joule_mir/dataflow/
mod.rs

1//! Dataflow Graph Representation for Energy-Aware Execution
2//!
3//! This module provides a dataflow graph (DFG) representation that enables:
4//! - Implicit parallelism from data dependencies
5//! - Energy-aware scheduling decisions
6//! - RipTide-style ordered dataflow execution
7//!
8//! # Key Concepts
9//!
10//! - **Operators**: Computation nodes that fire when inputs are ready
11//! - **Channels**: Bounded FIFO queues connecting operators
12//! - **Steer**: Conditional routing (RipTide's control-flow primitive)
13//! - **Stream**: Fused loop induction variable (27-52% operator reduction)
14//! - **Carry**: Loop-carried dependencies
15//!
16//! # Energy-Aware Scheduling
17//!
18//! Unlike traditional dataflow, Joule's scheduler considers:
19//! - Current thermal state
20//! - Power/energy budgets
21//! - Available parallelism
22//!
23//! This allows the runtime to dynamically throttle parallelism when
24//! thermal or energy constraints are active.
25
26pub mod extract;
27pub mod optimize;
28pub mod schedule;
29
30pub use extract::DfgExtractor;
31pub use optimize::{DfgOptimizer, StreamFusion};
32pub use schedule::{EnergyAwareScheduler, SchedulingConfig, ThermalConstraint};
33
34use std::collections::{HashMap, HashSet, VecDeque};
35use std::fmt;
36
37// ============================================================================
38// Identifiers
39// ============================================================================
40
/// Unique identifier for a dataflow operator
///
/// Ids are dense indices: `add_operator` assigns each operator its
/// position in the graph's operator list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct OperatorId(pub u32);

impl OperatorId {
    /// Wrap a raw index as an operator id.
    pub fn new(id: u32) -> Self {
        OperatorId(id)
    }
}

impl fmt::Display for OperatorId {
    /// Renders as `op<N>`, e.g. `op3`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let OperatorId(raw) = self;
        write!(f, "op{}", raw)
    }
}
56
/// Unique identifier for a channel between operators
///
/// Ids are dense indices: `add_channel` assigns each channel its
/// position in the graph's channel list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ChannelId(pub u32);

impl ChannelId {
    /// Wrap a raw index as a channel id.
    pub fn new(id: u32) -> Self {
        ChannelId(id)
    }
}

impl fmt::Display for ChannelId {
    /// Renders as `ch<N>`, e.g. `ch3`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ChannelId(raw) = self;
        write!(f, "ch{}", raw)
    }
}
72
73// ============================================================================
74// Token Types
75// ============================================================================
76
/// Value carried by a token through a channel
///
/// Scalar variants convert to each other via the `as_*` accessors;
/// `Array` and `Tuple` are aggregates and never convert to scalars.
#[derive(Debug, Clone)]
pub enum TokenValue {
    /// Unit/void value
    Unit,
    /// Boolean
    Bool(bool),
    /// Signed integer (up to 64 bits)
    Int(i64),
    /// Unsigned integer (up to 64 bits)
    Uint(u64),
    /// 64-bit float
    Float(f64),
    /// Pointer/address
    Ptr(u64),
    /// Array of values
    Array(Vec<TokenValue>),
    /// Tuple of values
    Tuple(Vec<TokenValue>),
}

/// Error when converting a token value to an incompatible type
///
/// Only aggregate values (`Array`/`Tuple`) produce this error; every
/// scalar variant converts to every scalar target.
#[derive(Debug, Clone)]
pub struct TokenConversionError {
    pub value: String,
    pub target_type: &'static str,
}

impl fmt::Display for TokenConversionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { value, target_type } = self;
        write!(f, "cannot convert {} to {}", value, target_type)
    }
}

impl std::error::Error for TokenConversionError {}

impl TokenValue {
    /// Build the error returned when this value cannot become `target`.
    fn conversion_error(&self, target: &'static str) -> TokenConversionError {
        TokenConversionError {
            value: format!("{:?}", self),
            target_type: target,
        }
    }

    /// Coerce to `bool`: numerics are true when non-zero, `Unit` is false.
    pub fn as_bool(&self) -> Result<bool, TokenConversionError> {
        match *self {
            TokenValue::Bool(b) => Ok(b),
            TokenValue::Int(i) => Ok(i != 0),
            TokenValue::Uint(u) => Ok(u != 0),
            TokenValue::Unit => Ok(false),
            TokenValue::Float(f) => Ok(f != 0.0),
            TokenValue::Ptr(p) => Ok(p != 0),
            TokenValue::Array(_) | TokenValue::Tuple(_) => Err(self.conversion_error("bool")),
        }
    }

    /// Coerce to `i64` (float/uint/ptr use Rust `as` casting; `Unit` is 0).
    pub fn as_i64(&self) -> Result<i64, TokenConversionError> {
        match *self {
            TokenValue::Int(i) => Ok(i),
            TokenValue::Uint(u) => Ok(u as i64),
            TokenValue::Bool(b) => Ok(if b { 1 } else { 0 }),
            TokenValue::Float(f) => Ok(f as i64),
            TokenValue::Ptr(p) => Ok(p as i64),
            TokenValue::Unit => Ok(0),
            TokenValue::Array(_) | TokenValue::Tuple(_) => Err(self.conversion_error("i64")),
        }
    }

