joule_mir/
lib.rs

1//! Mid-level Intermediate Representation (MIR) for Joule
2//!
3//! MIR is the central IR where all optimizations happen, including:
4//! - Energy optimizations
5//! - Dead code elimination
6//! - Constant folding
7//! - Inlining
8//!
9//! Key features:
10//! - Control flow represented as basic blocks
11//! - SSA-like representation with explicit temporaries
12//! - All control flow is explicit (no implicit returns)
13//! - Borrow checking happens on MIR
14//! - Code generation targets (LLVM, Cranelift) consume MIR
15
16pub mod build;
17pub mod const_fold;
18pub mod dataflow;
19pub mod match_lowering;
20pub mod ndarray_fusion;
21pub mod pretty;
22pub mod visit;
23
24// Re-export dataflow types
25pub use dataflow::{
26    Channel, ChannelId, ComputeOp, DataflowGraph, DependencyAnalysis, DfOperator, EnergyEstimate,
27    MemoryOp, OperatorId, ScheduleStep, ScheduledDfg, TokenType, TokenValue,
28};
29
30use indexmap::IndexMap;
31use joule_common::{Span, Symbol};
32use joule_hir::HirId;
33pub use joule_hir::{BinOp, FunctionAttributes, InlineHint, Literal, ProcessorTarget, UnOp};
34use std::fmt;
35
36/// Unique identifier for functions
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
38pub struct FunctionId(pub u32);
39
40impl FunctionId {
41    pub fn new(id: u32) -> Self {
42        Self(id)
43    }
44
45    pub fn from_hir(hir_id: HirId) -> Self {
46        Self(hir_id.0)
47    }
48}
49
/// Identifies a basic block; rendered by `Display` as `bb<N>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct BasicBlockId(pub u32);

impl BasicBlockId {
    /// The identifier with raw value 0.
    pub const START: Self = BasicBlockId(0);

    /// Wraps a raw index as a basic block identifier.
    pub fn new(id: u32) -> Self {
        BasicBlockId(id)
    }
}

impl fmt::Display for BasicBlockId {
    /// Writes `bb` followed by the raw index.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let BasicBlockId(index) = *self;
        f.write_fmt(format_args!("bb{}", index))
    }
}
67
68/// Local variable identifier
69#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
70pub struct Local(pub u32);
71
72impl Local {
73    /// The return place (always local 0)
74    pub const RETURN: Self = Self(0);
75
76    pub fn new(id: u32) -> Self {
77        Self(id)
78    }
79
80    pub fn from_hir(hir_id: HirId) -> Self {
81        Self(hir_id.0)
82    }
83}
84
85impl fmt::Display for Local {
86    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87        if *self == Self::RETURN {
88            write!(f, "_ret")
89        } else {
90            write!(f, "_{}", self.0)
91        }
92    }
93}
94
/// Position of a field, stored as a raw `u32`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FieldIdx(pub u32);

