joule_mir/
const_fold.rs

//! Constant folding optimization pass for MIR.
//!
//! Walks basic blocks and evaluates `Rvalue::BinaryOp` / `Rvalue::UnaryOp`
//! whose operands are all `Operand::Constant`, replacing each foldable rvalue
//! with `Rvalue::Use(Operand::Constant(...))`.

6use crate::{BasicBlock, BinOp, Constant, FunctionMIR, Literal, Operand, Rvalue, Statement, UnOp};
7use joule_common::Span;
8
9/// Run constant folding on a MIR body. Returns the number of folds applied.
10pub fn fold_constants(body: &mut FunctionMIR) -> usize {
11    let mut folds = 0;
12
13    for block in &mut body.basic_blocks {
14        folds += fold_block(block);
15    }
16
17    folds
18}
19
20fn fold_block(block: &mut BasicBlock) -> usize {
21    let mut folds = 0;
22
23    for stmt in &mut block.statements {
24        if let Statement::Assign {
25            rvalue, span, ..
26        } = stmt
27        {
28            if let Some(folded) = try_fold_rvalue(rvalue, *span) {
29                *rvalue = folded;
30                folds += 1;
31            }
32        }
33    }
34
35    folds
36}
37
38fn try_fold_rvalue(rvalue: &Rvalue, span: Span) -> Option<Rvalue> {
39    match rvalue {
40        Rvalue::BinaryOp { op, left, right } => {
41            let left_const = operand_as_constant(left)?;
42            let right_const = operand_as_constant(right)?;
43            fold_binary(*op, left_const, right_const, span)
44        }
45        Rvalue::UnaryOp { op, operand } => {
46            let c = operand_as_constant(operand)?;
47            fold_unary(*op, c, span)
48        }
49        _ => None,
50    }
51}
52
53fn operand_as_constant(op: &Operand) -> Option<&Constant> {
54    match op {
55        Operand::Constant(c) => Some(c),
56        _ => None,
57    }
58}
59
60fn fold_binary(op: BinOp, left: &Constant, right: &Constant, span: Span) -> Option<Rvalue> {
61    let result = match (&left.literal, &right.literal) {
62        // Integer arithmetic
63        (Literal::Int(a, ty), Literal::Int(b, _)) => {
64            let ty = *ty;
65            match op {
66                BinOp::Add => a.checked_add(*b).map(|r| Literal::Int(r, ty)),
67                BinOp::Sub => a.checked_sub(*b).map(|r| Literal::Int(r, ty)),
68                BinOp::Mul => a.checked_mul(*b).map(|r| Literal::Int(r, ty)),
69                BinOp::Div => {
70                    if *b == 0 {
71                        None
72                    } else {
73                        a.checked_div(*b).map(|r| Literal::Int(r, ty))
74                    }
75                }
76                BinOp::Rem => {
77                    if *b == 0 {
78                        None
79                    } else {
80                        a.checked_rem(*b).map(|r| Literal::Int(r, ty))
81                    }
82                }
83                BinOp::BitAnd => Some(Literal::Int(a & b, ty)),
84                BinOp::BitOr => Some(Literal::Int(a | b, ty)),
85                BinOp::BitXor => Some(Literal::Int(a ^ b, ty)),
86                BinOp::Shl => Some(Literal::Int(a << (*b as u32), ty)),
87                BinOp::Shr => Some(Literal::Int(a >> (*b as u32), ty)),
88                BinOp::Eq => Some(Literal::Bool(a == b)),
89                BinOp::Ne => Some(Literal::Bool(a != b)),
90                BinOp::Lt => Some(Literal::Bool(a < b)),
91                BinOp::Le => Some(Literal::Bool(a <= b)),
92                BinOp::Gt => Some(Literal::Bool(a > b)),
93                BinOp::Ge => Some(Literal::Bool(a >= b)),
94                _ => None,
95            }
96        }
97
98        // Unsigned integer arithmetic
99        (Literal::Uint(a, ty), Literal::Uint(b, _)) => {
100            let ty = *ty;
101            match op {
102                BinOp::Add => a.checked_add(*b).map(|r| Literal::Uint(r, ty)),
103                BinOp::Sub => a.checked_sub(*b).map(|r| Literal::Uint(r, ty)),
104                BinOp::Mul => a.checked_mul(*b).map(|r| Literal::Uint(r, ty)),
105                BinOp::Div => {
106                    if *b == 0 {
107                        None
108                    } else {
109                        a.checked_div(*b).map(|r| Literal::Uint(r, ty))
110                    }
111                }
112                BinOp::Rem => {
113                    if *b == 0 {
114                        None
115                    } else {
116                        a.checked_rem(*b).map(|r| Literal::Uint(r, ty))
117                    }
118                }
119                BinOp::BitAnd => Some(Literal::Uint(a & b, ty)),
120                BinOp::BitOr => Some(Literal::Uint(a | b, ty)),
121                BinOp::BitXor => Some(Literal::Uint(a ^ b, ty)),
122                BinOp::Shl => Some(Literal::Uint(a << (*b as u32), ty)),
123                BinOp::Shr => Some(Literal::Uint(a >> (*b as u32), ty)),
124                BinOp::Eq => Some(Literal::Bool(a == b)),
125                BinOp::Ne => Some(Literal::Bool(a != b)),
126                BinOp::Lt => Some(Literal::Bool(a < b)),
127                BinOp::Le => Some(Literal::Bool(a <= b)),
128                BinOp::Gt => Some(Literal::Bool(a > b)),
129                BinOp::Ge => Some(Literal::Bool(a >= b)),
130                _ => None,
131            }
132        }
133
134        // Float arithmetic
135        (Literal::Float(a, ty), Literal::Float(b, _)) => {
136            let ty = *ty;
137            match op {
138                BinOp::Add => Some(Literal::Float(a + b, ty)),
139                BinOp::Sub => Some(Literal::Float(a - b, ty)),
140                BinOp::Mul => Some(Literal::Float(a * b, ty)),
141                BinOp::Div => {
142                    if *b == 0.0 {
143                        None
144                    } else {
145                        Some(Literal::Float(a / b, ty))
146                    }
147                }
148                BinOp::Rem => Some(Literal::Float(a % b, ty)),
149                BinOp::Eq => Some(Literal::Bool(a == b)),
150                BinOp::Ne => Some(Literal::Bool(a != b)),
151                BinOp::Lt => Some(Literal::Bool(a < b)),
152                BinOp::Le => Some(Literal::Bool(a <= b)),
153                BinOp::Gt => Some(Literal::Bool(a > b)),
154                BinOp::Ge => Some(Literal::Bool(a >= b)),
155                _ => None,
156            }
157        }
158
159        // Boolean operations
160        (Literal::Bool(a), Literal::Bool(b)) => match op {
161            BinOp::And => Some(Literal::Bool(*a && *b)),
162            BinOp::Or => Some(Literal::Bool(*a || *b)),
163            BinOp::Eq => Some(Literal::Bool(a == b)),
164            BinOp::Ne => Some(Literal::Bool(a != b)),
165            _ => None,
166        },
167
168        _ => None,
169    }?;
170
171    let ty = literal_ty(&result, &left.ty);
172    Some(Rvalue::Use(Operand::Constant(Constant::new(
173        result, ty, span,
174    ))))
175}
176
177fn fold_unary(op: UnOp, c: &Constant, span: Span) -> Option<Rvalue> {
178    let result = match (&c.literal, op) {
179        (Literal::Int(n, ty), UnOp::Neg) => n.checked_neg().map(|r| Literal::Int(r, *ty)),
180        (Literal::Float(n, ty), UnOp::Neg) => Some(Literal::Float(-n, *ty)),
181        (Literal::Bool(b), UnOp::Not) => Some(Literal::Bool(!b)),
182        (Literal::Int(n, ty), UnOp::Not) => Some(Literal::Int(!n, *ty)),
183        _ => None,
184    }?;
185
186    let ty = literal_ty(&result, &c.ty);
187    Some(Rvalue::Use(Operand::Constant(Constant::new(
188        result, ty, span,
189    ))))
190}
191
192/// Determine the MIR type for a folded literal.
193fn literal_ty(lit: &Literal, original_ty: &crate::Ty) -> crate::Ty {
194    match lit {
195        Literal::Bool(_) => crate::Ty::Bool,
196        _ => original_ty.clone(),
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{BasicBlock, FunctionId, FunctionMIR, Local, Place, Terminator, Ty};
    use joule_common::{Span, Symbol};
    use joule_hir::{FloatTy, IntTy};

    fn dummy_span() -> Span {
        Span {
            start: 0,
            end: 0,
            file_id: 0,
        }
    }

    fn int_const(n: i64) -> Operand {
        Operand::Constant(Constant::new(
            Literal::Int(n, IntTy::I64),
            Ty::Int(IntTy::I64),
            dummy_span(),
        ))
    }

    fn float_const(n: f64) -> Operand {
        Operand::Constant(Constant::new(
            Literal::Float(n, FloatTy::F64),
            Ty::Float(FloatTy::F64),
            dummy_span(),
        ))
    }

    fn bool_const(b: bool) -> Operand {
        Operand::Constant(Constant::new(Literal::Bool(b), Ty::Bool, dummy_span()))
    }

    fn make_assign(local: u32, rvalue: Rvalue) -> Statement {
        Statement::Assign {
            place: Place::from_local(Local::new(local)),
            rvalue,
            span: dummy_span(),
        }
    }

    /// Builds a one-block function with return type `ret_ty`. Its only
    /// statement assigns `rvalue` to local `local`. Runs [`fold_constants`]
    /// on it and returns the fold count together with the (possibly rewritten)
    /// body.
    fn fold_single(local: u32, ret_ty: Ty, rvalue: Rvalue) -> (usize, FunctionMIR) {
        let mut body = FunctionMIR::new(
            FunctionId::new(0),
            Symbol::intern("test"),
            ret_ty,
            dummy_span(),
        );
        let mut block = BasicBlock::new(Terminator::Return { span: dummy_span() });
        block.push_statement(make_assign(local, rvalue));
        body.basic_blocks.push(block);

        let folds = fold_constants(&mut body);
        (folds, body)
    }

    /// Returns the literal assigned by the first statement of the first block.
    ///
    /// Panics unless that statement has the shape a successful fold produces
    /// (an assignment of `Rvalue::Use(Operand::Constant(..))`), so a test can
    /// never pass vacuously.
    fn folded_literal(body: &FunctionMIR) -> &Literal {
        match &body.basic_blocks[0].statements[0] {
            Statement::Assign {
                rvalue: Rvalue::Use(Operand::Constant(c)),
                ..
            } => &c.literal,
            _ => panic!("expected folded constant"),
        }
    }

    /// Returns `true` when the first statement still holds an unfolded
    /// binary or unary operation.
    fn is_unfolded(body: &FunctionMIR) -> bool {
        matches!(
            &body.basic_blocks[0].statements[0],
            Statement::Assign {
                rvalue: Rvalue::BinaryOp { .. } | Rvalue::UnaryOp { .. },
                ..
            }
        )
    }

    #[test]
    fn test_fold_int_add() {
        let (folds, body) = fold_single(
            0,
            Ty::Int(IntTy::I64),
            Rvalue::BinaryOp {
                op: BinOp::Add,
                left: int_const(3),
                right: int_const(4),
            },
        );
        assert_eq!(folds, 1);
        assert_eq!(folded_literal(&body), &Literal::Int(7, IntTy::I64));
    }

    #[test]
    fn test_fold_float_mul() {
        let (folds, body) = fold_single(
            0,
            Ty::Float(FloatTy::F64),
            Rvalue::BinaryOp {
                op: BinOp::Mul,
                left: float_const(3.0),
                right: float_const(2.5),
            },
        );
        assert_eq!(folds, 1);
        assert_eq!(folded_literal(&body), &Literal::Float(7.5, FloatTy::F64));
    }

    #[test]
    fn test_fold_comparison() {
        let (folds, body) = fold_single(
            0,
            Ty::Bool,
            Rvalue::BinaryOp {
                op: BinOp::Lt,
                left: int_const(3),
                right: int_const(5),
            },
        );
        assert_eq!(folds, 1);
        assert_eq!(folded_literal(&body), &Literal::Bool(true));
    }

    #[test]
    fn test_fold_negation() {
        let (folds, body) = fold_single(
            0,
            Ty::Int(IntTy::I64),
            Rvalue::UnaryOp {
                op: UnOp::Neg,
                operand: int_const(42),
            },
        );
        assert_eq!(folds, 1);
        assert_eq!(folded_literal(&body), &Literal::Int(-42, IntTy::I64));
    }

    #[test]
    fn test_fold_bool_not() {
        let (folds, body) = fold_single(
            0,
            Ty::Bool,
            Rvalue::UnaryOp {
                op: UnOp::Not,
                operand: bool_const(true),
            },
        );
        assert_eq!(folds, 1);
        assert_eq!(folded_literal(&body), &Literal::Bool(false));
    }

    #[test]
    fn test_fold_int_bitwise_not() {
        let (folds, body) = fold_single(
            0,
            Ty::Int(IntTy::I64),
            Rvalue::UnaryOp {
                op: UnOp::Not,
                operand: int_const(0),
            },
        );
        assert_eq!(folds, 1);
        assert_eq!(folded_literal(&body), &Literal::Int(-1, IntTy::I64));
    }

    #[test]
    fn test_no_fold_with_variable() {
        let (folds, body) = fold_single(
            1,
            Ty::Int(IntTy::I64),
            Rvalue::BinaryOp {
                op: BinOp::Add,
                left: Operand::Copy(Place::from_local(Local::new(0))),
                right: int_const(1),
            },
        );
        assert_eq!(folds, 0); // Can't fold: one operand is a variable
        assert!(is_unfolded(&body));
    }

    #[test]
    fn test_no_fold_div_by_zero() {
        let (folds, body) = fold_single(
            0,
            Ty::Int(IntTy::I64),
            Rvalue::BinaryOp {
                op: BinOp::Div,
                left: int_const(10),
                right: int_const(0),
            },
        );
        assert_eq!(folds, 0); // Preserves div-by-zero for runtime
        assert!(is_unfolded(&body));
    }

    #[test]
    fn test_no_fold_rem_by_zero() {
        let (folds, body) = fold_single(
            0,
            Ty::Int(IntTy::I64),
            Rvalue::BinaryOp {
                op: BinOp::Rem,
                left: int_const(10),
                right: int_const(0),
            },
        );
        assert_eq!(folds, 0); // Preserves rem-by-zero for runtime
        assert!(is_unfolded(&body));
    }

    #[test]
    fn test_no_fold_add_overflow() {
        let (folds, body) = fold_single(
            0,
            Ty::Int(IntTy::I64),
            Rvalue::BinaryOp {
                op: BinOp::Add,
                left: int_const(i64::MAX),
                right: int_const(1),
            },
        );
        assert_eq!(folds, 0); // Overflow is left for runtime
        assert!(is_unfolded(&body));
    }

    #[test]
    fn test_no_fold_neg_overflow() {
        let (folds, body) = fold_single(
            0,
            Ty::Int(IntTy::I64),
            Rvalue::UnaryOp {
                op: UnOp::Neg,
                operand: int_const(i64::MIN),
            },
        );
        assert_eq!(folds, 0); // -i64::MIN does not fit in i64
        assert!(is_unfolded(&body));
    }

    #[test]
    fn test_fold_boolean_and() {
        let (folds, body) = fold_single(
            0,
            Ty::Bool,
            Rvalue::BinaryOp {
                op: BinOp::And,
                left: bool_const(true),
                right: bool_const(false),
            },
        );
        assert_eq!(folds, 1);
        assert_eq!(folded_literal(&body), &Literal::Bool(false));
    }

    #[test]
    fn test_fold_bitwise_xor() {
        let (folds, body) = fold_single(
            0,
            Ty::Int(IntTy::I64),
            Rvalue::BinaryOp {
                op: BinOp::BitXor,
                left: int_const(0xFF),
                right: int_const(0x0F),
            },
        );
        assert_eq!(folds, 1);
        assert_eq!(folded_literal(&body), &Literal::Int(0xF0, IntTy::I64));
    }
}