1pub mod extract;
27pub mod optimize;
28pub mod schedule;
29
30pub use extract::DfgExtractor;
31pub use optimize::{DfgOptimizer, StreamFusion};
32pub use schedule::{EnergyAwareScheduler, SchedulingConfig, ThermalConstraint};
33
34use std::collections::{HashMap, HashSet, VecDeque};
35use std::fmt;
36
/// Unique identifier for an operator node in a [`DataflowGraph`].
///
/// Wraps the operator's index into `DataflowGraph::operators`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct OperatorId(pub u32);
44
45impl OperatorId {
46 pub fn new(id: u32) -> Self {
47 Self(id)
48 }
49}
50
51impl fmt::Display for OperatorId {
52 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53 write!(f, "op{}", self.0)
54 }
55}
56
/// Unique identifier for a token channel in a [`DataflowGraph`].
///
/// Wraps the channel's index into `DataflowGraph::channels`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ChannelId(pub u32);
60
61impl ChannelId {
62 pub fn new(id: u32) -> Self {
63 Self(id)
64 }
65}
66
67impl fmt::Display for ChannelId {
68 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69 write!(f, "ch{}", self.0)
70 }
71}
72
/// A runtime value flowing through a dataflow channel.
#[derive(Debug, Clone)]
pub enum TokenValue {
    /// Pure control token carrying no data.
    Unit,
    Bool(bool),
    /// Signed integer, stored widened to 64 bits.
    Int(i64),
    /// Unsigned integer, stored widened to 64 bits.
    Uint(u64),
    Float(f64),
    /// Raw memory address, represented as a 64-bit value.
    Ptr(u64),
    /// Sequence of values.
    Array(Vec<TokenValue>),
    /// Fixed-arity group of values.
    Tuple(Vec<TokenValue>),
}
97
/// Error returned when a [`TokenValue`] cannot be converted to a scalar
/// (only `Array` and `Tuple` values refuse conversion).
#[derive(Debug, Clone)]
pub struct TokenConversionError {
    /// Debug rendering of the offending value.
    pub value: String,
    /// Name of the scalar type that was requested ("bool", "i64", ...).
    pub target_type: &'static str,
}
104
105impl fmt::Display for TokenConversionError {
106 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107 write!(f, "cannot convert {} to {}", self.value, self.target_type)
108 }
109}
110
// Marker impl so the error can be used as a boxed `dyn Error`.
impl std::error::Error for TokenConversionError {}
112
113impl TokenValue {
114 pub fn as_bool(&self) -> Result<bool, TokenConversionError> {
115 match self {
116 TokenValue::Bool(b) => Ok(*b),
117 TokenValue::Int(i) => Ok(*i != 0),
118 TokenValue::Uint(u) => Ok(*u != 0),
119 TokenValue::Unit => Ok(false),
120 TokenValue::Float(f) => Ok(*f != 0.0),
121 TokenValue::Ptr(p) => Ok(*p != 0),
122 TokenValue::Array(_) | TokenValue::Tuple(_) => Err(TokenConversionError {
123 value: format!("{:?}", self),
124 target_type: "bool",
125 }),
126 }
127 }
128
129 pub fn as_i64(&self) -> Result<i64, TokenConversionError> {
130 match self {
131 TokenValue::Int(i) => Ok(*i),
132 TokenValue::Uint(u) => Ok(*u as i64),
133 TokenValue::Bool(b) => Ok(if *b { 1 } else { 0 }),
134 TokenValue::Float(f) => Ok(*f as i64),
135 TokenValue::Ptr(p) => Ok(*p as i64),
136 TokenValue::Unit => Ok(0),
137 TokenValue::Array(_) | TokenValue::Tuple(_) => Err(TokenConversionError {
138 value: format!("{:?}", self),
139 target_type: "i64",
140 }),
141 }
142 }
143
144 pub fn as_u64(&self) -> Result<u64, TokenConversionError> {
145 match self {
146 TokenValue::Uint(u) => Ok(*u),
147 TokenValue::Int(i) => Ok(*i as u64),
148 TokenValue::Bool(b) => Ok(if *b { 1 } else { 0 }),
149 TokenValue::Float(f) => Ok(*f as u64),
150 TokenValue::Ptr(p) => Ok(*p),
151 TokenValue::Unit => Ok(0),
152 TokenValue::Array(_) | TokenValue::Tuple(_) => Err(TokenConversionError {
153 value: format!("{:?}", self),
154 target_type: "u64",
155 }),
156 }
157 }
158
159 pub fn as_f64(&self) -> Result<f64, TokenConversionError> {
160 match self {
161 TokenValue::Float(f) => Ok(*f),
162 TokenValue::Int(i) => Ok(*i as f64),
163 TokenValue::Uint(u) => Ok(*u as f64),
164 TokenValue::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
165 TokenValue::Ptr(p) => Ok(*p as f64),
166 TokenValue::Unit => Ok(0.0),
167 TokenValue::Array(_) | TokenValue::Tuple(_) => Err(TokenConversionError {
168 value: format!("{:?}", self),
169 target_type: "f64",
170 }),
171 }
172 }
173}
174
/// Static type of the tokens carried on a channel.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TokenType {
    /// Control-only token with no payload.
    Unit,
    Bool,
    /// Integer with explicit bit width and signedness.
    Int {
        bits: u8,
        signed: bool,
    },
    /// Floating-point value with explicit bit width.
    Float {
        bits: u8,
    },
    Ptr,
    /// Fixed-size array of a single element type.
    Array {
        element: Box<TokenType>,
        size: usize,
    },
    Tuple(Vec<TokenType>),
}
194
195impl TokenType {
196 pub fn i32() -> Self {
197 TokenType::Int {
198 bits: 32,
199 signed: true,
200 }
201 }
202
203 pub fn i64() -> Self {
204 TokenType::Int {
205 bits: 64,
206 signed: true,
207 }
208 }
209
210 pub fn u32() -> Self {
211 TokenType::Int {
212 bits: 32,
213 signed: false,
214 }
215 }
216
217 pub fn u64() -> Self {
218 TokenType::Int {
219 bits: 64,
220 signed: false,
221 }
222 }
223
224 pub fn f32() -> Self {
225 TokenType::Float { bits: 32 }
226 }
227
228 pub fn f64() -> Self {
229 TokenType::Float { bits: 64 }
230 }
231}
232
/// Primitive computation performed by a `Compute` operator.
///
/// Arity is given by [`ComputeOp::input_count`]; relative cost by
/// [`ComputeOp::energy_cost`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ComputeOp {
    // Arithmetic.
    Add,
    Sub,
    Mul,
    Div,
    Rem,
    /// Unary negation.
    Neg,

    // Bitwise.
    BitAnd,
    BitOr,
    BitXor,
    BitNot,
    Shl,
    Shr,

    // Comparisons.
    Eq,
    Ne,
    Lt,
    Le,
    Gt,
    Ge,

    // Logical.
    And,
    Or,
    Not,

    // Conversions / width changes.
    IntToFloat,
    FloatToInt,
    SignExtend,
    ZeroExtend,
    Truncate,

    // Math helpers.
    Min,
    Max,
    Abs,
    Sqrt,
    /// Fused multiply-add; the only ternary operation.
    Fma,
}
283
284impl ComputeOp {
285 pub fn input_count(&self) -> usize {
287 match self {
288 ComputeOp::Neg
290 | ComputeOp::BitNot
291 | ComputeOp::Not
292 | ComputeOp::IntToFloat
293 | ComputeOp::FloatToInt
294 | ComputeOp::SignExtend
295 | ComputeOp::ZeroExtend
296 | ComputeOp::Truncate
297 | ComputeOp::Abs
298 | ComputeOp::Sqrt => 1,
299
300 ComputeOp::Add
302 | ComputeOp::Sub
303 | ComputeOp::Mul
304 | ComputeOp::Div
305 | ComputeOp::Rem
306 | ComputeOp::BitAnd
307 | ComputeOp::BitOr
308 | ComputeOp::BitXor
309 | ComputeOp::Shl
310 | ComputeOp::Shr
311 | ComputeOp::Eq
312 | ComputeOp::Ne
313 | ComputeOp::Lt
314 | ComputeOp::Le
315 | ComputeOp::Gt
316 | ComputeOp::Ge
317 | ComputeOp::And
318 | ComputeOp::Or
319 | ComputeOp::Min
320 | ComputeOp::Max => 2,
321
322 ComputeOp::Fma => 3,
324 }
325 }
326
327 pub fn energy_cost(&self) -> u32 {
329 match self {
330 ComputeOp::Add
332 | ComputeOp::Sub
333 | ComputeOp::Neg
334 | ComputeOp::BitAnd
335 | ComputeOp::BitOr
336 | ComputeOp::BitXor
337 | ComputeOp::BitNot
338 | ComputeOp::Shl
339 | ComputeOp::Shr
340 | ComputeOp::Not
341 | ComputeOp::And
342 | ComputeOp::Or
343 | ComputeOp::Eq
344 | ComputeOp::Ne
345 | ComputeOp::Lt
346 | ComputeOp::Le
347 | ComputeOp::Gt
348 | ComputeOp::Ge => 1,
349
350 ComputeOp::Mul
352 | ComputeOp::Min
353 | ComputeOp::Max
354 | ComputeOp::Abs
355 | ComputeOp::SignExtend
356 | ComputeOp::ZeroExtend
357 | ComputeOp::Truncate => 3,
358
359 ComputeOp::Div
361 | ComputeOp::Rem
362 | ComputeOp::IntToFloat
363 | ComputeOp::FloatToInt
364 | ComputeOp::Sqrt => 10,
365
366 ComputeOp::Fma => 4,
368 }
369 }
370}
371
/// Direction of a `Memory` operator's access.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemoryOp {
    Load,
    Store,
}
378
/// A node in the dataflow graph. Each variant names the channels it
/// reads and writes; [`DfOperator::inputs`] / [`DfOperator::outputs`]
/// enumerate them.
///
/// NOTE(review): the behavioral descriptions below are inferred from
/// field names and channel wiring; the executing runtime is not in this
/// file — confirm against the interpreter/scheduler.
#[derive(Debug, Clone)]
pub enum DfOperator {
    /// Pure computation: applies `op` to `inputs`, emits on `output`.
    Compute {
        op: ComputeOp,
        inputs: Vec<ChannelId>,
        output: ChannelId,
    },

