joule_mir/
visit.rs

1//! Visitor pattern for traversing MIR
2//!
3//! This module provides visitor traits for traversing and transforming MIR.
4//! Used by optimization passes, analysis passes, and code generation.
5
6use crate::{
7    BasicBlock, BasicBlockId, FunctionMIR, Local, MirContext, Operand, Place, Rvalue, Statement,
8    Terminator,
9};
10
11/// Visitor trait for MIR
12///
13/// Implement this trait to traverse MIR structures.
14/// All methods have default implementations that do nothing.
15pub trait MirVisitor {
16    /// Visit a MIR context
17    fn visit_context(&mut self, ctx: &MirContext) {
18        for (_, func) in &ctx.functions {
19            self.visit_function(func);
20        }
21    }
22
23    /// Visit a function
24    fn visit_function(&mut self, func: &FunctionMIR) {
25        for (idx, block) in func.basic_blocks.iter().enumerate() {
26            self.visit_basic_block(BasicBlockId::new(idx as u32), block);
27        }
28    }
29
30    /// Visit a basic block
31    fn visit_basic_block(&mut self, _id: BasicBlockId, block: &BasicBlock) {
32        for stmt in &block.statements {
33            self.visit_statement(stmt);
34        }
35        self.visit_terminator(&block.terminator);
36    }
37
38    /// Visit a statement
39    fn visit_statement(&mut self, stmt: &Statement) {
40        match stmt {
41            Statement::Assign { place, rvalue, .. } => {
42                self.visit_place(place);
43                self.visit_rvalue(rvalue);
44            }
45            Statement::StorageLive { local, .. } | Statement::StorageDead { local, .. } => {
46                self.visit_local(*local);
47            }
48            Statement::Nop => {}
49        }
50    }
51
52    /// Visit a terminator
53    fn visit_terminator(&mut self, term: &Terminator) {
54        match term {
55            Terminator::Return { .. } => {}
56            Terminator::Goto { target, .. } => {
57                self.visit_target(*target);
58            }
59            Terminator::SwitchInt {
60                discriminant,
61                targets,
62                ..
63            } => {
64                self.visit_operand(discriminant);
65                for (_, target) in &targets.branches {
66                    self.visit_target(*target);
67                }
68                self.visit_target(targets.otherwise);
69            }
70            Terminator::Call {
71                func,
72                args,
73                destination,
74                target,
75                ..
76            } => {
77                self.visit_operand(func);
78                for arg in args {
79                    self.visit_operand(arg);
80                }
81                self.visit_place(destination);
82                self.visit_target(*target);
83            }
84            Terminator::Abort { .. }
85            | Terminator::Unreachable { .. }
86            | Terminator::Cancel { .. } => {}
87
88            // Concurrency terminators
89            Terminator::Spawn {
90                func,
91                args,
92                destination,
93                target,
94                ..
95            } => {
96                self.visit_operand(func);
97                for arg in args {
98                    self.visit_operand(arg);
99                }
100                self.visit_place(destination);
101                self.visit_target(*target);
102            }
103            Terminator::TaskAwait {
104                task,
105                destination,
106                target,
107                ..
108            } => {
109                self.visit_operand(task);
110                self.visit_place(destination);
111                self.visit_target(*target);
112            }
113            Terminator::TaskGroupEnter {
114                destination,
115                body,
116                join_block,
117                ..
118            } => {
119                self.visit_place(destination);
120                self.visit_target(*body);
121                self.visit_target(*join_block);
122            }
123            Terminator::TaskGroupExit { group, target, .. } => {
124                self.visit_operand(group);
125                self.visit_target(*target);
126            }
127            Terminator::ChannelRecv {
128                channel,
129                destination,
130                target,
131                closed_target,
132                ..
133            } => {
134                self.visit_operand(channel);
135                self.visit_place(destination);
136                self.visit_target(*target);
137                self.visit_target(*closed_target);
138            }
139            Terminator::ChannelSend {
140                channel,
141                value,
142                target,
143                closed_target,
144                ..
145            } => {
146                self.visit_operand(channel);
147                self.visit_operand(value);
148                self.visit_target(*target);
149                self.visit_target(*closed_target);
150            }
151            Terminator::Select {
152                arms,
153                default,
154                destination,
155                selected_arm,
156                ..
157            } => {
158                for arm in arms {
159                    match &arm.operation {
160                        crate::ChannelOp::Recv { channel } => {
161                            self.visit_operand(channel);
162                        }
163                        crate::ChannelOp::Send { channel, value } => {
164                            self.visit_operand(channel);
165                            self.visit_operand(value);
166                        }
167                        crate::ChannelOp::Timeout { .. } => {}
168                    }
169                    self.visit_target(arm.target);
170                }
171                if let Some(default) = default {
172                    self.visit_target(*default);
173                }
174                self.visit_place(destination);
175                self.visit_place(selected_arm);
176            }
177        }
178    }
179
180    /// Visit a place
181    fn visit_place(&mut self, place: &Place) {
182        self.visit_local(place.local);
183        for elem in &place.projection {
184            if let crate::PlaceElem::Index(local) = elem {
185                self.visit_local(*local);
186            }
187        }
188    }
189
190    /// Visit an rvalue
191    fn visit_rvalue(&mut self, rvalue: &Rvalue) {
192        match rvalue {
193            Rvalue::Use(operand) => {
194                self.visit_operand(operand);
195            }
196            Rvalue::BinaryOp { left, right, .. } => {
197                self.visit_operand(left);
198                self.visit_operand(right);
199            }
200            Rvalue::UnaryOp { operand, .. } => {
201                self.visit_operand(operand);
202            }
203            Rvalue::Ref { place, .. } => {
204                self.visit_place(place);
205            }
206            Rvalue::Aggregate { operands, .. } => {
207                for operand in operands {
208                    self.visit_operand(operand);
209                }
210            }
211            Rvalue::Cast { operand, .. } => {
212                self.visit_operand(operand);
213            }
214            Rvalue::Discriminant { place } | Rvalue::Len { place } => {
215                self.visit_place(place);
216            }
217            Rvalue::SimdBinaryOp { left, right, .. } => {
218                self.visit_operand(left);
219                self.visit_operand(right);
220            }
221            Rvalue::SimdLoad { source, .. } => {
222                self.visit_place(source);
223            }
224            Rvalue::SimdStore { value, dest, .. } => {
225                self.visit_operand(value);
226                self.visit_place(dest);
227            }
228            Rvalue::SimdSplat { value, .. } => {
229                self.visit_operand(value);
230            }
231
232            // Concurrency rvalues
233            Rvalue::ChannelCreate { .. } => {}
234            Rvalue::ChannelTryRecv { channel } => {
235                self.visit_operand(channel);
236            }
237            Rvalue::ChannelTrySend { channel, value } => {
238                self.visit_operand(channel);
239                self.visit_operand(value);
240            }
241            Rvalue::ChannelSender { channel }
242            | Rvalue::ChannelReceiver { channel }
243            | Rvalue::ChannelClose { channel } => {
244                self.visit_operand(channel);
245            }
246            Rvalue::IsCancelled | Rvalue::CurrentTask => {}
247            Rvalue::Try(operand) => {
248                self.visit_operand(operand);
249            }
250            Rvalue::EnumVariant { fields, .. } => {
251                for op in fields {
252                    self.visit_operand(op);
253                }
254            }
255        }
256    }
257
258    /// Visit an operand
259    fn visit_operand(&mut self, operand: &Operand) {
260        match operand {
261            Operand::Copy(place) | Operand::Move(place) => {
262                self.visit_place(place);
263            }
264            Operand::Constant(_) => {}
265        }
266    }
267
268    /// Visit a local
269    fn visit_local(&mut self, _local: Local) {}
270
271    /// Visit a target block
272    fn visit_target(&mut self, _target: BasicBlockId) {}
273}
274
275/// Mutable visitor trait for MIR
276///
277/// Implement this trait to transform MIR structures.
278/// All methods have default implementations that do nothing.
279pub trait MirVisitorMut {
280    /// Visit a MIR context
281    fn visit_context_mut(&mut self, ctx: &mut MirContext) {
282        for (_, func) in &mut ctx.functions {
283            self.visit_function_mut(func);
284        }
285    }
286
287    /// Visit a function
288    fn visit_function_mut(&mut self, func: &mut FunctionMIR) {
289        for idx in 0..func.basic_blocks.len() {
290            let block = &mut func.basic_blocks[idx];
291            self.visit_basic_block_mut(BasicBlockId::new(idx as u32), block);
292        }
293    }
294
295    /// Visit a basic block
296    fn visit_basic_block_mut(&mut self, _id: BasicBlockId, block: &mut BasicBlock) {
297        for stmt in &mut block.statements {
298            self.visit_statement_mut(stmt);
299        }
300        self.visit_terminator_mut(&mut block.terminator);
301    }
302
303    /// Visit a statement
304    fn visit_statement_mut(&mut self, stmt: &mut Statement) {
305        match stmt {
306            Statement::Assign { place, rvalue, .. } => {
307                self.visit_place_mut(place);
308                self.visit_rvalue_mut(rvalue);
309            }
310            Statement::StorageLive { local, .. } | Statement::StorageDead { local, .. } => {
311                self.visit_local_mut(local);
312            }
313            Statement::Nop => {}
314        }
315    }
316
317    /// Visit a terminator
318    fn visit_terminator_mut(&mut self, term: &mut Terminator) {
319        match term {
320            Terminator::Return { .. } => {}
321            Terminator::Goto { target, .. } => {
322                self.visit_target_mut(target);
323            }
324            Terminator::SwitchInt {
325                discriminant,
326                targets,
327                ..
328            } => {
329                self.visit_operand_mut(discriminant);
330                for (_, target) in &mut targets.branches {
331                    self.visit_target_mut(target);
332                }
333                self.visit_target_mut(&mut targets.otherwise);
334            }
335            Terminator::Call {
336                func,
337                args,
338                destination,
339                target,
340                ..
341            } => {
342                self.visit_operand_mut(func);
343                for arg in args {
344                    self.visit_operand_mut(arg);
345                }
346                self.visit_place_mut(destination);
347                self.visit_target_mut(target);
348            }
349            Terminator::Abort { .. }
350            | Terminator::Unreachable { .. }
351            | Terminator::Cancel { .. } => {}
352
353            // Concurrency terminators
354            Terminator::Spawn {
355                func,
356                args,
357                destination,
358                target,
359                ..
360            } => {
361                self.visit_operand_mut(func);
362                for arg in args {
363                    self.visit_operand_mut(arg);
364                }
365                self.visit_place_mut(destination);
366                self.visit_target_mut(target);
367            }
368            Terminator::TaskAwait {
369                task,
370                destination,
371                target,
372                ..
373            } => {
374                self.visit_operand_mut(task);
375                self.visit_place_mut(destination);
376                self.visit_target_mut(target);
377            }
378            Terminator::TaskGroupEnter {
379                destination,
380                body,
381                join_block,
382                ..
383            } => {
384                self.visit_place_mut(destination);
385                self.visit_target_mut(body);
386                self.visit_target_mut(join_block);
387            }
388            Terminator::TaskGroupExit { group, target, .. } => {
389                self.visit_operand_mut(group);
390                self.visit_target_mut(target);
391            }
392            Terminator::ChannelRecv {
393                channel,
394                destination,
395                target,
396                closed_target,
397                ..
398            } => {
399                self.visit_operand_mut(channel);
400                self.visit_place_mut(destination);
401                self.visit_target_mut(target);
402                self.visit_target_mut(closed_target);
403            }
404            Terminator::ChannelSend {
405                channel,
406                value,
407                target,
408                closed_target,
409                ..
410            } => {
411                self.visit_operand_mut(channel);
412                self.visit_operand_mut(value);
413                self.visit_target_mut(target);
414                self.visit_target_mut(closed_target);
415            }
416            Terminator::Select {
417                arms,
418                default,
419                destination,
420                selected_arm,
421                ..
422            } => {
423                for arm in arms {
424                    match &mut arm.operation {
425                        crate::ChannelOp::Recv { channel } => {
426                            self.visit_operand_mut(channel);
427                        }
428                        crate::ChannelOp::Send { channel, value } => {
429                            self.visit_operand_mut(channel);
430                            self.visit_operand_mut(value);
431                        }
432                        crate::ChannelOp::Timeout { .. } => {}
433                    }
434                    self.visit_target_mut(&mut arm.target);
435                }
436                if let Some(default) = default {
437                    self.visit_target_mut(default);
438                }
439                self.visit_place_mut(destination);
440                self.visit_place_mut(selected_arm);
441            }
442        }
443    }
444
445    /// Visit a place
446    fn visit_place_mut(&mut self, place: &mut Place) {
447        self.visit_local_mut(&mut place.local);
448        for elem in &mut place.projection {
449            if let crate::PlaceElem::Index(local) = elem {
450                self.visit_local_mut(local);
451            }
452        }
453    }
454
455    /// Visit an rvalue
456    fn visit_rvalue_mut(&mut self, rvalue: &mut Rvalue) {
457        match rvalue {
458            Rvalue::Use(operand) => {
459                self.visit_operand_mut(operand);
460            }
461            Rvalue::BinaryOp { left, right, .. } => {
462                self.visit_operand_mut(left);
463                self.visit_operand_mut(right);
464            }
465            Rvalue::UnaryOp { operand, .. } => {
466                self.visit_operand_mut(operand);
467            }
468            Rvalue::Ref { place, .. } => {
469                self.visit_place_mut(place);
470            }
471            Rvalue::Aggregate { operands, .. } => {
472                for operand in operands {
473                    self.visit_operand_mut(operand);
474                }
475            }
476            Rvalue::Cast { operand, .. } => {
477                self.visit_operand_mut(operand);
478            }
479            Rvalue::Discriminant { place } | Rvalue::Len { place } => {
480                self.visit_place_mut(place);
481            }
482            Rvalue::SimdBinaryOp { left, right, .. } => {
483                self.visit_operand_mut(left);
484                self.visit_operand_mut(right);
485            }
486            Rvalue::SimdLoad { source, .. } => {
487                self.visit_place_mut(source);
488            }
489            Rvalue::SimdStore { value, dest, .. } => {
490                self.visit_operand_mut(value);
491                self.visit_place_mut(dest);
492            }
493            Rvalue::SimdSplat { value, .. } => {
494                self.visit_operand_mut(value);
495            }
496
497            // Concurrency rvalues
498            Rvalue::ChannelCreate { .. } => {}
499            Rvalue::ChannelTryRecv { channel } => {
500                self.visit_operand_mut(channel);
501            }
502            Rvalue::ChannelTrySend { channel, value } => {
503                self.visit_operand_mut(channel);
504                self.visit_operand_mut(value);
505            }
506            Rvalue::ChannelSender { channel }
507            | Rvalue::ChannelReceiver { channel }
508            | Rvalue::ChannelClose { channel } => {
509                self.visit_operand_mut(channel);
510            }
511            Rvalue::IsCancelled | Rvalue::CurrentTask => {}
512            Rvalue::Try(operand) => {
513                self.visit_operand_mut(operand);
514            }
515            Rvalue::EnumVariant { fields, .. } => {
516                for op in fields {
517                    self.visit_operand_mut(op);
518                }
519            }
520        }
521    }
522
523    /// Visit an operand
524    fn visit_operand_mut(&mut self, operand: &mut Operand) {
525        match operand {
526            Operand::Copy(place) | Operand::Move(place) => {
527                self.visit_place_mut(place);
528            }
529            Operand::Constant(_) => {}
530        }
531    }
532
533    /// Visit a local
534    fn visit_local_mut(&mut self, _local: &mut Local) {}
535
536    /// Visit a target block
537    fn visit_target_mut(&mut self, _target: &mut BasicBlockId) {}
538}
539
/// Example visitor that counts local *occurrences*.
///
/// Every call to `visit_local` increments [`LocalCounter::count`], so a local
/// mentioned in several places is counted once per mention; this is not a
/// count of distinct locals.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LocalCounter {
    /// Number of locals visited so far.
    pub count: usize,
}

