1use crate::{
7 BasicBlock, BasicBlockId, FunctionMIR, Local, MirContext, Operand, Place, Rvalue, Statement,
8 Terminator,
9};
10
11pub trait MirVisitor {
16 fn visit_context(&mut self, ctx: &MirContext) {
18 for (_, func) in &ctx.functions {
19 self.visit_function(func);
20 }
21 }
22
23 fn visit_function(&mut self, func: &FunctionMIR) {
25 for (idx, block) in func.basic_blocks.iter().enumerate() {
26 self.visit_basic_block(BasicBlockId::new(idx as u32), block);
27 }
28 }
29
30 fn visit_basic_block(&mut self, _id: BasicBlockId, block: &BasicBlock) {
32 for stmt in &block.statements {
33 self.visit_statement(stmt);
34 }
35 self.visit_terminator(&block.terminator);
36 }
37
38 fn visit_statement(&mut self, stmt: &Statement) {
40 match stmt {
41 Statement::Assign { place, rvalue, .. } => {
42 self.visit_place(place);
43 self.visit_rvalue(rvalue);
44 }
45 Statement::StorageLive { local, .. } | Statement::StorageDead { local, .. } => {
46 self.visit_local(*local);
47 }
48 Statement::Nop => {}
49 }
50 }
51
52 fn visit_terminator(&mut self, term: &Terminator) {
54 match term {
55 Terminator::Return { .. } => {}
56 Terminator::Goto { target, .. } => {
57 self.visit_target(*target);
58 }
59 Terminator::SwitchInt {
60 discriminant,
61 targets,
62 ..
63 } => {
64 self.visit_operand(discriminant);
65 for (_, target) in &targets.branches {
66 self.visit_target(*target);
67 }
68 self.visit_target(targets.otherwise);
69 }
70 Terminator::Call {
71 func,
72 args,
73 destination,
74 target,
75 ..
76 } => {
77 self.visit_operand(func);
78 for arg in args {
79 self.visit_operand(arg);
80 }
81 self.visit_place(destination);
82 self.visit_target(*target);
83 }
84 Terminator::Abort { .. }
85 | Terminator::Unreachable { .. }
86 | Terminator::Cancel { .. } => {}
87
88 Terminator::Spawn {
90 func,
91 args,
92 destination,
93 target,
94 ..
95 } => {
96 self.visit_operand(func);
97 for arg in args {
98 self.visit_operand(arg);
99 }
100 self.visit_place(destination);
101 self.visit_target(*target);
102 }
103 Terminator::TaskAwait {
104 task,
105 destination,
106 target,
107 ..
108 } => {
109 self.visit_operand(task);
110 self.visit_place(destination);
111 self.visit_target(*target);
112 }
113 Terminator::TaskGroupEnter {
114 destination,
115 body,
116 join_block,
117 ..
118 } => {
119 self.visit_place(destination);
120 self.visit_target(*body);
121 self.visit_target(*join_block);
122 }
123 Terminator::TaskGroupExit { group, target, .. } => {
124 self.visit_operand(group);
125 self.visit_target(*target);
126 }
127 Terminator::ChannelRecv {
128 channel,
129 destination,
130 target,
131 closed_target,
132 ..
133 } => {
134 self.visit_operand(channel);
135 self.visit_place(destination);
136 self.visit_target(*target);
137 self.visit_target(*closed_target);
138 }
139 Terminator::ChannelSend {
140 channel,
141 value,
142 target,
143 closed_target,
144 ..
145 } => {
146 self.visit_operand(channel);
147 self.visit_operand(value);
148 self.visit_target(*target);
149 self.visit_target(*closed_target);
150 }
151 Terminator::Select {
152 arms,
153 default,
154 destination,
155 selected_arm,
156 ..
157 } => {
158 for arm in arms {
159 match &arm.operation {
160 crate::ChannelOp::Recv { channel } => {
161 self.visit_operand(channel);
162 }
163 crate::ChannelOp::Send { channel, value } => {
164 self.visit_operand(channel);
165 self.visit_operand(value);
166 }
167 crate::ChannelOp::Timeout { .. } => {}
168 }
169 self.visit_target(arm.target);
170 }
171 if let Some(default) = default {
172 self.visit_target(*default);
173 }
174 self.visit_place(destination);
175 self.visit_place(selected_arm);
176 }
177 }
178 }
179
180 fn visit_place(&mut self, place: &Place) {
182 self.visit_local(place.local);
183 for elem in &place.projection {
184 if let crate::PlaceElem::Index(local) = elem {
185 self.visit_local(*local);
186 }
187 }
188 }
189
190 fn visit_rvalue(&mut self, rvalue: &Rvalue) {
192 match rvalue {
193 Rvalue::Use(operand) => {
194 self.visit_operand(operand);
195 }
196 Rvalue::BinaryOp { left, right, .. } => {
197 self.visit_operand(left);
198 self.visit_operand(right);
199 }
200 Rvalue::UnaryOp { operand, .. } => {
201 self.visit_operand(operand);
202 }
203 Rvalue::Ref { place, .. } => {
204 self.visit_place(place);
205 }
206 Rvalue::Aggregate { operands, .. } => {
207 for operand in operands {
208 self.visit_operand(operand);
209 }
210 }
211 Rvalue::Cast { operand, .. } => {
212 self.visit_operand(operand);
213 }
214 Rvalue::Discriminant { place } | Rvalue::Len { place } => {
215 self.visit_place(place);
216 }
217 Rvalue::SimdBinaryOp { left, right, .. } => {
218 self.visit_operand(left);
219 self.visit_operand(right);
220 }
221 Rvalue::SimdLoad { source, .. } => {
222 self.visit_place(source);
223 }
224 Rvalue::SimdStore { value, dest, .. } => {
225 self.visit_operand(value);
226 self.visit_place(dest);
227 }
228 Rvalue::SimdSplat { value, .. } => {
229 self.visit_operand(value);
230 }
231
232 Rvalue::ChannelCreate { .. } => {}
234 Rvalue::ChannelTryRecv { channel } => {
235 self.visit_operand(channel);
236 }
237 Rvalue::ChannelTrySend { channel, value } => {
238 self.visit_operand(channel);
239 self.visit_operand(value);
240 }
241 Rvalue::ChannelSender { channel }
242 | Rvalue::ChannelReceiver { channel }
243 | Rvalue::ChannelClose { channel } => {
244 self.visit_operand(channel);
245 }
246 Rvalue::IsCancelled | Rvalue::CurrentTask => {}
247 Rvalue::Try(operand) => {
248 self.visit_operand(operand);
249 }
250 Rvalue::EnumVariant { fields, .. } => {
251 for op in fields {
252 self.visit_operand(op);
253 }
254 }
255 }
256 }
257
258 fn visit_operand(&mut self, operand: &Operand) {
260 match operand {
261 Operand::Copy(place) | Operand::Move(place) => {
262 self.visit_place(place);
263 }
264 Operand::Constant(_) => {}
265 }
266 }
267
268 fn visit_local(&mut self, _local: Local) {}
270
271 fn visit_target(&mut self, _target: BasicBlockId) {}
273}
274
275pub trait MirVisitorMut {
280 fn visit_context_mut(&mut self, ctx: &mut MirContext) {
282 for (_, func) in &mut ctx.functions {
283 self.visit_function_mut(func);
284 }
285 }
286
287 fn visit_function_mut(&mut self, func: &mut FunctionMIR) {
289 for idx in 0..func.basic_blocks.len() {
290 let block = &mut func.basic_blocks[idx];
291 self.visit_basic_block_mut(BasicBlockId::new(idx as u32), block);
292 }
293 }
294
295 fn visit_basic_block_mut(&mut self, _id: BasicBlockId, block: &mut BasicBlock) {
297 for stmt in &mut block.statements {
298 self.visit_statement_mut(stmt);
299 }
300 self.visit_terminator_mut(&mut block.terminator);
301 }
302
303 fn visit_statement_mut(&mut self, stmt: &mut Statement) {
305 match stmt {
306 Statement::Assign { place, rvalue, .. } => {
307 self.visit_place_mut(place);
308 self.visit_rvalue_mut(rvalue);
309 }
310 Statement::StorageLive { local, .. } | Statement::StorageDead { local, .. } => {
311 self.visit_local_mut(local);
312 }
313 Statement::Nop => {}
314 }
315 }
316
317 fn visit_terminator_mut(&mut self, term: &mut Terminator) {
319 match term {
320 Terminator::Return { .. } => {}
321 Terminator::Goto { target, .. } => {
322 self.visit_target_mut(target);
323 }
324 Terminator::SwitchInt {
325 discriminant,
326 targets,
327 ..
328 } => {
329 self.visit_operand_mut(discriminant);
330 for (_, target) in &mut targets.branches {
331 self.visit_target_mut(target);
332 }
333 self.visit_target_mut(&mut targets.otherwise);
334 }
335 Terminator::Call {
336 func,
337 args,
338 destination,
339 target,
340 ..
341 } => {
342 self.visit_operand_mut(func);
343 for arg in args {
344 self.visit_operand_mut(arg);
345 }
346 self.visit_place_mut(destination);
347 self.visit_target_mut(target);
348 }
349 Terminator::Abort { .. }
350 | Terminator::Unreachable { .. }
351 | Terminator::Cancel { .. } => {}
352
353 Terminator::Spawn {
355 func,
356 args,
357 destination,
358 target,
359 ..
360 } => {
361 self.visit_operand_mut(func);
362 for arg in args {
363 self.visit_operand_mut(arg);
364 }
365 self.visit_place_mut(destination);
366 self.visit_target_mut(target);
367 }
368 Terminator::TaskAwait {
369 task,
370 destination,
371 target,
372 ..
373 } => {
374 self.visit_operand_mut(task);
375 self.visit_place_mut(destination);
376 self.visit_target_mut(target);
377 }
378 Terminator::TaskGroupEnter {
379 destination,
380 body,
381 join_block,
382 ..
383 } => {
384 self.visit_place_mut(destination);
385 self.visit_target_mut(body);
386 self.visit_target_mut(join_block);
387 }
388 Terminator::TaskGroupExit { group, target, .. } => {
389 self.visit_operand_mut(group);
390 self.visit_target_mut(target);
391 }
392 Terminator::ChannelRecv {
393 channel,
394 destination,
395 target,
396 closed_target,
397 ..
398 } => {
399 self.visit_operand_mut(channel);
400 self.visit_place_mut(destination);
401 self.visit_target_mut(target);
402 self.visit_target_mut(closed_target);
403 }
404 Terminator::ChannelSend {
405 channel,
406 value,
407 target,
408 closed_target,
409 ..
410 } => {
411 self.visit_operand_mut(channel);
412 self.visit_operand_mut(value);
413 self.visit_target_mut(target);
414 self.visit_target_mut(closed_target);
415 }
416 Terminator::Select {
417 arms,
418 default,
419 destination,
420 selected_arm,
421 ..
422 } => {
423 for arm in arms {
424 match &mut arm.operation {
425 crate::ChannelOp::Recv { channel } => {
426 self.visit_operand_mut(channel);
427 }
428 crate::ChannelOp::Send { channel, value } => {
429 self.visit_operand_mut(channel);
430 self.visit_operand_mut(value);
431 }
432 crate::ChannelOp::Timeout { .. } => {}
433 }
434 self.visit_target_mut(&mut arm.target);
435 }
436 if let Some(default) = default {
437 self.visit_target_mut(default);
438 }
439 self.visit_place_mut(destination);
440 self.visit_place_mut(selected_arm);
441 }
442 }
443 }
444
445 fn visit_place_mut(&mut self, place: &mut Place) {
447 self.visit_local_mut(&mut place.local);
448 for elem in &mut place.projection {
449 if let crate::PlaceElem::Index(local) = elem {
450 self.visit_local_mut(local);
451 }
452 }
453 }
454
455 fn visit_rvalue_mut(&mut self, rvalue: &mut Rvalue) {
457 match rvalue {
458 Rvalue::Use(operand) => {
459 self.visit_operand_mut(operand);
460 }
461 Rvalue::BinaryOp { left, right, .. } => {
462 self.visit_operand_mut(left);
463 self.visit_operand_mut(right);
464 }
465 Rvalue::UnaryOp { operand, .. } => {
466 self.visit_operand_mut(operand);
467 }
468 Rvalue::Ref { place, .. } => {
469 self.visit_place_mut(place);
470 }
471 Rvalue::Aggregate { operands, .. } => {
472 for operand in operands {
473 self.visit_operand_mut(operand);
474 }
475 }
476 Rvalue::Cast { operand, .. } => {
477 self.visit_operand_mut(operand);
478 }
479 Rvalue::Discriminant { place } | Rvalue::Len { place } => {
480 self.visit_place_mut(place);
481 }
482 Rvalue::SimdBinaryOp { left, right, .. } => {
483 self.visit_operand_mut(left);
484 self.visit_operand_mut(right);
485 }
486 Rvalue::SimdLoad { source, .. } => {
487 self.visit_place_mut(source);
488 }
489 Rvalue::SimdStore { value, dest, .. } => {
490 self.visit_operand_mut(value);
491 self.visit_place_mut(dest);
492 }
493 Rvalue::SimdSplat { value, .. } => {
494 self.visit_operand_mut(value);
495 }
496
497 Rvalue::ChannelCreate { .. } => {}
499 Rvalue::ChannelTryRecv { channel } => {
500 self.visit_operand_mut(channel);
501 }
502 Rvalue::ChannelTrySend { channel, value } => {
503 self.visit_operand_mut(channel);
504 self.visit_operand_mut(value);
505 }
506 Rvalue::ChannelSender { channel }
507 | Rvalue::ChannelReceiver { channel }
508 | Rvalue::ChannelClose { channel } => {
509 self.visit_operand_mut(channel);
510 }
511 Rvalue::IsCancelled | Rvalue::CurrentTask => {}
512 Rvalue::Try(operand) => {
513 self.visit_operand_mut(operand);
514 }
515 Rvalue::EnumVariant { fields, .. } => {
516 for op in fields {
517 self.visit_operand_mut(op);
518 }
519 }
520 }
521 }
522
523 fn visit_operand_mut(&mut self, operand: &mut Operand) {
525 match operand {
526 Operand::Copy(place) | Operand::Move(place) => {
527 self.visit_place_mut(place);
528 }
529 Operand::Constant(_) => {}
530 }
531 }
532
533 fn visit_local_mut(&mut self, _local: &mut Local) {}
535
536 fn visit_target_mut(&mut self, _target: &mut BasicBlockId) {}
538}
539
/// A [`MirVisitor`] that tallies how many times `visit_local` fires.
///
/// Every *occurrence* of a local is counted, so a local that appears twice
/// contributes two. This is not a count of distinct locals.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LocalCounter {
    /// Number of local occurrences seen so far.
    pub count: usize,
}
544
545impl LocalCounter {
546 pub fn new() -> Self {
547 Self { count: 0 }
548 }
549}
550
551impl Default for LocalCounter {
552 fn default() -> Self {
553 Self::new()
554 }
555}
556
557impl MirVisitor for LocalCounter {
558 fn visit_local(&mut self, _local: Local) {
559 self.count += 1;
560 }
561}
562
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{FunctionId, LocalDecl, Symbol, Ty};
    use joule_common::Span;

    /// Builds an empty unit-returning function to hang locals and blocks on.
    fn empty_function() -> FunctionMIR {
        FunctionMIR::new(
            FunctionId::new(0),
            Symbol::from_u32(0),
            Ty::Unit,
            Span::dummy(),
        )
    }

    /// Builds a function with two extra unit locals and one block holding
    /// `_1 = copy _2; return`.
    fn copy_function() -> FunctionMIR {
        let mut func = empty_function();
        func.add_local(LocalDecl::new(None, Ty::Unit, false, Span::dummy()));
        func.add_local(LocalDecl::new(None, Ty::Unit, false, Span::dummy()));

        let mut block = BasicBlock::new(Terminator::Return {
            span: Span::dummy(),
        });
        block.push_statement(Statement::Assign {
            place: Place::from_local(Local::new(1)),
            rvalue: Rvalue::Use(Operand::Copy(Place::from_local(Local::new(2)))),
            span: Span::dummy(),
        });
        func.add_block(block);
        func
    }

    #[test]
    fn test_local_counter() {
        let func = copy_function();

        let mut counter = LocalCounter::new();
        counter.visit_function(&func);

        // `_1` (assigned) and `_2` (copied) are each seen once. `Return`
        // mentions no locals.
        assert_eq!(counter.count, 2);
    }

    #[test]
    fn test_visitor_traversal() {
        struct TestVisitor {
            visited_blocks: Vec<BasicBlockId>,
        }

        impl TestVisitor {
            fn new() -> Self {
                Self {
                    visited_blocks: Vec::new(),
                }
            }
        }

        impl MirVisitor for TestVisitor {
            fn visit_basic_block(&mut self, id: BasicBlockId, block: &BasicBlock) {
                self.visited_blocks.push(id);
                // Overriding replaces the default walk, so recurse by hand.
                for stmt in &block.statements {
                    self.visit_statement(stmt);
                }
                self.visit_terminator(&block.terminator);
            }
        }

        let mut func = empty_function();
        for _ in 0..3 {
            func.add_block(BasicBlock::new(Terminator::Return {
                span: Span::dummy(),
            }));
        }

        let mut visitor = TestVisitor::new();
        visitor.visit_function(&func);

        assert_eq!(visitor.visited_blocks.len(), 3);
        assert_eq!(visitor.visited_blocks[0], BasicBlockId::new(0));
        assert_eq!(visitor.visited_blocks[1], BasicBlockId::new(1));
        assert_eq!(visitor.visited_blocks[2], BasicBlockId::new(2));
    }

    #[test]
    fn test_visitor_mut_rewrites_locals() {
        /// Replaces every occurrence of `from` with `to`.
        struct Renamer {
            from: Local,
            to: Local,
        }

        impl MirVisitorMut for Renamer {
            fn visit_local_mut(&mut self, local: &mut Local) {
                if *local == self.from {
                    *local = self.to;
                }
            }
        }

        /// Records every local occurrence in visit order.
        struct Collector {
            locals: Vec<Local>,
        }

        impl MirVisitor for Collector {
            fn visit_local(&mut self, local: Local) {
                self.locals.push(local);
            }
        }

        let mut func = copy_function();

        let mut renamer = Renamer {
            from: Local::new(2),
            to: Local::new(3),
        };
        renamer.visit_function_mut(&mut func);

        let mut collector = Collector { locals: Vec::new() };
        collector.visit_function(&func);

        // The destination `_1` is untouched. The copied operand `_2` was
        // rewritten to `_3`.
        assert_eq!(collector.locals, vec![Local::new(1), Local::new(3)]);
    }
}