    /// Memory access. `data` is an input for `Store` and the result
    /// output for `Load` (see `inputs()`/`outputs()`).
    Memory {
        op: MemoryOp,
        /// Address to access.
        address: ChannelId,
        /// Value channel; direction depends on `op`.
        data: ChannelId,
        /// Extra input channels used to order this access against
        /// other memory operations.
        ordering: Vec<ChannelId>,
    },

    /// Conditional routing: forwards `data` to `true_out` or
    /// `false_out` according to the `decider` token.
    Steer {
        decider: ChannelId,
        data: ChannelId,
        true_out: ChannelId,
        false_out: ChannelId,
    },

    /// Counted sequence generator driven by `start`/`step`/`bound`;
    /// emits values on `output` and a completion token on `done`.
    Stream {
        start: ChannelId,
        step: ChannelId,
        bound: ChannelId,
        output: ChannelId,
        done: ChannelId,
    },

    /// Loop-carried value: seeded from `initial`, thereafter fed from
    /// `feedback` while `continue_signal` allows.
    Carry {
        initial: ChannelId,
        feedback: ChannelId,
        continue_signal: ChannelId,
        output: ChannelId,
    },

    /// Gate between nested loop levels: selects between `inner` and
    /// `outer` values based on `level`.
    CarryGate {
        inner: ChannelId,
        outer: ChannelId,
        level: ChannelId,
        output: ChannelId,
    },