    /// Coerce to `u64` (signed/float values use Rust `as` casting; `Unit` is 0).
    pub fn as_u64(&self) -> Result<u64, TokenConversionError> {
        match *self {
            TokenValue::Uint(u) => Ok(u),
            TokenValue::Int(i) => Ok(i as u64),
            TokenValue::Bool(b) => Ok(if b { 1 } else { 0 }),
            TokenValue::Float(f) => Ok(f as u64),
            TokenValue::Ptr(p) => Ok(p),
            TokenValue::Unit => Ok(0),
            TokenValue::Array(_) | TokenValue::Tuple(_) => Err(self.conversion_error("u64")),
        }
    }

    /// Coerce to `f64` (integers/ptr use Rust `as` casting; `Unit` is 0.0).
    pub fn as_f64(&self) -> Result<f64, TokenConversionError> {
        match *self {
            TokenValue::Float(f) => Ok(f),
            TokenValue::Int(i) => Ok(i as f64),
            TokenValue::Uint(u) => Ok(u as f64),
            TokenValue::Bool(b) => Ok(if b { 1.0 } else { 0.0 }),
            TokenValue::Ptr(p) => Ok(p as f64),
            TokenValue::Unit => Ok(0.0),
            TokenValue::Array(_) | TokenValue::Tuple(_) => Err(self.conversion_error("f64")),
        }
    }
}
174
/// Type of token values
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TokenType {
    /// Unit/void type
    Unit,
    /// Boolean
    Bool,
    /// Integer with explicit width and signedness
    Int {
        bits: u8,
        signed: bool,
    },
    /// Float with explicit width
    Float {
        bits: u8,
    },
    /// Pointer/address
    Ptr,
    /// Fixed-size homogeneous array
    Array {
        element: Box<TokenType>,
        size: usize,
    },
    /// Heterogeneous tuple
    Tuple(Vec<TokenType>),
}

impl TokenType {
    /// Shared constructor for the sized-integer shorthands below.
    fn int(bits: u8, signed: bool) -> Self {
        TokenType::Int { bits, signed }
    }

    /// Signed 32-bit integer.
    pub fn i32() -> Self {
        Self::int(32, true)
    }

    /// Signed 64-bit integer.
    pub fn i64() -> Self {
        Self::int(64, true)
    }

    /// Unsigned 32-bit integer.
    pub fn u32() -> Self {
        Self::int(32, false)
    }

    /// Unsigned 64-bit integer.
    pub fn u64() -> Self {
        Self::int(64, false)
    }

    /// 32-bit float.
    pub fn f32() -> Self {
        TokenType::Float { bits: 32 }
    }

    /// 64-bit float.
    pub fn f64() -> Self {
        TokenType::Float { bits: 64 }
    }
}
232
233// ============================================================================
234// Compute Operations
235// ============================================================================
236
/// Arithmetic/logic operations for Compute operators
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ComputeOp {
    // Arithmetic
    Add,
    Sub,
    Mul,
    Div,
    Rem,
    Neg,

    // Bitwise
    BitAnd,
    BitOr,
    BitXor,
    BitNot,
    Shl,
    Shr,

    // Comparison
    Eq,
    Ne,
    Lt,
    Le,
    Gt,
    Ge,

    // Logical
    And,
    Or,
    Not,

    // Type conversion
    IntToFloat,
    FloatToInt,
    SignExtend,
    ZeroExtend,
    Truncate,

    // Special
    Min,
    Max,
    Abs,
    Sqrt,
    Fma, // Fused multiply-add: a * b + c
}

impl ComputeOp {
    /// Number of inputs required for this operation
    pub fn input_count(&self) -> usize {
        use ComputeOp::*;
        match self {
            // Unary: negation, bitwise/logical not, conversions, abs, sqrt.
            Neg | BitNot | Not | IntToFloat | FloatToInt | SignExtend | ZeroExtend | Truncate
            | Abs | Sqrt => 1,

            // Binary: arithmetic, bitwise, shifts, comparisons, min/max.
            Add | Sub | Mul | Div | Rem | BitAnd | BitOr | BitXor | Shl | Shr | Eq | Ne | Lt
            | Le | Gt | Ge | And | Or | Min | Max => 2,

            // Ternary: fused multiply-add takes a, b, c.
            Fma => 3,
        }
    }

