1use std::collections::HashMap;
28
/// Expression tree for one fused elementwise computation.
///
/// `fused_expr_to_c` renders a value of this type as a C expression string.
#[derive(Debug, Clone)]
pub enum FusedExpr {
    /// An input array, identified by index. Rendered as `src{idx}->data[i]`;
    /// `make_source` hands out the index as a position in the chain's source list.
    Source(usize),
    /// A binary operator applied to two sub-expressions.
    BinaryOp {
        /// Operator to apply.
        op: FusedBinOp,
        /// Left-hand operand.
        left: Box<FusedExpr>,
        /// Right-hand operand.
        right: Box<FusedExpr>,
    },
    /// A unary operator applied to one sub-expression.
    UnaryOp {
        /// Operator to apply.
        op: FusedUnaryOp,
        /// The operand the operator is applied to.
        operand: Box<FusedExpr>,
    },
    /// A floating-point constant, rendered with its `Display` formatting.
    Scalar(f64),
}
48
/// Binary operators that can appear inside a fused expression.
///
/// Each variant maps to a C infix operator (`c_op`) and to a callee-name
/// suffix (`from_suffix`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusedBinOp {
    /// Addition: C operator `+`, suffix `add`.
    Add,
    /// Subtraction: C operator `-`, suffix `sub`.
    Sub,
    /// Multiplication: C operator `*`, suffix `mul`.
    Mul,
    /// Division: C operator `/`, suffix `div`.
    Div,
}
57
/// Unary operators that can appear inside a fused expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusedUnaryOp {
    /// Negation; `fused_expr_to_c` renders it as `(-x)`.
    Neg,
    /// Absolute value; `fused_expr_to_c` renders it as `(x < 0 ? -x : x)`.
    Abs,
}
64
65impl FusedBinOp {
66 pub fn c_op(&self) -> &'static str {
68 match self {
69 Self::Add => "+",
70 Self::Sub => "-",
71 Self::Mul => "*",
72 Self::Div => "/",
73 }
74 }
75
76 pub fn from_suffix(s: &str) -> Option<Self> {
78 match s {
79 "add" => Some(Self::Add),
80 "sub" => Some(Self::Sub),
81 "mul" => Some(Self::Mul),
82 "div" => Some(Self::Div),
83 _ => None,
84 }
85 }
86}
87
/// One group of elementwise ndarray calls that `find_fusion_chains` found
/// could be evaluated as a single expression.
#[derive(Debug, Clone)]
pub struct FusionChain {
    /// Expression tree covering the whole chain.
    pub expr: FusedExpr,
    /// Locals read by the chain; `FusedExpr::Source(i)` refers to `sources[i]`.
    pub sources: Vec<crate::Local>,
    /// Destination local of the call at the root of the chain.
    pub result_local: crate::Local,
    /// Element-type portion of the root callee's name (the text between the
    /// `JouleNDArray_` prefix and the rank). Presumably a C type name; confirm
    /// against the code generator.
    pub elem_c: String,
    /// Rank parsed from the root callee's name.
    pub rank: u32,
    /// Intermediate locals folded into `expr`: each is produced by an
    /// elementwise call and was counted as used exactly once.
    pub eliminated_locals: Vec<crate::Local>,
}
104
105pub fn find_fusion_chains(func: &crate::FunctionMIR) -> Vec<FusionChain> {
110 let mut definitions: HashMap<crate::Local, NdarrayCallInfo> = HashMap::new();
112 let mut use_counts: HashMap<crate::Local, usize> = HashMap::new();
114
115 for bb in &func.basic_blocks {
117 for stmt in &bb.statements {
119 if let crate::Statement::Assign { rvalue, .. } = stmt {
120 count_rvalue_uses(rvalue, &mut use_counts);
121 }
122 }
123
124 if let crate::Terminator::Call {
126 func_name: Some(ref name),
127 ref args,
128 ref destination,
129 ..
130 } = bb.terminator
131 {
132 let name_str = name.as_str();
133 let dest_local = destination.local;
134 if let Some(info) = parse_ndarray_elementwise_call(&name_str, args, dest_local) {
135 definitions.insert(dest_local, info);
136 }
137 for arg in args {
139 if let crate::Operand::Move(place) | crate::Operand::Copy(place) = arg {
140 *use_counts.entry(place.local).or_insert(0) += 1;
141 }
142 }
143 }
144 }
145
146 let mut chains = Vec::new();
148 let mut visited: std::collections::HashSet<crate::Local> = std::collections::HashSet::new();
149
150 for (&local, info) in &definitions {
151 if visited.contains(&local) {
152 continue;
153 }
154 if use_counts.get(&local).copied().unwrap_or(0) <= 1
157 && definitions.values().any(|d| d.args.contains(&local))
158 {
159 continue; }
161
162 let mut sources = Vec::new();
163 let mut eliminated = Vec::new();
164 let expr = build_fused_expr(
165 local,
166 &definitions,
167 &use_counts,
168 &mut sources,
169 &mut eliminated,
170 &mut visited,
171 );
172
173 if !eliminated.is_empty() {
175 chains.push(FusionChain {
176 expr,
177 sources,
178 result_local: local,
179 elem_c: info.elem_c.clone(),
180 rank: info.rank,
181 eliminated_locals: eliminated,
182 });
183 }
184 }
185
186 chains
187}
188
/// What `parse_ndarray_elementwise_call` extracted from one elementwise call.
#[derive(Debug, Clone)]
struct NdarrayCallInfo {
    /// Operator named by the callee's final `_`-separated suffix.
    op: FusedBinOp,
    /// The locals of the call's two operands, in argument order.
    args: Vec<crate::Local>,
    /// Element-type portion of the callee name (between the `JouleNDArray_`
    /// prefix and the rank).
    elem_c: String,
    /// Rank parsed from the callee name.
    rank: u32,
}
197
198fn parse_ndarray_elementwise_call(
200 name: &str,
201 args: &[crate::Operand],
202 _destination: crate::Local,
203) -> Option<NdarrayCallInfo> {
204 let prefix = "JouleNDArray_";
206 if !name.starts_with(prefix) {
207 return None;
208 }
209 let rest = &name[prefix.len()..];
210
211 let last_sep = rest.rfind('_')?;
213 let op_name = &rest[last_sep + 1..];
214 let type_part = &rest[..last_sep];
215
216 let op = FusedBinOp::from_suffix(op_name)?;
217
218 let rank_sep = type_part.rfind('_')?;
220 let elem_c = type_part[..rank_sep].to_string();
221 let rank: u32 = type_part[rank_sep + 1..].parse().ok()?;
222
223 if args.len() != 2 {
225 return None;
226 }
227
228 let arg_locals: Vec<_> = args
229 .iter()
230 .filter_map(|a| match a {
231 crate::Operand::Move(p) | crate::Operand::Copy(p) => Some(p.local),
232 _ => None,
233 })
234 .collect();
235
236 if arg_locals.len() != 2 {
237 return None;
238 }
239
240 Some(NdarrayCallInfo {
241 op,
242 args: arg_locals,
243 elem_c,
244 rank,
245 })
246}
247
248fn build_fused_expr(
250 local: crate::Local,
251 definitions: &HashMap<crate::Local, NdarrayCallInfo>,
252 use_counts: &HashMap<crate::Local, usize>,
253 sources: &mut Vec<crate::Local>,
254 eliminated: &mut Vec<crate::Local>,
255 visited: &mut std::collections::HashSet<crate::Local>,
256) -> FusedExpr {
257 visited.insert(local);
258
259 if let Some(info) = definitions.get(&local) {
260 let left_local = info.args[0];
262 let right_local = info.args[1];
263
264 let left = if definitions.contains_key(&left_local)
265 && use_counts.get(&left_local).copied().unwrap_or(0) == 1
266 && !visited.contains(&left_local)
267 {
268 eliminated.push(left_local);
269 build_fused_expr(left_local, definitions, use_counts, sources, eliminated, visited)
270 } else {
271 make_source(left_local, sources)
272 };
273
274 let right = if definitions.contains_key(&right_local)
275 && use_counts.get(&right_local).copied().unwrap_or(0) == 1
276 && !visited.contains(&right_local)
277 {
278 eliminated.push(right_local);
279 build_fused_expr(right_local, definitions, use_counts, sources, eliminated, visited)
280 } else {
281 make_source(right_local, sources)
282 };
283
284 FusedExpr::BinaryOp {
285 op: info.op,
286 left: Box::new(left),
287 right: Box::new(right),
288 }
289 } else {
290 make_source(local, sources)
291 }
292}
293
294fn make_source(local: crate::Local, sources: &mut Vec<crate::Local>) -> FusedExpr {
295 if let Some(idx) = sources.iter().position(|&l| l == local) {
296 FusedExpr::Source(idx)
297 } else {
298 let idx = sources.len();
299 sources.push(local);
300 FusedExpr::Source(idx)
301 }
302}
303
304fn count_rvalue_uses(rvalue: &crate::Rvalue, counts: &mut HashMap<crate::Local, usize>) {
305 match rvalue {
306 crate::Rvalue::Use(op) | crate::Rvalue::UnaryOp { operand: op, .. } => {
307 count_operand_uses(op, counts);
308 }
309 crate::Rvalue::BinaryOp { left, right, .. } => {
310 count_operand_uses(left, counts);
311 count_operand_uses(right, counts);
312 }
313 crate::Rvalue::Ref { place, .. } => {
314 *counts.entry(place.local).or_insert(0) += 1;
315 }
316 crate::Rvalue::Aggregate { operands, .. } => {
317 for op in operands {
318 count_operand_uses(op, counts);
319 }
320 }
321 crate::Rvalue::Cast { operand, .. } => {
322 count_operand_uses(operand, counts);
323 }
324 crate::Rvalue::Discriminant { place } | crate::Rvalue::Len { place } => {
325 *counts.entry(place.local).or_insert(0) += 1;
326 }
327 _ => {}
328 }
329}
330
331fn count_operand_uses(op: &crate::Operand, counts: &mut HashMap<crate::Local, usize>) {
332 if let crate::Operand::Move(place) | crate::Operand::Copy(place) = op {
333 *counts.entry(place.local).or_insert(0) += 1;
334 }
335}
336
337pub fn fused_expr_to_c(expr: &FusedExpr) -> String {
340 match expr {
341 FusedExpr::Source(idx) => format!("src{idx}->data[i]"),
342 FusedExpr::BinaryOp { op, left, right } => {
343 format!(
344 "({} {} {})",
345 fused_expr_to_c(left),
346 op.c_op(),
347 fused_expr_to_c(right)
348 )
349 }
350 FusedExpr::UnaryOp {
351 op: FusedUnaryOp::Neg,
352 operand,
353 } => {
354 format!("(-{})", fused_expr_to_c(operand))
355 }
356 FusedExpr::UnaryOp {
357 op: FusedUnaryOp::Abs,
358 operand,
359 } => {
360 let inner = fused_expr_to_c(operand);
361 format!("({inner} < 0 ? -{inner} : {inner})")
362 }
363 FusedExpr::Scalar(v) => format!("{v}"),
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    /// Boxed source leaf, to keep the expression literals short.
    fn src(idx: usize) -> Box<FusedExpr> {
        Box::new(FusedExpr::Source(idx))
    }

    /// A nested binary expression renders with one pair of parentheses per node.
    #[test]
    fn test_fused_expr_to_c() {
        let product = FusedExpr::BinaryOp {
            op: FusedBinOp::Mul,
            left: src(1),
            right: src(2),
        };
        let sum = FusedExpr::BinaryOp {
            op: FusedBinOp::Add,
            left: src(0),
            right: Box::new(product),
        };

        let rendered = fused_expr_to_c(&sum);
        assert_eq!(
            rendered,
            "(src0->data[i] + (src1->data[i] * src2->data[i]))"
        );
    }

    /// Each known suffix maps to its operator; anything else maps to `None`.
    #[test]
    fn test_fused_bin_op_from_suffix() {
        let cases = [
            ("add", Some(FusedBinOp::Add)),
            ("sub", Some(FusedBinOp::Sub)),
            ("mul", Some(FusedBinOp::Mul)),
            ("div", Some(FusedBinOp::Div)),
            ("matmul", None),
        ];
        for &(suffix, expected) in cases.iter() {
            assert_eq!(FusedBinOp::from_suffix(suffix), expected);
        }
    }

    /// Negating a source wraps it in `(-...)`.
    #[test]
    fn test_neg_expr() {
        let negated = FusedExpr::UnaryOp {
            op: FusedUnaryOp::Neg,
            operand: src(0),
        };
        assert_eq!(fused_expr_to_c(&negated), "(-src0->data[i])");
    }
}