    /// Funnels tokens from several inputs onto one output.
    Merge {
        inputs: Vec<ChannelId>,
        output: ChannelId,
    },

    /// Fans one input out to several outputs.
    Split {
        input: ChannelId,
        outputs: Vec<ChannelId>,
    },

    /// Emits a constant token on `output`. `repeat` bounds the number
    /// of emissions; `None` presumably means unbounded — confirm.
    Constant {
        value: TokenValue,
        output: ChannelId,
        repeat: Option<u64>,
    },

    /// Graph entry point fed by external producer `external_id`.
    Source {
        external_id: u32,
        output: ChannelId,
    },

    /// Graph exit point draining `input` to external consumer
    /// `external_id`.
    Sink {
        input: ChannelId,
        external_id: u32,
    },

    /// Data-dependent choice: forwards one of `inputs`, chosen by
    /// `selector`, to `output`.
    Select {
        selector: ChannelId,
        inputs: Vec<ChannelId>,
        output: ChannelId,
    },

    /// Folds the `values` stream with `op`, seeded by `initial` and
    /// terminated by `count_or_done`; result on `output`.
    Reduce {
        op: ComputeOp,
        initial: ChannelId,
        values: ChannelId,
        count_or_done: ChannelId,
        output: ChannelId,
    },
}
526
527impl DfOperator {
528 pub fn inputs(&self) -> Vec<ChannelId> {
530 match self {
531 DfOperator::Compute { inputs, .. } => inputs.clone(),
532 DfOperator::Memory {
533 address,
534 data,
535 ordering,
536 op,
537 } => {
538 let mut inputs = vec![*address];
539 if *op == MemoryOp::Store {
540 inputs.push(*data);
541 }
542 inputs.extend(ordering.iter().copied());
543 inputs
544 }
545 DfOperator::Steer { decider, data, .. } => vec![*decider, *data],
546 DfOperator::Stream {
547 start, step, bound, ..
548 } => vec![*start, *step, *bound],
549 DfOperator::Carry {
550 initial,
551 feedback,
552 continue_signal,
553 ..
554 } => {
555 vec![*initial, *feedback, *continue_signal]
556 }
557 DfOperator::CarryGate {
558 inner,
559 outer,
560 level,
561 ..
562 } => vec![*inner, *outer, *level],
563 DfOperator::Merge { inputs, .. } => inputs.clone(),
564 DfOperator::Split { input, .. } => vec![*input],
565 DfOperator::Constant { .. } => vec![],
566 DfOperator::Source { .. } => vec![],
567 DfOperator::Sink { input, .. } => vec![*input],
568 DfOperator::Select {
569 selector, inputs, ..
570 } => {
571 let mut all = vec![*selector];
572 all.extend(inputs.iter().copied());
573 all
574 }
575 DfOperator::Reduce {
576 initial,
577 values,
578 count_or_done,
579 ..
580 } => {
581 vec![*initial, *values, *count_or_done]
582 }
583 }
584 }
585
586 pub fn outputs(&self) -> Vec<ChannelId> {
588 match self {
589 DfOperator::Compute { output, .. } => vec![*output],
590 DfOperator::Memory { data, op, .. } => {
591 if *op == MemoryOp::Load {
592 vec![*data]
593 } else {
594 vec![] }
596 }
597 DfOperator::Steer {
598 true_out,
599 false_out,
600 ..
601 } => vec![*true_out, *false_out],
602 DfOperator::Stream { output, done, .. } => vec![*output, *done],
603 DfOperator::Carry { output, .. } => vec![*output],
604 DfOperator::CarryGate { output, .. } => vec![*output],
605 DfOperator::Merge { output, .. } => vec![*output],
606 DfOperator::Split { outputs, .. } => outputs.clone(),
607 DfOperator::Constant { output, .. } => vec![*output],
608 DfOperator::Source { output, .. } => vec![*output],
609 DfOperator::Sink { .. } => vec![],
610 DfOperator::Select { output, .. } => vec![*output],
611 DfOperator::Reduce { output, .. } => vec![*output],
612 }
613 }
614
615 pub fn energy_cost(&self) -> u32 {
617 match self {
618 DfOperator::Compute { op, .. } => op.energy_cost(),
619 DfOperator::Memory { .. } => 20, DfOperator::Steer { .. } => 1, DfOperator::Stream { .. } => 2, DfOperator::Carry { .. } => 1,
623 DfOperator::CarryGate { .. } => 1,
624 DfOperator::Merge { .. } => 1,
625 DfOperator::Split { .. } => 1,
626 DfOperator::Constant { .. } => 0,
627 DfOperator::Source { .. } => 0,
628 DfOperator::Sink { .. } => 0,
629 DfOperator::Select { .. } => 1,
630 DfOperator::Reduce { op, .. } => op.energy_cost() + 2,
631 }
632 }
633
634 pub fn is_control(&self) -> bool {
636 matches!(
637 self,
638 DfOperator::Steer { .. }
639 | DfOperator::Stream { .. }
640 | DfOperator::Carry { .. }
641 | DfOperator::CarryGate { .. }
642 | DfOperator::Merge { .. }
643 | DfOperator::Select { .. }
644 )
645 }
646}
647
/// A token channel connecting one producer operator to its consumers.
#[derive(Debug, Clone)]
pub struct Channel {
    pub id: ChannelId,
    /// Type of every token carried on this channel.
    pub token_type: TokenType,
    /// Buffer depth in tokens; defaults to 4 in [`Channel::new`].
    pub capacity: usize,
    /// Producing operator, if wired (set by `DataflowGraph::add_operator`).
    pub source: Option<OperatorId>,
    /// Consuming operators (appended by `DataflowGraph::add_operator`).
    pub destinations: Vec<OperatorId>,
}
665
666impl Channel {
667 pub fn new(id: ChannelId, token_type: TokenType) -> Self {
668 Self {
669 id,
670 token_type,
671 capacity: 4, source: None,
673 destinations: Vec::new(),
674 }
675 }
676
677 pub fn with_capacity(mut self, capacity: usize) -> Self {
678 self.capacity = capacity;
679 self
680 }
681}
682
/// An energy-aware dataflow program: operators connected by token
/// channels, with optional per-operator energy estimates.
#[derive(Debug, Clone)]
pub struct DataflowGraph {
    /// Operator nodes; an operator's `OperatorId` is its index here.
    pub operators: Vec<DfOperator>,
    /// Token channels; a channel's `ChannelId` is its index here.
    pub channels: Vec<Channel>,
    /// Channels fed from outside the graph (deduplicated).
    pub inputs: Vec<ChannelId>,
    /// Channels observed from outside the graph (deduplicated).
    pub outputs: Vec<ChannelId>,
    /// Optional per-operator energy model; summed by
    /// `total_energy_estimate`.
    pub energy_estimates: HashMap<OperatorId, EnergyEstimate>,
    /// Optional human-readable name, shown by the `Display` impl.
    pub name: Option<String>,
}
703
/// Per-operator energy model.
///
/// NOTE(review): `DataflowGraph::total_energy_estimate` computes
/// `static_cost * firing_count + dynamic_cost`, i.e. the *static* cost
/// scales with firings and the *dynamic* cost is added once — the
/// opposite of what the names usually suggest. Confirm the intended
/// semantics.
#[derive(Debug, Clone, Copy, Default)]
pub struct EnergyEstimate {
    // Cost multiplied by `firing_count` in the total.
    pub static_cost: f64,
    // Cost added once, independent of firings.
    pub dynamic_cost: f64,
    // Expected number of firings.
    pub firing_count: u64,
}
714
715impl DataflowGraph {
716 pub fn new() -> Self {
718 Self {
719 operators: Vec::new(),
720 channels: Vec::new(),
721 inputs: Vec::new(),
722 outputs: Vec::new(),
723 energy_estimates: HashMap::new(),
724 name: None,
725 }
726 }
727
728 pub fn with_name(name: impl Into<String>) -> Self {
730 Self {
731 name: Some(name.into()),
732 ..Self::new()
733 }
734 }
735
736 pub fn add_operator(&mut self, op: DfOperator) -> OperatorId {
738 let id = OperatorId::new(self.operators.len() as u32);
739
740 for ch_id in op.outputs() {
742 if let Some(ch) = self.channels.get_mut(ch_id.0 as usize) {
743 ch.source = Some(id);
744 }
745 }
746 for ch_id in op.inputs() {
747 if let Some(ch) = self.channels.get_mut(ch_id.0 as usize) {
748 ch.destinations.push(id);
749 }
750 }
751
752 self.operators.push(op);
753 id
754 }
755
756 pub fn add_channel(&mut self, token_type: TokenType) -> ChannelId {
758 let id = ChannelId::new(self.channels.len() as u32);
759 self.channels.push(Channel::new(id, token_type));
760 id
761 }
762
763 pub fn add_channel_with_capacity(
765 &mut self,
766 token_type: TokenType,
767 capacity: usize,
768 ) -> ChannelId {
769 let id = ChannelId::new(self.channels.len() as u32);
770 self.channels
771 .push(Channel::new(id, token_type).with_capacity(capacity));
772 id
773 }
774
775 pub fn add_input(&mut self, channel: ChannelId) {
777 if !self.inputs.contains(&channel) {
778 self.inputs.push(channel);
779 }
780 }
781
782 pub fn add_output(&mut self, channel: ChannelId) {
784 if !self.outputs.contains(&channel) {
785 self.outputs.push(channel);
786 }
787 }
788
789 pub fn total_energy_estimate(&self) -> f64 {
791 self.energy_estimates
792 .values()
793 .map(|e| e.static_cost * e.firing_count as f64 + e.dynamic_cost)
794 .sum()
795 }
796
797 pub fn operator_counts(&self) -> HashMap<&'static str, usize> {
799 let mut counts = HashMap::new();
800 for op in &self.operators {
801 let name = match op {
802 DfOperator::Compute { .. } => "Compute",
803 DfOperator::Memory { .. } => "Memory",
804 DfOperator::Steer { .. } => "Steer",
805 DfOperator::Stream { .. } => "Stream",
806 DfOperator::Carry { .. } => "Carry",
807 DfOperator::CarryGate { .. } => "CarryGate",
808 DfOperator::Merge { .. } => "Merge",
809 DfOperator::Split { .. } => "Split",
810 DfOperator::Constant { .. } => "Constant",
811 DfOperator::Source { .. } => "Source",
812 DfOperator::Sink { .. } => "Sink",
813 DfOperator::Select { .. } => "Select",
814 DfOperator::Reduce { .. } => "Reduce",
815 };
816 *counts.entry(name).or_insert(0) += 1;
817 }
818 counts
819 }
820}
821
822impl Default for DataflowGraph {
823 fn default() -> Self {
824 Self::new()
825 }
826}
827
/// Operator-level dependency information derived from channel wiring;
/// built by [`DependencyAnalysis::analyze`].
#[derive(Debug)]
pub struct DependencyAnalysis {
    /// For each operator, the operators producing its input tokens.
    pub predecessors: HashMap<OperatorId, HashSet<OperatorId>>,
    /// For each operator, the operators consuming its output tokens.
    pub successors: HashMap<OperatorId, HashSet<OperatorId>>,
    /// Operators with no predecessors (graph entry points).
    pub sources: Vec<OperatorId>,
    /// Operators with no successors (graph exit points).
    pub sinks: Vec<OperatorId>,
    /// Length in edges of the longest dependency chain.
    pub critical_path_length: usize,
    /// One longest dependency chain, source-to-end order.
    pub critical_path: Vec<OperatorId>,
}
848
impl DependencyAnalysis {
    /// Derives operator-level dependency relations from channel wiring:
    /// an edge `a -> b` exists when `b` reads a channel whose recorded
    /// `source` is `a`.
    pub fn analyze(dfg: &DataflowGraph) -> Self {
        let mut predecessors: HashMap<OperatorId, HashSet<OperatorId>> = HashMap::new();
        let mut successors: HashMap<OperatorId, HashSet<OperatorId>> = HashMap::new();

        // Pre-seed every operator with empty sets so the unwraps in the
        // edge-building loop below cannot miss.
        for i in 0..dfg.operators.len() {
            let id = OperatorId::new(i as u32);
            predecessors.insert(id, HashSet::new());
            successors.insert(id, HashSet::new());
        }

        // Walk each operator's input channels back to their producers.
        for (op_idx, op) in dfg.operators.iter().enumerate() {
            let op_id = OperatorId::new(op_idx as u32);

            for input_ch in op.inputs() {
                if let Some(channel) = dfg.channels.get(input_ch.0 as usize) {
                    if let Some(source_op) = channel.source {
                        predecessors.get_mut(&op_id).unwrap().insert(source_op);
                        successors.get_mut(&source_op).unwrap().insert(op_id);
                    }
                }
            }
        }

        // Entry points: operators with no predecessors.
        let sources: Vec<_> = predecessors
            .iter()
            .filter(|(_, preds)| preds.is_empty())
            .map(|(id, _)| *id)
            .collect();

        // Exit points: operators with no successors.
        let sinks: Vec<_> = successors
            .iter()
            .filter(|(_, succs)| succs.is_empty())
            .map(|(id, _)| *id)
            .collect();

        let (critical_path_length, critical_path) =
            Self::find_critical_path(&predecessors, &successors, &sources, dfg.operators.len());

        Self {
            predecessors,
            successors,
            sources,
            sinks,
            critical_path_length,
            critical_path,
        }
    }