    /// Estimated energy cost (relative units, higher = more energy)
    pub fn energy_cost(&self) -> u32 {
        use ComputeOp::*;
        match self {
            // Cheap: simple ALU, bitwise, logic, and comparison ops.
            Add | Sub | Neg | BitAnd | BitOr | BitXor | BitNot | Shl | Shr | Not | And | Or
            | Eq | Ne | Lt | Le | Gt | Ge => 1,

            // Medium: multiply, select-style ops, and width changes.
            Mul | Min | Max | Abs | SignExtend | ZeroExtend | Truncate => 3,

            // Expensive: division, int/float conversion, square root.
            Div | Rem | IntToFloat | FloatToInt | Sqrt => 10,

            // Fma is a three-input op but a fused unit keeps it cheaper
            // than the division-class operations above.
            Fma => 4,
        }
    }
}
371
/// Memory operations
///
/// Used by [`DfOperator::Memory`]; the variant decides whether the
/// operator's `data` channel is an output (`Load`) or an input (`Store`)
/// — see `DfOperator::inputs`/`outputs`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemoryOp {
    /// Read from memory; produces a token on the `data` channel.
    Load,
    /// Write to memory; consumes a token from the `data` channel.
    Store,
}
378
379// ============================================================================
380// Dataflow Operators (RipTide-inspired)
381// ============================================================================
382
/// A dataflow operator that fires when all inputs are ready
///
/// Channel connectivity is derived from the fields via [`DfOperator::inputs`]
/// and [`DfOperator::outputs`]; per-firing cost via [`DfOperator::energy_cost`].
#[derive(Debug, Clone)]
pub enum DfOperator {
    /// Pure computation: inputs → output
    Compute {
        op: ComputeOp,
        /// Operand channels; count should match `op.input_count()`
        inputs: Vec<ChannelId>,
        output: ChannelId,
    },

    /// Memory access with ordering constraints
    Memory {
        op: MemoryOp,
        /// Address input
        address: ChannelId,
        /// Data input (for stores) or output (for loads)
        data: ChannelId,
        /// Ordering dependencies (must complete before this fires)
        ordering: Vec<ChannelId>,
    },

    /// Steer: conditional routing (RipTide's key control-flow primitive)
    ///
    /// Routes data to one of two outputs based on a boolean decider.
    /// This enables control flow without traditional branches.
    Steer {
        /// Boolean condition channel
        decider: ChannelId,
        /// Value to route
        data: ChannelId,
        /// Output when decider is true
        true_out: ChannelId,
        /// Output when decider is false
        false_out: ChannelId,
    },

    /// Stream: fused loop induction variable
    ///
    /// This is RipTide's key optimization - fusing the induction variable
    /// computation (i = 0; i < N; i++) into a single operator.
    /// Achieves 27-52% operator count reduction.
    Stream {
        /// Starting value
        start: ChannelId,
        /// Step/increment
        step: ChannelId,
        /// Upper bound
        bound: ChannelId,
        /// Current value output (fires once per iteration)
        output: ChannelId,
        /// Done signal (true when loop complete)
        done: ChannelId,
    },

    /// Carry: loop-carried dependency
    ///
    /// Handles values that flow from one iteration to the next.
    Carry {
        /// Initial value (first iteration)
        initial: ChannelId,
        /// Feedback from previous iteration
        feedback: ChannelId,
        /// Continue signal (false = loop done)
        continue_signal: ChannelId,
        /// Output value
        output: ChannelId,
    },

    /// CarryGate: multi-level loop support
    ///
    /// Selects between tokens from inner and outer loops.
    CarryGate {
        /// Inner loop token
        inner: ChannelId,
        /// Outer loop token
        outer: ChannelId,
        /// Level selector
        level: ChannelId,
        /// Output
        output: ChannelId,
    },

    /// Merge: join divergent control paths
    ///
    /// Combines tokens from multiple paths into a single stream.
    /// Used for ordering at control flow join points.
    Merge {
        inputs: Vec<ChannelId>,
        output: ChannelId,
    },

    /// Split: fan-out to multiple consumers
    Split {
        input: ChannelId,
        outputs: Vec<ChannelId>,
    },

    /// Constant: emit a constant value
    Constant {
        value: TokenValue,
        output: ChannelId,
        /// How many times to emit (None = once)
        repeat: Option<u64>,
    },

    /// Source: external input to the graph
    ///
    /// Has no input channels; tokens are injected from outside.
    Source {
        /// External ID for this input
        external_id: u32,
        output: ChannelId,
    },

    /// Sink: external output from the graph
    ///
    /// Has no output channels; tokens leave the graph here.
    Sink {
        input: ChannelId,
        /// External ID for this output
        external_id: u32,
    },

    /// Select: multiplexer (choose one of N inputs based on selector)
    Select {
        /// Selector (index into inputs)
        selector: ChannelId,
        /// Input channels
        inputs: Vec<ChannelId>,
        /// Output
        output: ChannelId,
    },

