use super::{DataflowGraph, DependencyAnalysis, OperatorId, ScheduleStep, ScheduledDfg};
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet, VecDeque};
26
/// Thermal state the scheduler must respect.
///
/// Each state maps to two scalars: a factor applied to the configured maximum
/// parallelism ([`ThermalConstraint::parallelism_factor`]) and a multiplier
/// applied to the base energy estimate ([`ThermalConstraint::energy_multiplier`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThermalConstraint {
    /// Full parallelism, 0.9x energy.
    Cool,
    /// Full parallelism, 1.0x energy.
    Nominal,
    /// Three quarters of the parallelism, 1.3x energy.
    Elevated,
    /// Half the parallelism, 1.6x energy.
    Hot,
    /// Zero parallelism from thermal scaling, 2.0x energy.
    Critical,
}

impl ThermalConstraint {
    /// Fraction (0.0..=1.0) of the configured maximum parallelism allowed in this state.
    pub fn parallelism_factor(&self) -> f64 {
        match *self {
            Self::Cool | Self::Nominal => 1.0,
            Self::Elevated => 0.75,
            Self::Hot => 0.5,
            Self::Critical => 0.0,
        }
    }

    /// Scale applied to the summed base energy estimate in this state.
    pub fn energy_multiplier(&self) -> f64 {
        match *self {
            Self::Cool => 0.9,
            Self::Nominal => 1.0,
            Self::Elevated => 1.3,
            Self::Hot => 1.6,
            Self::Critical => 2.0,
        }
    }
}
65
/// Tuning knobs consumed by [`EnergyAwareScheduler`].
#[derive(Debug, Clone)]
pub struct SchedulingConfig {
    /// Operators per schedule step before thermal scaling is applied.
    pub max_parallelism: usize,
    /// Optional energy budget in joules; when set, the scheduler also uses it
    /// to cap the number of operators per step.
    pub energy_budget_j: Option<f64>,
    /// When true, operators with a lower `energy_cost()` receive a higher priority.
    pub prefer_efficiency: bool,
    /// Floor on operators per step, applied after the thermal and budget caps.
    pub min_parallelism: usize,
    /// Joules attributed to one unit of an operator's `energy_cost()`.
    pub base_energy_per_op: f64,
}

impl Default for SchedulingConfig {
    /// 8-way parallelism, no budget, `prefer_efficiency` off, floor of 1, 1e-6 J per unit.
    fn default() -> Self {
        Self {
            max_parallelism: 8,
            energy_budget_j: None,
            prefer_efficiency: false,
            min_parallelism: 1,
            base_energy_per_op: 1e-6,
        }
    }
}

impl SchedulingConfig {
    /// Preset that uses every one of `cores` and keeps at least two operators per step.
    pub fn high_performance(cores: usize) -> Self {
        Self {
            max_parallelism: cores,
            min_parallelism: 2,
            ..Self::default()
        }
    }

    /// Preset capped at 4-way parallelism with a `budget_j` joule budget and
    /// efficiency-first ranking.
    pub fn energy_efficient(budget_j: f64) -> Self {
        Self {
            max_parallelism: 4,
            energy_budget_j: Some(budget_j),
            prefer_efficiency: true,
            ..Self::default()
        }
    }

