// joule_mir/dataflow/extract.rs

//! MIR to Dataflow Graph Extraction
//!
//! This module converts MIR (basic-block form) into a dataflow graph representation.
//! The extraction process covers:
//!
//! 1. Creating a channel for each MIR local variable
//! 2. Converting statements to operators (mostly Compute/Memory; some rvalues map to
//!    Merge, Split, Source or Sink)
//! 3. Converting terminators to control-flow operators (Steer for two-way branches,
//!    plus Split, Merge, Source and Sink)
//! 4. Detecting loop back edges; these are recorded for future Stream-operator
//!    generation, but no Stream operators are emitted yet
//! 5. Marking function parameters as external inputs
//!
//! # Example
//!
//! ```ignore
//! let mir = ...;  // MIR for a function
//! let dfg = DfgExtractor::extract(&mir);
//! ```
17
18use super::{ChannelId, ComputeOp, DataflowGraph, DfOperator, MemoryOp, TokenType, TokenValue};
19use crate::{
20    BasicBlock, BasicBlockId, BinOp, FunctionMIR, Local, Operand, Place, Rvalue, Statement,
21    Terminator, Ty, UnOp,
22};
23use std::collections::HashMap;
24
25/// Extracts a dataflow graph from MIR
26pub struct DfgExtractor {
27    dfg: DataflowGraph,
28    /// Map from MIR local to channel ID
29    local_to_channel: HashMap<Local, ChannelId>,
30    /// Map from basic block to entry channel (for control flow)
31    block_entry: HashMap<BasicBlockId, ChannelId>,
32    /// Current block being processed
33    current_block: Option<BasicBlockId>,
34    /// Loop detection: back edges
35    back_edges: Vec<(BasicBlockId, BasicBlockId)>,
36}
37
38impl DfgExtractor {
39    /// Create a new extractor
40    pub fn new() -> Self {
41        Self {
42            dfg: DataflowGraph::new(),
43            local_to_channel: HashMap::new(),
44            block_entry: HashMap::new(),
45            current_block: None,
46            back_edges: Vec::new(),
47        }
48    }
49
50    /// Extract a dataflow graph from a MIR function
51    pub fn extract(mir: &FunctionMIR) -> DataflowGraph {
52        let mut extractor = Self::new();
53        extractor.dfg.name = Some(mir.name.as_str().to_string());
54
55        // Phase 1: Create channels for all locals
56        extractor.create_local_channels(mir);
57
58        // Phase 2: Detect loops (back edges)
59        extractor.detect_loops(mir);
60
61        // Phase 3: Process each basic block
62        for (bb_idx, block) in mir.basic_blocks.iter().enumerate() {
63            let bb_id = BasicBlockId::new(bb_idx as u32);
64            extractor.current_block = Some(bb_id);
65            extractor.process_block(bb_id, block, mir);
66        }
67
68        // Phase 4: Mark external inputs/outputs
69        extractor.mark_io(mir);
70
71        extractor.dfg
72    }
73
74    /// Create channels for all MIR locals
75    fn create_local_channels(&mut self, mir: &FunctionMIR) {
76        for (idx, local_decl) in mir.locals.iter().enumerate() {
77            let local = Local::new(idx as u32);
78            let token_type = Self::mir_ty_to_token_type(&local_decl.ty);
79            let channel = self.dfg.add_channel(token_type);
80            self.local_to_channel.insert(local, channel);
81        }
82    }
83
84    /// Detect loop back edges using DFS
85    fn detect_loops(&mut self, mir: &FunctionMIR) {
86        if mir.basic_blocks.is_empty() {
87            return;
88        }
89
90        let mut visited = vec![false; mir.basic_blocks.len()];
91        let mut in_stack = vec![false; mir.basic_blocks.len()];
92        self.dfs_detect_loops(mir, BasicBlockId::START, &mut visited, &mut in_stack);
93    }
94
95    fn dfs_detect_loops(
96        &mut self,
97        mir: &FunctionMIR,
98        bb: BasicBlockId,
99        visited: &mut [bool],
100        in_stack: &mut [bool],
101    ) {
102        let idx = bb.0 as usize;
103        if idx >= mir.basic_blocks.len() {
104            return;
105        }
106
107        visited[idx] = true;
108        in_stack[idx] = true;
109
110        let successors = self.get_successors(&mir.basic_blocks[idx].terminator);
111        for succ in successors {
112            let succ_idx = succ.0 as usize;
113            if succ_idx >= mir.basic_blocks.len() {
114                continue;
115            }
116
117            if !visited[succ_idx] {
118                self.dfs_detect_loops(mir, succ, visited, in_stack);
119            } else if in_stack[succ_idx] {
120                // Back edge found - this is a loop
121                self.back_edges.push((bb, succ));
122            }
123        }
124
125        in_stack[idx] = false;
126    }
127
128    fn get_successors(&self, terminator: &Terminator) -> Vec<BasicBlockId> {
129        match terminator {
130            Terminator::Return { .. }
131            | Terminator::Abort { .. }
132            | Terminator::Unreachable { .. }
133            | Terminator::Cancel { .. } => {
134                vec![]
135            }
136            Terminator::Goto { target, .. } => vec![*target],
137            Terminator::SwitchInt { targets, .. } => {
138                let mut succs: Vec<_> = targets.branches.iter().map(|(_, bb)| *bb).collect();
139                succs.push(targets.otherwise);
140                succs
141            }
142            Terminator::Call { target, .. } => vec![*target],
143            Terminator::Spawn { target, .. } => vec![*target],
144            Terminator::TaskAwait { target, .. } => vec![*target],
145            Terminator::TaskGroupEnter {
146                body, join_block, ..
147            } => vec![*body, *join_block],
148            Terminator::TaskGroupExit { target, .. } => vec![*target],
149            Terminator::ChannelRecv {
150                target,
151                closed_target,
152                ..
153            } => vec![*target, *closed_target],
154            Terminator::ChannelSend {
155                target,
156                closed_target,
157                ..
158            } => vec![*target, *closed_target],
159            Terminator::Select { arms, default, .. } => {
160                let mut succs: Vec<_> = arms.iter().map(|arm| arm.target).collect();
161                if let Some(def) = default {
162                    succs.push(*def);
163                }
164                succs
165            }
166        }
167    }
168
169    /// Process a basic block
170    fn process_block(&mut self, bb_id: BasicBlockId, block: &BasicBlock, mir: &FunctionMIR) {
171        // Create block entry channel if needed
172        if !self.block_entry.contains_key(&bb_id) {
173            let entry = self.dfg.add_channel(TokenType::Unit);
174            self.block_entry.insert(bb_id, entry);
175        }
176
177        // Process statements
178        for stmt in &block.statements {
179            self.process_statement(stmt);
180        }
181
182        // Process terminator
183        self.process_terminator(&block.terminator, mir);
184    }
185
186    /// Process a statement
187    fn process_statement(&mut self, stmt: &Statement) {
188        match stmt {
189            Statement::Assign { place, rvalue, .. } => {
190                self.process_assignment(place, rvalue);
191            }
192            Statement::StorageLive { .. } | Statement::StorageDead { .. } | Statement::Nop => {
193                // No dataflow impact
194            }
195        }
196    }
197
198    /// Process an assignment
199    fn process_assignment(&mut self, place: &Place, rvalue: &Rvalue) {
200        let dest_channel = self.place_to_channel(place);
201
202        match rvalue {
203            Rvalue::Use(operand) => {
204                // Simple copy/move - create identity operator or direct connection
205                if let Some(src_channel) = self.operand_to_channel(operand) {
206                    // Identity: just connect channels
207                    self.dfg.add_operator(DfOperator::Compute {
208                        op: ComputeOp::Add, // Add 0 is identity
209                        inputs: vec![src_channel],
210                        output: dest_channel,
211                    });
212                }
213            }
214
215            Rvalue::BinaryOp { op, left, right } => {
216                let compute_op = Self::binop_to_compute_op(op);
217                let left_ch = self.operand_to_channel(left);
218                let right_ch = self.operand_to_channel(right);
219
220                if let (Some(l), Some(r)) = (left_ch, right_ch) {
221                    self.dfg.add_operator(DfOperator::Compute {
222                        op: compute_op,
223                        inputs: vec![l, r],
224                        output: dest_channel,
225                    });
226                }
227            }
228
229            Rvalue::UnaryOp { op, operand } => {
230                let compute_op = Self::unop_to_compute_op(op);
231                if let Some(src) = self.operand_to_channel(operand) {
232                    self.dfg.add_operator(DfOperator::Compute {
233                        op: compute_op,
234                        inputs: vec![src],
235                        output: dest_channel,
236                    });
237                }
238            }
239
240            Rvalue::Ref { place, .. } => {
241                // Reference creation - treat as address computation
242                if let Some(src) = self.local_to_channel.get(&place.local) {
243                    self.dfg.add_operator(DfOperator::Compute {
244                        op: ComputeOp::Add, // Identity
245                        inputs: vec![*src],
246                        output: dest_channel,
247                    });
248                }
249            }
250
251            Rvalue::Aggregate { operands, .. } => {
252                let inputs: Vec<_> = operands
253                    .iter()
254                    .filter_map(|op| self.operand_to_channel(op))
255                    .collect();
256
257                if !inputs.is_empty() {
258                    // For now, treat as merge
259                    self.dfg.add_operator(DfOperator::Merge {
260                        inputs,
261                        output: dest_channel,
262                    });
263                }
264            }
265
266            Rvalue::Cast { operand, kind, .. } => {
267                if let Some(src) = self.operand_to_channel(operand) {
268                    let op = match kind {
269                        crate::CastKind::IntToFloat => ComputeOp::IntToFloat,
270                        crate::CastKind::FloatToInt => ComputeOp::FloatToInt,
271                        crate::CastKind::IntToInt => ComputeOp::SignExtend,
272                        crate::CastKind::FloatToFloat => ComputeOp::Truncate,
273                        crate::CastKind::PtrToPtr => ComputeOp::ZeroExtend,
274                        crate::CastKind::Bitcast => ComputeOp::ZeroExtend,
275                    };
276                    self.dfg.add_operator(DfOperator::Compute {
277                        op,
278                        inputs: vec![src],
279                        output: dest_channel,
280                    });
281                }
282            }
283
284            Rvalue::Len { place } => {
285                // Length extraction reads the length field from an array/slice
286                // structure. This is semantically a memory load (reading a
287                // field), modeled as a Load operation in the dataflow graph.
288                if let Some(src) = self.local_to_channel.get(&place.local) {
289                    self.dfg.add_operator(DfOperator::Memory {
290                        op: MemoryOp::Load,
291                        address: *src,
292                        data: dest_channel,
293                        ordering: vec![],
294                    });
295                }
296            }
297
298            Rvalue::Discriminant { place } => {
299                // Discriminant extraction reads the tag field of an enum and
300                // truncates it to the discriminant integer type. This is a
301                // load + truncate operation; we model the dominant cost as the
302                // memory read of the tag field.
303                if let Some(src) = self.local_to_channel.get(&place.local) {
304                    self.dfg.add_operator(DfOperator::Memory {
305                        op: MemoryOp::Load,
306                        address: *src,
307                        data: dest_channel,
308                        ordering: vec![],
309                    });
310                }
311            }
312
313            // SIMD operations
314            Rvalue::SimdBinaryOp {
315                op, left, right, ..
316            } => {
317                let compute_op = Self::simd_op_to_compute_op(op);
318                let left_ch = self.operand_to_channel(left);
319                let right_ch = self.operand_to_channel(right);
320
321                if let (Some(l), Some(r)) = (left_ch, right_ch) {
322                    self.dfg.add_operator(DfOperator::Compute {
323                        op: compute_op,
324                        inputs: vec![l, r],
325                        output: dest_channel,
326                    });
327                }
328            }
329
330            Rvalue::SimdLoad { source, .. } => {
331                if let Some(src) = self.local_to_channel.get(&source.local) {
332                    self.dfg.add_operator(DfOperator::Memory {
333                        op: MemoryOp::Load,
334                        address: *src,
335                        data: dest_channel,
336                        ordering: vec![],
337                    });
338                }
339            }
340
341            Rvalue::SimdStore { value, dest, .. } => {
342                let val_ch = self.operand_to_channel(value);
343                if let (Some(val), Some(dst)) = (val_ch, self.local_to_channel.get(&dest.local)) {
344                    self.dfg.add_operator(DfOperator::Memory {
345                        op: MemoryOp::Store,
346                        address: *dst,
347                        data: val,
348                        ordering: vec![],
349                    });
350                }
351            }
352
353            Rvalue::SimdSplat { value, .. } => {
354                if let Some(src) = self.operand_to_channel(value) {
355                    // Splat is just broadcasting
356                    self.dfg.add_operator(DfOperator::Split {
357                        input: src,
358                        outputs: vec![dest_channel],
359                    });
360                }
361            }
362
363            // Channel operations - convert to operators
364            Rvalue::ChannelCreate { .. } => {
365                // Channel creation is a source of values
366                self.dfg.add_operator(DfOperator::Source {
367                    external_id: dest_channel.0,
368                    output: dest_channel,
369                });
370            }
371
372            Rvalue::ChannelTryRecv { channel } => {
373                if let Some(ch) = self.operand_to_channel(channel) {
374                    self.dfg.add_operator(DfOperator::Compute {
375                        op: ComputeOp::Add,
376                        inputs: vec![ch],
377                        output: dest_channel,
378                    });
379                }
380            }
381
382            Rvalue::ChannelTrySend { channel, value } => {
383                let ch = self.operand_to_channel(channel);
384                let val = self.operand_to_channel(value);
385                if let (Some(c), Some(v)) = (ch, val) {
386                    self.dfg.add_operator(DfOperator::Merge {
387                        inputs: vec![c, v],
388                        output: dest_channel,
389                    });
390                }
391            }
392
393            Rvalue::ChannelSender { channel } | Rvalue::ChannelReceiver { channel } => {
394                if let Some(ch) = self.operand_to_channel(channel) {
395                    self.dfg.add_operator(DfOperator::Compute {
396                        op: ComputeOp::Add,
397                        inputs: vec![ch],
398                        output: dest_channel,
399                    });
400                }
401            }
402
403            Rvalue::ChannelClose { channel } => {
404                if let Some(ch) = self.operand_to_channel(channel) {
405                    self.dfg.add_operator(DfOperator::Sink {
406                        input: ch,
407                        external_id: dest_channel.0,
408                    });
409                }
410            }
411
412            Rvalue::IsCancelled | Rvalue::CurrentTask => {
413                self.dfg.add_operator(DfOperator::Source {
414                    external_id: dest_channel.0,
415                    output: dest_channel,
416                });
417            }
418            Rvalue::Try(_operand) => {
419                self.dfg.add_operator(DfOperator::Source {
420                    external_id: dest_channel.0,
421                    output: dest_channel,
422                });
423            }
424            Rvalue::EnumVariant { .. } => {
425                // Enum variant is a source (constant value)
426                self.dfg.add_operator(DfOperator::Source {
427                    external_id: dest_channel.0,
428                    output: dest_channel,
429                });
430            }
431        }
432    }
433
434    /// Process a terminator
435    fn process_terminator(&mut self, terminator: &Terminator, _mir: &FunctionMIR) {
436        match terminator {
437            Terminator::Return { .. } => {
438                // Return value goes to sink
439                if let Some(&ret_ch) = self.local_to_channel.get(&Local::RETURN) {
440                    self.dfg.add_operator(DfOperator::Sink {
441                        input: ret_ch,
442                        external_id: 0, // Return is always external_id 0
443                    });
444                    self.dfg.add_output(ret_ch);
445                }
446            }
447
448            Terminator::Goto { target, .. } => {
449                // Unconditional jump - connect to target block
450                let _entry = self.get_or_create_block_entry(*target);
451                // Goto doesn't need explicit dataflow - control flows implicitly
452            }
453
454            Terminator::SwitchInt {
455                discriminant,
456                targets,
457                ..
458            } => {
459                // Conditional branch - use Steer
460                if let Some(cond_ch) = self.operand_to_channel(discriminant) {
461                    // For simple if-else (one branch + otherwise)
462                    if targets.branches.len() == 1 {
463                        let true_target = targets.branches[0].1;
464                        let false_target = targets.otherwise;
465
466                        let true_entry = self.get_or_create_block_entry(true_target);
467                        let false_entry = self.get_or_create_block_entry(false_target);
468
469                        // Create data channel to route
470                        let data = self.dfg.add_channel(TokenType::Unit);
471                        self.dfg.add_operator(DfOperator::Constant {
472                            value: TokenValue::Unit,
473                            output: data,
474                            repeat: None,
475                        });
476
477                        self.dfg.add_operator(DfOperator::Steer {
478                            decider: cond_ch,
479                            data,
480                            true_out: true_entry,
481                            false_out: false_entry,
482                        });
483                    } else {
484                        // Multi-way switch - use Select
485                        // Collect branch entries first to avoid borrow checker issues
486                        let branch_bbs: Vec<_> =
487                            targets.branches.iter().map(|(_, bb)| *bb).collect();
488                        let mut outputs = Vec::with_capacity(branch_bbs.len() + 1);
489                        for bb in branch_bbs {
490                            outputs.push(self.get_or_create_block_entry(bb));
491                        }
492                        outputs.push(self.get_or_create_block_entry(targets.otherwise));
493
494                        let data = self.dfg.add_channel(TokenType::Unit);
495                        self.dfg.add_operator(DfOperator::Constant {
496                            value: TokenValue::Unit,
497                            output: data,
498                            repeat: None,
499                        });
500
501                        self.dfg.add_operator(DfOperator::Split {
502                            input: data,
503                            outputs,
504                        });
505                    }
506                }
507            }
508
509            Terminator::Call {
510                func: _,
511                args,
512                destination,
513                target: _,
514                func_name: _,
515                ..
516            } => {
517                // Function call - create compute node
518                let inputs: Vec<_> = args
519                    .iter()
520                    .filter_map(|arg| self.operand_to_channel(arg))
521                    .collect();
522
523                let dest_ch = self.place_to_channel(destination);
524
525                // Model the function call as a Merge: the result depends on all
526                // arguments. This correctly captures data dependencies (the output
527                // is available only after all inputs are ready) and the energy cost
528                // of the call is dominated by the callee's body, not the call
529                // instruction itself. For inter-procedural energy analysis, the
530                // callee's DFG would be inlined here.
531                if !inputs.is_empty() {
532                    self.dfg.add_operator(DfOperator::Merge {
533                        inputs,
534                        output: dest_ch,
535                    });
536                }
537            }
538
539            Terminator::Spawn {
540                func: _,
541                args: _,
542                destination,
543                ..
544            } => {
545                // Spawn - create a fork in the dataflow
546                // Spawned computation is represented as an async source
547                // (the actual args would be processed by the spawned task's DFG)
548                let dest_ch = self.place_to_channel(destination);
549                self.dfg.add_operator(DfOperator::Source {
550                    external_id: dest_ch.0,
551                    output: dest_ch,
552                });
553            }
554
555            Terminator::TaskAwait {
556                task, destination, ..
557            } => {
558                // Await - synchronization point
559                if let Some(task_ch) = self.operand_to_channel(task) {
560                    let dest_ch = self.place_to_channel(destination);
561                    self.dfg.add_operator(DfOperator::Compute {
562                        op: ComputeOp::Add,
563                        inputs: vec![task_ch],
564                        output: dest_ch,
565                    });
566                }
567            }
568
569            Terminator::ChannelRecv {
570                channel,
571                destination,
572                ..
573            } => {
574                if let Some(ch) = self.operand_to_channel(channel) {
575                    let dest_ch = self.place_to_channel(destination);
576                    self.dfg.add_operator(DfOperator::Compute {
577                        op: ComputeOp::Add,
578                        inputs: vec![ch],
579                        output: dest_ch,
580                    });
581                }
582            }
583
584            Terminator::ChannelSend { channel, value, .. } => {
585                let ch = self.operand_to_channel(channel);
586                let val = self.operand_to_channel(value);
587                if let (Some(c), Some(v)) = (ch, val) {
588                    // Send merges channel and value
589                    let out = self.dfg.add_channel(TokenType::Unit);
590                    self.dfg.add_operator(DfOperator::Merge {
591                        inputs: vec![c, v],
592                        output: out,
593                    });
594                }
595            }
596
597            // Other terminators - basic handling
598            Terminator::Abort { .. }
599            | Terminator::Unreachable { .. }
600            | Terminator::Cancel { .. } => {
601                // Terminal nodes - no dataflow continuation
602            }
603
604            Terminator::TaskGroupEnter { destination, .. } => {
605                let dest_ch = self.place_to_channel(destination);
606                self.dfg.add_operator(DfOperator::Source {
607                    external_id: dest_ch.0,
608                    output: dest_ch,
609                });
610            }
611
612            Terminator::TaskGroupExit { group, .. } => {
613                if let Some(ch) = self.operand_to_channel(group) {
614                    self.dfg.add_operator(DfOperator::Sink {
615                        input: ch,
616                        external_id: ch.0,
617                    });
618                }
619            }
620
621            Terminator::Select {
622                arms, destination, ..
623            } => {
624                let inputs: Vec<_> = arms
625                    .iter()
626                    .filter_map(|arm| match &arm.operation {
627                        crate::ChannelOp::Recv { channel } => self.operand_to_channel(channel),
628                        crate::ChannelOp::Send { channel, .. } => self.operand_to_channel(channel),
629                        crate::ChannelOp::Timeout { .. } => None,
630                    })
631                    .collect();
632
633                let dest_ch = self.place_to_channel(destination);
634                if !inputs.is_empty() {
635                    self.dfg.add_operator(DfOperator::Merge {
636                        inputs,
637                        output: dest_ch,
638                    });
639                }
640            }
641        }
642    }
643
644    fn get_or_create_block_entry(&mut self, bb: BasicBlockId) -> ChannelId {
645        if let Some(&ch) = self.block_entry.get(&bb) {
646            ch
647        } else {
648            let ch = self.dfg.add_channel(TokenType::Unit);
649            self.block_entry.insert(bb, ch);
650            ch
651        }
652    }
653
654    /// Mark external inputs and outputs
655    fn mark_io(&mut self, mir: &FunctionMIR) {
656        // Parameters are inputs
657        for param in &mir.params {
658            if let Some(&ch) = self.local_to_channel.get(param) {
659                self.dfg.add_input(ch);
660
661                // Add source operator for parameters
662                self.dfg.add_operator(DfOperator::Source {
663                    external_id: ch.0,
664                    output: ch,
665                });
666            }
667        }
668
669        // Return is output (already handled in Return terminator)
670    }
671
672    /// Convert a place to its channel
673    fn place_to_channel(&self, place: &Place) -> ChannelId {
674        // For now, just use the base local
675        // More complex projections would need additional channels
676        *self
677            .local_to_channel
678            .get(&place.local)
679            .unwrap_or(&ChannelId::new(0))
680    }
681
682    /// Convert an operand to a channel
683    fn operand_to_channel(&mut self, operand: &Operand) -> Option<ChannelId> {
684        match operand {
685            Operand::Copy(place) | Operand::Move(place) => {
686                self.local_to_channel.get(&place.local).copied()
687            }
688            Operand::Constant(constant) => {
689                // Create a constant operator
690                let token_type = Self::mir_ty_to_token_type(&constant.ty);
691                let output = self.dfg.add_channel(token_type);
692                let value = Self::literal_to_token_value(&constant.literal);
693
694                self.dfg.add_operator(DfOperator::Constant {
695                    value,
696                    output,
697                    repeat: None,
698                });
699
700                Some(output)
701            }
702        }
703    }
704
705    /// Convert MIR type to token type
706    fn mir_ty_to_token_type(ty: &Ty) -> TokenType {
707        match ty {
708            Ty::Bool => TokenType::Bool,
709            Ty::Int(int_ty) => {
710                let bits = match int_ty {
711                    crate::IntTy::I8 => 8,
712                    crate::IntTy::I16 => 16,
713                    crate::IntTy::I32 => 32,
714                    crate::IntTy::I64 => 64,
715                    crate::IntTy::Isize => 64,
716                };
717                TokenType::Int { bits, signed: true }
718            }
719            Ty::Uint(uint_ty) => {
720                let bits = match uint_ty {
721                    crate::UintTy::U8 => 8,
722                    crate::UintTy::U16 => 16,
723                    crate::UintTy::U32 => 32,
724                    crate::UintTy::U64 => 64,
725                    crate::UintTy::Usize => 64,
726                };
727                TokenType::Int {
728                    bits,
729                    signed: false,
730                }
731            }
732            Ty::Float(float_ty) => {
733                let bits = match float_ty {
734                    crate::FloatTy::F16 | crate::FloatTy::BF16 => 16,
735                    crate::FloatTy::F32 => 32,
736                    crate::FloatTy::F64 => 64,
737                };
738                TokenType::Float { bits }
739            }
740            Ty::Ref { .. } | Ty::RawPtr { .. } => TokenType::Ptr,
741            Ty::Unit => TokenType::Unit,
742            Ty::Tuple(tys) => {
743                TokenType::Tuple(tys.iter().map(Self::mir_ty_to_token_type).collect())
744            }
745            Ty::Array { element, size } => TokenType::Array {
746                element: Box::new(Self::mir_ty_to_token_type(element)),
747                size: *size as usize,
748            },
749            _ => TokenType::Unit, // Default for complex types
750        }
751    }
752
753    /// Convert literal to token value
754    fn literal_to_token_value(literal: &crate::Literal) -> TokenValue {
755        match literal {
756            crate::Literal::Bool(b) => TokenValue::Bool(*b),
757            crate::Literal::Int(i, _) => TokenValue::Int(*i),
758            crate::Literal::Uint(u, _) => TokenValue::Uint(*u),
759            crate::Literal::Float(f, _) => TokenValue::Float(*f),
760            crate::Literal::Char(c) => TokenValue::Uint(*c as u64),
761            // String literals are heap-allocated data referenced by pointer.
762            // In the dataflow graph, they are modeled as a pointer token with
763            // address 0 (the actual address is assigned at link time). This
764            // correctly represents the energy cost: reading a string literal
765            // involves a pointer load from the data section.
766            crate::Literal::String(_) => TokenValue::Ptr(0),
767        }
768    }
769
770    /// Convert BinOp to ComputeOp
771    fn binop_to_compute_op(op: &BinOp) -> ComputeOp {
772        match op {
773            BinOp::Add => ComputeOp::Add,
774            BinOp::Sub => ComputeOp::Sub,
775            BinOp::Mul => ComputeOp::Mul,
776            BinOp::Div => ComputeOp::Div,
777            BinOp::Rem => ComputeOp::Rem,
778            BinOp::BitAnd => ComputeOp::BitAnd,
779            BinOp::BitOr => ComputeOp::BitOr,
780            BinOp::BitXor => ComputeOp::BitXor,
781            BinOp::Shl => ComputeOp::Shl,
782            BinOp::Shr => ComputeOp::Shr,
783            BinOp::Eq => ComputeOp::Eq,
784            BinOp::Ne => ComputeOp::Ne,
785            BinOp::Lt => ComputeOp::Lt,
786            BinOp::Le => ComputeOp::Le,
787            BinOp::Gt => ComputeOp::Gt,
788            BinOp::Ge => ComputeOp::Ge,
789            BinOp::And => ComputeOp::And,
790            BinOp::Or => ComputeOp::Or,
791        }
792    }
793
794    /// Convert UnOp to ComputeOp
795    fn unop_to_compute_op(op: &UnOp) -> ComputeOp {
796        match op {
797            UnOp::Neg => ComputeOp::Neg,
798            UnOp::Not => ComputeOp::Not,
799        }
800    }
801
802    /// Convert SIMD op to ComputeOp
803    fn simd_op_to_compute_op(op: &crate::SimdOp) -> ComputeOp {
804        match op {
805            crate::SimdOp::Add => ComputeOp::Add,
806            crate::SimdOp::Sub => ComputeOp::Sub,
807            crate::SimdOp::Mul => ComputeOp::Mul,
808            crate::SimdOp::Div => ComputeOp::Div,
809            crate::SimdOp::Fma => ComputeOp::Fma,
810        }
811    }
812}
813
814impl Default for DfgExtractor {
815    fn default() -> Self {
816        Self::new()
817    }
818}
819
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{BasicBlock, LocalDecl, Terminator};
    use joule_common::{Span, Symbol};

    /// Placeholder span for synthetic MIR built in tests.
    fn dummy_span() -> Span {
        Span::dummy()
    }

    /// Add a named `i32` local to `mir` and return it.
    fn i32_local(mir: &mut FunctionMIR, name: &str) -> Local {
        mir.add_local(LocalDecl::new(
            Some(Symbol::intern(name)),
            Ty::Int(crate::IntTy::I32),
            false,
            dummy_span(),
        ))
    }

    #[test]
    fn test_simple_extraction() {
        // fn add(a: i32, b: i32) -> i32 { a + b }
        let mut mir = FunctionMIR::new(
            crate::FunctionId::new(0),
            Symbol::intern("add"),
            Ty::Int(crate::IntTy::I32),
            dummy_span(),
        );

        let a = i32_local(&mut mir, "a");
        let b = i32_local(&mut mir, "b");
        mir.params = vec![a, b];

        // Single block: `_0 = a + b; return`.
        let sum = Statement::Assign {
            place: Place::return_place(),
            rvalue: Rvalue::BinaryOp {
                op: BinOp::Add,
                left: Operand::from_local(a),
                right: Operand::from_local(b),
            },
            span: dummy_span(),
        };
        let mut body = BasicBlock::new(Terminator::Return { span: dummy_span() });
        body.push_statement(sum);
        mir.add_block(body);

        let dfg = DfgExtractor::extract(&mir);

        assert!(
            !dfg.operators.is_empty(),
            "expected at least the add operator"
        );
        assert!(!dfg.channels.is_empty());
    }
}