impl LocalCounter {
    /// Creates a counter that starts at zero.
    pub const fn new() -> Self {
        Self { count: 0 }
    }
}
550
551impl Default for LocalCounter {
552    fn default() -> Self {
553        Self::new()
554    }
555}
556
557impl MirVisitor for LocalCounter {
558    fn visit_local(&mut self, _local: Local) {
559        self.count += 1;
560    }
561}
562
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{FunctionId, LocalDecl, Symbol, Ty};
    use joule_common::Span;

    /// Builds a unit-returning `FunctionMIR` with placeholder id, name and span.
    fn empty_function() -> FunctionMIR {
        FunctionMIR::new(FunctionId::new(0), Symbol::from_u32(0), Ty::Unit, Span::dummy())
    }

    /// A `Return` terminator carrying a dummy span.
    fn return_terminator() -> Terminator {
        Terminator::Return {
            span: Span::dummy(),
        }
    }

    #[test]
    fn test_local_counter() {
        let mut func = empty_function();

        // Two unit-typed locals for the statement below to reference.
        for _ in 0..2 {
            func.add_local(LocalDecl::new(None, Ty::Unit, false, Span::dummy()));
        }

        // A single block holding `_1 = copy _2`.
        let mut block = BasicBlock::new(return_terminator());
        let source = Operand::Copy(Place::from_local(Local::new(2)));
        block.push_statement(Statement::Assign {
            place: Place::from_local(Local::new(1)),
            rvalue: Rvalue::Use(source),
            span: Span::dummy(),
        });
        func.add_block(block);

        let mut counter = LocalCounter::new();
        counter.visit_function(&func);

        // The destination (_1) and the source (_2) are each visited once.
        assert_eq!(counter.count, 2);
    }

    #[test]
    fn test_visitor_traversal() {
        // Records the id of every basic block it enters, in visit order.
        #[derive(Default)]
        struct BlockRecorder {
            seen: Vec<BasicBlockId>,
        }

        impl MirVisitor for BlockRecorder {
            fn visit_basic_block(&mut self, id: BasicBlockId, block: &BasicBlock) {
                self.seen.push(id);
                // Repeat the default walk so statements and the terminator
                // are still visited.
                for stmt in &block.statements {
                    self.visit_statement(stmt);
                }
                self.visit_terminator(&block.terminator);
            }
        }

        let mut func = empty_function();
        for _ in 0..3 {
            func.add_block(BasicBlock::new(return_terminator()));
        }

        let mut recorder = BlockRecorder::default();
        recorder.visit_function(&func);

        let expected: Vec<BasicBlockId> = (0..3).map(BasicBlockId::new).collect();
        assert_eq!(recorder.seen, expected);
    }
}