impl FieldIdx {
    /// Wraps a raw position as a field index.
    pub fn new(idx: u32) -> Self {
        FieldIdx(idx)
    }
}
104
/// Complete MIR for a function
///
/// Owns the function's locals and basic blocks. `Local` and `BasicBlockId`
/// values act as indices into `locals` and `basic_blocks` respectively.
#[derive(Debug, Clone)]
pub struct FunctionMIR {
    /// Function identifier
    pub id: FunctionId,
    /// Function name
    pub name: Symbol,
    /// Parameter locals (indices into locals array)
    pub params: Vec<Local>,
    /// Return type
    pub return_ty: Ty,
    /// All local variables (including params and temporaries).
    /// The constructors seed index 0 with the return place (`Local::RETURN`).
    pub locals: Vec<LocalDecl>,
    /// Basic blocks (indexed by BasicBlockId)
    pub basic_blocks: Vec<BasicBlock>,
    /// Function attributes for code generation decisions
    pub attributes: FunctionAttributes,
    /// Source span
    pub span: Span,
}
125
126impl FunctionMIR {
127    /// Create a new function MIR
128    pub fn new(id: FunctionId, name: Symbol, return_ty: Ty, span: Span) -> Self {
129        Self {
130            id,
131            name,
132            params: Vec::new(),
133            return_ty,
134            locals: vec![LocalDecl::return_place()],
135            basic_blocks: Vec::new(),
136            attributes: FunctionAttributes::default(),
137            span,
138        }
139    }
140
141    /// Create a new function MIR with attributes
142    pub fn with_attributes(
143        id: FunctionId,
144        name: Symbol,
145        return_ty: Ty,
146        attributes: FunctionAttributes,
147        span: Span,
148    ) -> Self {
149        Self {
150            id,
151            name,
152            params: Vec::new(),
153            return_ty,
154            locals: vec![LocalDecl::return_place()],
155            basic_blocks: Vec::new(),
156            attributes,
157            span,
158        }
159    }
160
161    /// Check if this function is a GPU kernel
162    pub fn is_gpu_kernel(&self) -> bool {
163        self.attributes.is_gpu_kernel
164            || matches!(self.attributes.target, Some(ProcessorTarget::Gpu))
165    }
166
167    /// Add a new local variable
168    pub fn add_local(&mut self, decl: LocalDecl) -> Local {
169        let local = Local::new(self.locals.len() as u32);
170        self.locals.push(decl);
171        local
172    }
173
174    /// Add a new basic block
175    pub fn add_block(&mut self, block: BasicBlock) -> BasicBlockId {
176        let id = BasicBlockId::new(self.basic_blocks.len() as u32);
177        self.basic_blocks.push(block);
178        id
179    }
180
181    /// Get a basic block by ID
182    pub fn block(&self, id: BasicBlockId) -> &BasicBlock {
183        &self.basic_blocks[id.0 as usize]
184    }
185
186    /// Get a mutable basic block by ID
187    pub fn block_mut(&mut self, id: BasicBlockId) -> &mut BasicBlock {
188        &mut self.basic_blocks[id.0 as usize]
189    }
190}
191
192/// Local variable declaration
193#[derive(Debug, Clone)]
194pub struct LocalDecl {
195    /// Name (for debugging)
196    pub name: Option<Symbol>,
197    /// Type
198    pub ty: Ty,
199    /// Is this mutable?
200    pub mutable: bool,
201    /// Source span
202    pub span: Span,
203}
204
205impl LocalDecl {
206    /// Create the return place declaration
207    pub fn return_place() -> Self {
208        Self {
209            name: None,
210            ty: Ty::Unit,
211            mutable: true,
212            span: Span::dummy(),
213        }
214    }
215
216    /// Create a new local declaration
217    pub fn new(name: Option<Symbol>, ty: Ty, mutable: bool, span: Span) -> Self {
218        Self {
219            name,
220            ty,
221            mutable,
222            span,
223        }
224    }
225}
226
227/// Basic block - a sequence of statements ending with a terminator
228#[derive(Debug, Clone)]
229pub struct BasicBlock {
230    /// Statements in this block
231    pub statements: Vec<Statement>,
232    /// Terminator (determines control flow)
233    pub terminator: Terminator,
234}
235
236impl BasicBlock {
237    /// Create a new basic block with a terminator
238    pub fn new(terminator: Terminator) -> Self {
239        Self {
240            statements: Vec::new(),
241            terminator,
242        }
243    }
244
245    /// Add a statement to this block
246    pub fn push_statement(&mut self, stmt: Statement) {
247        self.statements.push(stmt);
248    }
249}
250
/// Statement - a side-effecting operation that doesn't transfer control
///
/// Statements live in `BasicBlock::statements`; control transfer is expressed
/// separately by the block's `Terminator`.
#[derive(Debug, Clone)]
pub enum Statement {
    /// Assignment: place = rvalue
    Assign {
        /// Destination that is written
        place: Place,
        /// Value that is computed and stored
        rvalue: Rvalue,
        /// Source span
        span: Span,
    },

    /// Mark a local as "live" (storage allocated)
    StorageLive { local: Local, span: Span },

    /// Mark a local as "dead" (storage can be freed)
    StorageDead { local: Local, span: Span },

    /// No-op (used for alignment or as placeholder)
    Nop,
}
270
/// Terminator - ends a basic block and determines control flow
///
/// Every variant carries the source `span` it was lowered from. Variants that
/// continue execution name their successor blocks explicitly through
/// `BasicBlockId` fields (`target`, `closed_target`, `body`, `join_block`, ...).
#[derive(Debug, Clone)]
pub enum Terminator {
    /// Return from function (using Local::RETURN)
    Return { span: Span },

    /// Unconditional jump
    Goto { target: BasicBlockId, span: Span },

    /// Conditional branch
    SwitchInt {
        /// The discriminant to switch on
        discriminant: Operand,
        /// Possible targets
        targets: SwitchTargets,
        span: Span,
    },

    /// Function call
    Call {
        /// The function to call
        func: Operand,
        /// Name of the function (for built-ins and external functions)
        func_name: Option<Symbol>,
        /// Arguments
        args: Vec<Operand>,
        /// Where to store the return value
        destination: Place,
        /// Block to jump to after call
        target: BasicBlockId,
        /// Whether this is a virtual dispatch call on a trait object
        is_virtual: bool,
        span: Span,
    },

    /// Panic/abort
    Abort { span: Span },

    /// Unreachable code
    Unreachable { span: Span },

    // === Concurrency Terminators ===
    /// Spawn a new task
    /// The spawned function runs concurrently and returns a TaskHandle<T>
    Spawn {
        /// The function/closure to spawn
        func: Operand,
        /// Arguments to the function
        args: Vec<Operand>,
        /// Where to store the task handle
        destination: Place,
        /// Block to continue in after spawn
        target: BasicBlockId,
        span: Span,
    },

    /// Await a task handle to get its result
    /// Blocks until the task completes
    TaskAwait {
        /// Task handle to await
        task: Operand,
        /// Where to store the result
        destination: Place,
        /// Block to continue in after await
        target: BasicBlockId,
        span: Span,
    },