    /// Longest path (in edge count) through the dependency graph, found
    /// with a Kahn-style topological relaxation; also returns one
    /// maximal path reconstructed from parent links.
    ///
    /// NOTE(review): operators on a cycle (e.g. a `Carry` feedback
    /// wired back through channels) never reach in-degree 0 and are
    /// skipped, leaving their distance at 0 — confirm cyclic graphs are
    /// not analyzed here, or that degraded results are acceptable.
    fn find_critical_path(
        predecessors: &HashMap<OperatorId, HashSet<OperatorId>>,
        successors: &HashMap<OperatorId, HashSet<OperatorId>>,
        sources: &[OperatorId],
        num_ops: usize,
    ) -> (usize, Vec<OperatorId>) {
        if num_ops == 0 {
            return (0, Vec::new());
        }

        // dist[op] = longest known path from any source to `op`;
        // parent backlinks reconstruct one maximal path at the end.
        let mut in_degree: HashMap<OperatorId, usize> = HashMap::new();
        let mut dist: HashMap<OperatorId, usize> = HashMap::new();
        let mut parent: HashMap<OperatorId, Option<OperatorId>> = HashMap::new();

        for i in 0..num_ops {
            let id = OperatorId::new(i as u32);
            in_degree.insert(id, predecessors.get(&id).map(|s| s.len()).unwrap_or(0));
            dist.insert(id, 0);
            parent.insert(id, None);
        }

        let mut queue: VecDeque<OperatorId> = sources.iter().copied().collect();

        while let Some(op) = queue.pop_front() {
            let op_dist = *dist.get(&op).unwrap();

            if let Some(succs) = successors.get(&op) {
                for &succ in succs {
                    // Relax: extend the longest path by one edge.
                    let new_dist = op_dist + 1;
                    if new_dist > *dist.get(&succ).unwrap() {
                        dist.insert(succ, new_dist);
                        parent.insert(succ, Some(op));
                    }

                    // Release the successor once all of its
                    // predecessors have been processed.
                    let degree = in_degree.get_mut(&succ).unwrap();
                    *degree -= 1;
                    if *degree == 0 {
                        queue.push_back(succ);
                    }
                }
            }
        }

        // The critical path ends at the operator farthest from any source.
        let (max_dist, end_op) = dist
            .iter()
            .max_by_key(|(_, d)| *d)
            .map(|(id, d)| (*d, *id))
            .unwrap_or((0, OperatorId::new(0)));

        // Follow parent links back to a source, then flip the order.
        let mut path = Vec::new();
        let mut current = Some(end_op);
        while let Some(op) = current {
            path.push(op);
            current = *parent.get(&op).unwrap();
        }
        path.reverse();

        (max_dist, path)
    }

