1use crate::{BasicBlock, BinOp, Constant, FunctionMIR, Literal, Operand, Rvalue, Statement, UnOp};
7use joule_common::Span;
8
9pub fn fold_constants(body: &mut FunctionMIR) -> usize {
11 let mut folds = 0;
12
13 for block in &mut body.basic_blocks {
14 folds += fold_block(block);
15 }
16
17 folds
18}
19
20fn fold_block(block: &mut BasicBlock) -> usize {
21 let mut folds = 0;
22
23 for stmt in &mut block.statements {
24 if let Statement::Assign {
25 rvalue, span, ..
26 } = stmt
27 {
28 if let Some(folded) = try_fold_rvalue(rvalue, *span) {
29 *rvalue = folded;
30 folds += 1;
31 }
32 }
33 }
34
35 folds
36}
37
38fn try_fold_rvalue(rvalue: &Rvalue, span: Span) -> Option<Rvalue> {
39 match rvalue {
40 Rvalue::BinaryOp { op, left, right } => {
41 let left_const = operand_as_constant(left)?;
42 let right_const = operand_as_constant(right)?;
43 fold_binary(*op, left_const, right_const, span)
44 }
45 Rvalue::UnaryOp { op, operand } => {
46 let c = operand_as_constant(operand)?;
47 fold_unary(*op, c, span)
48 }
49 _ => None,
50 }
51}
52
53fn operand_as_constant(op: &Operand) -> Option<&Constant> {
54 match op {
55 Operand::Constant(c) => Some(c),
56 _ => None,
57 }
58}
59
60fn fold_binary(op: BinOp, left: &Constant, right: &Constant, span: Span) -> Option<Rvalue> {
61 let result = match (&left.literal, &right.literal) {
62 (Literal::Int(a, ty), Literal::Int(b, _)) => {
64 let ty = *ty;
65 match op {
66 BinOp::Add => a.checked_add(*b).map(|r| Literal::Int(r, ty)),
67 BinOp::Sub => a.checked_sub(*b).map(|r| Literal::Int(r, ty)),
68 BinOp::Mul => a.checked_mul(*b).map(|r| Literal::Int(r, ty)),
69 BinOp::Div => {
70 if *b == 0 {
71 None
72 } else {
73 a.checked_div(*b).map(|r| Literal::Int(r, ty))
74 }
75 }
76 BinOp::Rem => {
77 if *b == 0 {
78 None
79 } else {
80 a.checked_rem(*b).map(|r| Literal::Int(r, ty))
81 }
82 }
83 BinOp::BitAnd => Some(Literal::Int(a & b, ty)),
84 BinOp::BitOr => Some(Literal::Int(a | b, ty)),
85 BinOp::BitXor => Some(Literal::Int(a ^ b, ty)),
86 BinOp::Shl => Some(Literal::Int(a << (*b as u32), ty)),
87 BinOp::Shr => Some(Literal::Int(a >> (*b as u32), ty)),
88 BinOp::Eq => Some(Literal::Bool(a == b)),
89 BinOp::Ne => Some(Literal::Bool(a != b)),
90 BinOp::Lt => Some(Literal::Bool(a < b)),
91 BinOp::Le => Some(Literal::Bool(a <= b)),
92 BinOp::Gt => Some(Literal::Bool(a > b)),
93 BinOp::Ge => Some(Literal::Bool(a >= b)),
94 _ => None,
95 }
96 }
97
98 (Literal::Uint(a, ty), Literal::Uint(b, _)) => {
100 let ty = *ty;
101 match op {
102 BinOp::Add => a.checked_add(*b).map(|r| Literal::Uint(r, ty)),
103 BinOp::Sub => a.checked_sub(*b).map(|r| Literal::Uint(r, ty)),
104 BinOp::Mul => a.checked_mul(*b).map(|r| Literal::Uint(r, ty)),
105 BinOp::Div => {
106 if *b == 0 {
107 None
108 } else {
109 a.checked_div(*b).map(|r| Literal::Uint(r, ty))
110 }
111 }
112 BinOp::Rem => {
113 if *b == 0 {
114 None
115 } else {
116 a.checked_rem(*b).map(|r| Literal::Uint(r, ty))
117 }
118 }
119 BinOp::BitAnd => Some(Literal::Uint(a & b, ty)),
120 BinOp::BitOr => Some(Literal::Uint(a | b, ty)),
121 BinOp::BitXor => Some(Literal::Uint(a ^ b, ty)),
122 BinOp::Shl => Some(Literal::Uint(a << (*b as u32), ty)),
123 BinOp::Shr => Some(Literal::Uint(a >> (*b as u32), ty)),
124 BinOp::Eq => Some(Literal::Bool(a == b)),
125 BinOp::Ne => Some(Literal::Bool(a != b)),
126 BinOp::Lt => Some(Literal::Bool(a < b)),
127 BinOp::Le => Some(Literal::Bool(a <= b)),
128 BinOp::Gt => Some(Literal::Bool(a > b)),
129 BinOp::Ge => Some(Literal::Bool(a >= b)),
130 _ => None,
131 }
132 }
133
134 (Literal::Float(a, ty), Literal::Float(b, _)) => {
136 let ty = *ty;
137 match op {
138 BinOp::Add => Some(Literal::Float(a + b, ty)),
139 BinOp::Sub => Some(Literal::Float(a - b, ty)),
140 BinOp::Mul => Some(Literal::Float(a * b, ty)),
141 BinOp::Div => {
142 if *b == 0.0 {
143 None
144 } else {
145 Some(Literal::Float(a / b, ty))
146 }
147 }
148 BinOp::Rem => Some(Literal::Float(a % b, ty)),
149 BinOp::Eq => Some(Literal::Bool(a == b)),
150 BinOp::Ne => Some(Literal::Bool(a != b)),
151 BinOp::Lt => Some(Literal::Bool(a < b)),
152 BinOp::Le => Some(Literal::Bool(a <= b)),
153 BinOp::Gt => Some(Literal::Bool(a > b)),
154 BinOp::Ge => Some(Literal::Bool(a >= b)),
155 _ => None,
156 }
157 }
158
159 (Literal::Bool(a), Literal::Bool(b)) => match op {
161 BinOp::And => Some(Literal::Bool(*a && *b)),
162 BinOp::Or => Some(Literal::Bool(*a || *b)),
163 BinOp::Eq => Some(Literal::Bool(a == b)),
164 BinOp::Ne => Some(Literal::Bool(a != b)),
165 _ => None,
166 },
167
168 _ => None,
169 }?;
170
171 let ty = literal_ty(&result, &left.ty);
172 Some(Rvalue::Use(Operand::Constant(Constant::new(
173 result, ty, span,
174 ))))
175}
176
177fn fold_unary(op: UnOp, c: &Constant, span: Span) -> Option<Rvalue> {
178 let result = match (&c.literal, op) {
179 (Literal::Int(n, ty), UnOp::Neg) => n.checked_neg().map(|r| Literal::Int(r, *ty)),
180 (Literal::Float(n, ty), UnOp::Neg) => Some(Literal::Float(-n, *ty)),
181 (Literal::Bool(b), UnOp::Not) => Some(Literal::Bool(!b)),
182 (Literal::Int(n, ty), UnOp::Not) => Some(Literal::Int(!n, *ty)),
183 _ => None,
184 }?;
185
186 let ty = literal_ty(&result, &c.ty);
187 Some(Rvalue::Use(Operand::Constant(Constant::new(
188 result, ty, span,
189 ))))
190}
191
192fn literal_ty(lit: &Literal, original_ty: &crate::Ty) -> crate::Ty {
194 match lit {
195 Literal::Bool(_) => crate::Ty::Bool,
196 _ => original_ty.clone(),
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{BasicBlock, FunctionId, FunctionMIR, Local, Place, Terminator, Ty};
    use joule_common::{Span, Symbol};
    use joule_hir::{FloatTy, IntTy};

    /// A zeroed span; source locations are irrelevant to these tests.
    fn dummy_span() -> Span {
        Span {
            start: 0,
            end: 0,
            file_id: 0,
        }
    }

    /// Builds an `i64` constant operand.
    fn int_const(n: i64) -> Operand {
        Operand::Constant(Constant::new(
            Literal::Int(n, IntTy::I64),
            Ty::Int(IntTy::I64),
            dummy_span(),
        ))
    }

    /// Builds an `f64` constant operand.
    fn float_const(n: f64) -> Operand {
        Operand::Constant(Constant::new(
            Literal::Float(n, FloatTy::F64),
            Ty::Float(FloatTy::F64),
            dummy_span(),
        ))
    }

    /// Builds a boolean constant operand.
    fn bool_const(b: bool) -> Operand {
        Operand::Constant(Constant::new(Literal::Bool(b), Ty::Bool, dummy_span()))
    }

    /// Builds an assignment of `rvalue` to the local with index `local`.
    fn make_assign(local: u32, rvalue: Rvalue) -> Statement {
        Statement::Assign {
            place: Place::from_local(Local::new(local)),
            rvalue,
            span: dummy_span(),
        }
    }

    /// Builds a function with one returning basic block whose only statement
    /// assigns `rvalue` to the local with index `local`.
    fn single_assign_body(ret_ty: Ty, local: u32, rvalue: Rvalue) -> FunctionMIR {
        let mut body = FunctionMIR::new(
            FunctionId::new(0),
            Symbol::intern("test"),
            ret_ty,
            dummy_span(),
        );
        let mut block = BasicBlock::new(Terminator::Return { span: dummy_span() });
        block.push_statement(make_assign(local, rvalue));
        body.basic_blocks.push(block);
        body
    }

    /// Returns the literal of the first statement of the first block, which
    /// must be an assignment of a constant.
    ///
    /// Panics on any other shape so a test can never pass without actually
    /// checking the folded value.
    fn folded_literal(body: &FunctionMIR) -> &Literal {
        match &body.basic_blocks[0].statements[0] {
            Statement::Assign {
                rvalue: Rvalue::Use(Operand::Constant(c)),
                ..
            } => &c.literal,
            _ => panic!("expected folded constant"),
        }
    }

    #[test]
    fn test_fold_int_add() {
        let mut body = single_assign_body(
            Ty::Int(IntTy::I64),
            0,
            Rvalue::BinaryOp {
                op: BinOp::Add,
                left: int_const(3),
                right: int_const(4),
            },
        );

        assert_eq!(fold_constants(&mut body), 1);
        assert_eq!(folded_literal(&body), &Literal::Int(7, IntTy::I64));
    }

    #[test]
    fn test_fold_float_mul() {
        let mut body = single_assign_body(
            Ty::Float(FloatTy::F64),
            0,
            Rvalue::BinaryOp {
                op: BinOp::Mul,
                left: float_const(3.0),
                right: float_const(2.5),
            },
        );

        assert_eq!(fold_constants(&mut body), 1);
        assert_eq!(folded_literal(&body), &Literal::Float(7.5, FloatTy::F64));
    }

    #[test]
    fn test_fold_comparison() {
        let mut body = single_assign_body(
            Ty::Bool,
            0,
            Rvalue::BinaryOp {
                op: BinOp::Lt,
                left: int_const(3),
                right: int_const(5),
            },
        );

        assert_eq!(fold_constants(&mut body), 1);
        assert_eq!(folded_literal(&body), &Literal::Bool(true));
    }

    #[test]
    fn test_fold_negation() {
        let mut body = single_assign_body(
            Ty::Int(IntTy::I64),
            0,
            Rvalue::UnaryOp {
                op: UnOp::Neg,
                operand: int_const(42),
            },
        );

        assert_eq!(fold_constants(&mut body), 1);
        assert_eq!(folded_literal(&body), &Literal::Int(-42, IntTy::I64));
    }

    #[test]
    fn test_no_fold_with_variable() {
        // One operand is a local, so nothing may be folded.
        let mut body = single_assign_body(
            Ty::Int(IntTy::I64),
            1,
            Rvalue::BinaryOp {
                op: BinOp::Add,
                left: Operand::Copy(Place::from_local(Local::new(0))),
                right: int_const(1),
            },
        );

        assert_eq!(fold_constants(&mut body), 0);
    }

    #[test]
    fn test_no_fold_div_by_zero() {
        // Division by a constant zero must be left for runtime.
        let mut body = single_assign_body(
            Ty::Int(IntTy::I64),
            0,
            Rvalue::BinaryOp {
                op: BinOp::Div,
                left: int_const(10),
                right: int_const(0),
            },
        );

        assert_eq!(fold_constants(&mut body), 0);
    }

    #[test]
    fn test_fold_boolean_and() {
        let mut body = single_assign_body(
            Ty::Bool,
            0,
            Rvalue::BinaryOp {
                op: BinOp::And,
                left: bool_const(true),
                right: bool_const(false),
            },
        );

        assert_eq!(fold_constants(&mut body), 1);
        assert_eq!(folded_literal(&body), &Literal::Bool(false));
    }

    #[test]
    fn test_fold_bitwise_xor() {
        let mut body = single_assign_body(
            Ty::Int(IntTy::I64),
            0,
            Rvalue::BinaryOp {
                op: BinOp::BitXor,
                left: int_const(0xFF),
                right: int_const(0x0F),
            },
        );

        assert_eq!(fold_constants(&mut body), 1);
        assert_eq!(folded_literal(&body), &Literal::Int(0xF0, IntTy::I64));
    }
}