    /// Enter a task group (structured concurrency)
    /// Creates a new scope where child tasks must complete
    TaskGroupEnter {
        /// Where to store the task group handle
        destination: Place,
        /// Block to continue in (the task group body)
        body: BasicBlockId,
        /// Block to jump to after all tasks complete
        join_block: BasicBlockId,
        span: Span,
    },

    /// Exit a task group, waiting for all spawned tasks
    TaskGroupExit {
        /// The task group handle
        group: Operand,
        /// Block to continue in after all tasks complete
        target: BasicBlockId,
        span: Span,
    },

    /// Blocking receive from a channel
    /// Suspends the current task until a value is available
    ChannelRecv {
        /// Channel to receive from
        channel: Operand,
        /// Where to store the received value
        destination: Place,
        /// Block to continue in after receive
        target: BasicBlockId,
        /// Block to jump to if channel is closed
        closed_target: BasicBlockId,
        span: Span,
    },

    /// Blocking send to a channel
    /// Suspends the current task if the channel is full
    ChannelSend {
        /// Channel to send to
        channel: Operand,
        /// Value to send
        value: Operand,
        /// Block to continue in after send
        target: BasicBlockId,
        /// Block to jump to if channel is closed
        closed_target: BasicBlockId,
        span: Span,
    },

    /// Select on multiple channel operations
    /// Waits for one of the operations to become ready
    ///
    /// Each arm names its own successor block (`SelectArm::target`).
    Select {
        /// Arms to select from
        arms: Vec<SelectArm>,
        /// Default block (if all operations would block and there's a default)
        default: Option<BasicBlockId>,
        /// Where to store the received value (if any)
        destination: Place,
        /// Where to store which arm was selected (index)
        selected_arm: Place,
        span: Span,
    },

    /// Cancel the current task
    Cancel { span: Span },
}
405
406/// Switch targets for conditional branches
407#[derive(Debug, Clone)]
408pub struct SwitchTargets {
409    /// (value, target) pairs
410    pub branches: Vec<(u128, BasicBlockId)>,
411    /// Default target (if no branch matches)
412    pub otherwise: BasicBlockId,
413}
414
415impl SwitchTargets {
416    /// Create a simple if-else (true -> then_block, false -> else_block)
417    pub fn if_else(then_block: BasicBlockId, else_block: BasicBlockId) -> Self {
418        Self {
419            branches: vec![(1, then_block)], // true = 1
420            otherwise: else_block,
421        }
422    }
423
424    /// Create switch with multiple branches
425    pub fn new(branches: Vec<(u128, BasicBlockId)>, otherwise: BasicBlockId) -> Self {
426        Self {
427            branches,
428            otherwise,
429        }
430    }
431}
432
433/// Place - an L-value (location that can be assigned to)
434#[derive(Debug, Clone, PartialEq, Eq, Hash)]
435pub struct Place {
436    pub local: Local,
437    pub projection: Vec<PlaceElem>,
438}
439
440impl Place {
441    /// Create a place from just a local
442    pub fn from_local(local: Local) -> Self {
443        Self {
444            local,
445            projection: Vec::new(),
446        }
447    }
448
449    /// Create the return place
450    pub fn return_place() -> Self {
451        Self::from_local(Local::RETURN)
452    }
453
454    /// Add a projection
455    pub fn project(mut self, elem: PlaceElem) -> Self {
456        self.projection.push(elem);
457        self
458    }
459
460    /// Dereference this place
461    pub fn deref(self) -> Self {
462        self.project(PlaceElem::Deref)
463    }
464
465    /// Field access
466    pub fn field(self, field: FieldIdx) -> Self {
467        self.project(PlaceElem::Field(field))
468    }
469
470    /// Index access
471    pub fn index(self, local: Local) -> Self {
472        self.project(PlaceElem::Index(local))
473    }
474}
475
476impl fmt::Display for Place {
477    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
478        write!(f, "{}", self.local)?;
479        for elem in &self.projection {
480            match elem {
481                PlaceElem::Deref => write!(f, ".*")?,
482                PlaceElem::Field(field) => write!(f, ".{}", field.0)?,
483                PlaceElem::Index(local) => write!(f, "[{}]", local)?,
484                PlaceElem::Downcast(name) => write!(f, " as {}", name.as_str())?,
485            }
486        }
487        Ok(())
488    }
489}
490
/// Place projection element
///
/// One step applied on top of a `Place`'s base local. `Place`'s `Display`
/// impl renders these as `.*`, `.N`, `[local]` and ` as Name` respectively.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PlaceElem {
    /// Dereference (*place)
    Deref,
    /// Field access (place.field)
    Field(FieldIdx),
    /// Index (place[index]); the index value is read from the given local
    Index(Local),
    /// Downcast to a specific enum variant (for tagged union field access)
    Downcast(Symbol),
}
503
504/// SIMD operation kind for vectorized operations
505#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
506pub enum SimdOp {
507    /// SIMD addition
508    Add,
509    /// SIMD subtraction
510    Sub,
511    /// SIMD multiplication
512    Mul,
513    /// SIMD division
514    Div,
515    /// Fused multiply-add (a * b + c)
516    Fma,
517}
518
519impl SimdOp {
520    /// Convert from BinOp if possible
521    pub fn from_binop(op: &BinOp) -> Option<Self> {
522        match op {
523            BinOp::Add => Some(SimdOp::Add),
524            BinOp::Sub => Some(SimdOp::Sub),
525            BinOp::Mul => Some(SimdOp::Mul),
526            BinOp::Div => Some(SimdOp::Div),
527            _ => None,
528        }
529    }
530
531    /// Convert back to BinOp
532    pub fn to_binop(&self) -> Option<BinOp> {
533        match self {
534            SimdOp::Add => Some(BinOp::Add),
535            SimdOp::Sub => Some(BinOp::Sub),
536            SimdOp::Mul => Some(BinOp::Mul),
537            SimdOp::Div => Some(BinOp::Div),
538            SimdOp::Fma => None, // FMA has no direct BinOp equivalent
539        }
540    }
541}
542
/// Number of lanes processed by a vectorized operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SimdWidth {
    /// 2 elements (64-bit total for 32-bit elements)
    X2,
    /// 4 elements (128-bit SSE)
    X4,
    /// 8 elements (256-bit AVX/AVX2)
    X8,
    /// 16 elements (512-bit AVX-512)
    X16,
}