    /// Width of the graph at each dependency level: index = level
    /// (longest edge distance from a source), value = operator count.
    ///
    /// NOTE(review): a successor is enqueued as soon as all of its
    /// predecessors merely *appear* in `levels`, which can happen
    /// before those levels are final; a node already dequeued is never
    /// re-processed, so on some DAG shapes a stale (too-small) level
    /// can propagate. Verify against a known-wide graph if exact
    /// widths matter.
    pub fn parallelism_profile(&self, _dfg: &DataflowGraph) -> Vec<usize> {
        let mut levels: HashMap<OperatorId, usize> = HashMap::new();
        let mut queue: VecDeque<OperatorId> = self.sources.iter().copied().collect();

        for &src in &self.sources {
            levels.insert(src, 0);
        }

        while let Some(op) = queue.pop_front() {
            let op_level = *levels.get(&op).unwrap();

            if let Some(succs) = self.successors.get(&op) {
                for &succ in succs {
                    // A successor sits one level below its deepest
                    // parent seen so far.
                    let new_level = op_level + 1;
                    let current_level = levels.entry(succ).or_insert(0);
                    if new_level > *current_level {
                        *current_level = new_level;
                    }

                    let all_preds_done = self
                        .predecessors
                        .get(&succ)
                        .map(|preds| preds.iter().all(|p| levels.contains_key(p)))
                        .unwrap_or(true);

                    if all_preds_done && !queue.contains(&succ) {
                        queue.push_back(succ);
                    }
                }
            }
        }

        // Histogram: count operators per level.
        let max_level = levels.values().max().copied().unwrap_or(0);
        let mut profile = vec![0; max_level + 1];
        for level in levels.values() {
            profile[*level] += 1;
        }

        profile
    }
}
1011
/// A dataflow graph paired with an execution schedule and resource
/// estimates (produced by the scheduling pass; see `schedule` module).
#[derive(Debug, Clone)]
pub struct ScheduledDfg {
    /// The scheduled graph.
    pub dfg: DataflowGraph,
    /// Ordered execution steps.
    pub schedule: Vec<ScheduleStep>,
    /// Estimated total energy; the `_j` suffix indicates joules.
    pub estimated_energy_j: f64,
    /// Estimated total duration; the `_s` suffix indicates seconds.
    pub estimated_duration_s: f64,
}
1028
/// One step of a schedule: a set of operators intended to fire together.
#[derive(Debug, Clone)]
pub struct ScheduleStep {
    /// Operators fired in this step.
    pub operators: Vec<OperatorId>,
    /// Upper bound on concurrent firings within the step.
    pub max_parallelism: usize,
    /// Estimated energy spent by this step.
    pub estimated_energy: f64,
}
1039
1040impl fmt::Display for DataflowGraph {
1045 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1046 writeln!(f, "DataflowGraph {{")?;
1047 if let Some(name) = &self.name {
1048 writeln!(f, " name: {}", name)?;
1049 }
1050 writeln!(f, " operators: {},", self.operators.len())?;
1051 writeln!(f, " channels: {},", self.channels.len())?;
1052 writeln!(f, " inputs: {:?},", self.inputs)?;
1053 writeln!(f, " outputs: {:?},", self.outputs)?;
1054
1055 writeln!(f, " operator_counts: {{")?;
1056 for (name, count) in self.operator_counts() {
1057 writeln!(f, " {}: {},", name, count)?;
1058 }
1059 writeln!(f, " }}")?;
1060
1061 writeln!(f, "}}")
1062 }
1063}
1064
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_create_dfg() {
        let graph = DataflowGraph::with_name("test");
        assert_eq!(graph.name.as_deref(), Some("test"));
        assert!(graph.operators.is_empty());
        assert!(graph.channels.is_empty());
    }

    #[test]
    fn test_add_channel_and_operator() {
        let mut graph = DataflowGraph::new();

        // Two graph inputs feeding a single adder.
        let lhs = graph.add_channel(TokenType::i32());
        let rhs = graph.add_channel(TokenType::i32());
        let sum = graph.add_channel(TokenType::i32());

        graph.add_input(lhs);
        graph.add_input(rhs);

        let adder = graph.add_operator(DfOperator::Compute {
            op: ComputeOp::Add,
            inputs: vec![lhs, rhs],
            output: sum,
        });
        graph.add_output(sum);

        assert_eq!(adder, OperatorId::new(0));
        assert_eq!(graph.operators.len(), 1);
        assert_eq!(graph.channels.len(), 3);
        assert_eq!(graph.inputs.len(), 2);
        assert_eq!(graph.outputs.len(), 1);
    }

    #[test]
    fn test_steer_operator() {
        let mut graph = DataflowGraph::new();

        let cond = graph.add_channel(TokenType::Bool);
        let payload = graph.add_channel(TokenType::i32());
        let taken = graph.add_channel(TokenType::i32());
        let not_taken = graph.add_channel(TokenType::i32());

        let steer = DfOperator::Steer {
            decider: cond,
            data: payload,
            true_out: taken,
            false_out: not_taken,
        };

        assert!(steer.is_control());
        assert_eq!(steer.inputs().len(), 2);
        assert_eq!(steer.outputs().len(), 2);
        assert_eq!(steer.energy_cost(), 1);
    }

    #[test]
    fn test_stream_operator() {
        let mut graph = DataflowGraph::new();

        let stream = DfOperator::Stream {
            start: graph.add_channel(TokenType::i64()),
            step: graph.add_channel(TokenType::i64()),
            bound: graph.add_channel(TokenType::i64()),
            output: graph.add_channel(TokenType::i64()),
            done: graph.add_channel(TokenType::Bool),
        };

        assert!(stream.is_control());
        assert_eq!(stream.inputs().len(), 3);
        assert_eq!(stream.outputs().len(), 2);
    }

    #[test]
    fn test_dependency_analysis() {
        // Linear pipeline: Source -> Compute -> Sink.
        let mut graph = DataflowGraph::new();

        let raw = graph.add_channel(TokenType::i32());
        let processed = graph.add_channel(TokenType::i32());
        let _spare = graph.add_channel(TokenType::i32());

        graph.add_operator(DfOperator::Source {
            external_id: 0,
            output: raw,
        });
        graph.add_operator(DfOperator::Compute {
            op: ComputeOp::Add,
            inputs: vec![raw],
            output: processed,
        });
        graph.add_operator(DfOperator::Sink {
            input: processed,
            external_id: 0,
        });

        let analysis = DependencyAnalysis::analyze(&graph);

        assert_eq!(analysis.sources.len(), 1);
        assert_eq!(analysis.sinks.len(), 1);
        assert_eq!(analysis.critical_path_length, 2);
    }

    #[test]
    fn test_parallel_graph() {
        // Diamond: Source -> Split -> {Mul, Add} -> Add -> Sink.
        let mut graph = DataflowGraph::new();

        let entry = graph.add_channel(TokenType::i32());
        let left = graph.add_channel(TokenType::i32());
        let right = graph.add_channel(TokenType::i32());
        let exit = graph.add_channel(TokenType::i32());

        graph.add_operator(DfOperator::Source {
            external_id: 0,
            output: entry,
        });

        let fan_left = graph.add_channel(TokenType::i32());
        let fan_right = graph.add_channel(TokenType::i32());
        graph.add_operator(DfOperator::Split {
            input: entry,
            outputs: vec![fan_left, fan_right],
        });

        graph.add_operator(DfOperator::Compute {
            op: ComputeOp::Mul,
            inputs: vec![fan_left],
            output: left,
        });
        graph.add_operator(DfOperator::Compute {
            op: ComputeOp::Add,
            inputs: vec![fan_right],
            output: right,
        });
        graph.add_operator(DfOperator::Compute {
            op: ComputeOp::Add,
            inputs: vec![left, right],
            output: exit,
        });
        graph.add_operator(DfOperator::Sink {
            input: exit,
            external_id: 0,
        });

        let analysis = DependencyAnalysis::analyze(&graph);
        let profile = analysis.parallelism_profile(&graph);

        // The two middle Computes should share a level.
        assert!(profile.iter().any(|&width| width >= 2));
    }

    #[test]
    fn test_compute_op_costs() {
        assert_eq!(ComputeOp::Add.energy_cost(), 1);
        assert_eq!(ComputeOp::Mul.energy_cost(), 3);
        assert_eq!(ComputeOp::Div.energy_cost(), 10);
        assert_eq!(ComputeOp::Fma.energy_cost(), 4);

        assert_eq!(ComputeOp::Add.input_count(), 2);
        assert_eq!(ComputeOp::Neg.input_count(), 1);
        assert_eq!(ComputeOp::Fma.input_count(), 3);
    }

    #[test]
    #[allow(clippy::float_cmp, clippy::approx_constant)]
    fn test_token_value_conversions() {
        let truthy = TokenValue::Bool(true);
        assert!(truthy.as_bool().unwrap());
        assert_eq!(truthy.as_i64().unwrap(), 1);

        let negative = TokenValue::Int(-42);
        assert_eq!(negative.as_i64().unwrap(), -42);
        assert_eq!(negative.as_f64().unwrap(), -42.0);

        let pi_ish = TokenValue::Float(3.14);
        assert!((pi_ish.as_f64().unwrap() - 3.14).abs() < 0.001);

        // Structured values refuse every scalar conversion.
        let aggregate = TokenValue::Array(vec![TokenValue::Int(1)]);
        assert!(aggregate.as_bool().is_err());
        assert!(aggregate.as_i64().is_err());
        assert!(aggregate.as_u64().is_err());
        assert!(aggregate.as_f64().is_err());
    }
}