    /// Reduce: accumulate values (for reductions like sum, product)
    Reduce {
        /// Reduction operation
        op: ComputeOp,
        /// Initial/identity value
        initial: ChannelId,
        /// Stream of values to reduce
        values: ChannelId,
        /// Count of values (or done signal)
        count_or_done: ChannelId,
        /// Final result
        output: ChannelId,
    },
}
526
527impl DfOperator {
528    /// Get all input channels for this operator
529    pub fn inputs(&self) -> Vec<ChannelId> {
530        match self {
531            DfOperator::Compute { inputs, .. } => inputs.clone(),
532            DfOperator::Memory {
533                address,
534                data,
535                ordering,
536                op,
537            } => {
538                let mut inputs = vec![*address];
539                if *op == MemoryOp::Store {
540                    inputs.push(*data);
541                }
542                inputs.extend(ordering.iter().copied());
543                inputs
544            }
545            DfOperator::Steer { decider, data, .. } => vec![*decider, *data],
546            DfOperator::Stream {
547                start, step, bound, ..
548            } => vec![*start, *step, *bound],
549            DfOperator::Carry {
550                initial,
551                feedback,
552                continue_signal,
553                ..
554            } => {
555                vec![*initial, *feedback, *continue_signal]
556            }
557            DfOperator::CarryGate {
558                inner,
559                outer,
560                level,
561                ..
562            } => vec![*inner, *outer, *level],
563            DfOperator::Merge { inputs, .. } => inputs.clone(),
564            DfOperator::Split { input, .. } => vec![*input],
565            DfOperator::Constant { .. } => vec![],
566            DfOperator::Source { .. } => vec![],
567            DfOperator::Sink { input, .. } => vec![*input],
568            DfOperator::Select {
569                selector, inputs, ..
570            } => {
571                let mut all = vec![*selector];
572                all.extend(inputs.iter().copied());
573                all
574            }
575            DfOperator::Reduce {
576                initial,
577                values,
578                count_or_done,
579                ..
580            } => {
581                vec![*initial, *values, *count_or_done]
582            }
583        }
584    }
585
586    /// Get all output channels for this operator
587    pub fn outputs(&self) -> Vec<ChannelId> {
588        match self {
589            DfOperator::Compute { output, .. } => vec![*output],
590            DfOperator::Memory { data, op, .. } => {
591                if *op == MemoryOp::Load {
592                    vec![*data]
593                } else {
594                    vec![] // Store has no data output
595                }
596            }
597            DfOperator::Steer {
598                true_out,
599                false_out,
600                ..
601            } => vec![*true_out, *false_out],
602            DfOperator::Stream { output, done, .. } => vec![*output, *done],
603            DfOperator::Carry { output, .. } => vec![*output],
604            DfOperator::CarryGate { output, .. } => vec![*output],
605            DfOperator::Merge { output, .. } => vec![*output],
606            DfOperator::Split { outputs, .. } => outputs.clone(),
607            DfOperator::Constant { output, .. } => vec![*output],
608            DfOperator::Source { output, .. } => vec![*output],
609            DfOperator::Sink { .. } => vec![],
610            DfOperator::Select { output, .. } => vec![*output],
611            DfOperator::Reduce { output, .. } => vec![*output],
612        }
613    }
614
615    /// Estimated energy cost for this operator
616    pub fn energy_cost(&self) -> u32 {
617        match self {
618            DfOperator::Compute { op, .. } => op.energy_cost(),
619            DfOperator::Memory { .. } => 20, // Memory access is expensive
620            DfOperator::Steer { .. } => 1,   // Just routing
621            DfOperator::Stream { .. } => 2,  // Fused counter
622            DfOperator::Carry { .. } => 1,
623            DfOperator::CarryGate { .. } => 1,
624            DfOperator::Merge { .. } => 1,
625            DfOperator::Split { .. } => 1,
626            DfOperator::Constant { .. } => 0,
627            DfOperator::Source { .. } => 0,
628            DfOperator::Sink { .. } => 0,
629            DfOperator::Select { .. } => 1,
630            DfOperator::Reduce { op, .. } => op.energy_cost() + 2,
631        }
632    }
633
634    /// Is this a control-flow operator?
635    pub fn is_control(&self) -> bool {
636        matches!(
637            self,
638            DfOperator::Steer { .. }
639                | DfOperator::Stream { .. }
640                | DfOperator::Carry { .. }
641                | DfOperator::CarryGate { .. }
642                | DfOperator::Merge { .. }
643                | DfOperator::Select { .. }
644        )
645    }
646}
647
648// ============================================================================
649// Channel Definition
650// ============================================================================
651
652/// A channel connecting operators
653#[derive(Debug, Clone)]
654pub struct Channel {
655    pub id: ChannelId,
656    /// Type of tokens in this channel
657    pub token_type: TokenType,
658    /// Bounded buffer capacity
659    pub capacity: usize,
660    /// Source operator (None for external inputs)
661    pub source: Option<OperatorId>,
662    /// Destination operators
663    pub destinations: Vec<OperatorId>,
664}
665
666impl Channel {
667    pub fn new(id: ChannelId, token_type: TokenType) -> Self {
668        Self {
669            id,
670            token_type,
671            capacity: 4, // Default capacity
672            source: None,
673            destinations: Vec::new(),
674        }
675    }
676
677    pub fn with_capacity(mut self, capacity: usize) -> Self {
678        self.capacity = capacity;
679        self
680    }
681}
682
683// ============================================================================
684// Dataflow Graph
685// ============================================================================
686
/// A complete dataflow graph
///
/// Operators and channels live in `Vec`s; an `OperatorId`/`ChannelId`
/// is the item's index in the corresponding vector (see `add_operator`
/// and `add_channel`).
#[derive(Debug, Clone)]
pub struct DataflowGraph {
    /// All operators in the graph
    pub operators: Vec<DfOperator>,
    /// All channels
    pub channels: Vec<Channel>,
    /// External input channels (fed from outside)
    pub inputs: Vec<ChannelId>,
    /// External output channels (results)
    pub outputs: Vec<ChannelId>,
    /// Energy estimates per operator (not every operator need have one)
    pub energy_estimates: HashMap<OperatorId, EnergyEstimate>,
    /// Name/label for debugging
    pub name: Option<String>,
}
703
/// Energy estimate for an operator
///
/// `DataflowGraph::total_energy_estimate` combines these as
/// `static_cost * firing_count + dynamic_cost`.
#[derive(Debug, Clone, Copy, Default)]
pub struct EnergyEstimate {
    /// Static energy cost (per firing)
    pub static_cost: f64,
    /// Dynamic energy cost (scales with data; added once, not per firing)
    pub dynamic_cost: f64,
    /// Estimated number of firings
    pub firing_count: u64,
}
714
715impl DataflowGraph {
716    /// Create a new empty dataflow graph
717    pub fn new() -> Self {
718        Self {
719            operators: Vec::new(),
720            channels: Vec::new(),
721            inputs: Vec::new(),
722            outputs: Vec::new(),
723            energy_estimates: HashMap::new(),
724            name: None,
725        }
726    }
727
728    /// Create with a name
729    pub fn with_name(name: impl Into<String>) -> Self {
730        Self {
731            name: Some(name.into()),
732            ..Self::new()
733        }
734    }
735
736    /// Add an operator and return its ID
737    pub fn add_operator(&mut self, op: DfOperator) -> OperatorId {
738        let id = OperatorId::new(self.operators.len() as u32);
739
740        // Update channel connections
741        for ch_id in op.outputs() {
742            if let Some(ch) = self.channels.get_mut(ch_id.0 as usize) {
743                ch.source = Some(id);
744            }
745        }
746        for ch_id in op.inputs() {
747            if let Some(ch) = self.channels.get_mut(ch_id.0 as usize) {
748                ch.destinations.push(id);
749            }
750        }
751
752        self.operators.push(op);
753        id
754    }
755
756    /// Add a channel and return its ID
757    pub fn add_channel(&mut self, token_type: TokenType) -> ChannelId {
758        let id = ChannelId::new(self.channels.len() as u32);
759        self.channels.push(Channel::new(id, token_type));
760        id
761    }
762
763    /// Add a channel with specified capacity
764    pub fn add_channel_with_capacity(
765        &mut self,
766        token_type: TokenType,
767        capacity: usize,
768    ) -> ChannelId {
769        let id = ChannelId::new(self.channels.len() as u32);
770        self.channels
771            .push(Channel::new(id, token_type).with_capacity(capacity));
772        id
773    }
774
775    /// Mark a channel as external input
776    pub fn add_input(&mut self, channel: ChannelId) {
777        if !self.inputs.contains(&channel) {
778            self.inputs.push(channel);
779        }
780    }
781
782    /// Mark a channel as external output
783    pub fn add_output(&mut self, channel: ChannelId) {
784        if !self.outputs.contains(&channel) {
785            self.outputs.push(channel);
786        }
787    }
788
789    /// Get total estimated energy for the graph
790    pub fn total_energy_estimate(&self) -> f64 {
791        self.energy_estimates
792            .values()
793            .map(|e| e.static_cost * e.firing_count as f64 + e.dynamic_cost)
794            .sum()
795    }
796
797    /// Get operator count by type
798    pub fn operator_counts(&self) -> HashMap<&'static str, usize> {
799        let mut counts = HashMap::new();
800        for op in &self.operators {
801            let name = match op {
802                DfOperator::Compute { .. } => "Compute",
803                DfOperator::Memory { .. } => "Memory",
804                DfOperator::Steer { .. } => "Steer",
805                DfOperator::Stream { .. } => "Stream",
806                DfOperator::Carry { .. } => "Carry",
807                DfOperator::CarryGate { .. } => "CarryGate",
808                DfOperator::Merge { .. } => "Merge",
809                DfOperator::Split { .. } => "Split",
810                DfOperator::Constant { .. } => "Constant",
811                DfOperator::Source { .. } => "Source",
812                DfOperator::Sink { .. } => "Sink",
813                DfOperator::Select { .. } => "Select",
814                DfOperator::Reduce { .. } => "Reduce",
815            };
816            *counts.entry(name).or_insert(0) += 1;
817        }
818        counts
819    }
820}
821
impl Default for DataflowGraph {
    /// Equivalent to [`DataflowGraph::new`]: an empty, unnamed graph.
    fn default() -> Self {
        Self::new()
    }
}
827
828// ============================================================================
829// Dependency Analysis
830// ============================================================================
831
/// Dependency analysis for a dataflow graph
///
/// Built by [`DependencyAnalysis::analyze`] from channel connections:
/// an edge runs from the operator writing a channel to each operator
/// reading it.
#[derive(Debug)]
pub struct DependencyAnalysis {
    /// Predecessors of each operator (operators that must complete first)
    pub predecessors: HashMap<OperatorId, HashSet<OperatorId>>,
    /// Successors of each operator (operators that depend on this one)
    pub successors: HashMap<OperatorId, HashSet<OperatorId>>,
    /// Operators with no predecessors (can start immediately)
    pub sources: Vec<OperatorId>,
    /// Operators with no successors (final outputs)
    pub sinks: Vec<OperatorId>,
    /// Critical path length, counted in edges/hops — `critical_path`
    /// holds one more operator than this for a non-empty graph
    pub critical_path_length: usize,
    /// Operators on the critical path, in execution order
    pub critical_path: Vec<OperatorId>,
}
848
849impl DependencyAnalysis {
850    /// Analyze dependencies in a dataflow graph
851    pub fn analyze(dfg: &DataflowGraph) -> Self {
852        let mut predecessors: HashMap<OperatorId, HashSet<OperatorId>> = HashMap::new();
853        let mut successors: HashMap<OperatorId, HashSet<OperatorId>> = HashMap::new();
854
855        // Initialize empty sets for all operators
856        for i in 0..dfg.operators.len() {
857            let id = OperatorId::new(i as u32);
858            predecessors.insert(id, HashSet::new());
859            successors.insert(id, HashSet::new());
860        }
861
862        // Build dependency graph from channel connections
863        for (op_idx, op) in dfg.operators.iter().enumerate() {
864            let op_id = OperatorId::new(op_idx as u32);
865
866            for input_ch in op.inputs() {
867                if let Some(channel) = dfg.channels.get(input_ch.0 as usize) {
868                    if let Some(source_op) = channel.source {
869                        predecessors.get_mut(&op_id).unwrap().insert(source_op);
870                        successors.get_mut(&source_op).unwrap().insert(op_id);
871                    }
872                }
873            }
874        }
875
876        // Find sources and sinks
877        let sources: Vec<_> = predecessors
878            .iter()
879            .filter(|(_, preds)| preds.is_empty())
880            .map(|(id, _)| *id)
881            .collect();
882
883        let sinks: Vec<_> = successors
884            .iter()
885            .filter(|(_, succs)| succs.is_empty())
886            .map(|(id, _)| *id)
887            .collect();
888
889        // Compute critical path using longest path algorithm
890        let (critical_path_length, critical_path) =
891            Self::find_critical_path(&predecessors, &successors, &sources, dfg.operators.len());
892
893        Self {
894            predecessors,
895            successors,
896            sources,
897            sinks,
898            critical_path_length,
899            critical_path,
900        }
901    }
902
903    /// Find the critical path (longest dependency chain)
904    fn find_critical_path(
905        predecessors: &HashMap<OperatorId, HashSet<OperatorId>>,
906        successors: &HashMap<OperatorId, HashSet<OperatorId>>,
907        sources: &[OperatorId],
908        num_ops: usize,
909    ) -> (usize, Vec<OperatorId>) {
910        if num_ops == 0 {
911            return (0, Vec::new());
912        }
913
914        // Topological sort with distance tracking
915        let mut in_degree: HashMap<OperatorId, usize> = HashMap::new();
916        let mut dist: HashMap<OperatorId, usize> = HashMap::new();
917        let mut parent: HashMap<OperatorId, Option<OperatorId>> = HashMap::new();
918
919        for i in 0..num_ops {
920            let id = OperatorId::new(i as u32);
921            in_degree.insert(id, predecessors.get(&id).map(|s| s.len()).unwrap_or(0));
922            dist.insert(id, 0);
923            parent.insert(id, None);
924        }
925
926        let mut queue: VecDeque<OperatorId> = sources.iter().copied().collect();
927
928        while let Some(op) = queue.pop_front() {
929            let op_dist = *dist.get(&op).unwrap();
930
931            if let Some(succs) = successors.get(&op) {
932                for &succ in succs {
933                    let new_dist = op_dist + 1;
934                    if new_dist > *dist.get(&succ).unwrap() {
935                        dist.insert(succ, new_dist);
936                        parent.insert(succ, Some(op));
937                    }
938
939                    let degree = in_degree.get_mut(&succ).unwrap();
940                    *degree -= 1;
941                    if *degree == 0 {
942                        queue.push_back(succ);
943                    }
944                }
945            }
946        }
947
948        // Find the endpoint with maximum distance
949        let (max_dist, end_op) = dist
950            .iter()
951            .max_by_key(|(_, d)| *d)
952            .map(|(id, d)| (*d, *id))
953            .unwrap_or((0, OperatorId::new(0)));
954
955        // Reconstruct path
956        let mut path = Vec::new();
957        let mut current = Some(end_op);
958        while let Some(op) = current {
959            path.push(op);
960            current = *parent.get(&op).unwrap();
961        }
962        path.reverse();
963
964        (max_dist, path)
965    }
966
967    /// Get maximum available parallelism at each level
968    pub fn parallelism_profile(&self, _dfg: &DataflowGraph) -> Vec<usize> {
969        let mut levels: HashMap<OperatorId, usize> = HashMap::new();
970        let mut queue: VecDeque<OperatorId> = self.sources.iter().copied().collect();
971
972        for &src in &self.sources {
973            levels.insert(src, 0);
974        }
975
976        while let Some(op) = queue.pop_front() {
977            let op_level = *levels.get(&op).unwrap();
978
979            if let Some(succs) = self.successors.get(&op) {
980                for &succ in succs {
981                    let new_level = op_level + 1;
982                    let current_level = levels.entry(succ).or_insert(0);
983                    if new_level > *current_level {
984                        *current_level = new_level;
985                    }
986
987                    // Check if all predecessors have been processed
988                    let all_preds_done = self
989                        .predecessors
990                        .get(&succ)
991                        .map(|preds| preds.iter().all(|p| levels.contains_key(p)))
992                        .unwrap_or(true);
993
994                    if all_preds_done && !queue.contains(&succ) {
995                        queue.push_back(succ);
996                    }
997                }
998            }
999        }
1000
1001        // Count operators at each level
1002        let max_level = levels.values().max().copied().unwrap_or(0);
1003        let mut profile = vec![0; max_level + 1];
1004        for level in levels.values() {
1005            profile[*level] += 1;
1006        }
1007
1008        profile
1009    }
1010}
1011
1012// ============================================================================
1013// Scheduled Graph
1014// ============================================================================
1015
/// A scheduled dataflow graph ready for execution
///
/// Pairs a [`DataflowGraph`] with an ordered list of execution steps and
/// aggregate cost estimates produced by the scheduler.
#[derive(Debug, Clone)]
pub struct ScheduledDfg {
    /// The underlying dataflow graph
    pub dfg: DataflowGraph,
    /// Execution schedule (groups of operators that can run in parallel);
    /// steps execute in order, operators within a step may run concurrently
    pub schedule: Vec<ScheduleStep>,
    /// Estimated total energy for the whole schedule, in joules
    pub estimated_energy_j: f64,
    /// Estimated total duration, in seconds
    pub estimated_duration_s: f64,
}
1028
/// A step in the schedule (operators that can execute in parallel)
#[derive(Debug, Clone)]
pub struct ScheduleStep {
    /// Operators to execute in this step
    pub operators: Vec<OperatorId>,
    /// Maximum parallelism for this step (may be reduced by energy
    /// constraints, so it can be lower than `operators.len()`)
    pub max_parallelism: usize,
    /// Estimated energy for this step
    /// (NOTE(review): unit appears to be joules to match `ScheduledDfg`'s
    /// `estimated_energy_j` — confirm in the scheduler)
    pub estimated_energy: f64,
}
1039
1040// ============================================================================
1041// Display
1042// ============================================================================
1043
1044impl fmt::Display for DataflowGraph {
1045    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1046        writeln!(f, "DataflowGraph {{")?;
1047        if let Some(name) = &self.name {
1048            writeln!(f, "  name: {}", name)?;
1049        }
1050        writeln!(f, "  operators: {},", self.operators.len())?;
1051        writeln!(f, "  channels: {},", self.channels.len())?;
1052        writeln!(f, "  inputs: {:?},", self.inputs)?;
1053        writeln!(f, "  outputs: {:?},", self.outputs)?;
1054
1055        writeln!(f, "  operator_counts: {{")?;
1056        for (name, count) in self.operator_counts() {
1057            writeln!(f, "    {}: {},", name, count)?;
1058        }
1059        writeln!(f, "  }}")?;
1060
1061        writeln!(f, "}}")
1062    }
1063}
1064
1065// ============================================================================
1066// Tests
1067// ============================================================================
1068
/// Unit tests for the core dataflow-graph types and dependency analysis.
#[cfg(test)]
mod tests {
    use super::*;