impl SimdWidth {
    /// Returns how many elements a vector of this width holds.
    pub fn element_count(&self) -> usize {
        match *self {
            Self::X2 => 2,
            Self::X4 => 4,
            Self::X8 => 8,
            Self::X16 => 16,
        }
    }

    /// Returns the width that holds exactly `count` elements, or `None`
    /// when `count` is not one of 2, 4, 8 or 16.
    pub fn from_count(count: usize) -> Option<Self> {
        const ALL: [SimdWidth; 4] = [
            SimdWidth::X2,
            SimdWidth::X4,
            SimdWidth::X8,
            SimdWidth::X16,
        ];
        ALL.iter()
            .copied()
            .find(|width| width.element_count() == count)
    }
}
578
/// Channel operation kind for select expressions
///
/// Used as the `operation` of a `SelectArm`.
#[derive(Debug, Clone)]
pub enum ChannelOp {
    /// Receive from channel
    Recv {
        /// Channel to receive from
        channel: Operand,
    },
    /// Send to channel
    Send {
        /// Channel to send to
        channel: Operand,
        /// Value to send
        value: Operand,
    },
    /// Timeout operation
    Timeout {
        /// Duration in nanoseconds
        duration_ns: u64,
    },
}
600
/// Select arm for select expressions
///
/// Pairs one channel operation with the block that runs when it is chosen;
/// arms are stored in `Terminator::Select::arms`.
#[derive(Debug, Clone)]
pub struct SelectArm {
    /// The channel operation for this arm
    pub operation: ChannelOp,
    /// Block to jump to when this arm is selected
    pub target: BasicBlockId,
}
609
/// Rvalue - an expression that produces a value
///
/// Appears on the right-hand side of `Statement::Assign`.
#[derive(Debug, Clone)]
pub enum Rvalue {
    /// Use an operand (copy or move)
    Use(Operand),

    /// Binary operation
    BinaryOp {
        op: BinOp,
        left: Operand,
        right: Operand,
    },

    /// Unary operation
    UnaryOp { op: UnOp, operand: Operand },

    /// SIMD binary operation on vectors
    ///
    /// This represents a vectorized operation that processes multiple elements
    /// in parallel. The operands are arrays/slices containing `width` elements.
    SimdBinaryOp {
        /// The SIMD operation to perform
        op: SimdOp,
        /// Element type (e.g., f32, i32)
        element_ty: Ty,
        /// Vector width (number of elements processed in parallel)
        width: SimdWidth,
        /// Left operand vector (array or pointer to contiguous elements)
        left: Operand,
        /// Right operand vector (array or pointer to contiguous elements)
        right: Operand,
    },

    /// SIMD load - read `width` elements from an array or contiguous memory
    /// into a vector register
    SimdLoad {
        /// Element type
        element_ty: Ty,
        /// Vector width
        width: SimdWidth,
        /// Source place to load from (must be an array or contiguous memory)
        source: Place,
    },

    /// SIMD store - write the elements of a vector register to an array or
    /// contiguous memory
    SimdStore {
        /// Element type
        element_ty: Ty,
        /// Vector width
        width: SimdWidth,
        /// Value to store (SIMD vector operand)
        value: Operand,
        /// Destination place (must be an array or contiguous memory)
        dest: Place,
    },

    /// SIMD splat - broadcast a scalar value to all lanes
    SimdSplat {
        /// Element type
        element_ty: Ty,
        /// Vector width
        width: SimdWidth,
        /// Scalar value to broadcast
        value: Operand,
    },

    /// Create a reference (&place or &mut place)
    Ref { mutable: bool, place: Place },

    /// Aggregate construction (struct, tuple, array)
    Aggregate {
        /// What kind of aggregate is being built
        kind: AggregateKind,
        /// Operands supplying the aggregate's components
        operands: Vec<Operand>,
    },