    /// Preset capped at 2-way parallelism with efficiency-first ranking.
    pub fn thermal_constrained() -> Self {
        Self {
            max_parallelism: 2,
            prefer_efficiency: true,
            ..Self::default()
        }
    }
}
127
128pub struct EnergyAwareScheduler {
130 config: SchedulingConfig,
131}
132
133impl EnergyAwareScheduler {
134 pub fn new(config: SchedulingConfig) -> Self {
135 Self { config }
136 }
137
138 pub fn schedule(&self, dfg: &DataflowGraph, thermal: ThermalConstraint) -> ScheduledDfg {
140 let analysis = DependencyAnalysis::analyze(dfg);
142
143 let effective_parallelism = self.calculate_effective_parallelism(thermal);
145
146 let schedule = self.build_schedule(dfg, &analysis, effective_parallelism);
148
149 let estimated_energy = self.estimate_total_energy(dfg, &schedule, thermal);
151
152 let estimated_duration = self.estimate_duration(&schedule);
154
155 ScheduledDfg {
156 dfg: dfg.clone(),
157 schedule,
158 estimated_energy_j: estimated_energy,
159 estimated_duration_s: estimated_duration,
160 }
161 }
162
163 fn calculate_effective_parallelism(&self, thermal: ThermalConstraint) -> usize {
165 let base = self.config.max_parallelism as f64;
166 let thermal_factor = thermal.parallelism_factor();
167
168 let effective = (base * thermal_factor).round() as usize;
169
170 let energy_constrained = if let Some(budget) = self.config.energy_budget_j {
172 let ops_allowed = (budget / self.config.base_energy_per_op) as usize;
174 effective.min(ops_allowed.max(1))
175 } else {
176 effective
177 };
178
179 energy_constrained.max(self.config.min_parallelism)
180 }
181
182 fn build_schedule(
184 &self,
185 dfg: &DataflowGraph,
186 analysis: &DependencyAnalysis,
187 max_parallelism: usize,
188 ) -> Vec<ScheduleStep> {
189 if dfg.operators.is_empty() {
190 return Vec::new();
191 }
192
193 let levels = self.assign_levels(dfg, analysis);
195
196 let max_level = levels.values().max().copied().unwrap_or(0);
198
199 let mut level_groups: Vec<Vec<OperatorId>> = vec![Vec::new(); max_level + 1];
201 for (&op_id, &level) in &levels {
202 level_groups[level].push(op_id);
203 }
204
205 let mut schedule = Vec::new();
207
208 for level_ops in level_groups {
209 if level_ops.is_empty() {
210 continue;
211 }
212
213 let mut entries: Vec<ScheduleEntry> = level_ops
216 .iter()
217 .map(|&op_id| {
218 let energy = dfg
219 .operators
220 .get(op_id.0 as usize)
221 .map(|op| op.energy_cost())
222 .unwrap_or(0);
223 ScheduleEntry {
224 op_id,
225 priority: if self.config.prefer_efficiency {
226 i32::MAX - energy as i32
228 } else {
229 0 },
231 energy_cost: energy,
232 }
233 })
234 .collect();
235
236 entries.sort_by(|a, b| a.cmp(b).reverse());
238
239 let sorted_ops: Vec<OperatorId> = entries.iter().map(|e| e.op_id).collect();
240
241 for chunk in sorted_ops.chunks(max_parallelism) {
243 let step_energy: f64 = chunk
244 .iter()
245 .map(|op_id| {
246 dfg.operators
247 .get(op_id.0 as usize)
248 .map(|op| op.energy_cost() as f64 * self.config.base_energy_per_op)
249 .unwrap_or(0.0)
250 })
251 .sum();
252
253 schedule.push(ScheduleStep {
254 operators: chunk.to_vec(),
255 max_parallelism: chunk.len(),
256 estimated_energy: step_energy,
257 });
258 }
259 }
260
261 schedule
262 }
263
264 fn assign_levels(
266 &self,
267 dfg: &DataflowGraph,
268 analysis: &DependencyAnalysis,
269 ) -> HashMap<OperatorId, usize> {
270 let mut levels: HashMap<OperatorId, usize> = HashMap::new();
271 let mut queue: VecDeque<OperatorId> = VecDeque::new();
272
273 for &src in &analysis.sources {
275 levels.insert(src, 0);
276 queue.push_back(src);
277 }
278
279 while let Some(op) = queue.pop_front() {
281 let op_level = *levels.get(&op).unwrap_or(&0);
282
283 if let Some(succs) = analysis.successors.get(&op) {
284 for &succ in succs {
285 let new_level = op_level + 1;
286 let current = levels.entry(succ).or_insert(0);
287 if new_level > *current {
288 *current = new_level;
289 }
290
291 let all_preds_done = analysis
293 .predecessors
294 .get(&succ)
295 .map(|preds| preds.iter().all(|p| levels.contains_key(p)))
296 .unwrap_or(true);
297
298 if all_preds_done && !queue.contains(&succ) {
299 queue.push_back(succ);
300 }
301 }
302 }
303 }
304
305 for i in 0..dfg.operators.len() {
307 let op_id = OperatorId::new(i as u32);
308 levels.entry(op_id).or_insert(0);
309 }
310
311 levels
312 }
313
314 fn estimate_total_energy(
316 &self,
317 _dfg: &DataflowGraph,
318 schedule: &[ScheduleStep],
319 thermal: ThermalConstraint,
320 ) -> f64 {
321 let base_energy: f64 = schedule.iter().map(|s| s.estimated_energy).sum();
322
323 base_energy * thermal.energy_multiplier()
325 }
326
327 fn estimate_duration(&self, schedule: &[ScheduleStep]) -> f64 {
329 let base_time_per_op = 1e-6; schedule
334 .iter()
335 .map(|step| {
336 let ops = step.operators.len().max(1);
337 let parallelism = step.max_parallelism.max(1);
338 (ops as f64 / parallelism as f64) * base_time_per_op
339 })
340 .sum()
341 }
342
343 pub fn check_energy_budget(&self, scheduled: &ScheduledDfg) -> Option<f64> {
345 if let Some(budget) = self.config.energy_budget_j {
346 if scheduled.estimated_energy_j > budget {
347 return Some(scheduled.estimated_energy_j - budget);
348 }
349 }
350 None
351 }
352
353 pub fn reschedule_for_energy(&self, dfg: &DataflowGraph, target_energy_j: f64) -> ScheduledDfg {
355 let thermal = ThermalConstraint::Critical;
357 let mut scheduled = self.schedule(dfg, thermal);
358
359 if scheduled.estimated_energy_j > target_energy_j {
361 return scheduled;
363 }
364
365 for state in &[
367 ThermalConstraint::Hot,
368 ThermalConstraint::Elevated,
369 ThermalConstraint::Nominal,
370 ThermalConstraint::Cool,
371 ] {
372 let candidate = self.schedule(dfg, *state);
373 if candidate.estimated_energy_j <= target_energy_j {
374 scheduled = candidate;
375 } else {
376 break;
377 }
378 }
379
380 scheduled
381 }
382}
383
384impl Default for EnergyAwareScheduler {
385 fn default() -> Self {
386 Self::new(SchedulingConfig::default())
387 }
388}
389
390#[derive(Debug, Clone, Eq, PartialEq)]
392struct ScheduleEntry {
393 op_id: OperatorId,
394 priority: i32, energy_cost: u32,
396}
397
398impl Ord for ScheduleEntry {
399 fn cmp(&self, other: &Self) -> Ordering {
400 self.priority
402 .cmp(&other.priority)
403 .then_with(|| other.energy_cost.cmp(&self.energy_cost))
404 }
405}
406
407impl PartialOrd for ScheduleEntry {
408 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
409 Some(self.cmp(other))
410 }
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataflow::{ComputeOp, DfOperator, TokenType};
    use std::collections::HashSet;

    /// Source -> Compute(Add) -> Sink chain (plus one unused channel).
    fn linear_graph() -> DataflowGraph {
        let mut dfg = DataflowGraph::new();

        let ch1 = dfg.add_channel(TokenType::i32());
        let ch2 = dfg.add_channel(TokenType::i32());
        let _ch3 = dfg.add_channel(TokenType::i32());

        dfg.add_operator(DfOperator::Source {
            external_id: 0,
            output: ch1,
        });
        dfg.add_operator(DfOperator::Compute {
            op: ComputeOp::Add,
            inputs: vec![ch1],
            output: ch2,
        });
        dfg.add_operator(DfOperator::Sink {
            input: ch2,
            external_id: 0,
        });

        dfg
    }

    /// Asserts that every operator of `dfg` appears in exactly one step.
    fn assert_each_operator_scheduled_once(dfg: &DataflowGraph, scheduled: &ScheduledDfg) {
        let mut seen = HashSet::new();
        for step in &scheduled.schedule {
            for &op in &step.operators {
                assert!(seen.insert(op), "operator {:?} scheduled twice", op);
            }
        }
        assert_eq!(seen.len(), dfg.operators.len());
    }

    #[test]
    #[allow(clippy::float_cmp)] // exact literal values returned by a match
    fn test_thermal_constraint_factors() {
        assert_eq!(ThermalConstraint::Cool.parallelism_factor(), 1.0);
        assert_eq!(ThermalConstraint::Hot.parallelism_factor(), 0.5);
        assert_eq!(ThermalConstraint::Critical.parallelism_factor(), 0.0);

        assert!(ThermalConstraint::Hot.energy_multiplier() > 1.0);
        assert!(ThermalConstraint::Cool.energy_multiplier() < 1.0);
    }

    #[test]
    fn test_scheduling_config() {
        let config = SchedulingConfig::default();
        assert_eq!(config.max_parallelism, 8);
        assert!(config.energy_budget_j.is_none());

        let efficient = SchedulingConfig::energy_efficient(1.0);
        assert_eq!(efficient.energy_budget_j, Some(1.0));
        assert!(efficient.prefer_efficiency);

        let perf = SchedulingConfig::high_performance(16);
        assert_eq!(perf.max_parallelism, 16);
    }

    #[test]
    fn test_scheduler_empty_graph() {
        let dfg = DataflowGraph::new();
        let scheduler = EnergyAwareScheduler::default();
        let scheduled = scheduler.schedule(&dfg, ThermalConstraint::Nominal);

        assert!(scheduled.schedule.is_empty());
    }

    #[test]
    fn test_scheduler_linear_graph() {
        let dfg = linear_graph();

        let scheduler = EnergyAwareScheduler::default();
        let scheduled = scheduler.schedule(&dfg, ThermalConstraint::Nominal);

        // Each operator depends on the previous one, so each gets its own step.
        assert_eq!(scheduled.schedule.len(), 3);
        assert!(scheduled.schedule.iter().all(|step| step.operators.len() == 1));
        assert_each_operator_scheduled_once(&dfg, &scheduled);
    }

    #[test]
    fn test_scheduler_parallel_graph() {
        let mut dfg = DataflowGraph::new();

        let ch_in = dfg.add_channel(TokenType::i32());
        let ch_a = dfg.add_channel(TokenType::i32());
        let ch_b = dfg.add_channel(TokenType::i32());
        let ch_mul_out = dfg.add_channel(TokenType::i32());
        let ch_add_out = dfg.add_channel(TokenType::i32());

        dfg.add_operator(DfOperator::Source {
            external_id: 0,
            output: ch_in,
        });
        dfg.add_operator(DfOperator::Split {
            input: ch_in,
            outputs: vec![ch_a, ch_b],
        });
        dfg.add_operator(DfOperator::Compute {
            op: ComputeOp::Mul,
            inputs: vec![ch_a],
            output: ch_mul_out,
        });
        dfg.add_operator(DfOperator::Compute {
            op: ComputeOp::Add,
            inputs: vec![ch_b],
            output: ch_add_out,
        });

        let scheduler = EnergyAwareScheduler::new(SchedulingConfig {
            max_parallelism: 4,
            ..Default::default()
        });
        let scheduled = scheduler.schedule(&dfg, ThermalConstraint::Nominal);

        assert!(scheduled.schedule.len() <= dfg.operators.len());
        for step in &scheduled.schedule {
            assert!(step.operators.len() <= 4);
            assert_eq!(step.max_parallelism, step.operators.len());
        }
        assert_each_operator_scheduled_once(&dfg, &scheduled);
    }

    #[test]
    fn test_thermal_reduces_parallelism() {
        let mut dfg = DataflowGraph::new();

        for i in 0..10 {
            let ch = dfg.add_channel(TokenType::i32());
            dfg.add_operator(DfOperator::Source {
                external_id: i,
                output: ch,
            });
        }

        let scheduler = EnergyAwareScheduler::new(SchedulingConfig {
            max_parallelism: 8,
            ..Default::default()
        });

        let cool = scheduler.schedule(&dfg, ThermalConstraint::Cool);
        let hot = scheduler.schedule(&dfg, ThermalConstraint::Hot);
        let critical = scheduler.schedule(&dfg, ThermalConstraint::Critical);

        // Fewer operators per step means more steps.
        assert!(hot.schedule.len() >= cool.schedule.len());
        assert!(critical.schedule.len() >= hot.schedule.len());

        for scheduled in [&cool, &hot, &critical].iter() {
            assert_each_operator_scheduled_once(&dfg, scheduled);
        }
    }

    #[test]
    fn test_energy_budget_check() {
        let mut dfg = DataflowGraph::new();
        let ch = dfg.add_channel(TokenType::i32());
        dfg.add_operator(DfOperator::Source {
            external_id: 0,
            output: ch,
        });

        let budget = 1e-9;
        let scheduler = EnergyAwareScheduler::new(SchedulingConfig {
            energy_budget_j: Some(budget),
            ..Default::default()
        });

        let scheduled = scheduler.schedule(&dfg, ThermalConstraint::Nominal);
        let excess = scheduler.check_energy_budget(&scheduled);

        // The excess must be reported exactly when the estimate is over budget.
        let expected = if scheduled.estimated_energy_j > budget {
            Some(scheduled.estimated_energy_j - budget)
        } else {
            None
        };
        assert_eq!(excess, expected);
    }

    #[test]
    fn test_reschedule_for_energy_meets_generous_target() {
        let dfg = linear_graph();
        let scheduler = EnergyAwareScheduler::default();

        let target = f64::MAX;
        let scheduled = scheduler.reschedule_for_energy(&dfg, target);

        assert!(scheduled.estimated_energy_j <= target);
        assert_each_operator_scheduled_once(&dfg, &scheduled);
    }
}