    // A named, freshly-constructed graph starts empty.
    #[test]
    fn test_create_dfg() {
        let dfg = DataflowGraph::with_name("test");
        assert_eq!(dfg.name, Some("test".to_string()));
        assert!(dfg.operators.is_empty());
        assert!(dfg.channels.is_empty());
    }

    // Channels, operators, and graph inputs/outputs are all registered,
    // and operator ids are assigned sequentially starting at 0.
    #[test]
    fn test_add_channel_and_operator() {
        let mut dfg = DataflowGraph::new();

        let ch_a = dfg.add_channel(TokenType::i32());
        let ch_b = dfg.add_channel(TokenType::i32());
        let ch_out = dfg.add_channel(TokenType::i32());

        dfg.add_input(ch_a);
        dfg.add_input(ch_b);

        let op_id = dfg.add_operator(DfOperator::Compute {
            op: ComputeOp::Add,
            inputs: vec![ch_a, ch_b],
            output: ch_out,
        });

        dfg.add_output(ch_out);

        assert_eq!(dfg.operators.len(), 1);
        assert_eq!(dfg.channels.len(), 3);
        assert_eq!(dfg.inputs.len(), 2);
        assert_eq!(dfg.outputs.len(), 1);
        // First operator added must receive id 0.
        assert_eq!(op_id, OperatorId::new(0));
    }

