Core Implementation

Algorithm W represents one of the most elegant solutions to the type inference problem in functional programming languages. Developed by Robin Milner in 1978, it provides a sound and complete method for inferring the most general types in the Hindley-Milner type system. This chapter explores our Rust implementation, examining how the mathematical foundations translate into practical code that can handle lambda abstractions, function applications, let-polymorphism, and complex unification scenarios.

The core insight of Algorithm W lies in its systematic treatment of unknowns. Rather than trying to determine each type through local analysis, the algorithm assigns a fresh type variable to anything it cannot yet determine. It then solves the resulting equations with unification as it walks the syntax tree, threading the substitution found so far through the rest of the traversal. Constraint-based variants of Hindley-Milner first collect all equations and solve them afterwards, while Algorithm W interleaves the two steps. Because unification always returns a most general unifier, the algorithm finds the most general (principal) type of an expression, a property crucial for supporting polymorphism in functional languages.

You’ll often see this algorithm (or the type system, confusingly enough) referred to by many names:

  • Hindley-Milner
  • Hindley-Damas-Milner
  • Damas-Milner
  • HM
  • Algorithm W

Typing Rules

Before diving into the implementation details, let’s establish the formal typing rules that govern the Hindley-Milner type system. The rules use a small set of mathematical symbols, each with a precise meaning:

  • \( \Gamma \) (Gamma) - The type environment, which maps variables to their types. It’s like a dictionary that remembers what we know about each variable’s type.

  • \( \vdash \) (Turnstile) - The “entails” or “proves” symbol. When we write \( \Gamma \vdash e : \tau \), we’re saying “in environment \( \Gamma \), expression \( e \) has type \( \tau \).”

  • \( \tau \) (Tau) - Represents monomorphic types (monotypes) such as \( \text{Int} \), \( \text{Bool} \), or \( \text{Int} \to \text{Bool} \). A monotype contains no \( \forall \) quantifier, but it may still contain type variables, as in \( \alpha \to \alpha \).

  • \( \sigma \) (Sigma) - Represents polymorphic type schemes like \( \forall \alpha. \alpha \to \alpha \). These can be instantiated with different concrete types.

  • \( \forall \alpha \) (Forall Alpha) - Universal quantification over type variables. It means “for any type \( \alpha \).” This is how we express polymorphism.

  • \( \alpha, \beta, \gamma \) (Greek Letters) - Type variables that stand for unknown types during inference. Think of them as type-level unknowns that get solved.

  • \( [\tau/\alpha]\sigma \) - Type substitution, replacing all occurrences of type variable \( \alpha \) with type \( \tau \) in scheme \( \sigma \). This is how we instantiate polymorphic types.

  • \( S \) (Substitution) - A mapping from type variables to types, representing the solutions found by unification.

  • \( \text{gen}(\Gamma, \tau) \) - Generalization, which turns a monotype into a polytype by quantifying over the type variables that are free in \( \tau \) but not free in \( \Gamma \), that is, \( \text{ftv}(\tau) \setminus \text{ftv}(\Gamma) \).

  • \( \text{inst}(\sigma) \) - Instantiation, which creates a fresh monotype from a polytype by replacing quantified variables with fresh type variables.

  • \( \text{ftv}(\tau) \) - Free type variables, the set of unbound type variables appearing in type \( \tau \).

  • \( \emptyset \) (Empty Set) - The empty substitution, representing no changes to types.

  • \( [\alpha \mapsto \tau] \) - A substitution that maps type variable \( \alpha \) to type \( \tau \).

  • \( S_1 \circ S_2 \) - Composition of substitutions, applying \( S_2 \) first, then \( S_1 \).

  • \( \notin \) (Not In) - Set membership negation, used in the occurs check to prevent infinite types.
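A small worked example, with variables and an environment chosen purely for illustration, shows why the order in \( S_1 \circ S_2 \) matters and what gen and inst do in practice:

\[ S_1 = [\alpha \mapsto \text{Int}], \qquad S_2 = [\beta \mapsto \alpha \to \alpha] \]

\[ (S_1 \circ S_2)(\beta) = S_1(S_2(\beta)) = S_1(\alpha \to \alpha) = \text{Int} \to \text{Int}, \qquad (S_2 \circ S_1)(\beta) = S_2(S_1(\beta)) = S_2(\beta) = \alpha \to \alpha \]

\[ \text{gen}(\{y : \beta\}, \alpha \to \beta) = \forall \alpha.\, \alpha \to \beta, \qquad \text{inst}(\forall \alpha.\, \alpha \to \alpha) = \gamma \to \gamma \quad (\gamma \text{ fresh}) \]

Only \( \alpha \) is quantified by gen, because \( \beta \) is free in the environment and may still be constrained by the rest of the program.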

Now that we have our symbolic toolkit, let’s see how these pieces work together to create the elegant machinery of Algorithm W.

Core Typing Rules

The variable rule looks up types from the environment: \[ \frac{x : σ \in Γ \quad τ = \text{inst}(σ)}{Γ ⊢ x : τ} \text{(T-Var)} \]

Lambda abstraction introduces new variable bindings: \[ \frac{Γ, x : α ⊢ e : τ \quad α \text{ fresh}}{Γ ⊢ λx. e : α → τ} \text{(T-Lam)} \]

Function application combines types through unification: \[ \frac{Γ ⊢ e₁ : τ₁ \quad Γ ⊢ e₂ : τ₂ \quad α \text{ fresh} \quad S = \text{unify}(τ₁, τ₂ → α)}{Γ ⊢ e₁ \, e₂ : S(α)} \text{(T-App)} \]

Let-polymorphism allows generalization: \[ \frac{Γ ⊢ e₁ : τ₁ \quad σ = \text{gen}(Γ, τ₁) \quad Γ, x : σ ⊢ e₂ : τ₂}{Γ ⊢ \text{let } x = e₁ \text{ in } e₂ : τ₂} \text{(T-Let)} \]

Literals have their corresponding base types: \[ \frac{}{Γ ⊢ n : \text{Int}} \text{(T-LitInt)} \]

\[ \frac{}{Γ ⊢ b : \text{Bool}} \text{(T-LitBool)} \]

These rules capture the essence of the Hindley-Milner type system, where we infer the most general types while supporting true polymorphism through let-generalization.
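To see T-Let, T-Var, and T-App cooperate, consider the expression \( \text{let } id = \lambda x.\, x \text{ in } id\ id \). The trace below numbers type variables in the order the counter-based implementation later in this chapter would allocate them, by our reading of that code; any consistent renaming is equally valid.

\[
\begin{aligned}
\{\} &\vdash \lambda x.\, x : t_0 \to t_0 && \text{(T-Lam, T-Var)} \\
\text{gen}(\{\}, t_0 \to t_0) &= \forall t_0.\, t_0 \to t_0 && \text{(T-Let)} \\
\{id : \forall t_0.\, t_0 \to t_0\} &\vdash id : t_2 \to t_2, \quad id : t_3 \to t_3 && \text{(T-Var, two instantiations)} \\
\text{unify}(t_2 \to t_2,\ (t_3 \to t_3) \to t_1) &= [t_2 \mapsto t_3 \to t_3,\ t_1 \mapsto t_3 \to t_3] && \text{(T-App, } t_1 \text{ fresh)} \\
\{\} &\vdash \text{let } id = \lambda x.\, x \text{ in } id\ id : t_3 \to t_3 &&
\end{aligned}
\]

Because the two occurrences of \( id \) are instantiated independently, the self-application typechecks. If \( id \) is bound by a lambda instead, as in \( (\lambda id.\, id\ id)\ (\lambda x.\, x) \), it has a single monomorphic type \( \alpha \), and \( \text{unify}(\alpha, \alpha \to \beta) \) fails the occurs check.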

Abstract Syntax Trees

Our implementation begins with a careful modeling of both expressions and types as algebraic data types. The expression language extends the pure lambda calculus with practical constructs while maintaining the theoretical foundation.

#[derive(Debug, Clone, PartialEq)]
pub enum Lit {
    Int(i64),
    Bool(bool),
}

#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    Var(String),                       // x
    App(Box<Expr>, Box<Expr>),         // e₁ e₂
    Abs(String, Box<Expr>),            // λx. e
    Let(String, Box<Expr>, Box<Expr>), // let x = e₁ in e₂
    Lit(Lit),                          // integer or boolean literal
    Tuple(Vec<Expr>),                  // (e₁, ..., eₙ)
}

impl std::fmt::Display for Expr {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Expr::Var(name) => write!(f, "{}", name),
            Expr::Lit(Lit::Int(n)) => write!(f, "{}", n),
            Expr::Lit(Lit::Bool(b)) => write!(f, "{}", b),
            Expr::Abs(param, body) => write!(f, "λ{}.{}", param, body),
            Expr::App(func, arg) => match (func.as_ref(), arg.as_ref()) {
                // Parenthesize a lambda in function position and nested applications/lambdas as arguments
                (Expr::Abs(_, _), _) => write!(f, "({}) {}", func, arg),
                (_, Expr::App(_, _)) => write!(f, "{} ({})", func, arg),
                (_, Expr::Abs(_, _)) => write!(f, "{} ({})", func, arg),
                _ => write!(f, "{} {}", func, arg),
            },
            Expr::Let(var, value, body) => {
                write!(f, "let {} = {} in {}", var, value, body)
            }
            Expr::Tuple(exprs) => {
                write!(f, "(")?;
                for (i, expr) in exprs.iter().enumerate() {
                    if i > 0 {
                        write!(f, ", ")?;
                    }
                    write!(f, "{}", expr)?;
                }
                write!(f, ")")
            }
        }
    }
}

The expression AST captures the essential constructs of our language. Variables (Var), function abstractions (Abs), and function applications (App) correspond directly to the lambda calculus, with application driving computation through beta reduction. The Let construct introduces local bindings whose types can be generalized, which is where let-polymorphism enters. Literals (Lit) and tuples (Tuple) add concrete data, giving the inference algorithm base types and structured types to work with.
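Because every variant with sub-expressions holds them in a Box or Vec, functions over the AST are naturally recursive matches, just like the infer dispatcher later in the chapter. As a small illustration, the sketch below computes the free term variables of an expression: the names a type environment must supply, or T-Var reports an unbound variable. The free_vars function and the var/abs/app/let_in constructors are hypothetical helpers, not part of the chapter’s code; Let is treated as non-recursive, matching T-Let.

```rust
use std::collections::BTreeSet;

#[derive(Debug, Clone, PartialEq)]
pub enum Lit {
    Int(i64),
    Bool(bool),
}

#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    Var(String),
    App(Box<Expr>, Box<Expr>),
    Abs(String, Box<Expr>),
    Let(String, Box<Expr>, Box<Expr>),
    Lit(Lit),
    Tuple(Vec<Expr>),
}

// Hypothetical convenience constructors that hide the Box/String plumbing.
fn var(name: &str) -> Expr {
    Expr::Var(name.to_string())
}
fn abs(param: &str, body: Expr) -> Expr {
    Expr::Abs(param.to_string(), Box::new(body))
}
fn app(func: Expr, arg: Expr) -> Expr {
    Expr::App(Box::new(func), Box::new(arg))
}
fn let_in(name: &str, value: Expr, body: Expr) -> Expr {
    Expr::Let(name.to_string(), Box::new(value), Box::new(body))
}

/// Term variables occurring free in `expr` (hypothetical helper).
fn free_vars(expr: &Expr) -> BTreeSet<String> {
    match expr {
        Expr::Var(name) => BTreeSet::from([name.clone()]),
        Expr::Lit(_) => BTreeSet::new(),
        Expr::Abs(param, body) => {
            let mut set = free_vars(body);
            set.remove(param); // λ binds its parameter inside the body
            set
        }
        Expr::App(func, arg) => {
            let mut set = free_vars(func);
            set.extend(free_vars(arg));
            set
        }
        Expr::Let(name, value, body) => {
            // `name` scopes over the body only (non-recursive let)
            let mut set = free_vars(body);
            set.remove(name);
            set.extend(free_vars(value));
            set
        }
        Expr::Tuple(items) => items.iter().flat_map(free_vars).collect(),
    }
}
```

For example, `let id = λx.x in id y` has the single free variable y, so inference would only succeed in an environment that binds y.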

The type system mirrors this structure with its own AST that represents the types these expressions can have.

#[derive(Debug, Clone, PartialEq)]
pub enum Type {
    Var(String),                 // type variable, e.g. t0
    Arrow(Box<Type>, Box<Type>), // function type τ₁ → τ₂
    Int,
    Bool,
    Tuple(Vec<Type>),            // (τ₁, ..., τₙ)
}

#[derive(Debug, Clone, PartialEq)]
pub struct Scheme {
    pub vars: Vec<String>, // Quantified type variables
    pub ty: Type,          // The type being quantified over
}

impl std::fmt::Display for Type {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Type::Var(name) => write!(f, "{}", name),
            Type::Int => write!(f, "Int"),
            Type::Bool => write!(f, "Bool"),
            Type::Arrow(t1, t2) => {
                // Add parentheses around left side if it's an arrow to avoid ambiguity
                match t1.as_ref() {
                    Type::Arrow(_, _) => write!(f, "({}) → {}", t1, t2),
                    _ => write!(f, "{} → {}", t1, t2),
                }
            }
            Type::Tuple(types) => {
                write!(f, "(")?;
                for (i, ty) in types.iter().enumerate() {
                    if i > 0 {
                        write!(f, ", ")?;
                    }
                    write!(f, "{}", ty)?;
                }
                write!(f, ")")
            }
        }
    }
}

impl std::fmt::Display for Scheme {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.vars.is_empty() {
            write!(f, "{}", self.ty)
        } else {
            write!(f, "forall {}. {}", self.vars.join(" "), self.ty)
        }
    }
}

Type variables (Type::Var) serve as placeholders during inference; unification may bind them to other types, and any that remain unconstrained can later be generalized at a let. Arrow types (Type::Arrow) represent function types, pairing a parameter type with a return type. Multi-argument functions are curried, so Int → Bool → Int stands for Arrow(Int, Arrow(Bool, Int)), which is why the printer only parenthesizes an arrow on the left. Base types like Int and Bool provide the foundation, while tuple types (Type::Tuple) support structured data. The recursive nature of these types lets us express arbitrarily complex type structures, from simple integers to higher-order functions that manipulate other functions.
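The curried encoding is easy to see by walking the right-nested chain of Arrow nodes. The sketch below repeats the Type definition so it runs on its own and adds two hypothetical helpers, arrow and split_arrow, which are not part of the chapter’s code.

```rust
#[derive(Debug, Clone, PartialEq)]
pub enum Type {
    Var(String),
    Arrow(Box<Type>, Box<Type>),
    Int,
    Bool,
    Tuple(Vec<Type>),
}

// Hypothetical helper: build an arrow type without writing Box::new twice.
fn arrow(param: Type, ret: Type) -> Type {
    Type::Arrow(Box::new(param), Box::new(ret))
}

/// Hypothetical helper: peel a curried function type into its parameter
/// types and final result by following the right-nested Arrow chain.
fn split_arrow(ty: &Type) -> (Vec<Type>, Type) {
    let mut params = Vec::new();
    let mut current = ty;
    while let Type::Arrow(param, ret) = current {
        params.push(param.as_ref().clone());
        current = ret.as_ref();
    }
    (params, current.clone())
}
```

A curried Int → Bool → Int splits into two parameters, whereas the higher-order (Int → Bool) → Int has a single parameter that is itself a function.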

Type Inference Algorithm

Algorithm W operates on several fundamental data structures that capture the essential concepts of type inference. These type aliases provide names for the core abstractions and make the algorithm’s implementation more readable.

The type variable abstraction represents unknown types that will be resolved during inference. Term variables represent program variables that appear in expressions. The type environment maps term variables to their type schemes, while substitutions map type variables to types (which may themselves contain type variables).

use std::collections::{BTreeMap, HashMap};

pub type TyVar = String;
pub type TmVar = String;
pub type Env = BTreeMap<TmVar, Scheme>;  // Now stores schemes, not types
pub type Subst = HashMap<TyVar, Type>;

These aliases encapsulate the fundamental data flow in Algorithm W. Type variables like t0, t1, and t2 serve as placeholders that unification binds to other types as inference progresses. Term variables represent the actual identifiers in source programs. The environment now tracks polymorphic type schemes rather than plain types, enabling let-polymorphism, while substitutions record the solutions discovered by unification.

The choice of String for both type and term variables keeps our implementation simple. Production implementations often use cheaper representations, such as integer identifiers or mutable references for type variables, interned strings for names, or de Bruijn indices for bound term variables, but strings make the fundamental algorithms easier to follow.

The heart of our Algorithm W implementation lies in the TypeInference struct, which maintains the state necessary for sound type inference across an entire program.

#![allow(unused)]
fn main() {
use std::collections::{BTreeMap, HashMap, HashSet};
use std::fmt;
use crate::ast::{Expr, Lit, Scheme, Type};
use crate::errors::{InferenceError, Result};
pub type TyVar = String;
pub type TmVar = String;
pub type Env = BTreeMap<TmVar, Scheme>;
pub type Subst = HashMap<TyVar, Type>;
#[derive(Debug)]
pub struct InferenceTree {
    pub rule: String,
    pub input: String,
    pub output: String,
    pub children: Vec<InferenceTree>,
}
impl InferenceTree {
    fn new(rule: &str, input: &str, output: &str, children: Vec<InferenceTree>) -> Self {
        Self {
            rule: rule.to_string(),
            input: input.to_string(),
            output: output.to_string(),
            children,
        }
    }
}
impl fmt::Display for InferenceTree {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        self.display_with_indent(f, 0)
    }
}
impl InferenceTree {
    fn display_with_indent(&self, f: &mut fmt::Formatter, indent: usize) -> fmt::Result {
        let prefix = "  ".repeat(indent);
        writeln!(
            f,
            "{}{}: {} => {}",
            prefix, self.rule, self.input, self.output
        )?;
        for child in &self.children {
            child.display_with_indent(f, indent + 1)?;
        }
        Ok(())
    }
}
impl Default for TypeInference {
    fn default() -> Self {
        Self::new()
    }
}
#[allow(clippy::only_used_in_recursion)]
impl TypeInference {
    pub fn new() -> Self {
        Self { counter: 0 }
    }

    fn fresh_tyvar(&mut self) -> TyVar {
        let var = format!("t{}", self.counter);
        self.counter += 1;
        var
    }

    fn pretty_env(&self, env: &Env) -> String {
        if env.is_empty() {
            "{}".to_string()
        } else {
            let entries: Vec<String> = env.iter().map(|(k, v)| format!("{}: {}", k, v)).collect();
            format!("{{{}}}", entries.join(", "))
        }
    }

    fn pretty_subst(&self, subst: &Subst) -> String {
        if subst.is_empty() {
            "{}".to_string()
        } else {
            let entries: Vec<String> = subst.iter().map(|(k, v)| format!("{}/{}", v, k)).collect();
            format!("{{{}}}", entries.join(", "))
        }
    }

    fn apply_subst(&self, subst: &Subst, ty: &Type) -> Type {
        match ty {
            Type::Var(name) => subst.get(name).cloned().unwrap_or_else(|| ty.clone()),
            Type::Arrow(t1, t2) => Type::Arrow(
                Box::new(self.apply_subst(subst, t1)),
                Box::new(self.apply_subst(subst, t2)),
            ),
            Type::Tuple(types) => {
                Type::Tuple(types.iter().map(|t| self.apply_subst(subst, t)).collect())
            }
            Type::Int | Type::Bool => ty.clone(),
        }
    }

    fn apply_subst_scheme(&self, subst: &Subst, scheme: &Scheme) -> Scheme {
        // Remove bindings for quantified variables to avoid capture
        let mut filtered_subst = subst.clone();
        for var in &scheme.vars {
            filtered_subst.remove(var);
        }
        Scheme {
            vars: scheme.vars.clone(),
            ty: self.apply_subst(&filtered_subst, &scheme.ty),
        }
    }

    fn apply_subst_env(&self, subst: &Subst, env: &Env) -> Env {
        env.iter()
            .map(|(k, v)| (k.clone(), self.apply_subst_scheme(subst, v)))
            .collect()
    }

    fn compose_subst(&self, s1: &Subst, s2: &Subst) -> Subst {
        let mut result = s1.clone();
        for (k, v) in s2 {
            result.insert(k.clone(), self.apply_subst(s1, v));
        }
        result
    }

    fn free_type_vars(&self, ty: &Type) -> HashSet<TyVar> {
        match ty {
            Type::Var(name) => {
                let mut set = HashSet::new();
                set.insert(name.clone());
                set
            }
            Type::Arrow(t1, t2) => {
                let mut set = self.free_type_vars(t1);
                set.extend(self.free_type_vars(t2));
                set
            }
            Type::Tuple(types) => {
                let mut set = HashSet::new();
                for t in types {
                    set.extend(self.free_type_vars(t));
                }
                set
            }
            Type::Int | Type::Bool => HashSet::new(),
        }
    }

    fn free_type_vars_scheme(&self, scheme: &Scheme) -> HashSet<TyVar> {
        let mut set = self.free_type_vars(&scheme.ty);
        // Remove quantified variables
        for var in &scheme.vars {
            set.remove(var);
        }
        set
    }

    fn free_type_vars_env(&self, env: &Env) -> HashSet<TyVar> {
        let mut set = HashSet::new();
        for scheme in env.values() {
            set.extend(self.free_type_vars_scheme(scheme));
        }
        set
    }

    fn generalize(&self, env: &Env, ty: &Type) -> Scheme {
        let type_vars = self.free_type_vars(ty);
        let env_vars = self.free_type_vars_env(env);
        let mut free_vars: Vec<_> = type_vars.difference(&env_vars).cloned().collect();
        free_vars.sort(); // Sort for deterministic behavior

        Scheme {
            vars: free_vars,
            ty: ty.clone(),
        }
    }

    fn instantiate(&mut self, scheme: &Scheme) -> Type {
        // Create fresh type variables for each quantified variable
        let mut subst = HashMap::new();
        for var in &scheme.vars {
            let fresh = self.fresh_tyvar();
            subst.insert(var.clone(), Type::Var(fresh));
        }

        self.apply_subst(&subst, &scheme.ty)
    }

    fn occurs_check(&self, var: &TyVar, ty: &Type) -> bool {
        match ty {
            Type::Var(name) => name == var,
            Type::Arrow(t1, t2) => self.occurs_check(var, t1) || self.occurs_check(var, t2),
            Type::Tuple(types) => types.iter().any(|t| self.occurs_check(var, t)),
            Type::Int | Type::Bool => false,
        }
    }

    fn unify(&self, t1: &Type, t2: &Type) -> Result<(Subst, InferenceTree)> {
        let input = format!("{} ~ {}", t1, t2);

        match (t1, t2) {
            (Type::Int, Type::Int) | (Type::Bool, Type::Bool) => {
                let tree = InferenceTree::new("Unify-Base", &input, "{}", vec![]);
                Ok((HashMap::new(), tree))
            }
            (Type::Var(v), ty) | (ty, Type::Var(v)) => {
                if ty == &Type::Var(v.clone()) {
                    let tree = InferenceTree::new("Unify-Var-Same", &input, "{}", vec![]);
                    Ok((HashMap::new(), tree))
                } else if self.occurs_check(v, ty) {
                    Err(InferenceError::OccursCheck {
                        var: v.clone(),
                        ty: ty.clone(),
                    })
                } else {
                    let mut subst = HashMap::new();
                    subst.insert(v.clone(), ty.clone());
                    let output = format!("{{{}/{}}}", ty, v);
                    let tree = InferenceTree::new("Unify-Var", &input, &output, vec![]);
                    Ok((subst, tree))
                }
            }
            (Type::Arrow(a1, a2), Type::Arrow(b1, b2)) => {
                let (s1, tree1) = self.unify(a1, b1)?;
                let a2_subst = self.apply_subst(&s1, a2);
                let b2_subst = self.apply_subst(&s1, b2);
                let (s2, tree2) = self.unify(&a2_subst, &b2_subst)?;
                let final_subst = self.compose_subst(&s2, &s1);
                let output = self.pretty_subst(&final_subst);
                let tree = InferenceTree::new("Unify-Arrow", &input, &output, vec![tree1, tree2]);
                Ok((final_subst, tree))
            }
            (Type::Tuple(ts1), Type::Tuple(ts2)) => {
                if ts1.len() != ts2.len() {
                    return Err(InferenceError::TupleLengthMismatch {
                        left_len: ts1.len(),
                        right_len: ts2.len(),
                    });
                }

                let mut subst = HashMap::new();
                let mut trees = Vec::new();

                for (t1, t2) in ts1.iter().zip(ts2.iter()) {
                    let t1_subst = self.apply_subst(&subst, t1);
                    let t2_subst = self.apply_subst(&subst, t2);
                    let (s, tree) = self.unify(&t1_subst, &t2_subst)?;
                    subst = self.compose_subst(&s, &subst);
                    trees.push(tree);
                }

                let output = self.pretty_subst(&subst);
                let tree = InferenceTree::new("Unify-Tuple", &input, &output, trees);
                Ok((subst, tree))
            }
            _ => Err(InferenceError::UnificationFailure {
                expected: t1.clone(),
                actual: t2.clone(),
            }),
        }
    }

    pub fn infer(&mut self, env: &Env, expr: &Expr) -> Result<(Subst, Type, InferenceTree)> {
        match expr {
            Expr::Lit(Lit::Int(_)) => self.infer_lit_int(env, expr),
            Expr::Lit(Lit::Bool(_)) => self.infer_lit_bool(env, expr),
            Expr::Var(name) => self.infer_var(env, expr, name),
            Expr::Abs(param, body) => self.infer_abs(env, expr, param, body),
            Expr::App(func, arg) => self.infer_app(env, expr, func, arg),
            Expr::Let(var, value, body) => self.infer_let(env, expr, var, value, body),
            Expr::Tuple(exprs) => self.infer_tuple(env, expr, exprs),
        }
    }

    /// T-LitInt: ─────────────────
    ///           Γ ⊢ n : Int
    fn infer_lit_int(&mut self, env: &Env, expr: &Expr) -> Result<(Subst, Type, InferenceTree)> {
        let input = format!("{} ⊢ {} ⇒", self.pretty_env(env), expr);
        let tree = InferenceTree::new("T-Int", &input, "Int", vec![]);
        Ok((HashMap::new(), Type::Int, tree))
    }

    /// T-LitBool: ─────────────────
    ///            Γ ⊢ b : Bool
    fn infer_lit_bool(&mut self, env: &Env, expr: &Expr) -> Result<(Subst, Type, InferenceTree)> {
        let input = format!("{} ⊢ {} ⇒", self.pretty_env(env), expr);
        let tree = InferenceTree::new("T-Bool", &input, "Bool", vec![]);
        Ok((HashMap::new(), Type::Bool, tree))
    }

    /// T-Var: x : σ ∈ Γ    τ = inst(σ)
    ///        ─────────────────────────
    ///               Γ ⊢ x : τ
    fn infer_var(
        &mut self,
        env: &Env,
        expr: &Expr,
        name: &str,
    ) -> Result<(Subst, Type, InferenceTree)> {
        let input = format!("{} ⊢ {} ⇒", self.pretty_env(env), expr);

        match env.get(name) {
            Some(scheme) => {
                let instantiated = self.instantiate(scheme);
                let output = format!("{}", instantiated);
                let tree = InferenceTree::new("T-Var", &input, &output, vec![]);
                Ok((HashMap::new(), instantiated, tree))
            }
            None => Err(InferenceError::UnboundVariable {
                name: name.to_string(),
            }),
        }
    }

    /// T-Lam: Γ, x : α ⊢ e : τ    α fresh
    ///        ─────────────────────────────
    ///           Γ ⊢ λx. e : α → τ
    fn infer_abs(
        &mut self,
        env: &Env,
        expr: &Expr,
        param: &str,
        body: &Expr,
    ) -> Result<(Subst, Type, InferenceTree)> {
        let input = format!("{} ⊢ {} ⇒", self.pretty_env(env), expr);

        let param_type = Type::Var(self.fresh_tyvar());
        let mut new_env = env.clone();
        // Insert a monomorphic scheme for the parameter
        let param_scheme = Scheme {
            vars: vec![],
            ty: param_type.clone(),
        };
        new_env.insert(param.to_string(), param_scheme);

        let (s1, body_type, tree1) = self.infer(&new_env, body)?;
        let param_type_subst = self.apply_subst(&s1, &param_type);
        let result_type = Type::Arrow(Box::new(param_type_subst), Box::new(body_type));

        let output = format!("{}", result_type);
        let tree = InferenceTree::new("T-Abs", &input, &output, vec![tree1]);
        Ok((s1, result_type, tree))
    }

    /// T-App: Γ ⊢ e₁ : τ₁    Γ ⊢ e₂ : τ₂    α fresh    S = unify(τ₁, τ₂ → α)
    ///        ──────────────────────────────────────────────────────────────
    ///                            Γ ⊢ e₁ e₂ : S(α)
    fn infer_app(
        &mut self,
        env: &Env,
        expr: &Expr,
        func: &Expr,
        arg: &Expr,
    ) -> Result<(Subst, Type, InferenceTree)> {
        let input = format!("{} ⊢ {} ⇒", self.pretty_env(env), expr);

        let result_type = Type::Var(self.fresh_tyvar());

        let (s1, func_type, tree1) = self.infer(env, func)?;
        let env_subst = self.apply_subst_env(&s1, env);
        let (s2, arg_type, tree2) = self.infer(&env_subst, arg)?;

        let func_type_subst = self.apply_subst(&s2, &func_type);
        let expected_func_type = Type::Arrow(Box::new(arg_type), Box::new(result_type.clone()));

        let (s3, tree3) = self.unify(&func_type_subst, &expected_func_type)?;

        let final_subst = self.compose_subst(&s3, &self.compose_subst(&s2, &s1));
        let final_type = self.apply_subst(&s3, &result_type);

        let output = format!("{}", final_type);
        let tree = InferenceTree::new("T-App", &input, &output, vec![tree1, tree2, tree3]);
        Ok((final_subst, final_type, tree))
    }

    /// T-Let: Γ ⊢ e₁ : τ₁    σ = gen(Γ, τ₁)    Γ, x : σ ⊢ e₂ : τ₂
    ///        ──────────────────────────────────────────────────────
    ///                     Γ ⊢ let x = e₁ in e₂ : τ₂
    fn infer_let(
        &mut self,
        env: &Env,
        expr: &Expr,
        var: &str,
        value: &Expr,
        body: &Expr,
    ) -> Result<(Subst, Type, InferenceTree)> {
        let input = format!("{} ⊢ {} ⇒", self.pretty_env(env), expr);

        let (s1, value_type, tree1) = self.infer(env, value)?;
        let env_subst = self.apply_subst_env(&s1, env);
        let generalized_type = self.generalize(&env_subst, &value_type);

        let mut new_env = env_subst;
        new_env.insert(var.to_string(), generalized_type);

        let (s2, body_type, tree2) = self.infer(&new_env, body)?;

        let final_subst = self.compose_subst(&s2, &s1);
        let output = format!("{}", body_type);
        let tree = InferenceTree::new("T-Let", &input, &output, vec![tree1, tree2]);
        Ok((final_subst, body_type, tree))
    }

    /// T-Tuple: Γ ⊢ e₁ : τ₁    ...    Γ ⊢ eₙ : τₙ
    ///          ─────────────────────────────────────
    ///              Γ ⊢ (e₁, ..., eₙ) : (τ₁, ..., τₙ)
    fn infer_tuple(
        &mut self,
        env: &Env,
        expr: &Expr,
        exprs: &[Expr],
    ) -> Result<(Subst, Type, InferenceTree)> {
        let input = format!("{} ⊢ {} ⇒", self.pretty_env(env), expr);

        let mut subst = HashMap::new();
        let mut types = Vec::new();
        let mut trees = Vec::new();
        let mut current_env = env.clone();

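        for expr in exprs {
            let (s, ty, tree) = self.infer(&current_env, expr)?;
            subst = self.compose_subst(&s, &subst);
            current_env = self.apply_subst_env(&s, &current_env);
            types.push(ty);
            trees.push(tree);
        }

        // Later elements may refine type variables that appear in earlier
        // element types, so apply the accumulated substitution to all of them.
        let types: Vec<Type> = types.iter().map(|t| self.apply_subst(&subst, t)).collect();
        let result_type = Type::Tuple(types);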
        let output = format!("{}", result_type);
        let tree = InferenceTree::new("T-Tuple", &input, &output, trees);
        Ok((subst, result_type, tree))
    }
}
pub fn run_inference(expr: &Expr) -> Result<InferenceTree> {
    let mut inference = TypeInference::new();
    let env = BTreeMap::new();
    let (_, _, tree) = inference.infer(&env, expr)?;
    Ok(tree)
}
pub fn infer_type_only(expr: &Expr) -> Result<Type> {
    let mut inference = TypeInference::new();
    let env = BTreeMap::new();
    let (_, ty, _) = inference.infer(&env, expr)?;
    Ok(ty)
}
pub struct TypeInference {
    counter: usize,
}
}
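The generalize method in the listing above is the step that makes T-Let polymorphic. The following self-contained sketch reproduces its logic with a trimmed Type (tuples omitted for brevity) so it can be run in isolation; the tv and arrow helpers are hypothetical conveniences, not part of the chapter’s code.

```rust
use std::collections::{BTreeMap, HashSet};

#[derive(Debug, Clone, PartialEq)]
enum Type {
    Var(String),
    Arrow(Box<Type>, Box<Type>),
    Int,
    Bool,
}

#[derive(Debug, Clone, PartialEq)]
struct Scheme {
    vars: Vec<String>,
    ty: Type,
}

type Env = BTreeMap<String, Scheme>;

// Hypothetical helpers for building types tersely.
fn tv(name: &str) -> Type {
    Type::Var(name.to_string())
}
fn arrow(a: Type, b: Type) -> Type {
    Type::Arrow(Box::new(a), Box::new(b))
}

fn free_type_vars(ty: &Type) -> HashSet<String> {
    match ty {
        Type::Var(name) => HashSet::from([name.clone()]),
        Type::Arrow(a, b) => {
            let mut set = free_type_vars(a);
            set.extend(free_type_vars(b));
            set
        }
        Type::Int | Type::Bool => HashSet::new(),
    }
}

// Free variables of a scheme exclude the ones it quantifies.
fn free_type_vars_scheme(scheme: &Scheme) -> HashSet<String> {
    let mut set = free_type_vars(&scheme.ty);
    for v in &scheme.vars {
        set.remove(v);
    }
    set
}

fn free_type_vars_env(env: &Env) -> HashSet<String> {
    env.values().flat_map(free_type_vars_scheme).collect()
}

// gen(Γ, τ): quantify over ftv(τ) \ ftv(Γ), sorted for deterministic output.
fn generalize(env: &Env, ty: &Type) -> Scheme {
    let env_vars = free_type_vars_env(env);
    let mut vars: Vec<String> = free_type_vars(ty).difference(&env_vars).cloned().collect();
    vars.sort();
    Scheme { vars, ty: ty.clone() }
}
```

With y : t1 in the environment, generalizing t0 → t1 quantifies only t0: t1 stays free because y’s type mentions it, and quantifying it would let different uses of the let-bound name disagree with y about what t1 is.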

Apart from the inference rules themselves, the engine’s only mutable state is a counter used to generate fresh type variables, guaranteeing that each unknown type gets a unique name. This counter-based approach provides a simple but effective way to avoid naming collisions during the inference process.

fn fresh_tyvar(&mut self) -> TyVar {
    let var = format!("t{}", self.counter);
    self.counter += 1;
    var
}

Fresh variable generation forms the foundation for Algorithm W’s systematic approach to handling unknowns. In this implementation fresh variables are created in three places: for a lambda’s parameter (infer_abs), for the result of an application (infer_app), and for each quantified variable when a scheme is instantiated (instantiate). Unification later binds these variables as more information about the program’s structure emerges.
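Instantiation is where fresh variables matter most for polymorphism: each use of a let-bound name must get its own copy of the quantified variables. The sketch below pairs a counter like fresh_tyvar with an instantiate step modeled on the chapter’s instantiate method. It is a trimmed, self-contained approximation; the Fresh struct and the tv/arrow helpers are hypothetical stand-ins for TypeInference and its surrounding code.

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq)]
enum Type {
    Var(String),
    Arrow(Box<Type>, Box<Type>),
    Int,
    Bool,
}

struct Scheme {
    vars: Vec<String>, // quantified variables
    ty: Type,
}

// Hypothetical helpers for building types tersely.
fn tv(name: &str) -> Type {
    Type::Var(name.to_string())
}
fn arrow(a: Type, b: Type) -> Type {
    Type::Arrow(Box::new(a), Box::new(b))
}

fn apply_subst(subst: &HashMap<String, Type>, ty: &Type) -> Type {
    match ty {
        Type::Var(name) => subst.get(name).cloned().unwrap_or_else(|| ty.clone()),
        Type::Arrow(a, b) => arrow(apply_subst(subst, a), apply_subst(subst, b)),
        Type::Int | Type::Bool => ty.clone(),
    }
}

// Hypothetical stand-in for the counter held by TypeInference.
struct Fresh {
    counter: usize,
}

impl Fresh {
    fn fresh_tyvar(&mut self) -> String {
        let var = format!("t{}", self.counter);
        self.counter += 1;
        var
    }

    // Replace every quantified variable with a brand-new type variable.
    fn instantiate(&mut self, scheme: &Scheme) -> Type {
        let mut subst = HashMap::new();
        for var in &scheme.vars {
            let fresh = self.fresh_tyvar();
            subst.insert(var.clone(), Type::Var(fresh));
        }
        apply_subst(&subst, &scheme.ty)
    }
}
```

Instantiating forall a. a → a twice yields t0 → t0 and then t1 → t1, so unifying one use with Int → Int leaves the other free to become Bool → Bool.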

Substitution and Unification

Type substitutions represent the core computational mechanism of Algorithm W. A substitution maps type variables to types (which may themselves mention other type variables), recording the parts of the type inference puzzle that have been solved so far.

The application of substitutions must handle the recursive structure of types correctly, ensuring that substitutions propagate through compound types like arrows and tuples.

fn apply_subst(&self, subst: &Subst, ty: &Type) -> Type {
    match ty {
        Type::Var(name) => subst.get(name).cloned().unwrap_or_else(|| ty.clone()),
        Type::Arrow(t1, t2) => Type::Arrow(
            Box::new(self.apply_subst(subst, t1)),
            Box::new(self.apply_subst(subst, t2)),
        ),
        Type::Tuple(types) => {
            Type::Tuple(types.iter().map(|t| self.apply_subst(subst, t)).collect())
        }
        Type::Int | Type::Bool => ty.clone(),
    }
}

Substitution application demonstrates how type information flows through our system. When we apply a substitution to an arrow type, we must apply it recursively to both the parameter and return types. This ensures that type information discovered in one part of a program correctly influences other parts.
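One detail of apply_subst is worth making explicit: it replaces each variable with a single lookup and does not re-apply the substitution to the result. The sketch below, a trimmed copy with tuples omitted and hypothetical tv/arrow helpers, shows that a substitution whose range mentions its own domain is only partially applied. This is one reason compose_subst applies s1 to the range of s2 instead of simply merging the two maps.

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq)]
enum Type {
    Var(String),
    Arrow(Box<Type>, Box<Type>),
    Int,
    Bool,
}

type Subst = HashMap<String, Type>;

// Hypothetical helpers for building types tersely.
fn tv(name: &str) -> Type {
    Type::Var(name.to_string())
}
fn arrow(a: Type, b: Type) -> Type {
    Type::Arrow(Box::new(a), Box::new(b))
}

// Same shape as the chapter's apply_subst: one lookup per variable,
// and the looked-up type is returned as-is, not substituted again.
fn apply_subst(subst: &Subst, ty: &Type) -> Type {
    match ty {
        Type::Var(name) => subst.get(name).cloned().unwrap_or_else(|| ty.clone()),
        Type::Arrow(a, b) => arrow(apply_subst(subst, a), apply_subst(subst, b)),
        Type::Int | Type::Bool => ty.clone(),
    }
}
```

Given { t0 ↦ t1 → t1, t1 ↦ Int }, applying it to t0 gives t1 → t1 rather than Int → Int, because the t1 inside the looked-up type is never revisited.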

Composition of substitutions allows us to combine multiple partial solutions into a more complete understanding of our program’s types.

fn compose_subst(&self, s1: &Subst, s2: &Subst) -> Subst {
    let mut result = s1.clone();
    for (k, v) in s2 {
        result.insert(k.clone(), self.apply_subst(s1, v));
    }
    result
}

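The composition operation combines substitutions produced by different parts of the inference process into a single, consistent substitution that represents our cumulative knowledge about the program's types. Order matters: compose_subst(s1, s2) behaves like applying s2 first and then s1 (written s1 ∘ s2). This is why the unification code writes compose_subst(&s2, &s1) to build S₂ ∘ S₁.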
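
The ordering can be checked concretely. The sketch below is self-contained and assumes a minimal Type enum with only the variants needed here. It verifies that the composed substitution agrees with applying the two substitutions one after the other, and that swapping the arguments does not.

```rust
use std::collections::HashMap;

// Just enough of the chapter's Type enum for this example (assumed shape).
#[derive(Clone, Debug, PartialEq)]
enum Type {
    Var(String),
    Int,
    Arrow(Box<Type>, Box<Type>),
}

type Subst = HashMap<String, Type>;

fn apply_subst(subst: &Subst, ty: &Type) -> Type {
    match ty {
        Type::Var(name) => subst.get(name).cloned().unwrap_or_else(|| ty.clone()),
        Type::Arrow(a, b) => {
            Type::Arrow(Box::new(apply_subst(subst, a)), Box::new(apply_subst(subst, b)))
        }
        Type::Int => ty.clone(),
    }
}

// Mirrors compose_subst above: the result behaves like "apply s2, then s1".
fn compose_subst(s1: &Subst, s2: &Subst) -> Subst {
    let mut result = s1.clone();
    for (k, v) in s2 {
        // Push s1 through each right-hand side of s2.
        result.insert(k.clone(), apply_subst(s1, v));
    }
    result
}
```

Take s2 = {t0 ↦ t1} and s1 = {t1 ↦ Int}. Then compose_subst(&s1, &s2) maps both t0 and t1 to Int, matching the effect of applying s2 and then s1. Swapping the arguments leaves t0 ↦ t1 unresolved, so the argument order must follow the order in which the substitutions were discovered.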

Substitutions must also be applied to entire type environments when we discover new type information. This operation updates all the types in the environment according to the current substitution.

fn apply_subst_env(&self, subst: &Subst, env: &Env) -> Env {
    env.iter()
        .map(|(k, v)| (k.clone(), self.apply_subst_scheme(subst, v)))
        .collect()
}

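Environment substitution keeps the environment consistent as inference progresses. Whenever we learn that a type variable stands for some type, every scheme in the environment must reflect that knowledge, not just the type currently at hand. The helper apply_subst_scheme, shown in the complete listing later in this chapter, first removes any bindings for the scheme's quantified variables. Those variables are local to the scheme and must not be replaced.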
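
Here is a small self-contained sketch of that capture-avoiding behavior. It assumes simplified Type and Scheme definitions that mirror the chapter's, which live in crate::ast and are not shown here.

```rust
use std::collections::HashMap;

// Assumed shapes mirroring the chapter's Type and Scheme (defined in crate::ast).
#[derive(Clone, Debug, PartialEq)]
enum Type {
    Var(String),
    Int,
    Bool,
    Arrow(Box<Type>, Box<Type>),
}

#[derive(Clone, Debug, PartialEq)]
struct Scheme {
    vars: Vec<String>, // quantified (bound) variables
    ty: Type,
}

type Subst = HashMap<String, Type>;

fn apply_subst(subst: &Subst, ty: &Type) -> Type {
    match ty {
        Type::Var(name) => subst.get(name).cloned().unwrap_or_else(|| ty.clone()),
        Type::Arrow(a, b) => {
            Type::Arrow(Box::new(apply_subst(subst, a)), Box::new(apply_subst(subst, b)))
        }
        Type::Int | Type::Bool => ty.clone(),
    }
}

// Same idea as apply_subst_scheme in the full listing: bound variables are
// local to the scheme, so bindings for them are dropped before applying.
fn apply_subst_scheme(subst: &Subst, scheme: &Scheme) -> Scheme {
    let mut filtered = subst.clone();
    for v in &scheme.vars {
        filtered.remove(v);
    }
    Scheme {
        vars: scheme.vars.clone(),
        ty: apply_subst(&filtered, &scheme.ty),
    }
}
```

Applying {t0 ↦ Int, t1 ↦ Bool} to ∀t0. t0 → t1 gives ∀t0. t0 → Bool. The free t1 is updated. The bound t0 is left alone, because it stands for "any type" rather than for the t0 the substitution talks about.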

Unification

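Unification is the heart of type inference: it solves equality constraints between types. Given two types τ₁ and τ₂, the unification algorithm either fails or produces a substitution S such that S(τ₁) = S(τ₂). The substitution it produces is the most general one with that property: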

Reflexivity - identical types unify trivially: \[ \frac{}{\text{unify}(τ, τ) = \emptyset} \text{(U-Refl)} \]

Variable unification with occurs check: \[ \frac{α \notin \text{ftv}(τ)}{\text{unify}(α, τ) = [α ↦ τ]} \text{(U-VarL)} \]

\[ \frac{α \notin \text{ftv}(τ)}{\text{unify}(τ, α) = [α ↦ τ]} \text{(U-VarR)} \]

Arrow type unification decomposes into domain and codomain: \[ \frac{S₁ = \text{unify}(τ₁, τ₃) \quad S₂ = \text{unify}(S₁(τ₂), S₁(τ₄))}{\text{unify}(τ₁ → τ₂, τ₃ → τ₄) = S₂ ∘ S₁} \text{(U-Arrow)} \]

Tuple unification requires component-wise unification: \[ \frac{S₁ = \text{unify}(τ₁, τ₃) \quad S₂ = \text{unify}(S₁(τ₂), S₁(τ₄))}{\text{unify}((τ₁, τ₂), (τ₃, τ₄)) = S₂ ∘ S₁} \text{(U-Tuple)} \]

Base type unification succeeds only for identical types: \[ \frac{}{\text{unify}(\text{Int}, \text{Int}) = \emptyset} \text{(U-Int)} \]

\[ \frac{}{\text{unify}(\text{Bool}, \text{Bool}) = \emptyset} \text{(U-Bool)} \]

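These rules solve type constraints systematically, and the occurs check in the variable rules keeps them sound. Any pair not covered by a rule fails to unify, for example Int against Bool, or an arrow against a tuple. The implementation also generalizes U-Tuple from pairs to tuples of any length. It unifies components left to right and rejects tuples of different lengths.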
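
As a worked example (our own, chosen for illustration), here is how the rules solve the constraint between t₀ → Int and Bool → t₁:

\[
\begin{aligned}
S₁ &= \text{unify}(t₀, \text{Bool}) = [t₀ ↦ \text{Bool}] \qquad \text{(U-VarL)} \\
S₂ &= \text{unify}(S₁(\text{Int}), S₁(t₁)) = \text{unify}(\text{Int}, t₁) = [t₁ ↦ \text{Int}] \qquad \text{(U-VarR)} \\
\text{unify}(t₀ → \text{Int}, \text{Bool} → t₁) &= S₂ ∘ S₁ = [t₀ ↦ \text{Bool}, t₁ ↦ \text{Int}] \qquad \text{(U-Arrow)}
\end{aligned}
\]

Applying the result to either side gives Bool → Int.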

When we have two types that must be equal (such as the parameter type of a function and the type of an argument being passed to it), unification determines whether this constraint can be satisfied and, if so, what substitution makes them equal.

fn unify(&self, t1: &Type, t2: &Type) -> Result<(Subst, InferenceTree)> {
    let input = format!("{} ~ {}", t1, t2);

    match (t1, t2) {
        (Type::Int, Type::Int) | (Type::Bool, Type::Bool) => {
            let tree = InferenceTree::new("Unify-Base", &input, "{}", vec![]);
            Ok((HashMap::new(), tree))
        }
        (Type::Var(v), ty) | (ty, Type::Var(v)) => {
            if ty == &Type::Var(v.clone()) {
                let tree = InferenceTree::new("Unify-Var-Same", &input, "{}", vec![]);
                Ok((HashMap::new(), tree))
            } else if self.occurs_check(v, ty) {
                Err(InferenceError::OccursCheck {
                    var: v.clone(),
                    ty: ty.clone(),
                })
            } else {
                let mut subst = HashMap::new();
                subst.insert(v.clone(), ty.clone());
                let output = format!("{{{}/{}}}", ty, v);
                let tree = InferenceTree::new("Unify-Var", &input, &output, vec![]);
                Ok((subst, tree))
            }
        }
        (Type::Arrow(a1, a2), Type::Arrow(b1, b2)) => {
            let (s1, tree1) = self.unify(a1, b1)?;
            let a2_subst = self.apply_subst(&s1, a2);
            let b2_subst = self.apply_subst(&s1, b2);
            let (s2, tree2) = self.unify(&a2_subst, &b2_subst)?;
            let final_subst = self.compose_subst(&s2, &s1);
            let output = self.pretty_subst(&final_subst);
            let tree = InferenceTree::new("Unify-Arrow", &input, &output, vec![tree1, tree2]);
            Ok((final_subst, tree))
        }
        (Type::Tuple(ts1), Type::Tuple(ts2)) => {
            if ts1.len() != ts2.len() {
                return Err(InferenceError::TupleLengthMismatch {
                    left_len: ts1.len(),
                    right_len: ts2.len(),
                });
            }

            let mut subst = HashMap::new();
            let mut trees = Vec::new();

            for (t1, t2) in ts1.iter().zip(ts2.iter()) {
                let t1_subst = self.apply_subst(&subst, t1);
                let t2_subst = self.apply_subst(&subst, t2);
                let (s, tree) = self.unify(&t1_subst, &t2_subst)?;
                subst = self.compose_subst(&s, &subst);
                trees.push(tree);
            }

            let output = self.pretty_subst(&subst);
            let tree = InferenceTree::new("Unify-Tuple", &input, &output, trees);
            Ok((subst, tree))
        }
        _ => Err(InferenceError::UnificationFailure {
            expected: t1.clone(),
            actual: t2.clone(),
        }),
    }
}

The unification algorithm handles several distinct cases, each representing a different constraint-solving scenario. When unifying two identical base types like Int with Int, no substitution is needed. When unifying a type variable with any other type, we create a substitution that maps the variable to that type, provided the occurs check passes.

The occurs check prevents infinite types by ensuring that a type variable doesn’t appear within the type it’s being unified with.

fn occurs_check(&self, var: &TyVar, ty: &Type) -> bool {
    match ty {
        Type::Var(name) => name == var,
        Type::Arrow(t1, t2) => self.occurs_check(var, t1) || self.occurs_check(var, t2),
        Type::Tuple(types) => types.iter().any(|t| self.occurs_check(var, t)),
        Type::Int | Type::Bool => false,
    }
}

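This check is essential for soundness. Without it, unifying t0 with t0 -> Int would yield the substitution t0 ↦ t0 -> Int. The only solution to that equation would be the infinite type ((… -> Int) -> Int) -> Int. Hindley-Milner types are finite, so such a constraint has no solution and must be rejected.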
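
A classic program that triggers the occurs check is self-application, \x -> x x. In the implementation's allocation order, T-Lam gives x the variable t₀, and T-App then allocates the result variable t₁. Inference reaches:

\[
\begin{aligned}
&x : t₀ ⊢ x : t₀ \qquad x : t₀ ⊢ x : t₀ \\
&\text{unify}(t₀,\ t₀ → t₁) \text{ fails, since } t₀ ∈ \text{ftv}(t₀ → t₁)
\end{aligned}
\]

The expression is therefore rejected with an occurs-check error instead of being assigned an infinite type.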

For compound types like arrows, unification becomes recursive. We unify the parameter types first, apply the resulting substitution to the return types, unify those, and compose the two substitutions as S₂ ∘ S₁. Applying S₁ before the second step lets information discovered in one component constrain the other. Tuples follow the same pattern, component by component.
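
The case analysis can also be exercised on its own. The sketch below is a simplified re-implementation, not the chapter's code. It drops the inference trees, reports errors as plain strings instead of InferenceError, and folds the tuple-length check into the catch-all case. Type and Subst are assumed to mirror the chapter's definitions.

```rust
use std::collections::HashMap;

// Assumed mirror of the chapter's Type enum.
#[derive(Clone, Debug, PartialEq)]
enum Type {
    Var(String),
    Int,
    Bool,
    Arrow(Box<Type>, Box<Type>),
    Tuple(Vec<Type>),
}

type Subst = HashMap<String, Type>;

fn apply_subst(s: &Subst, ty: &Type) -> Type {
    match ty {
        Type::Var(n) => s.get(n).cloned().unwrap_or_else(|| ty.clone()),
        Type::Arrow(a, b) => Type::Arrow(Box::new(apply_subst(s, a)), Box::new(apply_subst(s, b))),
        Type::Tuple(ts) => Type::Tuple(ts.iter().map(|t| apply_subst(s, t)).collect()),
        Type::Int | Type::Bool => ty.clone(),
    }
}

// compose_subst(s1, s2) behaves like "apply s2, then s1".
fn compose_subst(s1: &Subst, s2: &Subst) -> Subst {
    let mut out = s1.clone();
    for (k, v) in s2 {
        out.insert(k.clone(), apply_subst(s1, v));
    }
    out
}

// True if variable v appears anywhere inside ty.
fn occurs(v: &str, ty: &Type) -> bool {
    match ty {
        Type::Var(n) => n == v,
        Type::Arrow(a, b) => occurs(v, a) || occurs(v, b),
        Type::Tuple(ts) => ts.iter().any(|t| occurs(v, t)),
        Type::Int | Type::Bool => false,
    }
}

fn unify(t1: &Type, t2: &Type) -> Result<Subst, String> {
    match (t1, t2) {
        // Identical base types: nothing to solve.
        (Type::Int, Type::Int) | (Type::Bool, Type::Bool) => Ok(Subst::new()),
        // A variable on either side: bind it, unless that would be circular.
        (Type::Var(v), ty) | (ty, Type::Var(v)) => {
            if ty == &Type::Var(v.clone()) {
                Ok(Subst::new())
            } else if occurs(v, ty) {
                Err(format!("occurs check: {} in {:?}", v, ty))
            } else {
                Ok(Subst::from([(v.clone(), ty.clone())]))
            }
        }
        // Arrows: unify domains, push S1 into codomains, unify, compose.
        (Type::Arrow(a1, a2), Type::Arrow(b1, b2)) => {
            let s1 = unify(a1, b1)?;
            let s2 = unify(&apply_subst(&s1, a2), &apply_subst(&s1, b2))?;
            Ok(compose_subst(&s2, &s1))
        }
        // Tuples of equal length: unify left to right, threading the substitution.
        (Type::Tuple(xs), Type::Tuple(ys)) if xs.len() == ys.len() => {
            let mut s = Subst::new();
            for (x, y) in xs.iter().zip(ys) {
                let si = unify(&apply_subst(&s, x), &apply_subst(&s, y))?;
                s = compose_subst(&si, &s);
            }
            Ok(s)
        }
        _ => Err(format!("cannot unify {:?} with {:?}", t1, t2)),
    }
}
```

Unifying t0 → Int with Bool → t1 yields {t0 ↦ Bool, t1 ↦ Int}. Applying that substitution to both sides gives the same type, Bool → Int, which is exactly the defining property of a unifier.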

The Main Inference Algorithm

The central infer method implements Algorithm W proper, analyzing expressions to determine their types while accumulating the necessary substitutions. Our implementation uses a modular approach where each syntactic construct has its own helper method implementing the corresponding typing rule.

pub fn infer(&mut self, env: &Env, expr: &Expr) -> Result<(Subst, Type, InferenceTree)> {
    match expr {
        Expr::Lit(Lit::Int(_)) => self.infer_lit_int(env, expr),
        Expr::Lit(Lit::Bool(_)) => self.infer_lit_bool(env, expr),
        Expr::Var(name) => self.infer_var(env, expr, name),
        Expr::Abs(param, body) => self.infer_abs(env, expr, param, body),
        Expr::App(func, arg) => self.infer_app(env, expr, func, arg),
        Expr::Let(var, value, body) => self.infer_let(env, expr, var, value, body),
        Expr::Tuple(exprs) => self.infer_tuple(env, expr, exprs),
    }
}

Each helper method corresponds directly to a formal typing rule, making the relationship between theory and implementation explicit.

Variable Lookup

\[ \frac{x : σ \in Γ \quad τ = \text{inst}(σ)}{Γ ⊢ x : τ} \text{(T-Var)} \]

/// T-Var: x : σ ∈ Γ    τ = inst(σ)
///        ─────────────────────────
///               Γ ⊢ x : τ
fn infer_var(
    &mut self,
    env: &Env,
    expr: &Expr,
    name: &str,
) -> Result<(Subst, Type, InferenceTree)> {
    let input = format!("{} ⊢ {} ⇒", self.pretty_env(env), expr);

    match env.get(name) {
        Some(scheme) => {
            let instantiated = self.instantiate(scheme);
            let output = format!("{}", instantiated);
            let tree = InferenceTree::new("T-Var", &input, &output, vec![]);
            Ok((HashMap::new(), instantiated, tree))
        }
        None => Err(InferenceError::UnboundVariable {
            name: name.to_string(),
        }),
    }
}

Variable lookup requires instantiation of polymorphic types. When we find a variable in the environment, it might have a polymorphic type scheme like ∀α. α → α. We create a fresh monomorphic instance by replacing quantified variables with fresh type variables.

Lambda Abstraction

\[ \frac{Γ, x : α ⊢ e : τ \quad α \text{ fresh}}{Γ ⊢ λx. e : α → τ} \text{(T-Lam)} \]

/// T-Lam: Γ, x : α ⊢ e : τ    α fresh
///        ─────────────────────────────
///           Γ ⊢ λx. e : α → τ
fn infer_abs(
    &mut self,
    env: &Env,
    expr: &Expr,
    param: &str,
    body: &Expr,
) -> Result<(Subst, Type, InferenceTree)> {
    let input = format!("{} ⊢ {} ⇒", self.pretty_env(env), expr);

    let param_type = Type::Var(self.fresh_tyvar());
    let mut new_env = env.clone();
    // Insert a monomorphic scheme for the parameter
    let param_scheme = Scheme {
        vars: vec![],
        ty: param_type.clone(),
    };
    new_env.insert(param.to_string(), param_scheme);

    let (s1, body_type, tree1) = self.infer(&new_env, body)?;
    let param_type_subst = self.apply_subst(&s1, &param_type);
    let result_type = Type::Arrow(Box::new(param_type_subst), Box::new(body_type));

    let output = format!("{}", result_type);
    let tree = InferenceTree::new("T-Abs", &input, &output, vec![tree1]);
    Ok((s1, result_type, tree))
}

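Lambda abstractions introduce new variable bindings. We assign a fresh type variable to the parameter, extend the environment, and infer the body's type. Constraints discovered while inferring the body are propagated back by applying the resulting substitution to the parameter type. The parameter is bound with an empty quantifier list, so it stays monomorphic throughout the body. Only let can introduce polymorphism.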
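
For example, \f -> (f 1, f true) is rejected, even though the same uses would be accepted if f were let-bound and polymorphic. Here are the steps, using the variable names the implementation allocates:

\[
\begin{aligned}
&f : t₀ ⊢ f\ 1 : t₁ \qquad S = [t₀ ↦ \text{Int} → t₁] \\
&f : \text{Int} → t₁ ⊢ f\ \text{true} \qquad \text{unify}(\text{Int} → t₁,\ \text{Bool} → t₂) \text{ fails, since } \text{Int} \neq \text{Bool}
\end{aligned}
\]

The first application fixes f's type. Because f is never generalized, the second use cannot pick a different instantiation.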

Function Application

\[ \frac{Γ ⊢ e₁ : τ₁ \quad Γ ⊢ e₂ : τ₂ \quad α \text{ fresh} \quad S = \text{unify}(τ₁, τ₂ → α)}{Γ ⊢ e₁ \, e₂ : S(α)} \text{(T-App)} \]

/// T-App: Γ ⊢ e₁ : τ₁    Γ ⊢ e₂ : τ₂    α fresh    S = unify(τ₁, τ₂ → α)
///        ──────────────────────────────────────────────────────────────
///                            Γ ⊢ e₁ e₂ : S(α)
fn infer_app(
    &mut self,
    env: &Env,
    expr: &Expr,
    func: &Expr,
    arg: &Expr,
) -> Result<(Subst, Type, InferenceTree)> {
    let input = format!("{} ⊢ {} ⇒", self.pretty_env(env), expr);

    let result_type = Type::Var(self.fresh_tyvar());

    let (s1, func_type, tree1) = self.infer(env, func)?;
    let env_subst = self.apply_subst_env(&s1, env);
    let (s2, arg_type, tree2) = self.infer(&env_subst, arg)?;

    let func_type_subst = self.apply_subst(&s2, &func_type);
    let expected_func_type = Type::Arrow(Box::new(arg_type), Box::new(result_type.clone()));

    let (s3, tree3) = self.unify(&func_type_subst, &expected_func_type)?;

    let final_subst = self.compose_subst(&s3, &self.compose_subst(&s2, &s1));
    let final_type = self.apply_subst(&s3, &result_type);

    let output = format!("{}", final_type);
    let tree = InferenceTree::new("T-App", &input, &output, vec![tree1, tree2, tree3]);
    Ok((final_subst, final_type, tree))
}

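Application drives constraint generation. We first infer the function's type. We then infer the argument's type in the environment updated by S₁. Finally, we unify the function type, updated by S₂, with an arrow from the argument type to a fresh result variable α. The three substitutions are composed as S₃ ∘ S₂ ∘ S₁, and the type of the application is S₃(α).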
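
To make the threading concrete, here is a trace of (\x -> x) 42 in the implementation's allocation order. The result variable α = t₀ is allocated before the function is inferred:

\[
\begin{aligned}
&Γ ⊢ λx.\,x : t₁ → t₁ \qquad S₁ = ∅ \\
&Γ ⊢ 42 : \text{Int} \qquad S₂ = ∅ \\
&S₃ = \text{unify}(t₁ → t₁,\ \text{Int} → t₀) = [t₁ ↦ \text{Int},\ t₀ ↦ \text{Int}] \\
&Γ ⊢ (λx.\,x)\ 42 : S₃(t₀) = \text{Int}
\end{aligned}
\]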

Let-Polymorphism

\[ \frac{Γ ⊢ e₁ : τ₁ \quad σ = \text{gen}(Γ, τ₁) \quad Γ, x : σ ⊢ e₂ : τ₂}{Γ ⊢ \text{let } x = e₁ \text{ in } e₂ : τ₂} \text{(T-Let)} \]

/// T-Let: Γ ⊢ e₁ : τ₁    σ = gen(Γ, τ₁)    Γ, x : σ ⊢ e₂ : τ₂
///        ──────────────────────────────────────────────────────
///                     Γ ⊢ let x = e₁ in e₂ : τ₂
fn infer_let(
    &mut self,
    env: &Env,
    expr: &Expr,
    var: &str,
    value: &Expr,
    body: &Expr,
) -> Result<(Subst, Type, InferenceTree)> {
    let input = format!("{} ⊢ {} ⇒", self.pretty_env(env), expr);

    let (s1, value_type, tree1) = self.infer(env, value)?;
    let env_subst = self.apply_subst_env(&s1, env);
    let generalized_type = self.generalize(&env_subst, &value_type);

    let mut new_env = env_subst;
    new_env.insert(var.to_string(), generalized_type);

    let (s2, body_type, tree2) = self.infer(&new_env, body)?;

    let final_subst = self.compose_subst(&s2, &s1);
    let output = format!("{}", body_type);
    let tree = InferenceTree::new("T-Let", &input, &output, vec![tree1, tree2]);
    Ok((final_subst, body_type, tree))
}

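Let expressions enable polymorphism through generalization. After inferring the bound expression's type, we apply S₁ to the environment. We then generalize by quantifying over the type variables that do not occur free in that updated environment. The let body is inferred with the resulting scheme in scope, so each use of the bound variable can be instantiated differently.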

Literal Types

\[ \frac{}{Γ ⊢ n : \text{Int}} \text{(T-LitInt)} \]

\[ \frac{}{Γ ⊢ b : \text{Bool}} \text{(T-LitBool)} \]

/// T-LitInt: ─────────────────
///           Γ ⊢ n : Int
fn infer_lit_int(&mut self, env: &Env, expr: &Expr) -> Result<(Subst, Type, InferenceTree)> {
    let input = format!("{} ⊢ {} ⇒", self.pretty_env(env), expr);
    let tree = InferenceTree::new("T-Int", &input, "Int", vec![]);
    Ok((HashMap::new(), Type::Int, tree))
}

/// T-LitBool: ─────────────────
///            Γ ⊢ b : Bool
fn infer_lit_bool(&mut self, env: &Env, expr: &Expr) -> Result<(Subst, Type, InferenceTree)> {
    let input = format!("{} ⊢ {} ⇒", self.pretty_env(env), expr);
    let tree = InferenceTree::new("T-Bool", &input, "Bool", vec![]);
    Ok((HashMap::new(), Type::Bool, tree))
}

Literals have known types and require no constraint generation.

Generalization and Instantiation

The generalization and instantiation mechanisms handle let-polymorphism, allowing variables bound in let expressions to be used with multiple different types.

Understanding Generalization with Examples

Generalization turns a type that contains free type variables into a polymorphic type scheme. Consider this simple example:

let id = \x -> x in (id 42, id true)

When we infer the type of \x -> x, we get something like t0 → t0, where t0 is a type variable. Since t0 does not occur free in the environment (which is empty here), we can generalize the type to ∀t0. t0 → t0, making id polymorphic.

This is what allows id to be used with both 42 (type Int) and true (type Bool) in the same expression. Without generalization, the first use would fix t0 to Int, making the second use fail.
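
Here is the trace in the implementation's allocation order. T-App allocates its result variable before inferring the function, so id 42 uses t₁ for its result and t₂ for the fresh instance of id:

\[
\begin{aligned}
&⊢ λx.\,x : t₀ → t₀ \qquad \text{gen}(∅, t₀ → t₀) = ∀t₀.\ t₀ → t₀ \\
&\text{id } 42:\ \text{inst} = t₂ → t₂,\ \ \text{unify}(t₂ → t₂,\ \text{Int} → t₁) = [t₂ ↦ \text{Int},\ t₁ ↦ \text{Int}] \Rightarrow \text{Int} \\
&\text{id true}:\ \text{inst} = t₄ → t₄,\ \ \text{unify}(t₄ → t₄,\ \text{Bool} → t₃) = [t₄ ↦ \text{Bool},\ t₃ ↦ \text{Bool}] \Rightarrow \text{Bool} \\
&⊢ \text{let id} = λx.\,x \text{ in } (\text{id } 42, \text{id true}) : (\text{Int}, \text{Bool})
\end{aligned}
\]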

fn generalize(&self, env: &Env, ty: &Type) -> Scheme {
    let type_vars = self.free_type_vars(ty);
    let env_vars = self.free_type_vars_env(env);
    let mut free_vars: Vec<_> = type_vars.difference(&env_vars).cloned().collect();
    free_vars.sort(); // Sort for deterministic behavior

    Scheme {
        vars: free_vars,
        ty: ty.clone(),
    }
}

Generalization identifies type variables that could be made polymorphic by checking which ones don’t appear free in the current environment. If a type variable isn’t constrained by anything else in scope, it’s safe to quantify over it.
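
The environment check matters as soon as the environment is non-empty. When inference reaches the let inside \y -> let f = \x -> y in f, the environment holds y : t0 and the bound expression has type t1 → t0. Only t1 may be quantified. The variable t0 is tied to y, and quantifying it would let f's result type drift away from y's. The sketch below checks exactly this, assuming pared-down Type, Scheme, and Env definitions that mirror the chapter's.

```rust
use std::collections::{BTreeMap, HashSet};

// Assumed mirrors of the chapter's definitions.
#[derive(Clone, Debug, PartialEq)]
enum Type {
    Var(String),
    Int,
    Arrow(Box<Type>, Box<Type>),
}

#[derive(Clone, Debug, PartialEq)]
struct Scheme {
    vars: Vec<String>,
    ty: Type,
}

type Env = BTreeMap<String, Scheme>;

// Every variable in a plain Type is free.
fn ftv(ty: &Type) -> HashSet<String> {
    match ty {
        Type::Var(n) => HashSet::from([n.clone()]),
        Type::Arrow(a, b) => {
            let mut s = ftv(a);
            s.extend(ftv(b));
            s
        }
        Type::Int => HashSet::new(),
    }
}

// Union over all schemes, minus each scheme's own quantified variables.
fn ftv_env(env: &Env) -> HashSet<String> {
    let mut out = HashSet::new();
    for scheme in env.values() {
        let mut s = ftv(&scheme.ty);
        for v in &scheme.vars {
            s.remove(v);
        }
        out.extend(s);
    }
    out
}

// Quantify only over variables the environment does not mention.
fn generalize(env: &Env, ty: &Type) -> Scheme {
    let env_vars = ftv_env(env);
    let mut vars: Vec<String> = ftv(ty).difference(&env_vars).cloned().collect();
    vars.sort(); // deterministic order, as in the chapter's version
    Scheme { vars, ty: ty.clone() }
}
```

With y : t0 in scope, t1 → t0 generalizes to ∀t1. t1 → t0. In an empty environment the same type would be generalized over both variables.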

Understanding Instantiation with Examples

Instantiation creates fresh monomorphic versions from polymorphic types. When we use a polymorphic function like our identity function, we need to create a fresh copy of its type for each use.

Consider this expression:

let id = \x -> x in id id

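Here we're applying the polymorphic identity function to itself. Each occurrence of id is instantiated with its own fresh variable, say β → β for the first and α → α for the second. Unification then sets β to α → α. As a result, the first id is used at type (α → α) → (α → α) and the second at type α → α. Because the two instantiations are independent, the application type-checks, and the whole expression has type α → α.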

fn instantiate(&mut self, scheme: &Scheme) -> Type {
    // Create fresh type variables for each quantified variable
    let mut subst = HashMap::new();
    for var in &scheme.vars {
        let fresh = self.fresh_tyvar();
        subst.insert(var.clone(), Type::Var(fresh));
    }

    self.apply_subst(&subst, &scheme.ty)
}

Instantiation replaces quantified type variables with fresh type variables. This ensures that each use of a polymorphic function gets its own independent type constraints, preventing interference between different call sites.
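
This self-contained sketch shows that repeated instantiations never share variables. It assumes pared-down Type and Scheme definitions. The Supply struct is a hypothetical stand-in for the counter inside TypeInference.

```rust
use std::collections::HashMap;

// Assumed mirrors of the chapter's definitions (only what this example needs).
#[derive(Clone, Debug, PartialEq)]
enum Type {
    Var(String),
    Arrow(Box<Type>, Box<Type>),
}

struct Scheme {
    vars: Vec<String>,
    ty: Type,
}

// Hypothetical fresh-name supply, mirroring fresh_tyvar's counter.
struct Supply {
    counter: usize,
}

impl Supply {
    fn fresh(&mut self) -> String {
        let v = format!("t{}", self.counter);
        self.counter += 1;
        v
    }

    // Replace each quantified variable with a brand-new one.
    fn instantiate(&mut self, scheme: &Scheme) -> Type {
        let mut subst = HashMap::new();
        for v in &scheme.vars {
            subst.insert(v.clone(), Type::Var(self.fresh()));
        }
        apply(&subst, &scheme.ty)
    }
}

fn apply(subst: &HashMap<String, Type>, ty: &Type) -> Type {
    match ty {
        Type::Var(n) => subst.get(n).cloned().unwrap_or_else(|| ty.clone()),
        Type::Arrow(a, b) => Type::Arrow(Box::new(apply(subst, a)), Box::new(apply(subst, b))),
    }
}
```

Starting the counter at 1 and instantiating ∀t0. t0 → t0 twice gives t1 → t1 and then t2 → t2. Constraints placed on one copy therefore cannot leak into the other.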

Free Type Variables

Generalization depends on computing the free type variables in both individual types and entire environments. These operations identify which type variables could potentially be generalized versus those that are already constrained.

fn free_type_vars(&self, ty: &Type) -> HashSet<TyVar> {
    match ty {
        Type::Var(name) => {
            let mut set = HashSet::new();
            set.insert(name.clone());
            set
        }
        Type::Arrow(t1, t2) => {
            let mut set = self.free_type_vars(t1);
            set.extend(self.free_type_vars(t2));
            set
        }
        Type::Tuple(types) => {
            let mut set = HashSet::new();
            for t in types {
                set.extend(self.free_type_vars(t));
            }
            set
        }
        Type::Int | Type::Bool => HashSet::new(),
    }
}

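The free type variables computation traverses type structures recursively. A plain Type has no quantifiers, so every variable it contains is free. For compound types like arrows and tuples, the function must visit every component so that no variable is missed. Quantified variables only appear in a Scheme, and free_type_vars_scheme (in the complete listing) removes them from the result.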

fn free_type_vars_env(&self, env: &Env) -> HashSet<TyVar> {
    let mut set = HashSet::new();
    for scheme in env.values() {
        set.extend(self.free_type_vars_scheme(scheme));
    }
    set
}

Computing the free variables of an environment means taking the union of the free variables of every scheme in it, after removing each scheme's quantified variables. The result is the set of type variables constrained by the current context, which generalization must leave unquantified.

Together, generalization and instantiation give let-bound variables genuine polymorphism: every quantified variable in a scheme is replaced by a fresh type variable each time the scheme is instantiated. This is what lets the identity function work on integers in one context and booleans in another. For example, let id = \x -> x in (id, id) produces the type (t1 -> t1, t2 -> t2), because each occurrence of id receives its own fresh variable.

Error Handling and Inference Trees

Our implementation provides detailed error reporting and generates inference trees that show the step-by-step reasoning process.

#![allow(unused)]
fn main() {
use std::collections::{BTreeMap, HashMap, HashSet};
use std::fmt;
use crate::ast::{Expr, Lit, Scheme, Type};
use crate::errors::{InferenceError, Result};
pub type TyVar = String;
pub type TmVar = String;
pub type Env = BTreeMap<TmVar, Scheme>;
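pub type Subst = HashMap<TyVar, Type>;
// One node of the derivation: the rule applied, its input judgment,
// its result, and the sub-derivations it depends on.
#[derive(Debug)]
pub struct InferenceTree {
    pub rule: String,
    pub input: String,
    pub output: String,
    pub children: Vec<InferenceTree>,
}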
impl InferenceTree {
    fn new(rule: &str, input: &str, output: &str, children: Vec<InferenceTree>) -> Self {
        Self {
            rule: rule.to_string(),
            input: input.to_string(),
            output: output.to_string(),
            children,
        }
    }
}
impl fmt::Display for InferenceTree {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        self.display_with_indent(f, 0)
    }
}
impl InferenceTree {
    fn display_with_indent(&self, f: &mut fmt::Formatter, indent: usize) -> fmt::Result {
        let prefix = "  ".repeat(indent);
        writeln!(
            f,
            "{}{}: {} => {}",
            prefix, self.rule, self.input, self.output
        )?;
        for child in &self.children {
            child.display_with_indent(f, indent + 1)?;
        }
        Ok(())
    }
}
pub struct TypeInference {
    counter: usize,
}
impl Default for TypeInference {
    fn default() -> Self {
        Self::new()
    }
}
#[allow(clippy::only_used_in_recursion)]
impl TypeInference {
    pub fn new() -> Self {
        Self { counter: 0 }
    }

    fn fresh_tyvar(&mut self) -> TyVar {
        let var = format!("t{}", self.counter);
        self.counter += 1;
        var
    }

    fn pretty_env(&self, env: &Env) -> String {
        if env.is_empty() {
            "{}".to_string()
        } else {
            let entries: Vec<String> = env.iter().map(|(k, v)| format!("{}: {}", k, v)).collect();
            format!("{{{}}}", entries.join(", "))
        }
    }

    fn pretty_subst(&self, subst: &Subst) -> String {
        if subst.is_empty() {
            "{}".to_string()
        } else {
            let entries: Vec<String> = subst.iter().map(|(k, v)| format!("{}/{}", v, k)).collect();
            format!("{{{}}}", entries.join(", "))
        }
    }

    fn apply_subst(&self, subst: &Subst, ty: &Type) -> Type {
        match ty {
            Type::Var(name) => subst.get(name).cloned().unwrap_or_else(|| ty.clone()),
            Type::Arrow(t1, t2) => Type::Arrow(
                Box::new(self.apply_subst(subst, t1)),
                Box::new(self.apply_subst(subst, t2)),
            ),
            Type::Tuple(types) => {
                Type::Tuple(types.iter().map(|t| self.apply_subst(subst, t)).collect())
            }
            Type::Int | Type::Bool => ty.clone(),
        }
    }

    fn apply_subst_scheme(&self, subst: &Subst, scheme: &Scheme) -> Scheme {
        // Remove bindings for quantified variables to avoid capture
        let mut filtered_subst = subst.clone();
        for var in &scheme.vars {
            filtered_subst.remove(var);
        }
        Scheme {
            vars: scheme.vars.clone(),
            ty: self.apply_subst(&filtered_subst, &scheme.ty),
        }
    }

    fn apply_subst_env(&self, subst: &Subst, env: &Env) -> Env {
        env.iter()
            .map(|(k, v)| (k.clone(), self.apply_subst_scheme(subst, v)))
            .collect()
    }

    fn compose_subst(&self, s1: &Subst, s2: &Subst) -> Subst {
        let mut result = s1.clone();
        for (k, v) in s2 {
            result.insert(k.clone(), self.apply_subst(s1, v));
        }
        result
    }

    fn free_type_vars(&self, ty: &Type) -> HashSet<TyVar> {
        match ty {
            Type::Var(name) => {
                let mut set = HashSet::new();
                set.insert(name.clone());
                set
            }
            Type::Arrow(t1, t2) => {
                let mut set = self.free_type_vars(t1);
                set.extend(self.free_type_vars(t2));
                set
            }
            Type::Tuple(types) => {
                let mut set = HashSet::new();
                for t in types {
                    set.extend(self.free_type_vars(t));
                }
                set
            }
            Type::Int | Type::Bool => HashSet::new(),
        }
    }

    fn free_type_vars_scheme(&self, scheme: &Scheme) -> HashSet<TyVar> {
        let mut set = self.free_type_vars(&scheme.ty);
        // Remove quantified variables
        for var in &scheme.vars {
            set.remove(var);
        }
        set
    }

    fn free_type_vars_env(&self, env: &Env) -> HashSet<TyVar> {
        let mut set = HashSet::new();
        for scheme in env.values() {
            set.extend(self.free_type_vars_scheme(scheme));
        }
        set
    }

    fn generalize(&self, env: &Env, ty: &Type) -> Scheme {
        let type_vars = self.free_type_vars(ty);
        let env_vars = self.free_type_vars_env(env);
        let mut free_vars: Vec<_> = type_vars.difference(&env_vars).cloned().collect();
        free_vars.sort(); // Sort for deterministic behavior

        Scheme {
            vars: free_vars,
            ty: ty.clone(),
        }
    }

    fn instantiate(&mut self, scheme: &Scheme) -> Type {
        // Create fresh type variables for each quantified variable
        let mut subst = HashMap::new();
        for var in &scheme.vars {
            let fresh = self.fresh_tyvar();
            subst.insert(var.clone(), Type::Var(fresh));
        }

        self.apply_subst(&subst, &scheme.ty)
    }

    fn occurs_check(&self, var: &TyVar, ty: &Type) -> bool {
        match ty {
            Type::Var(name) => name == var,
            Type::Arrow(t1, t2) => self.occurs_check(var, t1) || self.occurs_check(var, t2),
            Type::Tuple(types) => types.iter().any(|t| self.occurs_check(var, t)),
            Type::Int | Type::Bool => false,
        }
    }

    fn unify(&self, t1: &Type, t2: &Type) -> Result<(Subst, InferenceTree)> {
        let input = format!("{} ~ {}", t1, t2);

        match (t1, t2) {
            (Type::Int, Type::Int) | (Type::Bool, Type::Bool) => {
                let tree = InferenceTree::new("Unify-Base", &input, "{}", vec![]);
                Ok((HashMap::new(), tree))
            }
            (Type::Var(v), ty) | (ty, Type::Var(v)) => {
                if ty == &Type::Var(v.clone()) {
                    let tree = InferenceTree::new("Unify-Var-Same", &input, "{}", vec![]);
                    Ok((HashMap::new(), tree))
                } else if self.occurs_check(v, ty) {
                    Err(InferenceError::OccursCheck {
                        var: v.clone(),
                        ty: ty.clone(),
                    })
                } else {
                    let mut subst = HashMap::new();
                    subst.insert(v.clone(), ty.clone());
                    let output = format!("{{{}/{}}}", ty, v);
                    let tree = InferenceTree::new("Unify-Var", &input, &output, vec![]);
                    Ok((subst, tree))
                }
            }
            (Type::Arrow(a1, a2), Type::Arrow(b1, b2)) => {
                let (s1, tree1) = self.unify(a1, b1)?;
                let a2_subst = self.apply_subst(&s1, a2);
                let b2_subst = self.apply_subst(&s1, b2);
                let (s2, tree2) = self.unify(&a2_subst, &b2_subst)?;
                let final_subst = self.compose_subst(&s2, &s1);
                let output = self.pretty_subst(&final_subst);
                let tree = InferenceTree::new("Unify-Arrow", &input, &output, vec![tree1, tree2]);
                Ok((final_subst, tree))
            }
            (Type::Tuple(ts1), Type::Tuple(ts2)) => {
                if ts1.len() != ts2.len() {
                    return Err(InferenceError::TupleLengthMismatch {
                        left_len: ts1.len(),
                        right_len: ts2.len(),
                    });
                }

                let mut subst = HashMap::new();
                let mut trees = Vec::new();

                for (t1, t2) in ts1.iter().zip(ts2.iter()) {
                    let t1_subst = self.apply_subst(&subst, t1);
                    let t2_subst = self.apply_subst(&subst, t2);
                    let (s, tree) = self.unify(&t1_subst, &t2_subst)?;
                    subst = self.compose_subst(&s, &subst);
                    trees.push(tree);
                }

                let output = self.pretty_subst(&subst);
                let tree = InferenceTree::new("Unify-Tuple", &input, &output, trees);
                Ok((subst, tree))
            }
            _ => Err(InferenceError::UnificationFailure {
                expected: t1.clone(),
                actual: t2.clone(),
            }),
        }
    }

    pub fn infer(&mut self, env: &Env, expr: &Expr) -> Result<(Subst, Type, InferenceTree)> {
        match expr {
            Expr::Lit(Lit::Int(_)) => self.infer_lit_int(env, expr),
            Expr::Lit(Lit::Bool(_)) => self.infer_lit_bool(env, expr),
            Expr::Var(name) => self.infer_var(env, expr, name),
            Expr::Abs(param, body) => self.infer_abs(env, expr, param, body),
            Expr::App(func, arg) => self.infer_app(env, expr, func, arg),
            Expr::Let(var, value, body) => self.infer_let(env, expr, var, value, body),
            Expr::Tuple(exprs) => self.infer_tuple(env, expr, exprs),
        }
    }

    /// T-LitInt: ─────────────────
    ///           Γ ⊢ n : Int
    fn infer_lit_int(&mut self, env: &Env, expr: &Expr) -> Result<(Subst, Type, InferenceTree)> {
        let input = format!("{} ⊢ {} ⇒", self.pretty_env(env), expr);
        let tree = InferenceTree::new("T-Int", &input, "Int", vec![]);
        Ok((HashMap::new(), Type::Int, tree))
    }

    /// T-LitBool: ─────────────────
    ///            Γ ⊢ b : Bool
    fn infer_lit_bool(&mut self, env: &Env, expr: &Expr) -> Result<(Subst, Type, InferenceTree)> {
        let input = format!("{} ⊢ {} ⇒", self.pretty_env(env), expr);
        let tree = InferenceTree::new("T-Bool", &input, "Bool", vec![]);
        Ok((HashMap::new(), Type::Bool, tree))
    }

    /// T-Var: x : σ ∈ Γ    τ = inst(σ)
    ///        ─────────────────────────
    ///               Γ ⊢ x : τ
    fn infer_var(
        &mut self,
        env: &Env,
        expr: &Expr,
        name: &str,
    ) -> Result<(Subst, Type, InferenceTree)> {
        let input = format!("{} ⊢ {} ⇒", self.pretty_env(env), expr);

        match env.get(name) {
            Some(scheme) => {
                let instantiated = self.instantiate(scheme);
                let output = format!("{}", instantiated);
                let tree = InferenceTree::new("T-Var", &input, &output, vec![]);
                Ok((HashMap::new(), instantiated, tree))
            }
            None => Err(InferenceError::UnboundVariable {
                name: name.to_string(),
            }),
        }
    }

    /// T-Lam: Γ, x : α ⊢ e : τ    α fresh
    ///        ─────────────────────────────
    ///           Γ ⊢ λx. e : α → τ
    fn infer_abs(
        &mut self,
        env: &Env,
        expr: &Expr,
        param: &str,
        body: &Expr,
    ) -> Result<(Subst, Type, InferenceTree)> {
        let input = format!("{} ⊢ {} ⇒", self.pretty_env(env), expr);

        let param_type = Type::Var(self.fresh_tyvar());
        let mut new_env = env.clone();
        // Insert a monomorphic scheme for the parameter
        let param_scheme = Scheme {
            vars: vec![],
            ty: param_type.clone(),
        };
        new_env.insert(param.to_string(), param_scheme);

        let (s1, body_type, tree1) = self.infer(&new_env, body)?;
        let param_type_subst = self.apply_subst(&s1, &param_type);
        let result_type = Type::Arrow(Box::new(param_type_subst), Box::new(body_type));

        let output = format!("{}", result_type);
        let tree = InferenceTree::new("T-Abs", &input, &output, vec![tree1]);
        Ok((s1, result_type, tree))
    }

    /// T-App: Γ ⊢ e₁ : τ₁    Γ ⊢ e₂ : τ₂    α fresh    S = unify(τ₁, τ₂ → α)
    ///        ──────────────────────────────────────────────────────────────
    ///                            Γ ⊢ e₁ e₂ : S(α)
    fn infer_app(
        &mut self,
        env: &Env,
        expr: &Expr,
        func: &Expr,
        arg: &Expr,
    ) -> Result<(Subst, Type, InferenceTree)> {
        let input = format!("{} ⊢ {} ⇒", self.pretty_env(env), expr);

        let result_type = Type::Var(self.fresh_tyvar());

        let (s1, func_type, tree1) = self.infer(env, func)?;
        let env_subst = self.apply_subst_env(&s1, env);
        let (s2, arg_type, tree2) = self.infer(&env_subst, arg)?;

        let func_type_subst = self.apply_subst(&s2, &func_type);
        let expected_func_type = Type::Arrow(Box::new(arg_type), Box::new(result_type.clone()));

        let (s3, tree3) = self.unify(&func_type_subst, &expected_func_type)?;

        let final_subst = self.compose_subst(&s3, &self.compose_subst(&s2, &s1));
        let final_type = self.apply_subst(&s3, &result_type);

        let output = format!("{}", final_type);
        let tree = InferenceTree::new("T-App", &input, &output, vec![tree1, tree2, tree3]);
        Ok((final_subst, final_type, tree))
    }

    /// T-Let: Γ ⊢ e₁ : τ₁    σ = gen(Γ, τ₁)    Γ, x : σ ⊢ e₂ : τ₂
    ///        ──────────────────────────────────────────────────────
    ///                     Γ ⊢ let x = e₁ in e₂ : τ₂
    fn infer_let(
        &mut self,
        env: &Env,
        expr: &Expr,
        var: &str,
        value: &Expr,
        body: &Expr,
    ) -> Result<(Subst, Type, InferenceTree)> {
        let input = format!("{} ⊢ {} ⇒", self.pretty_env(env), expr);

        let (s1, value_type, tree1) = self.infer(env, value)?;
        let env_subst = self.apply_subst_env(&s1, env);
        let generalized_type = self.generalize(&env_subst, &value_type);

        let mut new_env = env_subst;
        new_env.insert(var.to_string(), generalized_type);

        let (s2, body_type, tree2) = self.infer(&new_env, body)?;

        let final_subst = self.compose_subst(&s2, &s1);
        let output = format!("{}", body_type);
        let tree = InferenceTree::new("T-Let", &input, &output, vec![tree1, tree2]);
        Ok((final_subst, body_type, tree))
    }

    /// T-Tuple: Γ ⊢ e₁ : τ₁    ...    Γ ⊢ eₙ : τₙ
    ///          ─────────────────────────────────────
    ///              Γ ⊢ (e₁, ..., eₙ) : (τ₁, ..., τₙ)
    fn infer_tuple(
        &mut self,
        env: &Env,
        expr: &Expr,
        exprs: &[Expr],
    ) -> Result<(Subst, Type, InferenceTree)> {
        let input = format!("{} ⊢ {} ⇒", self.pretty_env(env), expr);

        let mut subst = HashMap::new();
        let mut types = Vec::new();
        let mut trees = Vec::new();
        let mut current_env = env.clone();

        for expr in exprs {
            let (s, ty, tree) = self.infer(&current_env, expr)?;
            subst = self.compose_subst(&s, &subst);
            current_env = self.apply_subst_env(&s, &current_env);
            types.push(ty);
            trees.push(tree);
        }

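        // Later elements may refine type variables that occur in earlier
        // elements' types, so apply the final substitution to every component.
        let result_type = self.apply_subst(&subst, &Type::Tuple(types));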
        let output = format!("{}", result_type);
        let tree = InferenceTree::new("T-Tuple", &input, &output, trees);
        Ok((subst, result_type, tree))
    }
}
pub fn run_inference(expr: &Expr) -> Result<InferenceTree> {
    let mut inference = TypeInference::new();
    let env = BTreeMap::new();
    let (_, _, tree) = inference.infer(&env, expr)?;
    Ok(tree)
}
pub fn infer_type_only(expr: &Expr) -> Result<Type> {
    let mut inference = TypeInference::new();
    let env = BTreeMap::new();
    let (_, ty, _) = inference.infer(&env, expr)?;
    Ok(ty)
}
#[derive(Debug)]
pub struct InferenceTree {
    pub rule: String,
    pub input: String,
    pub output: String,
    pub children: Vec<InferenceTree>,
}
}

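These inference trees serve both debugging and educational purposes. Each node records the rule that was applied (rule), the judgment being solved (input), its result (output), and the sub-derivations that justify it (children). Together they make explicit the reasoning that Algorithm W otherwise performs silently, showing how type information flows through the program and how constraints are generated and solved.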
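To see what the printed form looks like, here is a minimal standalone sketch of the same layout that InferenceTree's Display implementation uses: one line per node, indented two spaces per level. It redefines a stripped-down InferenceTree locally so that it compiles without the crate's ast and errors modules. It also prints a derivation for λx.x that is built by hand rather than produced by the inference engine.

```rust
use std::fmt;

// Local stand-in for the chapter's InferenceTree (same four fields),
// so this sketch compiles on its own.
struct InferenceTree {
    rule: String,
    input: String,
    output: String,
    children: Vec<InferenceTree>,
}

impl InferenceTree {
    fn new(rule: &str, input: &str, output: &str, children: Vec<InferenceTree>) -> Self {
        Self {
            rule: rule.to_string(),
            input: input.to_string(),
            output: output.to_string(),
            children,
        }
    }

    // One line per node, indented two spaces per level of nesting.
    fn write_indented(&self, f: &mut fmt::Formatter, indent: usize) -> fmt::Result {
        writeln!(f, "{}{}: {} => {}", "  ".repeat(indent), self.rule, self.input, self.output)?;
        for child in &self.children {
            child.write_indented(f, indent + 1)?;
        }
        Ok(())
    }
}

impl fmt::Display for InferenceTree {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        self.write_indented(f, 0)
    }
}

fn main() {
    // Hand-built derivation for λx.x, where the parameter x received t0.
    let var = InferenceTree::new("T-Var", "{x: t0} ⊢ x ⇒", "t0", vec![]);
    let abs = InferenceTree::new("T-Abs", "{} ⊢ λx.x ⇒", "t0 → t0", vec![var]);
    print!("{}", abs);
}
```

The two lines it prints have the same shape as the T-Abs/T-Var lines in the example trace later in this chapter.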

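The public interface offers two entry points. run_inference returns the full derivation tree, while infer_type_only discards the tree and returns just the inferred type. Both create a fresh TypeInference and start from an empty environment. As a result, any free variable is reported as unbound, and type variable numbering restarts at t0 on every call.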

pub fn run_inference(expr: &Expr) -> Result<InferenceTree> {
    let mut inference = TypeInference::new();
    let env = BTreeMap::new();
    let (_, _, tree) = inference.infer(&env, expr)?;
    Ok(tree)
}

pub fn infer_type_only(expr: &Expr) -> Result<Type> {
    let mut inference = TypeInference::new();
    let env = BTreeMap::new();
    let (_, ty, _) = inference.infer(&env, expr)?;
    Ok(ty)
}
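Because run_inference returns an ordinary tree, callers can post-process it. As a hypothetical example, a debugger might tally which rules fired or list every unification step. The helpers rule_counts, unify_steps, and sample_tree below are illustrative and not part of the chapter's code. The sketch also redefines InferenceTree locally with the same four fields so that it compiles on its own.

```rust
use std::collections::BTreeMap;

// Local stand-in with the same four public fields as the chapter's
// InferenceTree, so this sketch compiles on its own.
struct InferenceTree {
    rule: String,
    input: String,
    output: String,
    children: Vec<InferenceTree>,
}

fn node(rule: &str, input: &str, output: &str, children: Vec<InferenceTree>) -> InferenceTree {
    InferenceTree {
        rule: rule.to_string(),
        input: input.to_string(),
        output: output.to_string(),
        children,
    }
}

// Count how many times each rule fires, visiting every node once.
fn rule_counts(tree: &InferenceTree) -> BTreeMap<String, usize> {
    let mut counts = BTreeMap::new();
    let mut stack = vec![tree];
    while let Some(current) = stack.pop() {
        *counts.entry(current.rule.clone()).or_insert(0) += 1;
        stack.extend(current.children.iter());
    }
    counts
}

// Collect every unification step as "input => output", in pre-order.
fn unify_steps(tree: &InferenceTree, out: &mut Vec<String>) {
    if tree.rule.starts_with("Unify") {
        out.push(format!("{} => {}", tree.input, tree.output));
    }
    for child in &tree.children {
        unify_steps(child, out);
    }
}

// Hand-built copy of the `const 42` subtree from the example trace later in
// this chapter, with the environments dropped from the inputs for brevity.
fn sample_tree() -> InferenceTree {
    node("T-App", "const 42", "t5 → Int", vec![
        node("T-Var", "const", "t4 → t5 → t4", vec![]),
        node("T-Int", "42", "Int", vec![]),
        node("Unify-Arrow", "t4 → t5 → t4 ~ Int → t3", "{t5 → Int/t3, Int/t4}", vec![
            node("Unify-Var", "t4 ~ Int", "{Int/t4}", vec![]),
            node("Unify-Var", "t5 → Int ~ t3", "{t5 → Int/t3}", vec![]),
        ]),
    ])
}

fn main() {
    let tree = sample_tree();
    for (rule, n) in rule_counts(&tree) {
        println!("{}: {}", rule, n);
    }
    let mut steps = Vec::new();
    unify_steps(&tree, &mut steps);
    for step in &steps {
        println!("{}", step);
    }
}
```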

Example Usage

To see Algorithm W in action, let’s type-check a polymorphic function that demonstrates let-polymorphism and generalization:

$ cargo run -- "let const = \\x -> \\y -> x in const 42 true"

This prints the parsed expression, the final type, and the full inference trace:

Parsed expression: let const = λx.λy.x in const 42 true

Type inference successful!
Final type: Int

Inference trace:
T-Let: {} ⊢ let const = λx.λy.x in const 42 true ⇒ => Int
  T-Abs: {} ⊢ λx.λy.x ⇒ => t0 → t1 → t0
    T-Abs: {x: t0} ⊢ λy.x ⇒ => t1 → t0
      T-Var: {x: t0, y: t1} ⊢ x ⇒ => t0
  T-App: {const: forall t0 t1. t0 → t1 → t0} ⊢ const 42 true ⇒ => Int
    T-App: {const: forall t0 t1. t0 → t1 → t0} ⊢ const 42 ⇒ => t5 → Int
      T-Var: {const: forall t0 t1. t0 → t1 → t0} ⊢ const ⇒ => t4 → t5 → t4
      T-Int: {const: forall t0 t1. t0 → t1 → t0} ⊢ 42 ⇒ => Int
      Unify-Arrow: t4 → t5 → t4 ~ Int → t3 => {t5 → Int/t3, Int/t4}
        Unify-Var: t4 ~ Int => {Int/t4}
        Unify-Var: t5 → Int ~ t3 => {t5 → Int/t3}
    T-Bool: {const: forall t0 t1. t0 → t1 → t0} ⊢ true ⇒ => Bool
    Unify-Arrow: t5 → Int ~ Bool → t2 => {Int/t2, Bool/t5}
      Unify-Var: t5 ~ Bool => {Bool/t5}
      Unify-Var: Int ~ t2 => {Int/t2}

This trace shows several key aspects of Algorithm W:

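  1. Generalization: The lambda \x -> \y -> x is first inferred as t0 → t1 → t0. Because the environment is empty, neither t0 nor t1 is free in it. Binding the lambda to const in the let-expression therefore generalizes it to ∀t0 t1. t0 → t1 → t0, which the trace prints as forall t0 t1. t0 → t1 → t0.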

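  2. Instantiation: Each use of const instantiates the scheme with fresh type variables. Here those are t4 and t5, because t2 and t3 were already allocated as the result variables of the outer and inner applications. Fresh variables per use are what allow const to be used polymorphically.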

  3. Unification: Applying const to 42 unifies t4 → t5 → t4 with Int → t3, which gives t4 = Int and t3 = t5 → Int. Applying that result to true unifies t5 → Int with Bool → t2, which gives t5 = Bool and t2 = Int.
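Items 1 and 2 can be replayed outside the inference engine. The following standalone sketch mirrors the logic of generalize and instantiate, using its own cut-down Type and Scheme rather than the crate's ast types. It first quantifies the variables that are not free in the environment, then gives each quantified variable a fresh name. A BTreeSet stands in for the explicit sort in generalize, and the counter is set to 4 by hand to match the trace.

```rust
use std::collections::{BTreeSet, HashMap};

// Cut-down type language for this sketch (not the crate's ast types).
#[derive(Clone, Debug, PartialEq)]
enum Type {
    Int,
    Var(String),
    Arrow(Box<Type>, Box<Type>),
}

// A type scheme ∀vars. ty
struct Scheme {
    vars: Vec<String>,
    ty: Type,
}

fn var(name: &str) -> Type {
    Type::Var(name.to_string())
}

fn arrow(a: Type, b: Type) -> Type {
    Type::Arrow(Box::new(a), Box::new(b))
}

fn free_vars(ty: &Type, out: &mut BTreeSet<String>) {
    match ty {
        Type::Int => {}
        Type::Var(v) => {
            out.insert(v.clone());
        }
        Type::Arrow(a, b) => {
            free_vars(a, out);
            free_vars(b, out);
        }
    }
}

// Quantify the variables of `ty` that are not free in the environment.
// `env_vars` plays the role of the free type variables of Γ.
fn generalize(env_vars: &BTreeSet<String>, ty: &Type) -> Scheme {
    let mut fv = BTreeSet::new();
    free_vars(ty, &mut fv);
    Scheme {
        vars: fv.difference(env_vars).cloned().collect(), // sorted by BTreeSet
        ty: ty.clone(),
    }
}

fn apply(subst: &HashMap<String, Type>, ty: &Type) -> Type {
    match ty {
        Type::Int => Type::Int,
        Type::Var(v) => subst.get(v).cloned().unwrap_or_else(|| ty.clone()),
        Type::Arrow(a, b) => arrow(apply(subst, a), apply(subst, b)),
    }
}

// Replace each quantified variable with a fresh tN drawn from `counter`.
fn instantiate(counter: &mut usize, scheme: &Scheme) -> Type {
    let mut subst = HashMap::new();
    for v in &scheme.vars {
        subst.insert(v.clone(), Type::Var(format!("t{}", counter)));
        *counter += 1;
    }
    apply(&subst, &scheme.ty)
}

fn main() {
    // λx.λy.x was inferred as t0 → t1 → t0 in the empty environment,
    // so both variables are generalized.
    let k = arrow(var("t0"), arrow(var("t1"), var("t0")));
    let scheme = generalize(&BTreeSet::new(), &k);
    assert_eq!(scheme.vars, vec!["t0", "t1"]);

    // A variable that is free in the environment (here t0) is not quantified.
    let env_vars = BTreeSet::from(["t0".to_string()]);
    assert_eq!(generalize(&env_vars, &arrow(var("t0"), arrow(var("t1"), Type::Int))).vars, vec!["t1"]);

    // t2 and t3 are already taken by the two applications, so the next
    // fresh variables are t4 and t5, giving t4 → t5 → t4 as in the trace.
    let mut counter = 4;
    let inst = instantiate(&mut counter, &scheme);
    assert_eq!(inst, arrow(var("t4"), arrow(var("t5"), var("t4"))));
    println!("{:?}", inst);
}
```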

The final type is Int. The type of const 42 true is the type of its first argument (42 : Int), independent of the type of the second argument (true : Bool).
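The two Unify-Arrow steps behind this result can also be reproduced in isolation. The sketch below follows the structure of the chapter's unify: a variable is bound after an occurs check, and arrows are unified component-wise, with the first substitution applied before the second component is unified. It drops the tree-building, uses its own cut-down Type, and reports errors as plain strings instead of InferenceError.

```rust
use std::collections::HashMap;

// Cut-down type language for this sketch (not the crate's ast::Type).
#[derive(Clone, Debug, PartialEq)]
enum Type {
    Int,
    Bool,
    Var(String),
    Arrow(Box<Type>, Box<Type>),
}

type Subst = HashMap<String, Type>;

fn var(name: &str) -> Type {
    Type::Var(name.to_string())
}

fn arrow(a: Type, b: Type) -> Type {
    Type::Arrow(Box::new(a), Box::new(b))
}

fn apply(s: &Subst, ty: &Type) -> Type {
    match ty {
        Type::Var(v) => s.get(v).cloned().unwrap_or_else(|| ty.clone()),
        Type::Arrow(a, b) => arrow(apply(s, a), apply(s, b)),
        Type::Int | Type::Bool => ty.clone(),
    }
}

// s1 ∘ s2: apply s2 first, then s1 (same shape as compose_subst).
fn compose(s1: &Subst, s2: &Subst) -> Subst {
    let mut out = s1.clone();
    for (k, v) in s2 {
        out.insert(k.clone(), apply(s1, v));
    }
    out
}

fn occurs(v: &str, ty: &Type) -> bool {
    match ty {
        Type::Var(name) => name.as_str() == v,
        Type::Arrow(a, b) => occurs(v, a) || occurs(v, b),
        Type::Int | Type::Bool => false,
    }
}

fn unify(t1: &Type, t2: &Type) -> Result<Subst, String> {
    match (t1, t2) {
        (Type::Int, Type::Int) | (Type::Bool, Type::Bool) => Ok(Subst::new()),
        (Type::Var(v), ty) | (ty, Type::Var(v)) => {
            if ty == &Type::Var(v.clone()) {
                Ok(Subst::new())
            } else if occurs(v, ty) {
                Err(format!("occurs check: {} occurs in {:?}", v, ty))
            } else {
                Ok(Subst::from([(v.clone(), ty.clone())]))
            }
        }
        (Type::Arrow(a1, a2), Type::Arrow(b1, b2)) => {
            let s1 = unify(a1, b1)?;
            let s2 = unify(&apply(&s1, a2), &apply(&s1, b2))?;
            Ok(compose(&s2, &s1))
        }
        _ => Err(format!("cannot unify {:?} with {:?}", t1, t2)),
    }
}

fn main() {
    // Inner step of the trace: t4 → t5 → t4 ~ Int → t3.
    let s = unify(&arrow(var("t4"), arrow(var("t5"), var("t4"))), &arrow(Type::Int, var("t3"))).unwrap();
    assert_eq!(s["t4"], Type::Int);
    assert_eq!(s["t3"], arrow(var("t5"), Type::Int));

    // Outer step: t5 → Int ~ Bool → t2, so the result variable t2 becomes Int.
    let s_outer = unify(&arrow(var("t5"), Type::Int), &arrow(Type::Bool, var("t2"))).unwrap();
    assert_eq!(apply(&s_outer, &var("t2")), Type::Int);

    // Failures: mismatched base types, and a variable occurring in its own solution.
    assert!(unify(&Type::Int, &Type::Bool).is_err());
    assert!(unify(&var("t0"), &arrow(var("t0"), Type::Int)).is_err());
}
```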

The two public entry points show how Algorithm W can be embedded in larger systems. run_inference returns the derivation tree and suits educational tools and debuggers. infer_type_only provides the minimal interface a compiler needs for type checking.
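Tools that compare traces textually have one practical caveat. pretty_subst iterates over a HashMap, whose iteration order is unspecified, so the entries of a printed substitution such as {t5 → Int/t3, Int/t4} may appear in a different order on another run. Sorting the entries before printing is one option. The sketch below is hypothetical: pretty_subst_sorted is not part of the chapter's code, and it uses its own cut-down Type. It sorts variable names lexicographically, so t10 would come before t2.

```rust
use std::collections::HashMap;
use std::fmt;

// Cut-down type language for this sketch (not the crate's ast::Type).
#[derive(Clone, Debug, PartialEq)]
enum Type {
    Int,
    Bool,
    Var(String),
    Arrow(Box<Type>, Box<Type>),
}

impl fmt::Display for Type {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Type::Int => write!(f, "Int"),
            Type::Bool => write!(f, "Bool"),
            Type::Var(v) => write!(f, "{}", v),
            // Arrows associate to the right, so only a left-hand arrow needs parentheses.
            Type::Arrow(a, b) if matches!(**a, Type::Arrow(..)) => write!(f, "({}) → {}", a, b),
            Type::Arrow(a, b) => write!(f, "{} → {}", a, b),
        }
    }
}

// Same "{type/var, ...}" format as pretty_subst, but entries are sorted by
// variable name so that repeated runs print identically.
fn pretty_subst_sorted(subst: &HashMap<String, Type>) -> String {
    let mut entries: Vec<(&String, &Type)> = subst.iter().collect();
    entries.sort_by(|a, b| a.0.cmp(b.0));
    let parts: Vec<String> = entries.iter().map(|(k, v)| format!("{}/{}", v, k)).collect();
    format!("{{{}}}", parts.join(", "))
}

fn main() {
    // The substitution from the inner Unify-Arrow node of the trace.
    let mut s = HashMap::new();
    s.insert("t4".to_string(), Type::Int);
    s.insert("t3".to_string(), Type::Arrow(Box::new(Type::Var("t5".to_string())), Box::new(Type::Int)));
    println!("{}", pretty_subst_sorted(&s)); // {t5 → Int/t3, Int/t4}
}
```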