    /// Type cast
    Cast {
        /// Value being cast
        operand: Operand,
        /// Type to cast to
        target_ty: Ty,
        /// Which conversion to perform
        kind: CastKind,
    },

    /// Discriminant (get enum variant)
    Discriminant { place: Place },

    /// Length of array/slice
    Len { place: Place },

    // === Concurrency Operations ===
    /// Create a bounded channel
    ChannelCreate {
        /// Element type for the channel
        element_ty: Ty,
        /// Channel capacity (number of elements)
        capacity: u64,
    },

    /// Receive from a channel (non-blocking check)
    /// Returns Option<T> - Some(value) if value available, None otherwise
    ChannelTryRecv {
        /// Channel to receive from
        channel: Operand,
    },

    /// Send to a channel (non-blocking check)
    /// Returns bool - true if sent successfully, false if channel full
    ChannelTrySend {
        /// Channel to send to
        channel: Operand,
        /// Value to send
        value: Operand,
    },

    /// Get the sender handle from a channel
    ChannelSender {
        /// Channel to get sender from
        channel: Operand,
    },

    /// Get the receiver handle from a channel
    ChannelReceiver {
        /// Channel to get receiver from
        channel: Operand,
    },

    /// Close a channel
    ChannelClose {
        /// Channel to close
        channel: Operand,
    },

    /// Check if cancelled
    IsCancelled,

    /// Get the current task handle
    CurrentTask,

    /// Try operator: unwrap Result/Option, panicking on error
    Try(Operand),

