joule_mir/
match_lowering.rs

//! Match expression lowering for Joule MIR.
//!
//! Converts `match` expressions into decision trees of MIR basic blocks, using the
//! classic pattern-matrix compilation algorithm (specialization and default matrices).
5
6use crate::{
7    BasicBlock, BasicBlockId, FieldIdx, FunctionMIR, IntTy, Local, LocalDecl, Operand, Place,
8    Rvalue, Statement, SwitchTargets, Terminator, Ty,
9};
10use joule_common::Span;
11
12/// A pattern in the pattern matrix
13#[derive(Debug, Clone)]
14pub enum Pat {
15    /// Wildcard pattern
16    Wild,
17    /// Binding pattern
18    Bind {
19        name: String,
20        inner: Option<Box<Pat>>,
21    },
22    /// Literal pattern
23    Literal(LitPat),
24    /// Constructor pattern (enum variant, struct)
25    Ctor { ctor: CtorKind, fields: Vec<Pat> },
26    /// Or-pattern
27    Or(Vec<Pat>),
28    /// Range pattern
29    Range {
30        lo: Option<i128>,
31        hi: Option<i128>,
32        inclusive: bool,
33    },
34}
35
36impl Pat {
37    /// Check if this pattern is a wildcard or binding (always matches)
38    pub fn is_wildcard(&self) -> bool {
39        matches!(self, Pat::Wild | Pat::Bind { .. })
40    }
41
42    /// Get the fields of this pattern if it's a constructor
43    pub fn fields(&self) -> Option<&[Pat]> {
44        match self {
45            Pat::Ctor { fields, .. } => Some(fields),
46            _ => None,
47        }
48    }
49
50    /// Check if this pattern binds a name
51    pub fn binding_name(&self) -> Option<&str> {
52        match self {
53            Pat::Bind { name, .. } => Some(name),
54            _ => None,
55        }
56    }
57}
58
/// A literal value that can appear directly in a pattern.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LitPat {
    /// Integer literal, widened to `i128`.
    Int(i128),
    /// `true` / `false`.
    Bool(bool),
    /// Character literal.
    Char(char),
    /// String literal.
    Str(String),
}

impl LitPat {
    /// Map this literal to the `u128` key used as a `SwitchInt` branch value.
    ///
    /// - `Int(n)` is reinterpreted bit-for-bit (`n as u128`), so negative values become
    ///   their two's-complement `u128` form (e.g. `-1` becomes `u128::MAX`).
    /// - `Bool` maps to `0` / `1`, and `Char` maps to its Unicode scalar value.
    /// - `Str` always maps to `0`, so distinct strings collide. String matching needs a
    ///   comparison-based lowering rather than an integer switch.
    pub fn to_discriminant(&self) -> u128 {
        match *self {
            LitPat::Int(value) => value as u128,
            LitPat::Bool(flag) => u128::from(flag),
            LitPat::Char(ch) => u128::from(u32::from(ch)),
            LitPat::Str(_) => 0,
        }
    }
}
85
/// The head constructor of a `Pat::Ctor` pattern.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CtorKind {
    /// A variant of an enum, identified by enum id and variant index.
    Variant { enum_id: u32, variant_idx: u32 },
    /// A struct, identified by struct id.
    Struct { struct_id: u32 },
    /// A tuple with the given number of elements.
    Tuple(usize),
    /// An array with the given length.
    Array(usize),
    /// A slice pattern with `min_len` fixed elements; `has_rest` marks a rest (`..`) part.
    Slice { min_len: usize, has_rest: bool },
    /// A box pattern.
    Box,
    /// A reference pattern.
    Ref { mutable: bool },
}

impl CtorKind {
    /// Number of sub-pattern (field) columns this constructor contributes.
    ///
    /// `Variant` and `Struct` report `0` because their field counts live in type
    /// information that is not available here.
    pub fn arity(&self) -> usize {
        match *self {
            CtorKind::Tuple(len) | CtorKind::Array(len) => len,
            CtorKind::Slice { min_len, .. } => min_len,
            CtorKind::Box | CtorKind::Ref { .. } => 1,
            CtorKind::Variant { .. } | CtorKind::Struct { .. } => 0,
        }
    }