    // Steer (RipTide's conditional-routing primitive) is a control operator
    // with two inputs (decider, data), two outputs (true/false branches),
    // and unit energy cost.
    #[test]
    fn test_steer_operator() {
        let mut dfg = DataflowGraph::new();

        let decider = dfg.add_channel(TokenType::Bool);
        let data = dfg.add_channel(TokenType::i32());
        let true_out = dfg.add_channel(TokenType::i32());
        let false_out = dfg.add_channel(TokenType::i32());

        let steer = DfOperator::Steer {
            decider,
            data,
            true_out,
            false_out,
        };

        assert!(steer.is_control());
        assert_eq!(steer.inputs().len(), 2);
        assert_eq!(steer.outputs().len(), 2);
        assert_eq!(steer.energy_cost(), 1);
    }

    // Stream (fused loop induction variable) is a control operator with
    // three inputs (start, step, bound) and two outputs (value, done flag).
    #[test]
    fn test_stream_operator() {
        let mut dfg = DataflowGraph::new();

        let start = dfg.add_channel(TokenType::i64());
        let step = dfg.add_channel(TokenType::i64());
        let bound = dfg.add_channel(TokenType::i64());
        let output = dfg.add_channel(TokenType::i64());
        let done = dfg.add_channel(TokenType::Bool);

        let stream = DfOperator::Stream {
            start,
            step,
            bound,
            output,
            done,
        };

        assert!(stream.is_control());
        assert_eq!(stream.inputs().len(), 3);
        assert_eq!(stream.outputs().len(), 2);
    }