    /// Enum variant construction (like TokenKind::Arrow); `fields` is empty
    /// for unit variants
    EnumVariant {
        /// Name of the variant
        variant: Symbol,
        /// Field operands (empty for unit variants)
        fields: Vec<Operand>,
    },
}
757
/// Aggregate construction kind
///
/// Selects what `Rvalue::Aggregate` builds from its operands.
#[derive(Debug, Clone)]
pub enum AggregateKind {
    /// Tuple
    Tuple,
    /// Array [T; N]
    // NOTE(review): the payload is presumably the element type `T` — confirm against the builder.
    Array(Ty),
    /// Struct (HirId, optional variant name for enum variant struct literals)
    Struct(HirId, Option<joule_common::Symbol>),
    /// Closure environment (closure id)
    Closure(u32),
}
770
/// Type of cast
///
/// Carried by `Rvalue::Cast` alongside the operand and the target type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CastKind {
    /// Integer to integer
    IntToInt,
    /// Float to float
    FloatToFloat,
    /// Integer to float
    IntToFloat,
    /// Float to integer
    FloatToInt,
    /// Pointer cast
    PtrToPtr,
    /// Bitcast (reinterpret bits without conversion, e.g., i64 bits → f64)
    Bitcast,
}
787
788/// Operand - a value used in an operation
789#[derive(Debug, Clone)]
790pub enum Operand {
791    /// Copy a value from a place
792    Copy(Place),
793    /// Move a value from a place
794    Move(Place),
795    /// Constant value
796    Constant(Constant),
797}
798
799impl Operand {
800    /// Create an operand from a local (uses copy for Copy types)
801    pub fn from_local(local: Local) -> Self {
802        Self::Copy(Place::from_local(local))
803    }
804}
805
806/// Constant value
807#[derive(Debug, Clone)]
808pub struct Constant {
809    pub literal: Literal,
810    pub ty: Ty,
811    pub span: Span,
812}
813
814impl Constant {
815    pub fn new(literal: Literal, ty: Ty, span: Span) -> Self {
816        Self { literal, ty, span }
817    }
818}
819
/// Type representation in MIR (simplified from HIR)
///
/// Produced from HIR types by `Ty::from_hir`. Derives `PartialEq`, `Eq` and
/// `Hash`, so types can be compared and used as hash keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Ty {
    /// Boolean
    Bool,
    /// Character
    Char,
    /// Integer (signed)
    Int(IntTy),
    /// Integer (unsigned)
    Uint(UintTy),
    /// Floating point
    Float(FloatTy),
    /// Reference
    Ref { mutable: bool, inner: Box<Ty> },
    /// Raw pointer
    RawPtr { mutable: bool, inner: Box<Ty> },
    /// Array [T; N]
    Array { element: Box<Ty>, size: u64 },
    /// Slice [T]
    Slice { element: Box<Ty> },
    /// Tuple
    Tuple(Vec<Ty>),
    /// Function pointer
    FnPtr { params: Vec<Ty>, return_ty: Box<Ty> },
    /// Named type (struct, enum).
    /// `from_hir` also lowers the HIR `String` and `Opaque` types to this variant.
    Named { def_id: HirId, name: Symbol },
    /// Generic type instantiation (e.g., Vec<i32>)
    Generic {
        def_id: HirId,
        name: Symbol,
        args: Vec<Ty>,
    },
    /// Type parameter reference (e.g., T in fn foo<T>)
    TypeParam { index: u32, name: Symbol },
    /// Unit ()
    Unit,
    /// Never !
    Never,

    /// Trait object type (dyn Trait)
    TraitObject { trait_name: Symbol },

    // === Concurrency Types ===
    /// Task handle - represents a spawned task that will produce a value.
    /// `from_hir` lowers both the HIR `Future` and `Task` types to this variant.
    Task { result_ty: Box<Ty> },

    /// Task group handle - for structured concurrency
    TaskGroup,

    /// Channel - bidirectional communication channel
    Channel { element_ty: Box<Ty> },

    /// Channel sender half
    Sender { element_ty: Box<Ty> },

    /// Channel receiver half
    Receiver { element_ty: Box<Ty> },

    /// Closure type (id, params, return, captures)
    Closure {
        id: u32,
        params: Vec<Ty>,
        return_ty: Box<Ty>,
        capture_tys: Vec<Ty>,
    },

    /// Union type: `i64 | f64 | String` — tagged discriminated union
    Union { variants: Vec<Ty> },

    // === N-Dimensional Array Types ===
    /// Owned N-dimensional array: NDArray[T; N] — heap-allocated contiguous data
    NDArray { element: Box<Ty>, rank: u32 },
    /// Borrowed N-dimensional view: NDView[T; N] — zero-copy slice into an NDArray
    NDView { element: Box<Ty>, rank: u32 },
    /// Copy-on-write array: CowArray[T; N] — shared data, cloned on mutation
    CowArray { element: Box<Ty>, rank: u32 },
    /// Dynamic-rank array: DynArray[T] — rank determined at runtime
    DynArray { element: Box<Ty> },

    // === Const-Generic Types ===
    /// Small vector with inline storage: SmallVec[T; N]
    SmallVec { element: Box<Ty>, capacity: u32 },
    /// SIMD vector type: Simd[T; N]
    Simd { element: Box<Ty>, lanes: u32 },
}
906
907impl Ty {
908    /// Convert from HIR type
909    pub fn from_hir(hir_ty: &joule_hir::Ty) -> Self {
910        match hir_ty {
911            joule_hir::Ty::Bool => Ty::Bool,
912            joule_hir::Ty::Char => Ty::Char,
913            joule_hir::Ty::Int(int_ty) => Ty::Int(*int_ty),
914            joule_hir::Ty::Uint(uint_ty) => Ty::Uint(*uint_ty),
915            joule_hir::Ty::Float(float_ty) => Ty::Float(*float_ty),
916            joule_hir::Ty::Ref { mutable, inner } => Ty::Ref {
917                mutable: *mutable,
918                inner: Box::new(Ty::from_hir(inner)),
919            },
920            joule_hir::Ty::Array { element, size } => Ty::Array {
921                element: Box::new(Ty::from_hir(element)),
922                size: *size,
923            },
924            joule_hir::Ty::Slice { element } => Ty::Slice {
925                element: Box::new(Ty::from_hir(element)),
926            },
927            joule_hir::Ty::Tuple(tys) => Ty::Tuple(tys.iter().map(Ty::from_hir).collect()),
928            joule_hir::Ty::Function { params, return_ty } => Ty::FnPtr {
929                params: params.iter().map(Ty::from_hir).collect(),
930                return_ty: Box::new(Ty::from_hir(return_ty)),
931            },
932            joule_hir::Ty::Named { def_id, name } => Ty::Named {
933                def_id: *def_id,
934                name: *name,
935            },
936            joule_hir::Ty::Unit => Ty::Unit,
937            joule_hir::Ty::Never => Ty::Never,
938            joule_hir::Ty::String => {
939                // String is a special case - use a named type with proper symbol
940                Ty::Named {
941                    def_id: HirId::new(0),
942                    name: Symbol::intern("String"),
943                }
944            }
945            joule_hir::Ty::Infer(_) | joule_hir::Ty::Error => {
946                // These should be resolved before MIR generation
947                Ty::Unit
948            }
949
950            // Concurrency types
951            joule_hir::Ty::Future { result_ty } => Ty::Task {
952                result_ty: Box::new(Ty::from_hir(result_ty)),
953            },
954            joule_hir::Ty::Task { result_ty } => Ty::Task {
955                result_ty: Box::new(Ty::from_hir(result_ty)),
956            },
957            joule_hir::Ty::TaskGroup => Ty::TaskGroup,
958            joule_hir::Ty::Channel { element_ty } => Ty::Channel {
959                element_ty: Box::new(Ty::from_hir(element_ty)),
960            },
961            joule_hir::Ty::Sender { element_ty } => Ty::Sender {
962                element_ty: Box::new(Ty::from_hir(element_ty)),
963            },
964            joule_hir::Ty::Receiver { element_ty } => Ty::Receiver {
965                element_ty: Box::new(Ty::from_hir(element_ty)),
966            },
967            joule_hir::Ty::Generic { def_id, name, args } => Ty::Generic {
968                def_id: *def_id,
969                name: *name,
970                args: args.iter().map(Ty::from_hir).collect(),
971            },
972            joule_hir::Ty::TypeParam { index, name } => Ty::TypeParam {
973                index: *index,
974                name: *name,
975            },
976            joule_hir::Ty::Closure {
977                id,
978                params,
979                return_ty,
980                capture_tys,
981            } => Ty::Closure {
982                id: id.0,
983                params: params.iter().map(Ty::from_hir).collect(),
984                return_ty: Box::new(Ty::from_hir(return_ty)),
985                capture_tys: capture_tys.iter().map(Ty::from_hir).collect(),
986            },
987            joule_hir::Ty::TraitObject { trait_name } => Ty::TraitObject {
988                trait_name: *trait_name,
989            },
990            joule_hir::Ty::Union { variants } => Ty::Union {
991                variants: variants.iter().map(Ty::from_hir).collect(),
992            },
993            joule_hir::Ty::Opaque { name, def_id } => Ty::Named {
994                def_id: *def_id,
995                name: *name,
996            },
997
998            // N-dimensional array types
999            joule_hir::Ty::NDArray { element, rank } => Ty::NDArray {
1000                element: Box::new(Ty::from_hir(element)),
1001                rank: *rank,
1002            },
1003            joule_hir::Ty::NDView { element, rank } => Ty::NDView {
1004                element: Box::new(Ty::from_hir(element)),
1005                rank: *rank,
1006            },
1007            joule_hir::Ty::CowArray { element, rank } => Ty::CowArray {
1008                element: Box::new(Ty::from_hir(element)),
1009                rank: *rank,
1010            },
1011            joule_hir::Ty::DynArray { element } => Ty::DynArray {
1012                element: Box::new(Ty::from_hir(element)),
1013            },
1014
1015            // Const-generic types
1016            joule_hir::Ty::SmallVec { element, capacity } => Ty::SmallVec {
1017                element: Box::new(Ty::from_hir(element)),
1018                capacity: *capacity,
1019            },
1020            joule_hir::Ty::Simd { element, lanes } => Ty::Simd {
1021                element: Box::new(Ty::from_hir(element)),
1022                lanes: *lanes,
1023            },
1024        }
1025    }
1026
1027    /// Check if this is a Copy type
1028    pub fn is_copy(&self) -> bool {
1029        match self {
1030            Ty::Bool
1031            | Ty::Char
1032            | Ty::Int(_)
1033            | Ty::Uint(_)
1034            | Ty::Float(_)
1035            | Ty::RawPtr { .. }
1036            | Ty::FnPtr { .. }
1037            | Ty::Unit
1038            | Ty::Never => true,
1039            Ty::Ref { .. } => true, // Shared references are Copy
1040            Ty::Tuple(tys) => tys.iter().all(|ty| ty.is_copy()),
1041            Ty::Array { element, .. } => element.is_copy(),
1042            _ => false,
1043        }
1044    }
1045
1046    /// Check if this is a zero-sized type
1047    pub fn is_zst(&self) -> bool {
1048        match self {
1049            Ty::Unit | Ty::Never => true,
1050            Ty::Tuple(tys) => tys.iter().all(|ty| ty.is_zst()),
1051            Ty::Array { element, size } => *size == 0 || element.is_zst(),
1052            _ => false,
1053        }
1054    }
1055}
1056
1057/// Integer types
1058pub use joule_hir::IntTy;
1059
1060/// Unsigned integer types
1061pub use joule_hir::UintTy;
1062
1063/// Float types
1064pub use joule_hir::FloatTy;
1065
/// MIR context - holds all MIR for a compilation unit.
///
/// Function bodies are stored here as `FunctionMIR`; every other item kind
/// (structs, enums, consts, statics, extern functions, traits, impls and
/// type aliases) is kept in its `joule_hir` form.
#[derive(Debug, Clone)]
pub struct MirContext {
    /// All functions, keyed by their `FunctionId`
    pub functions: IndexMap<FunctionId, FunctionMIR>,
    /// Struct definitions (HIR form)
    pub structs: Vec<joule_hir::Struct>,
    /// Enum definitions (HIR form)
    pub enums: Vec<joule_hir::Enum>,
    /// Constant definitions (HIR form)
    pub consts: Vec<joule_hir::Const>,
    /// Static variable definitions (HIR form)
    pub statics: Vec<joule_hir::Static>,
    /// Extern function declarations, stored as HIR functions
    pub extern_fns: Vec<joule_hir::Function>,
    /// Trait definitions (HIR form)
    pub traits: Vec<joule_hir::Trait>,
    /// Impl blocks (HIR form)
    pub impls: Vec<joule_hir::Impl>,
    /// Type aliases (name → resolved HIR type)
    pub type_aliases: std::collections::HashMap<String, joule_hir::Ty>,
}
1088
1089impl MirContext {
1090    /// Create a new MIR context
1091    pub fn new() -> Self {
1092        Self {
1093            functions: IndexMap::new(),
1094            structs: Vec::new(),
1095            enums: Vec::new(),
1096            consts: Vec::new(),
1097            statics: Vec::new(),
1098            extern_fns: Vec::new(),
1099            traits: Vec::new(),
1100            impls: Vec::new(),
1101            type_aliases: std::collections::HashMap::new(),
1102        }
1103    }
1104
1105    /// Add a function to the context
1106    pub fn add_function(&mut self, func: FunctionMIR) -> FunctionId {
1107        let id = func.id;
1108        self.functions.insert(id, func);
1109        id
1110    }
1111
1112    /// Get a function by ID
1113    pub fn get_function(&self, id: FunctionId) -> Option<&FunctionMIR> {
1114        self.functions.get(&id)
1115    }
1116
1117    /// Get a mutable function by ID
1118    pub fn get_function_mut(&mut self, id: FunctionId) -> Option<&mut FunctionMIR> {
1119        self.functions.get_mut(&id)
1120    }
1121}
1122
1123impl Default for MirContext {
1124    fn default() -> Self {
1125        Self::new()
1126    }
1127}
1128
#[cfg(test)]
mod tests {
    use super::*;

