1use super::{
14 ChannelId, ComputeOp, DataflowGraph, DependencyAnalysis, DfOperator, EnergyEstimate,
15 OperatorId, TokenType, TokenValue,
16};
17use std::collections::{HashMap, HashSet};
18
/// Summary of what one `DfgOptimizer::optimize` run did to a graph.
#[derive(Clone, Debug, Default)]
pub struct OptimizationStats {
    /// Operator count measured before any pass ran.
    pub operators_before: usize,
    /// Operator count measured after all passes ran.
    pub operators_after: usize,
    /// Channel count measured before any pass ran.
    pub channels_before: usize,
    /// Channel count measured after all passes ran.
    pub channels_after: usize,
    /// Number of loops rewritten into `Stream` operators.
    pub streams_created: usize,
    /// Operators reported as eliminated by the dead-code pass.
    pub operators_eliminated: usize,
}

impl OptimizationStats {
    /// Percentage (0–100) of the original operators that are gone after optimization.
    ///
    /// Returns `0.0` when there were no operators to begin with. It also returns `0.0`
    /// when the graph grew, because the shrinkage saturates at zero and is never
    /// reported as negative.
    pub fn operator_reduction_percent(&self) -> f64 {
        let before = self.operators_before;
        if before == 0 {
            0.0
        } else {
            let removed = before.saturating_sub(self.operators_after) as f64;
            (removed / before as f64) * 100.0
        }
    }
}
46
/// Configuration for the pass that rewrites counted `Carry` loops into `Stream` operators.
///
/// NOTE(review): neither threshold below is read by the fusion logic yet; they are
/// carried as configuration only.
#[derive(Debug, Clone)]
pub struct StreamFusion {
    /// Minimum iteration count intended to gate fusion (set to 4 by `new`).
    pub min_iterations: u64,
    /// Whether nested loops are intended to be fused (set to `true` by `new`).
    pub fuse_nested: bool,
}
70
71impl StreamFusion {
72 pub fn new() -> Self {
73 Self {
74 min_iterations: 4,
75 fuse_nested: true,
76 }
77 }
78
79 pub fn optimize(&self, dfg: &mut DataflowGraph) -> usize {
81 let mut streams_created = 0;
82
83 let candidates = self.find_induction_patterns(dfg);
85
86 for candidate in candidates {
87 if self.can_fuse(&candidate, dfg) {
88 self.fuse_to_stream(candidate, dfg);
89 streams_created += 1;
90 }
91 }
92
93 streams_created
94 }
95
96 fn find_induction_patterns(&self, dfg: &DataflowGraph) -> Vec<InductionPattern> {
98 let mut patterns = Vec::new();
99
100 for (op_idx, op) in dfg.operators.iter().enumerate() {
102 if let DfOperator::Carry {
103 initial,
104 feedback,
105 continue_signal,
106 output,
107 } = op
108 {
109 let initial_const = self.find_constant_source(*initial, dfg);
115 let add_op = self.find_add_operator(*feedback, dfg);
116 let cmp_info = self.find_comparison(*continue_signal, dfg);
117
118 if let (Some(start), Some((add_idx, step_ch)), Some((cmp_idx, bound_ch))) =
119 (initial_const, add_op, cmp_info)
120 {
121 patterns.push(InductionPattern {
122 carry_idx: op_idx,
123 initial_channel: *initial,
124 feedback_channel: *feedback,
125 continue_channel: *continue_signal,
126 output_channel: *output,
127 start_value: start,
128 step_channel: step_ch,
129 bound_channel: bound_ch,
130 add_operator: add_idx,
131 compare_operator: cmp_idx,
132 });
133 }
134 }
135 }
136
137 patterns
138 }
139
140 fn find_constant_source(&self, ch: ChannelId, dfg: &DataflowGraph) -> Option<i64> {
142 for op in &dfg.operators {
143 if let DfOperator::Constant { value, output, .. } = op {
144 if *output == ch {
145 return match value {
146 TokenValue::Int(i) => Some(*i),
147 TokenValue::Uint(u) => Some(*u as i64),
148 _ => None,
149 };
150 }
151 }
152 }
153 None
154 }
155
156 fn find_add_operator(&self, ch: ChannelId, dfg: &DataflowGraph) -> Option<(usize, ChannelId)> {
158 for (idx, op) in dfg.operators.iter().enumerate() {
159 if let DfOperator::Compute {
160 op: ComputeOp::Add,
161 inputs,
162 output,
163 } = op
164 {
165 if *output == ch && inputs.len() == 2 {
166 return Some((idx, inputs[1]));
169 }
170 }
171 }
172 None
173 }
174
175 fn find_comparison(&self, ch: ChannelId, dfg: &DataflowGraph) -> Option<(usize, ChannelId)> {
177 for (idx, op) in dfg.operators.iter().enumerate() {
178 if let DfOperator::Compute {
179 op: cmp_op,
180 inputs,
181 output,
182 } = op
183 {
184 if *output == ch
185 && matches!(
186 cmp_op,
187 ComputeOp::Lt
188 | ComputeOp::Le
189 | ComputeOp::Gt
190 | ComputeOp::Ge
191 | ComputeOp::Ne
192 )
193 {
194 if inputs.len() == 2 {
195 return Some((idx, inputs[1]));
197 }
198 }
199 }
200 }
201 None
202 }
203
204 fn can_fuse(&self, pattern: &InductionPattern, dfg: &DataflowGraph) -> bool {
206 let step_const = self.find_constant_source(pattern.step_channel, dfg);
211
212 step_const.is_some()
214 }
215
216 fn fuse_to_stream(&self, pattern: InductionPattern, dfg: &mut DataflowGraph) {
218 let start_ch = pattern.initial_channel;
220 let step_ch = pattern.step_channel;
221 let bound_ch = pattern.bound_channel;
222
223 let done_ch = dfg.add_channel(TokenType::Bool);
226
227 dfg.add_operator(DfOperator::Stream {
229 start: start_ch,
230 step: step_ch,
231 bound: bound_ch,
232 output: pattern.output_channel,
233 done: done_ch,
234 });
235
236 let carry_dead_ch = dfg.add_channel(TokenType::i64());
239 let add_dead_ch = dfg.add_channel(TokenType::i64());
240 let cmp_dead_ch = dfg.add_channel(TokenType::Bool);
241
242 if let Some(op) = dfg.operators.get_mut(pattern.carry_idx) {
247 let _ = pattern.feedback_channel;
248 let _ = pattern.continue_channel;
249 *op = DfOperator::Constant {
250 value: TokenValue::Int(pattern.start_value),
251 output: carry_dead_ch,
252 repeat: None,
253 };
254 }
255
256 if let Some(op) = dfg.operators.get_mut(pattern.add_operator) {
258 *op = DfOperator::Constant {
259 value: TokenValue::Int(pattern.start_value),
260 output: add_dead_ch,
261 repeat: None,
262 };
263 }
264
265 if let Some(op) = dfg.operators.get_mut(pattern.compare_operator) {
268 *op = DfOperator::Constant {
269 value: TokenValue::Uint(0),
270 output: cmp_dead_ch,
271 repeat: None,
272 };
273 }
274 }
275}
276
277impl Default for StreamFusion {
278 fn default() -> Self {
279 Self::new()
280 }
281}
282
283#[derive(Debug)]
285struct InductionPattern {
286 carry_idx: usize,
287 initial_channel: ChannelId,
288 feedback_channel: ChannelId,
289 continue_channel: ChannelId,
290 output_channel: ChannelId,
291 start_value: i64,
292 step_channel: ChannelId,
293 bound_channel: ChannelId,
294 add_operator: usize,
295 compare_operator: usize,
296}
297
298pub struct DfgOptimizer {
300 pub stream_fusion: StreamFusion,
302 pub eliminate_dead_code: bool,
304 pub propagate_constants: bool,
306}
307
308impl DfgOptimizer {
309 pub fn new() -> Self {
310 Self {
311 stream_fusion: StreamFusion::new(),
312 eliminate_dead_code: true,
313 propagate_constants: true,
314 }
315 }
316
317 pub fn optimize(&self, dfg: &mut DataflowGraph) -> OptimizationStats {
319 let operators_before = dfg.operators.len();
320 let channels_before = dfg.channels.len();
321
322 let streams_created = self.stream_fusion.optimize(dfg);
324
325 let dead_eliminated = if self.eliminate_dead_code {
327 self.eliminate_dead_code(dfg)
328 } else {
329 0
330 };
331
332 if self.propagate_constants {
334 self.propagate_constants(dfg);
335 }
336
337 OptimizationStats {
338 operators_before,
339 operators_after: dfg.operators.len(),
340 channels_before,
341 channels_after: dfg.channels.len(),
342 streams_created,
343 operators_eliminated: dead_eliminated,
344 }
345 }
346
347 fn eliminate_dead_code(&self, dfg: &mut DataflowGraph) -> usize {
349 let analysis = DependencyAnalysis::analyze(dfg);
350
351 let mut reachable = HashSet::new();
353
354 let mut worklist: Vec<_> = analysis.sinks.clone();
356
357 while let Some(op_id) = worklist.pop() {
358 if reachable.contains(&op_id) {
359 continue;
360 }
361 reachable.insert(op_id);
362
363 if let Some(preds) = analysis.predecessors.get(&op_id) {
365 for pred in preds {
366 worklist.push(*pred);
367 }
368 }
369 }
370
371 let total_ops = dfg.operators.len();
373 let reachable_count = reachable.len();
374
375 total_ops.saturating_sub(reachable_count)
376 }
377
378 fn propagate_constants(&self, dfg: &mut DataflowGraph) {
380 let mut constants: HashMap<ChannelId, TokenValue> = HashMap::new();
382
383 for op in &dfg.operators {
384 if let DfOperator::Constant { value, output, .. } = op {
385 constants.insert(*output, value.clone());
386 }
387 }
388
389 }
392
393 pub fn estimate_energy(
395 &self,
396 dfg: &mut DataflowGraph,
397 firing_estimates: &HashMap<OperatorId, u64>,
398 ) {
399 for (idx, op) in dfg.operators.iter().enumerate() {
400 let op_id = OperatorId::new(idx as u32);
401 let static_cost = op.energy_cost() as f64;
402 let firing_count = firing_estimates.get(&op_id).copied().unwrap_or(1);
403
404 dfg.energy_estimates.insert(
405 op_id,
406 EnergyEstimate {
407 static_cost,
408 dynamic_cost: 0.0, firing_count,
410 },
411 );
412 }
413 }
414}
415
416impl Default for DfgOptimizer {
417 fn default() -> Self {
418 Self::new()
419 }
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stream_fusion_new() {
        let fusion = StreamFusion::new();
        assert_eq!(fusion.min_iterations, 4);
        assert!(fusion.fuse_nested);
    }

    #[test]
    fn test_stream_fusion_default_matches_new() {
        let fusion = StreamFusion::default();
        assert_eq!(fusion.min_iterations, 4);
        assert!(fusion.fuse_nested);
    }

    #[test]
    fn test_optimizer_default() {
        let optimizer = DfgOptimizer::new();
        assert!(optimizer.eliminate_dead_code);
        assert!(optimizer.propagate_constants);
    }

    #[test]
    fn test_optimization_stats() {
        let stats = OptimizationStats {
            operators_before: 100,
            operators_after: 75,
            channels_before: 50,
            channels_after: 40,
            streams_created: 5,
            operators_eliminated: 25,
        };

        assert!((stats.operator_reduction_percent() - 25.0).abs() < 0.01);
    }

    #[test]
    fn test_reduction_percent_edge_cases() {
        // No operators before: must not divide by zero.
        assert_eq!(OptimizationStats::default().operator_reduction_percent(), 0.0);

        // Graph grew: the reduction saturates at zero instead of going negative.
        let grown = OptimizationStats {
            operators_before: 4,
            operators_after: 6,
            ..Default::default()
        };
        assert_eq!(grown.operator_reduction_percent(), 0.0);
    }

    #[test]
    fn test_simple_optimization() {
        let mut dfg = DataflowGraph::new();

        let ch1 = dfg.add_channel(TokenType::i32());
        let _ch2 = dfg.add_channel(TokenType::i32());

        dfg.add_operator(DfOperator::Source {
            external_id: 0,
            output: ch1,
        });

        dfg.add_operator(DfOperator::Sink {
            input: ch1,
            external_id: 0,
        });

        let optimizer = DfgOptimizer::new();
        let stats = optimizer.optimize(&mut dfg);

        assert!(stats.operators_before >= 2);
        // With no `Carry` operators there is nothing to fuse, so no operators or
        // channels are added. The remaining passes do not resize the graph either.
        assert_eq!(stats.streams_created, 0);
        assert_eq!(stats.operators_after, stats.operators_before);
        assert_eq!(stats.channels_after, stats.channels_before);
        assert!(stats.operators_eliminated <= stats.operators_before);
    }
}