    // A linear 3-operator chain has one source, one sink, and a critical
    // path of length 2 (two edges: Source -> Compute -> Sink).
    #[test]
    fn test_dependency_analysis() {
        let mut dfg = DataflowGraph::new();

        // Create a simple graph: A -> B -> C
        let ch1 = dfg.add_channel(TokenType::i32());
        let ch2 = dfg.add_channel(TokenType::i32());
        let _ch3 = dfg.add_channel(TokenType::i32());

        // Source operator
        dfg.add_operator(DfOperator::Source {
            external_id: 0,
            output: ch1,
        });

        // Compute operator
        dfg.add_operator(DfOperator::Compute {
            op: ComputeOp::Add,
            inputs: vec![ch1],
            output: ch2,
        });

        // Sink operator
        dfg.add_operator(DfOperator::Sink {
            input: ch2,
            external_id: 0,
        });

        let analysis = DependencyAnalysis::analyze(&dfg);

        assert_eq!(analysis.sources.len(), 1);
        assert_eq!(analysis.sinks.len(), 1);
        assert_eq!(analysis.critical_path_length, 2);
    }

    // A diamond-shaped graph must expose parallelism: the two independent
    // Compute operators (B and C) share a level in the profile.
    #[test]
    fn test_parallel_graph() {
        let mut dfg = DataflowGraph::new();

        // Create parallel graph:
        //     A
        //    / \
        //   B   C
        //    \ /
        //     D

        let ch_in = dfg.add_channel(TokenType::i32());
        let ch_b = dfg.add_channel(TokenType::i32());
        let ch_c = dfg.add_channel(TokenType::i32());
        let ch_out = dfg.add_channel(TokenType::i32());

        // A: Source
        dfg.add_operator(DfOperator::Source {
            external_id: 0,
            output: ch_in,
        });

        // Split to B and C
        let split_b = dfg.add_channel(TokenType::i32());
        let split_c = dfg.add_channel(TokenType::i32());
        dfg.add_operator(DfOperator::Split {
            input: ch_in,
            outputs: vec![split_b, split_c],
        });

        // B: Compute
        dfg.add_operator(DfOperator::Compute {
            op: ComputeOp::Mul,
            inputs: vec![split_b],
            output: ch_b,
        });

        // C: Compute
        dfg.add_operator(DfOperator::Compute {
            op: ComputeOp::Add,
            inputs: vec![split_c],
            output: ch_c,
        });

        // D: Merge and output
        dfg.add_operator(DfOperator::Compute {
            op: ComputeOp::Add,
            inputs: vec![ch_b, ch_c],
            output: ch_out,
        });

        dfg.add_operator(DfOperator::Sink {
            input: ch_out,
            external_id: 0,
        });

        let analysis = DependencyAnalysis::analyze(&dfg);
        let profile = analysis.parallelism_profile(&dfg);

        // Should have some level with parallelism > 1
        assert!(profile.iter().any(|&p| p >= 2));
    }