    /// A `Return` terminator carrying a dummy span.
    fn return_terminator() -> Terminator {
        Terminator::Return {
            span: Span::dummy(),
        }
    }

    /// A function with id 0, symbol 0 and a `Unit` return type.
    fn unit_function() -> FunctionMIR {
        FunctionMIR::new(
            FunctionId::new(0),
            Symbol::from_u32(0),
            Ty::Unit,
            Span::dummy(),
        )
    }

    /// The `i32` type, used as a representative sized scalar.
    fn i32_ty() -> Ty {
        Ty::Int(IntTy::I32)
    }

    #[test]
    fn test_basic_block_creation() {
        let block = BasicBlock::new(return_terminator());

        // A fresh block has no statements and keeps the given terminator.
        assert_eq!(block.statements.len(), 0);
        assert!(matches!(block.terminator, Terminator::Return { .. }));
    }

    #[test]
    fn test_place_construction() {
        let base_local = Local::new(1);
        let base = Place::from_local(base_local);
        assert_eq!(base.local, base_local);
        assert!(base.projection.is_empty());

        // Dereferencing appends exactly one `Deref` projection element.
        let pointee = base.deref();
        assert_eq!(pointee.projection.len(), 1);
        assert!(matches!(pointee.projection[0], PlaceElem::Deref));
    }