    /// Integer tag to switch on, if this constructor has one.
    ///
    /// Only enum variants are tagged (by `variant_idx`); every other constructor yields `None`.
    pub fn discriminant(&self) -> Option<u128> {
        if let CtorKind::Variant { variant_idx, .. } = *self {
            Some(u128::from(variant_idx))
        } else {
            None
        }
    }
}
127
128/// A match arm for lowering
129#[derive(Debug, Clone)]
130pub struct MatchArm {
131    /// The pattern
132    pub pattern: Pat,
133    /// Guard expression (if any) - represented as a local containing the guard result
134    pub guard: Option<Local>,
135    /// Body block ID
136    pub body: BasicBlockId,
137    /// Bindings to be set before entering the body
138    pub bindings: Vec<(String, Local)>,
139    /// Span
140    pub span: Span,
141}
142
143/// Match lowering context
144pub struct MatchLowering<'a> {
145    /// The function we're lowering into
146    func: &'a mut FunctionMIR,
147    /// Type information for the scrutinee
148    _scrutinee_ty: Ty,
149    /// Generated blocks (to be merged into function)
150    new_blocks: Vec<BasicBlock>,
151}
152
153/// A lowered basic block (for return value)
154#[derive(Debug)]
155pub struct LoweredBlock {
156    pub id: BasicBlockId,
157    pub statements: Vec<Statement>,
158    pub terminator: Terminator,
159}
160
161impl<'a> MatchLowering<'a> {
162    /// Create a new match lowering context
163    pub fn new(func: &'a mut FunctionMIR, scrutinee_ty: Ty) -> Self {
164        Self {
165            func,
166            _scrutinee_ty: scrutinee_ty,
167            new_blocks: Vec::new(),
168        }
169    }
170
171    /// Lower a match expression
172    ///
173    /// Returns the entry block ID for the match
174    pub fn lower_match(
175        &mut self,
176        scrutinee: Place,
177        arms: &[MatchArm],
178        result_place: Place,
179        join_block: BasicBlockId,
180        span: Span,
181    ) -> BasicBlockId {
182        // Build the pattern matrix
183        let matrix = self.build_matrix(arms);
184
185        // Compile the matrix into a decision tree
186        let entry_block = self.compile_matrix(
187            matrix,
188            vec![scrutinee],
189            arms,
190            result_place,
191            join_block,
192            span,
193        );
194
195        // Move new blocks into the function
196        for block in std::mem::take(&mut self.new_blocks) {
197            self.func.basic_blocks.push(block);
198        }
199
200        entry_block
201    }
202
203    /// Build the pattern matrix from match arms
204    fn build_matrix(&self, arms: &[MatchArm]) -> PatternMatrix {
205        let rows: Vec<_> = arms
206            .iter()
207            .enumerate()
208            .map(|(i, arm)| PatternRow {
209                patterns: vec![arm.pattern.clone()],
210                guard: arm.guard,
211                arm_index: i,
212            })
213            .collect();
214
215        PatternMatrix { rows }
216    }
217
218    /// Compile a pattern matrix into decision tree
219    fn compile_matrix(
220        &mut self,
221        matrix: PatternMatrix,
222        scrutinees: Vec<Place>,
223        arms: &[MatchArm],
224        result_place: Place,
225        join_block: BasicBlockId,
226        span: Span,
227    ) -> BasicBlockId {
228        // Base case: empty matrix (unreachable)
229        if matrix.rows.is_empty() {
230            let block_id = self.new_block_id();
231            self.new_blocks
232                .push(BasicBlock::new(Terminator::Unreachable { span }));
233            return block_id;
234        }
235
236        // Base case: all patterns in first column are wildcards
237        if matrix.all_wildcards() {
238            // First row matches - emit bindings and jump to body
239            let first_row = &matrix.rows[0];
240            let arm = &arms[first_row.arm_index];
241
242            // Check guard if present
243            if let Some(guard_local) = first_row.guard {
244                return self.compile_with_guard(
245                    guard_local,
246                    arm,
247                    matrix.rows[1..].to_vec(),
248                    scrutinees,
249                    arms,
250                    result_place,
251                    join_block,
252                    span,
253                );
254            }
255
256            // No guard - just emit bindings and jump to body
257            let block_id = self.new_block_id();
258            let statements = self.emit_bindings(&arm.bindings, &scrutinees, span);
259
260            self.new_blocks.push(BasicBlock {
261                statements,
262                terminator: Terminator::Goto {
263                    target: arm.body,
264                    span,
265                },
266            });
267
268            return block_id;
269        }
270
271        // Find the best column to split on
272        let col = self.select_column(&matrix);
273
274        // No scrutinee for this column (shouldn't happen)
275        if col >= scrutinees.len() {
276            let block_id = self.new_block_id();
277            self.new_blocks
278                .push(BasicBlock::new(Terminator::Unreachable { span }));
279            return block_id;
280        }
281
282        let scrutinee = scrutinees[col].clone();
283
284        // Get the constructors used in this column
285        let ctors = self.collect_constructors(&matrix, col);
286
287        // If no constructors, all are wildcards - move to next column
288        if ctors.is_empty() {
289            let reduced = matrix.remove_column(col);
290            let mut new_scrutinees = scrutinees.clone();
291            new_scrutinees.remove(col);
292            return self.compile_matrix(
293                reduced,
294                new_scrutinees,
295                arms,
296                result_place,
297                join_block,
298                span,
299            );
300        }
301
302        // Check if we have literal patterns
303        let literals = self.collect_literals(&matrix, col);
304        if !literals.is_empty() {
305            return self.compile_literal_switch(
306                matrix,
307                col,
308                scrutinee,
309                scrutinees,
310                arms,
311                result_place,
312                join_block,
313                span,
314            );
315        }
316
317        // Create a switch on the constructor discriminant
318        self.compile_ctor_switch(
319            matrix,
320            col,
321            scrutinee,
322            scrutinees,
323            arms,
324            result_place,
325            join_block,
326            span,
327            &ctors,
328        )
329    }
330
331    /// Compile with a guard check
332    fn compile_with_guard(
333        &mut self,
334        guard_local: Local,
335        arm: &MatchArm,
336        remaining_rows: Vec<PatternRow>,
337        scrutinees: Vec<Place>,
338        arms: &[MatchArm],
339        result_place: Place,
340        join_block: BasicBlockId,
341        span: Span,
342    ) -> BasicBlockId {
343        let guard_block = self.new_block_id();
344        let body_block = arm.body;
345
346        // Create else branch for when guard fails
347        let else_matrix = PatternMatrix {
348            rows: remaining_rows,
349        };
350        let else_block = self.compile_matrix(
351            else_matrix,
352            scrutinees.clone(),
353            arms,
354            result_place.clone(),
355            join_block,
356            span,
357        );
358
359        // Emit the guard check
360        let statements = self.emit_bindings(&arm.bindings, &scrutinees, span);
361
362        self.new_blocks.push(BasicBlock {
363            statements,
364            terminator: Terminator::SwitchInt {
365                discriminant: Operand::Copy(Place::from_local(guard_local)),
366                targets: SwitchTargets {
367                    branches: vec![(1, body_block)], // true -> body
368                    otherwise: else_block,           // false -> next pattern
369                },
370                span,
371            },
372        });
373
374        guard_block
375    }
376
377    /// Compile a switch on literal values
378    fn compile_literal_switch(
379        &mut self,
380        matrix: PatternMatrix,
381        col: usize,
382        scrutinee: Place,
383        scrutinees: Vec<Place>,
384        arms: &[MatchArm],
385        result_place: Place,
386        join_block: BasicBlockId,
387        span: Span,
388    ) -> BasicBlockId {
389        let literals = self.collect_literals(&matrix, col);
390        let switch_block = self.new_block_id();
391
392        let mut branches = Vec::new();
393
394        for lit in &literals {
395            // Specialize matrix for this literal
396            let specialized = matrix.specialize_literal(col, lit);
397            let mut new_scrutinees = scrutinees.clone();
398            new_scrutinees.remove(col);
399
400            let case_block = self.compile_matrix(
401                specialized,
402                new_scrutinees,
403                arms,
404                result_place.clone(),
405                join_block,
406                span,
407            );
408
409            branches.push((lit.to_discriminant(), case_block));
410        }
411
412        // Default case (wildcard rows)
413        let default_matrix = matrix.default(col);
414        let mut default_scrutinees = scrutinees.clone();
415        default_scrutinees.remove(col);
416
417        let default_block = if default_matrix.rows.is_empty() {
418            let unreachable_block = self.new_block_id();
419            self.new_blocks
420                .push(BasicBlock::new(Terminator::Unreachable { span }));
421            unreachable_block
422        } else {
423            self.compile_matrix(
424                default_matrix,
425                default_scrutinees,
426                arms,
427                result_place,
428                join_block,
429                span,
430            )
431        };
432
433        self.new_blocks.push(BasicBlock {
434            statements: vec![],
435            terminator: Terminator::SwitchInt {
436                discriminant: Operand::Copy(scrutinee),
437                targets: SwitchTargets {
438                    branches,
439                    otherwise: default_block,
440                },
441                span,
442            },
443        });
444
445        switch_block
446    }
447
448    /// Compile a switch on constructor discriminant
449    fn compile_ctor_switch(
450        &mut self,
451        matrix: PatternMatrix,
452        col: usize,
453        scrutinee: Place,
454        scrutinees: Vec<Place>,
455        arms: &[MatchArm],
456        result_place: Place,
457        join_block: BasicBlockId,
458        span: Span,
459        ctors: &[CtorKind],
460    ) -> BasicBlockId {
461        let switch_block = self.new_block_id();
462        let mut branches = Vec::new();
463
464        for ctor in ctors {
465            // Specialize matrix for this constructor
466            let specialized = matrix.specialize(col, ctor);
467
468            // Create places for constructor fields
469            let mut new_scrutinees = scrutinees.clone();
470            new_scrutinees.remove(col);
471
472            // Add field places
473            let arity = ctor.arity();
474            for i in 0..arity {
475                let field_place = scrutinee.clone().field(FieldIdx::new(i as u32));
476                new_scrutinees.push(field_place);
477            }
478
479            let case_block = self.compile_matrix(
480                specialized,
481                new_scrutinees,
482                arms,
483                result_place.clone(),
484                join_block,
485                span,
486            );
487
488            if let Some(discr) = ctor.discriminant() {
489                branches.push((discr, case_block));
490            } else {
491                // Single constructor (struct, tuple) - direct jump
492                branches.push((0, case_block));
493            }
494        }
495
496        // Default case for non-exhaustive matches
497        let default_matrix = matrix.default(col);
498        let mut default_scrutinees = scrutinees.clone();
499        default_scrutinees.remove(col);
500
501        let default_block = if default_matrix.rows.is_empty() {
502            let unreachable_block = self.new_block_id();
503            self.new_blocks
504                .push(BasicBlock::new(Terminator::Unreachable { span }));
505            unreachable_block
506        } else {
507            self.compile_matrix(
508                default_matrix,
509                default_scrutinees,
510                arms,
511                result_place,
512                join_block,
513                span,
514            )
515        };
516
517        // If there's only one constructor and no discriminant, use Goto
518        if ctors.len() == 1 && ctors[0].discriminant().is_none() {
519            self.new_blocks.push(BasicBlock {
520                statements: vec![],
521                terminator: Terminator::Goto {
522                    target: branches[0].1,
523                    span,
524                },
525            });
526        } else {
527            // Get discriminant
528            let discr_local = self.new_local(Ty::Int(IntTy::I32), span);
529            let discr_place = Place::from_local(discr_local);
530
531            self.new_blocks.push(BasicBlock {
532                statements: vec![Statement::Assign {
533                    place: discr_place.clone(),
534                    rvalue: Rvalue::Discriminant { place: scrutinee },
535                    span,
536                }],
537                terminator: Terminator::SwitchInt {
538                    discriminant: Operand::Copy(discr_place),
539                    targets: SwitchTargets {
540                        branches,
541                        otherwise: default_block,
542                    },
543                    span,
544                },
545            });
546        }
547
548        switch_block
549    }
550
551    /// Emit binding statements
552    fn emit_bindings(
553        &self,
554        bindings: &[(String, Local)],
555        _scrutinees: &[Place],
556        span: Span,
557    ) -> Vec<Statement> {
558        // In a full implementation, this would copy/move values from scrutinees to locals
559        // For now, we mark bindings as live
560        bindings
561            .iter()
562            .map(|(_, local)| Statement::StorageLive {
563                local: *local,
564                span,
565            })
566            .collect()
567    }
568
569    /// Create a new block ID
570    fn new_block_id(&self) -> BasicBlockId {
571        BasicBlockId::new((self.func.basic_blocks.len() + self.new_blocks.len()) as u32)
572    }
573
574    /// Create a new local variable
575    fn new_local(&mut self, ty: Ty, span: Span) -> Local {
576        let local = Local::new(self.func.locals.len() as u32);
577        self.func.locals.push(LocalDecl::new(None, ty, false, span));
578        local
579    }
580
581    /// Select the best column to split on
582    fn select_column(&self, matrix: &PatternMatrix) -> usize {
583        // Heuristic: first column with a constructor
584        for col in 0..matrix.width() {
585            for row in &matrix.rows {
586                if col < row.patterns.len() {
587                    match &row.patterns[col] {
588                        Pat::Wild | Pat::Bind { .. } => continue,
589                        _ => return col,
590                    }
591                }
592            }
593        }
594        0
595    }
596
597    /// Collect all constructors used in a column
598    fn collect_constructors(&self, matrix: &PatternMatrix, col: usize) -> Vec<CtorKind> {
599        let mut ctors = Vec::new();
600        for row in &matrix.rows {
601            if col < row.patterns.len() {
602                if let Pat::Ctor { ctor, .. } = &row.patterns[col] {
603                    if !ctors.contains(ctor) {
604                        ctors.push(ctor.clone());
605                    }
606                }
607            }
608        }
609        ctors
610    }
611
612    /// Collect all literals used in a column
613    fn collect_literals(&self, matrix: &PatternMatrix, col: usize) -> Vec<LitPat> {
614        let mut lits = Vec::new();
615        for row in &matrix.rows {
616            if col < row.patterns.len() {
617                if let Pat::Literal(lit) = &row.patterns[col] {
618                    if !lits.contains(lit) {
619                        lits.push(lit.clone());
620                    }
621                }
622            }
623        }
624        lits
625    }
626}
627
628/// The pattern matrix used in match compilation
629#[derive(Debug, Clone)]
630pub struct PatternMatrix {
631    pub rows: Vec<PatternRow>,
632}
633
634/// A row in the pattern matrix
635#[derive(Debug, Clone)]
636pub struct PatternRow {
637    pub patterns: Vec<Pat>,
638    pub guard: Option<Local>,
639    pub arm_index: usize,
640}
641
642impl PatternMatrix {
643    /// Get the width (number of columns) of the matrix
644    pub fn width(&self) -> usize {
645        self.rows.first().map_or(0, |r| r.patterns.len())
646    }
647
648    /// Check if all patterns in the first column are wildcards
649    pub fn all_wildcards(&self) -> bool {
650        self.rows.iter().all(|row| {
651            row.patterns
652                .first()
653                .map(|p| p.is_wildcard())
654                .unwrap_or(true)
655        })
656    }
657
658    /// Remove a column from the matrix
659    pub fn remove_column(&self, col: usize) -> PatternMatrix {
660        let rows = self
661            .rows
662            .iter()
663            .map(|row| PatternRow {
664                patterns: row
665                    .patterns
666                    .iter()
667                    .enumerate()
668                    .filter(|(i, _)| *i != col)
669                    .map(|(_, p)| p.clone())
670                    .collect(),
671                guard: row.guard,
672                arm_index: row.arm_index,
673            })
674            .collect();
675        PatternMatrix { rows }
676    }
677
678    /// Specialize the matrix for a constructor
679    pub fn specialize(&self, col: usize, ctor: &CtorKind) -> PatternMatrix {
680        let arity = ctor.arity();
681        let rows = self
682            .rows
683            .iter()
684            .filter_map(|row| {
685                if col >= row.patterns.len() {
686                    return None;
687                }
688                match &row.patterns[col] {
689                    Pat::Wild | Pat::Bind { .. } => {
690                        // Wildcard matches any constructor - expand with wildcards
691                        let mut new_patterns: Vec<_> = (0..arity).map(|_| Pat::Wild).collect();
692                        new_patterns.extend(
693                            row.patterns
694                                .iter()
695                                .enumerate()
696                                .filter(|(i, _)| *i != col)
697                                .map(|(_, p)| p.clone()),
698                        );
699                        Some(PatternRow {
700                            patterns: new_patterns,
701                            guard: row.guard,
702                            arm_index: row.arm_index,
703                        })
704                    }
705                    Pat::Ctor {
706                        ctor: row_ctor,
707                        fields,
708                    } if row_ctor == ctor => {
709                        // Constructor matches - add fields
710                        let mut new_patterns = fields.clone();
711                        new_patterns.extend(
712                            row.patterns
713                                .iter()
714                                .enumerate()
715                                .filter(|(i, _)| *i != col)
716                                .map(|(_, p)| p.clone()),
717                        );
718                        Some(PatternRow {
719                            patterns: new_patterns,
720                            guard: row.guard,
721                            arm_index: row.arm_index,
722                        })
723                    }
724                    Pat::Or(pats) => {
725                        // Check if any alternative matches
726                        for pat in pats {
727                            match pat {
728                                Pat::Wild | Pat::Bind { .. } => {
729                                    let mut new_patterns: Vec<_> =
730                                        (0..arity).map(|_| Pat::Wild).collect();
731                                    new_patterns.extend(
732                                        row.patterns
733                                            .iter()
734                                            .enumerate()
735                                            .filter(|(i, _)| *i != col)
736                                            .map(|(_, p)| p.clone()),
737                                    );
738                                    return Some(PatternRow {
739                                        patterns: new_patterns,
740                                        guard: row.guard,
741                                        arm_index: row.arm_index,
742                                    });
743                                }
744                                Pat::Ctor {
745                                    ctor: pat_ctor,
746                                    fields,
747                                } if pat_ctor == ctor => {
748                                    let mut new_patterns = fields.clone();
749                                    new_patterns.extend(
750                                        row.patterns
751                                            .iter()
752                                            .enumerate()
753                                            .filter(|(i, _)| *i != col)
754                                            .map(|(_, p)| p.clone()),
755                                    );
756                                    return Some(PatternRow {
757                                        patterns: new_patterns,
758                                        guard: row.guard,
759                                        arm_index: row.arm_index,
760                                    });
761                                }
762                                _ => {}
763                            }
764                        }
765                        None
766                    }
767                    _ => None,
768                }
769            })
770            .collect();
771        PatternMatrix { rows }
772    }
773
774    /// Specialize the matrix for a literal
775    pub fn specialize_literal(&self, col: usize, lit: &LitPat) -> PatternMatrix {
776        let rows = self
777            .rows
778            .iter()
779            .filter_map(|row| {
780                if col >= row.patterns.len() {
781                    return None;
782                }
783                match &row.patterns[col] {
784                    Pat::Wild | Pat::Bind { .. } => {
785                        // Wildcard matches any literal
786                        let new_patterns: Vec<_> = row
787                            .patterns
788                            .iter()
789                            .enumerate()
790                            .filter(|(i, _)| *i != col)
791                            .map(|(_, p)| p.clone())
792                            .collect();
793                        Some(PatternRow {
794                            patterns: new_patterns,
795                            guard: row.guard,
796                            arm_index: row.arm_index,
797                        })
798                    }
799                    Pat::Literal(row_lit) if row_lit == lit => {
800                        // Literal matches
801                        let new_patterns: Vec<_> = row
802                            .patterns
803                            .iter()
804                            .enumerate()
805                            .filter(|(i, _)| *i != col)
806                            .map(|(_, p)| p.clone())
807                            .collect();
808                        Some(PatternRow {
809                            patterns: new_patterns,
810                            guard: row.guard,
811                            arm_index: row.arm_index,
812                        })
813                    }
814                    Pat::Or(pats) => {
815                        // Check if any alternative matches
816                        for pat in pats {
817                            match pat {
818                                Pat::Wild | Pat::Bind { .. } => {
819                                    let new_patterns: Vec<_> = row
820                                        .patterns
821                                        .iter()
822                                        .enumerate()
823                                        .filter(|(i, _)| *i != col)
824                                        .map(|(_, p)| p.clone())
825                                        .collect();
826                                    return Some(PatternRow {
827                                        patterns: new_patterns,
828                                        guard: row.guard,
829                                        arm_index: row.arm_index,
830                                    });
831                                }
832                                Pat::Literal(pat_lit) if pat_lit == lit => {
833                                    let new_patterns: Vec<_> = row
834                                        .patterns
835                                        .iter()
836                                        .enumerate()
837                                        .filter(|(i, _)| *i != col)
838                                        .map(|(_, p)| p.clone())
839                                        .collect();
840                                    return Some(PatternRow {
841                                        patterns: new_patterns,
842                                        guard: row.guard,
843                                        arm_index: row.arm_index,
844                                    });
845                                }
846                                _ => {}
847                            }
848                        }
849                        None
850                    }
851                    _ => None,
852                }
853            })
854            .collect();
855        PatternMatrix { rows }
856    }
857
858    /// Get rows that don't match any specific constructor (default matrix)
859    pub fn default(&self, col: usize) -> PatternMatrix {
860        let rows = self
861            .rows
862            .iter()
863            .filter_map(|row| {
864                if col >= row.patterns.len() {
865                    return Some(row.clone());
866                }
867                match &row.patterns[col] {
868                    Pat::Wild | Pat::Bind { .. } => {
869                        let new_patterns: Vec<_> = row
870                            .patterns
871                            .iter()
872                            .enumerate()
873                            .filter(|(i, _)| *i != col)
874                            .map(|(_, p)| p.clone())
875                            .collect();
876                        Some(PatternRow {
877                            patterns: new_patterns,
878                            guard: row.guard,
879                            arm_index: row.arm_index,
880                        })
881                    }
882                    Pat::Or(pats) => {
883                        // Include if any alternative is a wildcard
884                        if pats.iter().any(|p| p.is_wildcard()) {
885                            let new_patterns: Vec<_> = row
886                                .patterns
887                                .iter()
888                                .enumerate()
889                                .filter(|(i, _)| *i != col)
890                                .map(|(_, p)| p.clone())
891                                .collect();
892                            Some(PatternRow {
893                                patterns: new_patterns,
894                                guard: row.guard,
895                                arm_index: row.arm_index,
896                            })
897                        } else {
898                            None
899                        }
900                    }
901                    _ => None,
902                }
903            })
904            .collect();
905        PatternMatrix { rows }
906    }
907}
908
909/// Helper to create a simple match for common cases
910pub fn lower_simple_match(
911    func: &mut FunctionMIR,
912    scrutinee: Place,
913    _scrutinee_ty: Ty,
914    arms: Vec<(Pat, BasicBlockId)>,
915    default_block: BasicBlockId,
916    span: Span,
917) -> BasicBlockId {
918    let switch_block = BasicBlockId::new(func.basic_blocks.len() as u32);
919
920    let mut branches = Vec::new();
921    for (pat, target) in arms {
922        match pat {
923            Pat::Literal(lit) => {
924                branches.push((lit.to_discriminant(), target));
925            }
926            Pat::Ctor { ctor, .. } => {
927                if let Some(discr) = ctor.discriminant() {
928                    branches.push((discr, target));
929                }
930            }
931            _ => {}
932        }
933    }
934
935    func.basic_blocks.push(BasicBlock {
936        statements: vec![],
937        terminator: Terminator::SwitchInt {
938            discriminant: Operand::Copy(scrutinee),
939            targets: SwitchTargets {
940                branches,
941                otherwise: default_block,
942            },
943            span,
944        },
945    });
946
947    switch_block
948}
949
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an unguarded `PatternRow` for arm `$arm` from a list of patterns.
    macro_rules! row {
        ($arm:expr; $($pat:expr),* $(,)?) => {
            PatternRow {
                patterns: vec![$($pat),*],
                guard: None,
                arm_index: $arm,
            }
        };
    }

    /// Wrap rows into a `PatternMatrix`.
    fn matrix(rows: Vec<PatternRow>) -> PatternMatrix {
        PatternMatrix { rows }
    }

    /// Integer literal pattern shorthand.
    fn int(n: i128) -> Pat {
        Pat::Literal(LitPat::Int(n))
    }

    /// Plain binding pattern (no sub-pattern) shorthand.
    fn bind(name: &str) -> Pat {
        Pat::Bind {
            name: name.to_string(),
            inner: None,
        }
    }

    #[test]
    fn test_pattern_matrix_wildcards() {
        let m = matrix(vec![row!(0; Pat::Wild), row!(1; Pat::Wild)]);
        assert!(m.all_wildcards());
    }

    #[test]
    fn test_pattern_matrix_not_wildcards() {
        let m = matrix(vec![row!(0; int(1)), row!(1; Pat::Wild)]);
        assert!(!m.all_wildcards());
    }

    #[test]
    fn test_pattern_matrix_specialize_literal() {
        let m = matrix(vec![row!(0; int(1)), row!(1; int(2)), row!(2; Pat::Wild)]);

        let specialized = m.specialize_literal(0, &LitPat::Int(1));
        assert_eq!(specialized.rows.len(), 2); // matching literal + wildcard
        assert_eq!(specialized.rows[0].arm_index, 0);
        assert_eq!(specialized.rows[1].arm_index, 2);
    }

    #[test]
    fn test_pattern_matrix_default() {
        let m = matrix(vec![row!(0; int(1)), row!(1; Pat::Wild)]);

        let default = m.default(0);
        assert_eq!(default.rows.len(), 1);
        assert_eq!(default.rows[0].arm_index, 1);
    }

    #[test]
    fn test_pattern_matrix_default_removes_column() {
        let m = matrix(vec![
            row!(0; Pat::Wild, int(7)),
            row!(1; int(1), Pat::Wild),
        ]);

        let default = m.default(0);
        assert_eq!(default.rows.len(), 1);
        assert_eq!(default.rows[0].arm_index, 0);
        assert_eq!(default.rows[0].patterns.len(), 1);
        assert!(matches!(
            default.rows[0].patterns[0],
            Pat::Literal(LitPat::Int(7))
        ));
    }

    #[test]
    fn test_pattern_matrix_default_or_with_wildcard() {
        let m = matrix(vec![
            row!(0; Pat::Or(vec![int(1), Pat::Wild])),
            row!(1; int(2)),
            row!(2; bind("x")),
        ]);

        let default = m.default(0);
        assert_eq!(default.rows.len(), 2);
        assert_eq!(default.rows[0].arm_index, 0);
        assert!(default.rows[0].patterns.is_empty());
        assert_eq!(default.rows[1].arm_index, 2);
    }

    #[test]
    fn test_pattern_matrix_default_keeps_short_rows() {
        let m = matrix(vec![row!(0; int(1))]);

        // Column 3 does not exist in the row, so it is carried over untouched.
        let default = m.default(3);
        assert_eq!(default.rows.len(), 1);
        assert_eq!(default.rows[0].arm_index, 0);
        assert_eq!(default.rows[0].patterns.len(), 1);
    }

    #[test]
    fn test_ctor_kind_arity() {
        assert_eq!(CtorKind::Tuple(3).arity(), 3);
        assert_eq!(CtorKind::Array(5).arity(), 5);
        assert_eq!(CtorKind::Box.arity(), 1);
        assert_eq!(CtorKind::Ref { mutable: false }.arity(), 1);
    }

    #[test]
    fn test_lit_pat_discriminant() {
        assert_eq!(LitPat::Bool(true).to_discriminant(), 1);
        assert_eq!(LitPat::Bool(false).to_discriminant(), 0);
        assert_eq!(LitPat::Int(42).to_discriminant(), 42);
        assert_eq!(LitPat::Char('A').to_discriminant(), 65);
    }

    #[test]
    fn test_pattern_is_wildcard() {
        assert!(Pat::Wild.is_wildcard());
        assert!(bind("x").is_wildcard());
        assert!(!int(1).is_wildcard());
    }

    #[test]
    fn test_pattern_binding_name() {
        assert_eq!(bind("x").binding_name(), Some("x"));
        assert_eq!(Pat::Wild.binding_name(), None);
    }

    #[test]
    fn test_or_pattern_specialize() {
        let m = matrix(vec![
            row!(0; Pat::Or(vec![int(1), int(2)])),
            row!(1; Pat::Wild),
        ]);

        let specialized = m.specialize_literal(0, &LitPat::Int(1));
        assert_eq!(specialized.rows.len(), 2);
    }
}