    // Relative energy costs (Add < Mul < Div) and arities of compute ops.
    #[test]
    fn test_compute_op_costs() {
        assert_eq!(ComputeOp::Add.energy_cost(), 1);
        assert_eq!(ComputeOp::Mul.energy_cost(), 3);
        assert_eq!(ComputeOp::Div.energy_cost(), 10);
        assert_eq!(ComputeOp::Fma.energy_cost(), 4);

        assert_eq!(ComputeOp::Add.input_count(), 2);
        assert_eq!(ComputeOp::Neg.input_count(), 1);
        assert_eq!(ComputeOp::Fma.input_count(), 3);
    }

    // Scalar token values convert between bool/int/float representations;
    // aggregate values (arrays, tuples) must fail with an error, not panic.
    #[test]
    #[allow(clippy::float_cmp, clippy::approx_constant)]
    fn test_token_value_conversions() {
        let b = TokenValue::Bool(true);
        assert!(b.as_bool().unwrap());
        // Booleans widen numerically: true -> 1.
        assert_eq!(b.as_i64().unwrap(), 1);

        let i = TokenValue::Int(-42);
        assert_eq!(i.as_i64().unwrap(), -42);
        assert_eq!(i.as_f64().unwrap(), -42.0);

        let f = TokenValue::Float(3.14);
        assert!((f.as_f64().unwrap() - 3.14).abs() < 0.001);

        // Verify that array/tuple conversions return errors, not panics
        let arr = TokenValue::Array(vec![TokenValue::Int(1)]);
        assert!(arr.as_bool().is_err());
        assert!(arr.as_i64().is_err());
        assert!(arr.as_u64().is_err());
        assert!(arr.as_f64().is_err());
    }
}