    #[test]
    fn test_function_mir_creation() {
        let mut function = unit_function();

        // The first local added by the caller is expected to get index 1.
        let decl = LocalDecl::new(Some(Symbol::from_u32(1)), i32_ty(), false, Span::dummy());
        assert_eq!(function.add_local(decl), Local::new(1));

        // The first block added is expected to get index 0.
        let block = BasicBlock::new(return_terminator());
        assert_eq!(function.add_block(block), BasicBlockId::new(0));
    }

    #[test]
    fn test_switch_targets_if_else() {
        let on_true = BasicBlockId::new(1);
        let on_false = BasicBlockId::new(2);
        let switch = SwitchTargets::if_else(on_true, on_false);

        // One explicit branch (value 1 -> then-block); else is the fallback.
        assert_eq!(switch.branches.len(), 1);
        assert_eq!(switch.branches[0], (1, on_true));
        assert_eq!(switch.otherwise, on_false);
    }

    #[test]
    fn test_ty_is_copy() {
        assert!(i32_ty().is_copy());
        assert!(Ty::Bool.is_copy());
        assert!(Ty::Unit.is_copy());

        // Note: References (including &mut) are Copy in MIR.
        // They copy the pointer, not the data.
        for mutable in [false, true] {
            let reference = Ty::Ref {
                mutable,
                inner: Box::new(i32_ty()),
            };
            assert!(reference.is_copy());
        }
    }

    #[test]
    fn test_ty_is_zst() {
        assert!(Ty::Unit.is_zst());
        assert!(Ty::Never.is_zst());
        assert!(Ty::Tuple(vec![]).is_zst());

        // A zero-length array is zero-sized even with a sized element type.
        let empty_array = Ty::Array {
            element: Box::new(i32_ty()),
            size: 0,
        };
        assert!(empty_array.is_zst());

        assert!(!i32_ty().is_zst());
    }

    #[test]
    fn test_mir_context() {
        let mut context = MirContext::new();
        let id = context.add_function(unit_function());

        assert!(context.get_function(id).is_some());
        assert_eq!(context.get_function(id).unwrap().id, id);
    }

    #[test]
    fn test_local_display() {
        assert_eq!(Local::RETURN.to_string(), "_ret");
        assert_eq!(Local::new(1).to_string(), "_1");
        assert_eq!(Local::new(42).to_string(), "_42");
    }

    #[test]
    fn test_basic_block_id_display() {
        assert_eq!(BasicBlockId::START.to_string(), "bb0");
        assert_eq!(BasicBlockId::new(5).to_string(), "bb5");
    }

    #[test]
    fn test_place_display() {
        let base = Place::from_local(Local::new(1));
        assert_eq!(base.to_string(), "_1");

        let first_field = base.clone().field(FieldIdx::new(0));
        assert_eq!(first_field.to_string(), "_1.0");

        let pointee = base.deref();
        assert_eq!(pointee.to_string(), "_1.*");
    }
}