joule_mir/
ndarray_fusion.rs

1//! NDArray expression fusion pass.
2//!
3//! Identifies chains of element-wise ndarray operations that produce single-use
4//! temporaries, and fuses them into a single loop. For example:
5//!
6//! ```text
7//! // Before fusion:
8//! let t1 = a.add(&b);    // allocates + loops
9//! let t2 = t1.mul(&c);   // allocates + loops
10//! let result = t2.sub(&d); // allocates + loops
11//!
12//! // After fusion (conceptually):
13//! let result = fused_loop(a, b, c, d, |a, b, c, d| (a + b) * c - d);
14//! ```
15//!
16//! This eliminates intermediate allocations and reduces memory traffic from
17//! 3 full passes over data to 1 pass. The energy savings are substantial:
18//! each eliminated pass saves ~5 pJ per cacheline of data.
19//!
20//! The fusion pass works at the MIR level by:
21//! 1. Identifying `Terminator::Call` sequences targeting ndarray element-wise functions
22//! 2. Building a `FusedExpr` tree from the chain
23//! 3. Replacing the chain with a single fused call
24//!
25//! The C codegen then emits the fused call as one loop with the composed expression.
26
27use std::collections::HashMap;
28
29/// A fused expression tree representing composed element-wise operations.
30#[derive(Debug, Clone)]
31pub enum FusedExpr {
32    /// Reference to an input source array (index into sources list)
33    Source(usize),
34    /// Binary operation on two sub-expressions
35    BinaryOp {
36        op: FusedBinOp,
37        left: Box<FusedExpr>,
38        right: Box<FusedExpr>,
39    },
40    /// Unary operation on a sub-expression
41    UnaryOp {
42        op: FusedUnaryOp,
43        operand: Box<FusedExpr>,
44    },
45    /// Scalar constant
46    Scalar(f64),
47}
48
/// The element-wise binary operators the fusion pass knows how to compose.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusedBinOp {
    /// Element-wise addition.
    Add,
    /// Element-wise subtraction.
    Sub,
    /// Element-wise multiplication.
    Mul,
    /// Element-wise division.
    Div,
}
57
/// The element-wise unary operators the fusion pass knows how to compose.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusedUnaryOp {
    /// Arithmetic negation.
    Neg,
    /// Absolute value.
    Abs,
}
64
65impl FusedBinOp {
66    /// Get the C operator symbol for this operation
67    pub fn c_op(&self) -> &'static str {
68        match self {
69            Self::Add => "+",
70            Self::Sub => "-",
71            Self::Mul => "*",
72            Self::Div => "/",
73        }
74    }
75
76    /// Parse from ndarray function name suffix
77    pub fn from_suffix(s: &str) -> Option<Self> {
78        match s {
79            "add" => Some(Self::Add),
80            "sub" => Some(Self::Sub),
81            "mul" => Some(Self::Mul),
82            "div" => Some(Self::Div),
83            _ => None,
84        }
85    }
86}
87
88/// A fusion opportunity: a chain of element-wise ops that can be fused.
89#[derive(Debug, Clone)]
90pub struct FusionChain {
91    /// The fused expression tree
92    pub expr: FusedExpr,
93    /// Source array locals (in order of Source(0), Source(1), ...)
94    pub sources: Vec<crate::Local>,
95    /// The result local where the fused output is stored
96    pub result_local: crate::Local,
97    /// The element type (C identifier, e.g. "double")
98    pub elem_c: String,
99    /// Rank of the ndarray
100    pub rank: u32,
101    /// Intermediate locals that can be eliminated after fusion
102    pub eliminated_locals: Vec<crate::Local>,
103}
104
105/// Analyze a function's MIR to find fusion opportunities.
106///
107/// Returns a list of fusion chains found. Each chain represents a sequence
108/// of element-wise operations that can be replaced by a single fused loop.
109pub fn find_fusion_chains(func: &crate::FunctionMIR) -> Vec<FusionChain> {
110    // Map from local → the ndarray call that defines it
111    let mut definitions: HashMap<crate::Local, NdarrayCallInfo> = HashMap::new();
112    // Map from local → number of uses
113    let mut use_counts: HashMap<crate::Local, usize> = HashMap::new();
114
115    // Pass 1: collect ndarray element-wise call definitions and count uses
116    for bb in &func.basic_blocks {
117        // Count uses in statements
118        for stmt in &bb.statements {
119            if let crate::Statement::Assign { rvalue, .. } = stmt {
120                count_rvalue_uses(rvalue, &mut use_counts);
121            }
122        }
123
124        // Check terminator for ndarray calls
125        if let crate::Terminator::Call {
126            func_name: Some(ref name),
127            ref args,
128            ref destination,
129            ..
130        } = bb.terminator
131        {
132            let name_str = name.as_str();
133            let dest_local = destination.local;
134            if let Some(info) = parse_ndarray_elementwise_call(&name_str, args, dest_local) {
135                definitions.insert(dest_local, info);
136            }
137            // Count args as uses
138            for arg in args {
139                if let crate::Operand::Move(place) | crate::Operand::Copy(place) = arg {
140                    *use_counts.entry(place.local).or_insert(0) += 1;
141                }
142            }
143        }
144    }
145
146    // Pass 2: build chains by following single-use temporaries backward
147    let mut chains = Vec::new();
148    let mut visited: std::collections::HashSet<crate::Local> = std::collections::HashSet::new();
149
150    for (&local, info) in &definitions {
151        if visited.contains(&local) {
152            continue;
153        }
154        // Only start chains from results that are NOT single-use intermediates
155        // (i.e., they're used more than once, or they're the final result)
156        if use_counts.get(&local).copied().unwrap_or(0) <= 1
157            && definitions.values().any(|d| d.args.contains(&local))
158        {
159            continue; // This is an intermediate — will be caught by another chain
160        }
161
162        let mut sources = Vec::new();
163        let mut eliminated = Vec::new();
164        let expr = build_fused_expr(
165            local,
166            &definitions,
167            &use_counts,
168            &mut sources,
169            &mut eliminated,
170            &mut visited,
171        );
172
173        // Only worth fusing if we eliminated at least one intermediate
174        if !eliminated.is_empty() {
175            chains.push(FusionChain {
176                expr,
177                sources,
178                result_local: local,
179                elem_c: info.elem_c.clone(),
180                rank: info.rank,
181                eliminated_locals: eliminated,
182            });
183        }
184    }
185
186    chains
187}
188
189/// Info about an ndarray element-wise call
190#[derive(Debug, Clone)]
191struct NdarrayCallInfo {
192    op: FusedBinOp,
193    args: Vec<crate::Local>,
194    elem_c: String,
195    rank: u32,
196}
197
198/// Parse a Terminator::Call to see if it's an ndarray element-wise operation
199fn parse_ndarray_elementwise_call(
200    name: &str,
201    args: &[crate::Operand],
202    _destination: crate::Local,
203) -> Option<NdarrayCallInfo> {
204    // Pattern: JouleNDArray_{elem}_{rank}_{op}
205    let prefix = "JouleNDArray_";
206    if !name.starts_with(prefix) {
207        return None;
208    }
209    let rest = &name[prefix.len()..];
210
211    // Find the last underscore to get the operation name
212    let last_sep = rest.rfind('_')?;
213    let op_name = &rest[last_sep + 1..];
214    let type_part = &rest[..last_sep];
215
216    let op = FusedBinOp::from_suffix(op_name)?;
217
218    // Parse elem and rank from type_part (e.g., "double_2")
219    let rank_sep = type_part.rfind('_')?;
220    let elem_c = type_part[..rank_sep].to_string();
221    let rank: u32 = type_part[rank_sep + 1..].parse().ok()?;
222
223    // Must have exactly 2 args for binary op
224    if args.len() != 2 {
225        return None;
226    }
227
228    let arg_locals: Vec<_> = args
229        .iter()
230        .filter_map(|a| match a {
231            crate::Operand::Move(p) | crate::Operand::Copy(p) => Some(p.local),
232            _ => None,
233        })
234        .collect();
235
236    if arg_locals.len() != 2 {
237        return None;
238    }
239
240    Some(NdarrayCallInfo {
241        op,
242        args: arg_locals,
243        elem_c,
244        rank,
245    })
246}
247
248/// Recursively build a fused expression tree from a chain of operations
249fn build_fused_expr(
250    local: crate::Local,
251    definitions: &HashMap<crate::Local, NdarrayCallInfo>,
252    use_counts: &HashMap<crate::Local, usize>,
253    sources: &mut Vec<crate::Local>,
254    eliminated: &mut Vec<crate::Local>,
255    visited: &mut std::collections::HashSet<crate::Local>,
256) -> FusedExpr {
257    visited.insert(local);
258
259    if let Some(info) = definitions.get(&local) {
260        // Check if both args can be fused (single-use intermediates)
261        let left_local = info.args[0];
262        let right_local = info.args[1];
263
264        let left = if definitions.contains_key(&left_local)
265            && use_counts.get(&left_local).copied().unwrap_or(0) == 1
266            && !visited.contains(&left_local)
267        {
268            eliminated.push(left_local);
269            build_fused_expr(left_local, definitions, use_counts, sources, eliminated, visited)
270        } else {
271            make_source(left_local, sources)
272        };
273
274        let right = if definitions.contains_key(&right_local)
275            && use_counts.get(&right_local).copied().unwrap_or(0) == 1
276            && !visited.contains(&right_local)
277        {
278            eliminated.push(right_local);
279            build_fused_expr(right_local, definitions, use_counts, sources, eliminated, visited)
280        } else {
281            make_source(right_local, sources)
282        };
283
284        FusedExpr::BinaryOp {
285            op: info.op,
286            left: Box::new(left),
287            right: Box::new(right),
288        }
289    } else {
290        make_source(local, sources)
291    }
292}
293
294fn make_source(local: crate::Local, sources: &mut Vec<crate::Local>) -> FusedExpr {
295    if let Some(idx) = sources.iter().position(|&l| l == local) {
296        FusedExpr::Source(idx)
297    } else {
298        let idx = sources.len();
299        sources.push(local);
300        FusedExpr::Source(idx)
301    }
302}
303
304fn count_rvalue_uses(rvalue: &crate::Rvalue, counts: &mut HashMap<crate::Local, usize>) {
305    match rvalue {
306        crate::Rvalue::Use(op) | crate::Rvalue::UnaryOp { operand: op, .. } => {
307            count_operand_uses(op, counts);
308        }
309        crate::Rvalue::BinaryOp { left, right, .. } => {
310            count_operand_uses(left, counts);
311            count_operand_uses(right, counts);
312        }
313        crate::Rvalue::Ref { place, .. } => {
314            *counts.entry(place.local).or_insert(0) += 1;
315        }
316        crate::Rvalue::Aggregate { operands, .. } => {
317            for op in operands {
318                count_operand_uses(op, counts);
319            }
320        }
321        crate::Rvalue::Cast { operand, .. } => {
322            count_operand_uses(operand, counts);
323        }
324        crate::Rvalue::Discriminant { place } | crate::Rvalue::Len { place } => {
325            *counts.entry(place.local).or_insert(0) += 1;
326        }
327        _ => {}
328    }
329}
330
331fn count_operand_uses(op: &crate::Operand, counts: &mut HashMap<crate::Local, usize>) {
332    if let crate::Operand::Move(place) | crate::Operand::Copy(place) = op {
333        *counts.entry(place.local).or_insert(0) += 1;
334    }
335}
336
337/// Generate the C expression string for a fused expression tree.
338/// The sources are accessed as `src0->data[i]`, `src1->data[i]`, etc.
339pub fn fused_expr_to_c(expr: &FusedExpr) -> String {
340    match expr {
341        FusedExpr::Source(idx) => format!("src{idx}->data[i]"),
342        FusedExpr::BinaryOp { op, left, right } => {
343            format!(
344                "({} {} {})",
345                fused_expr_to_c(left),
346                op.c_op(),
347                fused_expr_to_c(right)
348            )
349        }
350        FusedExpr::UnaryOp {
351            op: FusedUnaryOp::Neg,
352            operand,
353        } => {
354            format!("(-{})", fused_expr_to_c(operand))
355        }
356        FusedExpr::UnaryOp {
357            op: FusedUnaryOp::Abs,
358            operand,
359        } => {
360            let inner = fused_expr_to_c(operand);
361            format!("({inner} < 0 ? -{inner} : {inner})")
362        }
363        FusedExpr::Scalar(v) => format!("{v}"),
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a boxed source leaf.
    fn src(idx: usize) -> Box<FusedExpr> {
        Box::new(FusedExpr::Source(idx))
    }

    #[test]
    fn test_fused_expr_to_c() {
        // a + b * c → (src0->data[i] + (src1->data[i] * src2->data[i]))
        let product = FusedExpr::BinaryOp {
            op: FusedBinOp::Mul,
            left: src(1),
            right: src(2),
        };
        let sum = FusedExpr::BinaryOp {
            op: FusedBinOp::Add,
            left: src(0),
            right: Box::new(product),
        };
        assert_eq!(
            fused_expr_to_c(&sum),
            "(src0->data[i] + (src1->data[i] * src2->data[i]))"
        );
    }

    #[test]
    fn test_fused_bin_op_from_suffix() {
        let known = [
            ("add", FusedBinOp::Add),
            ("sub", FusedBinOp::Sub),
            ("mul", FusedBinOp::Mul),
            ("div", FusedBinOp::Div),
        ];
        for (suffix, op) in known.iter() {
            assert_eq!(FusedBinOp::from_suffix(suffix), Some(*op));
        }
        assert_eq!(FusedBinOp::from_suffix("matmul"), None);
    }

    #[test]
    fn test_fused_bin_op_c_op() {
        assert_eq!(FusedBinOp::Add.c_op(), "+");
        assert_eq!(FusedBinOp::Sub.c_op(), "-");
        assert_eq!(FusedBinOp::Mul.c_op(), "*");
        assert_eq!(FusedBinOp::Div.c_op(), "/");
    }

    #[test]
    fn test_neg_expr() {
        let expr = FusedExpr::UnaryOp {
            op: FusedUnaryOp::Neg,
            operand: src(0),
        };
        assert_eq!(fused_expr_to_c(&expr), "(-src0->data[i])");
    }

    #[test]
    fn test_abs_expr() {
        let expr = FusedExpr::UnaryOp {
            op: FusedUnaryOp::Abs,
            operand: src(0),
        };
        assert_eq!(
            fused_expr_to_c(&expr),
            "(src0->data[i] < 0 ? -src0->data[i] : src0->data[i])"
        );
    }

    #[test]
    fn test_fractional_scalar_expr() {
        // src0 * 2.5 → (src0->data[i] * 2.5)
        let expr = FusedExpr::BinaryOp {
            op: FusedBinOp::Mul,
            left: src(0),
            right: Box::new(FusedExpr::Scalar(2.5)),
        };
        assert_eq!(fused_expr_to_c(&expr), "(src0->data[i] * 2.5)");
    }
}