
Introduction
This project started as a collection of personal notes about Rust libraries useful for compiler development. As I explored different crates and built prototypes, I found myself repeatedly looking up the same patterns and examples. What began as scattered markdown files evolved into a structured internal reference document for our team.
The guide focuses on practical, compiler-specific use cases for each crate. Rather than duplicating existing documentation, it shows how these libraries solve real problems in lexing, parsing, type checking, and code generation. Each example is a working implementation that demonstrates patterns we’ve found effective in production compiler projects.
We’re sharing this publicly in the hope that others building compilers in Rust will find it useful. The examples are intentionally concise and focused on compiler engineering tasks. All code is tested and ready to use as a starting point for your own implementations.
The source code for all examples in this guide is available at https://github.com/sdiehl/compiler-crates.
Technology Stacks
Choosing the right combination of crates can significantly impact your compiler project’s success. Here are our recommendations based on different use cases and experience levels.
| Use Case | Parsing | Lexing | Code Generation | Error Reporting |
|---|---|---|---|---|
| Simple | pest or chumsky | - | cranelift | ariadne |
| Rapid Prototyping | pest or chumsky | - | cranelift | codespan-reporting |
| Performance-Critical | lalrpop | logos | inkwell | codespan-reporting |
| Production Compilers | lalrpop | logos | melior | codespan-reporting |
The examples in this guide demonstrate these combinations in practice, showing how different crates work together to build complete compiler pipelines.
Terminology
So you want to build a new compiler? Building a compiler is one of the most challenging and rewarding projects you can undertake. Some compilers rival operating systems in their complexity, but the journey of creating one provides deep insights into how programming languages work at their most fundamental level.
Building a compiler might be the right choice when a compiler for your desired language and target platform simply doesn’t exist. Perhaps you’re creating a domain-specific language for your industry, targeting unusual hardware, or implementing a new programming paradigm. Beyond practical needs, compiler development is profoundly educational. You’ll gain intimate knowledge of how kernels, compilers, and runtime libraries interact, and you’ll understand what it truly takes to implement a programming language. Is it strictly necessary to learn? No. But so few things in life are.
However, compiler development demands significant knowledge. You need complete understanding of your source language specification and deep knowledge of your target architecture’s assembly language. Creating a production-quality compiler that rivals GCC or LLVM in optimization capabilities requires years (if not decades) of dedicated work. Full compliance with language specifications proves surprisingly difficult, as edge cases and subtle interactions between features often reveal themselves only during implementation.
Understanding compiler terminology helps navigate the field’s extensive literature. The host system runs the compiler itself, while the target system runs the compiled programs. When host and target match, you produce native executables. When they differ, you’ve built a cross-compiler. The runtime encompasses libraries and processes available on the target system that programs depend on. Two machines with identical hardware but different runtime resources effectively become different targets.
An executable contains all information necessary for the target system to launch your program. While it could be a flat binary, modern executables include linking information, relocation data, and metadata. The linker creates connections between your program and the runtime it depends on. Programs without these dependencies, like operating system kernels, are called freestanding. A compiler capable of compiling its own source code is self-hosting, representing a significant milestone in compiler maturity.
Modern compilers divide their work into distinct phases, each handling specific transformation tasks. This modular architecture enables code reuse and simplifies development. The standard pipeline consists of three major components working in sequence.
The front end accepts source code in a specific programming language and transforms it into an intermediate representation (or IR for short). This phase handles all language-specific processing including parsing syntax, checking types, and resolving names. By producing a common IR, front ends for different languages can share the same optimization and code generation infrastructure.
The middle end operates on the intermediate representation to improve code quality. This optional phase applies optimization algorithms that eliminate redundancy, improve performance, and reduce code size. Because it works with abstract IR rather than source code or machine code, optimizations implemented here benefit all source languages and target architectures.
The back end consumes intermediate representation and produces executable code for specific target architectures. This phase handles machine-specific concerns like register allocation, instruction selection, and executable file format generation. By separating target-specific code into the back end, the same IR can be compiled for different architectures without modifying earlier phases.
Front End Components
The front end transforms human-readable source code into a form suitable for analysis and optimization. After accepting files and processing command-line options through its user interface, the front end processes code through several stages.
A preprocessor handles textual transformations before compilation begins. In C-like languages, this includes copying header file contents, expanding macros, and processing conditional compilation directives. The preprocessor works purely with text, unaware of the language’s actual syntax.
The scanner (or lexer) reads preprocessed source text and produces a stream of tokens representing the language’s basic vocabulary. Each identifier, keyword, operator, and literal becomes a discrete token. The scanner handles details like recognizing number formats, processing string escape sequences, and skipping whitespace.
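The scanner's job can be illustrated with a small hand-written sketch. Everything here (the `Token` variants and the `scan` function) is invented for this example and dependency-free; crates like logos generate equivalent code from declarative annotations.

```rust
/// A minimal hand-written scanner for identifiers, integers, and two operators.
#[derive(Debug, PartialEq)]
enum Token {
    Number(i64),
    Ident(String),
    Plus,
    Star,
}

fn scan(src: &str) -> Vec<Token> {
    let mut tokens = Vec::new();
    let mut chars = src.chars().peekable();
    while let Some(&c) = chars.peek() {
        match c {
            c if c.is_whitespace() => {
                chars.next(); // skip whitespace between tokens
            }
            '+' => {
                chars.next();
                tokens.push(Token::Plus);
            }
            '*' => {
                chars.next();
                tokens.push(Token::Star);
            }
            c if c.is_ascii_digit() => {
                // accumulate consecutive digits into one Number token
                let mut n = 0i64;
                while let Some(&d) = chars.peek() {
                    match d.to_digit(10) {
                        Some(v) => {
                            n = n * 10 + v as i64;
                            chars.next();
                        }
                        None => break,
                    }
                }
                tokens.push(Token::Number(n));
            }
            c if c.is_alphabetic() => {
                // accumulate alphanumeric characters into an identifier
                let mut s = String::new();
                while let Some(&a) = chars.peek() {
                    if a.is_alphanumeric() {
                        s.push(a);
                        chars.next();
                    } else {
                        break;
                    }
                }
                tokens.push(Token::Ident(s));
            }
            _ => {
                chars.next(); // a real scanner would report an error here
            }
        }
    }
    tokens
}

fn main() {
    println!("{:?}", scan("x + 42 * y"));
}
```

A production scanner would also track byte offsets for each token so later phases can report precise source locations.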
The parser consumes tokens and constructs a tree structure representing the program’s syntactic structure. This parse tree captures how tokens group into expressions, statements, and declarations according to the language grammar. Modern parsers often build more abstract representations that omit syntactic noise like parentheses and semicolons.
The semantic analyzer traverses the parse tree to determine meaning. It builds symbol tables mapping names to their declarations, checks that types match correctly, and verifies that the program follows all language rules not captured by the grammar. This phase transforms a syntactically valid program into a semantically valid one.
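The symbol-table side of semantic analysis can be sketched as a stack of hash maps, one per lexical scope. The `SymbolTable` type and its methods below are invented for this illustration:

```rust
use std::collections::HashMap;

/// A scoped symbol table: each lexical scope gets its own name -> type map.
struct SymbolTable {
    scopes: Vec<HashMap<String, String>>,
}

impl SymbolTable {
    fn new() -> Self {
        SymbolTable { scopes: vec![HashMap::new()] } // start with the global scope
    }
    fn enter_scope(&mut self) {
        self.scopes.push(HashMap::new());
    }
    fn exit_scope(&mut self) {
        self.scopes.pop();
    }
    fn declare(&mut self, name: &str, ty: &str) {
        self.scopes
            .last_mut()
            .expect("at least the global scope exists")
            .insert(name.to_string(), ty.to_string());
    }
    /// Lookup walks scopes from innermost to outermost, so inner
    /// declarations shadow outer ones.
    fn lookup(&self, name: &str) -> Option<&str> {
        self.scopes
            .iter()
            .rev()
            .find_map(|scope| scope.get(name).map(|t| t.as_str()))
    }
}

fn main() {
    let mut table = SymbolTable::new();
    table.declare("x", "i64");
    table.enter_scope();
    table.declare("x", "bool"); // shadows the outer x
    assert_eq!(table.lookup("x"), Some("bool"));
    table.exit_scope();
    assert_eq!(table.lookup("x"), Some("i64"));
    println!("shadowing resolved correctly");
}
```

Real compilers store richer symbol data (spans, mutability, definition sites), but the innermost-first lookup discipline is the same.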
A type checker is often an integral part of semantic analysis. It ensures that operations are applied to compatible types, infers types where possible, and enforces language-specific type rules. Type checking can be simple in dynamically-typed languages or complex in statically-typed languages with features like generics and polymorphism. See my Typechecker Zoo writeup for more details on this phase.
Finally, the front end generates its intermediate representation. A well-designed IR captures all source program semantics while adding explicit information about types, control flow, and data dependencies that later phases need. The IR might resemble the parse tree, use three-address code, employ static single assignment form, or adopt more exotic representations.
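As a concrete illustration of three-address code, here is a toy IR where every intermediate result gets an explicit name, plus a small interpreter to check the lowering. The `Instr` and `Operand` types are invented for this example:

```rust
use std::collections::HashMap;

/// An operand is either a compiler-generated temporary or an immediate.
#[derive(Debug, Clone, Copy)]
enum Operand {
    Temp(u32),
    Const(i64),
}

/// Three-address instructions: one operation, one named result each.
#[derive(Debug, Clone, Copy)]
enum Instr {
    Add { dest: u32, lhs: Operand, rhs: Operand }, // t_dest = lhs + rhs
    Mul { dest: u32, lhs: Operand, rhs: Operand }, // t_dest = lhs * rhs
}

/// Lowering `(1 + 2) * 3` produces a flat instruction list.
fn lower_example() -> Vec<Instr> {
    vec![
        Instr::Add { dest: 0, lhs: Operand::Const(1), rhs: Operand::Const(2) }, // t0 = 1 + 2
        Instr::Mul { dest: 1, lhs: Operand::Temp(0), rhs: Operand::Const(3) },  // t1 = t0 * 3
    ]
}

fn value(temps: &HashMap<u32, i64>, op: Operand) -> i64 {
    match op {
        Operand::Temp(t) => temps[&t],
        Operand::Const(c) => c,
    }
}

/// A tiny interpreter for the IR, handy for testing lowering.
fn eval(ir: &[Instr]) -> i64 {
    let mut temps = HashMap::new();
    let mut last = 0;
    for instr in ir {
        let (dest, v) = match *instr {
            Instr::Add { dest, lhs, rhs } => (dest, value(&temps, lhs) + value(&temps, rhs)),
            Instr::Mul { dest, lhs, rhs } => (dest, value(&temps, lhs) * value(&temps, rhs)),
        };
        temps.insert(dest, v);
        last = v;
    }
    last
}

fn main() {
    println!("(1 + 2) * 3 evaluates to {}", eval(&lower_example()));
}
```

Static single assignment form adds the further restriction that each temporary is defined exactly once, which this tiny example already happens to satisfy.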
Middle End Processing
The middle end hosts numerous optimization passes that improve program performance without changing its observable behavior. These optimizations work at various granularities from individual instructions to whole-program transformations.
Common optimizations include dead code elimination to remove unreachable or unnecessary code, constant propagation to replace variables with known values, and loop optimizations that reduce iteration overhead. More sophisticated techniques like inlining, vectorization, and escape analysis require complex analysis but can dramatically improve performance.
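Constant propagation's simpler cousin, constant folding, fits in a few lines. This sketch over a toy expression type (invented for this example) folds children first, then collapses additions of two known constants:

```rust
/// A toy expression IR for demonstrating a folding pass.
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Const(i64),
    Var(String),
    Add(Box<Expr>, Box<Expr>),
}

/// Bottom-up constant folding: fold subtrees, then combine constants.
fn fold(e: Expr) -> Expr {
    match e {
        Expr::Add(l, r) => {
            let (l, r) = (fold(*l), fold(*r));
            match (&l, &r) {
                (Expr::Const(a), Expr::Const(b)) => Expr::Const(a + b),
                _ => Expr::Add(Box::new(l), Box::new(r)),
            }
        }
        other => other, // constants and variables are already in normal form
    }
}

fn main() {
    // (1 + 2) + x folds to 3 + x
    let e = Expr::Add(
        Box::new(Expr::Add(Box::new(Expr::Const(1)), Box::new(Expr::Const(2)))),
        Box::new(Expr::Var("x".into())),
    );
    println!("{:?}", fold(e));
}
```

Real passes must also respect language semantics (overflow behavior, floating-point rounding) before combining operations at compile time.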
The middle end also performs essential transformations even in unoptimized builds. These include lowering high-level constructs to simpler operations and inserting runtime checks required by the language specification. While you might omit sophisticated optimizations in a simple compiler, some middle-end processing often proves necessary.
Back End Components
The back end bridges the gap between abstract intermediate representation and concrete machine code. This phase handles all target-specific details that earlier phases deliberately ignored.
The code generator traverses the IR and emits assembly-like instructions. To simplify this process, it typically assumes unlimited registers and ignores calling conventions initially. This pseudo-assembly captures the desired computation without committing to specific resource allocation decisions.
The register allocator maps the code generator’s virtual registers onto the limited physical registers available on the target CPU. This critically important phase uses sophisticated algorithms to minimize memory traffic by keeping frequently-used values in registers. Poor register allocation can devastate performance.
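The classic linear-scan algorithm can be sketched in miniature, assuming live intervals have already been computed by a separate analysis. The `Interval` type and the spill-everything-that-doesn't-fit policy are simplified for illustration:

```rust
/// A live interval: virtual register `vreg` is live from `start` to `end`.
#[derive(Debug, Clone)]
struct Interval {
    vreg: u32,
    start: u32,
    end: u32,
}

/// Toy linear scan: walk intervals in start order, reuse registers whose
/// intervals have expired, and spill when none are free.
/// Returns (vreg, physical register) assignments and the spilled vregs.
fn linear_scan(mut intervals: Vec<Interval>, num_regs: usize) -> (Vec<(u32, usize)>, Vec<u32>) {
    intervals.sort_by_key(|iv| iv.start);
    let mut free: Vec<usize> = (0..num_regs).collect();
    let mut active: Vec<(Interval, usize)> = Vec::new(); // currently live, with register
    let mut assigned = Vec::new();
    let mut spilled = Vec::new();
    for iv in intervals {
        // Expire intervals that ended before this one starts, freeing registers.
        active.retain(|(a, r)| {
            if a.end < iv.start {
                free.push(*r);
                false
            } else {
                true
            }
        });
        match free.pop() {
            Some(r) => {
                assigned.push((iv.vreg, r));
                active.push((iv, r));
            }
            None => spilled.push(iv.vreg), // no register left: spill to memory
        }
    }
    (assigned, spilled)
}

fn main() {
    let intervals = vec![
        Interval { vreg: 0, start: 0, end: 4 },
        Interval { vreg: 1, start: 1, end: 2 },
        Interval { vreg: 2, start: 3, end: 5 },
    ];
    let (assigned, spilled) = linear_scan(intervals, 2);
    println!("assigned: {:?}, spilled: {:?}", assigned, spilled);
}
```

Production allocators refine this with spill-cost heuristics, live-range splitting, and coalescing, but the expire-then-allocate loop is the core idea.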
The assembler translates assembly language into machine code bytes. It encodes each instruction according to the target architecture’s rules and tracks addresses of labels for jump instructions. The assembler produces object files containing machine code plus metadata about symbols and relocations.
The linker combines object files with libraries to create complete executables. It resolves references between compilation units, performs relocations to assign final addresses, and adds runtime loader information. While sometimes considered separate from compilation, linking remains essential for producing runnable programs.
Remember that compiler construction is as much an empirical engineering discipline as theoretical knowledge. Start simple, test thoroughly, and gradually add complexity. Even a basic compiler that handles a subset of a language provides valuable learning experiences.
But we live in great times where much of the complexity has been abstracted into off-the-shelf libraries. The crates covered here give you professional-quality tools to build upon, letting you focus on the interesting problems unique to your language and target platform. It’s never been a better time to be a compiler developer.
Parser Comparison
This page provides a comprehensive comparison of the parser generators and parser combinator libraries covered in this guide. Each parser uses different parsing algorithms and techniques, making them suitable for different language implementation scenarios.
Legend
- 🟢 Full support
- 🟡 Partial support or with limitations
- 🔴 Not supported
Parser Overview
| Parser | Type | Algorithm | Grammar Format | Performance | Error Recovery | Learning Curve | Production Ready | Best For |
|---|---|---|---|---|---|---|---|---|
| nom | Parser Combinator | Recursive Descent | Rust combinators | ⭐⭐⭐⭐⭐ | ⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | Binary formats, streaming protocols |
| pest | Parser Generator | PEG | External .pest files | ⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐ | ⭐⭐⭐⭐ | Prototyping, DSLs, configuration languages |
| lalrpop | Parser Generator | LALR(1) | External .lalrpop files | ⭐⭐⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐⭐ | Production compilers, programming languages |
| chumsky | Parser Combinator | Recursive Descent | Rust combinators | ⭐⭐⭐ | ⭐⭐⭐⭐⭐ | ⭐⭐ | ⭐⭐⭐ | Error recovery, IDE support |
| winnow | Parser Combinator | Recursive Descent | Rust combinators | ⭐⭐⭐⭐⭐ | ⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐⭐ | Successor to nom, cleaner API |
| pom | Parser Combinator | Recursive Descent | Rust combinators | ⭐⭐⭐⭐ | ⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐ | Simple parsers, educational |
Parsing Algorithm Characteristics
| Algorithm | Left Recursion | Ambiguity | Backtracking | Memory Usage | Parse Time | Lookahead |
|---|---|---|---|---|---|---|
| LALR(1) | 🟢 Handles naturally | 🔴 Must resolve | 🔴 None | Low | O(n) | 1 token |
| PEG | 🔴 Requires rewriting | 🟢 First match wins | 🟢 Unlimited | Medium | O(n) typical, O(n²) worst | Unlimited |
| Recursive Descent | 🔴 Stack overflow | 🟢 Can handle | 🟢 Manual | Low | O(n) to O(n²) | Unlimited |
Detailed Feature Comparison
| Feature | nom | pest | lalrpop | chumsky | winnow | pom |
|---|---|---|---|---|---|---|
| **Grammar Definition** | | | | | | |
| External grammar files | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 |
| Inline in Rust code | 🟢 | 🔴 | 🔴 | 🟢 | 🟢 | 🟢 |
| Type-safe | 🟢 | 🟡 | 🟢 | 🟢 | 🟢 | 🟢 |
| Grammar validation | Runtime | Runtime | Compile-time | Runtime | Runtime | Runtime |
| **Parsing Features** | | | | | | |
| Streaming input | 🟢 | 🔴 | 🔴 | 🟡 | 🟢 | 🔴 |
| Zero-copy parsing | 🟢 | 🟢 | 🟢 | 🟡 | 🟢 | 🟢 |
| Incremental parsing | 🔴 | 🔴 | 🔴 | 🔴 | 🔴 | 🔴 |
| Memoization/Packrat | 🔴 | 🟢 | 🔴 | 🟡 | 🔴 | 🟡 |
| Custom lexer support | 🟢 | N/A | 🟢 | 🟢 | 🟢 | 🟢 |
| **Error Handling** | | | | | | |
| Error recovery | 🔴 | 🟢 | 🟡 | 🟢 | 🔴 | 🔴 |
| Custom error types | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 |
| Error position tracking | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 |
| Multiple errors | 🔴 | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 |
| Contextual errors | 🟢 | 🟢 | 🟡 | 🟢 | 🟢 | 🟡 |
| **AST Generation** | | | | | | |
| Automatic AST generation | 🔴 | 🟡 | 🟢 | 🔴 | 🔴 | 🔴 |
| Custom AST types | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 |
| Location spans | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 |
| **Development Experience** | | | | | | |
| IDE support | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
| Debugging tools | 🟡 | 🟢 | 🟢 | 🟢 | 🟡 | 🟡 |
| Documentation quality | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐⭐ | ⭐⭐⭐ |
Grammar Complexity Support
| Feature | nom | pest | lalrpop | chumsky | winnow | pom |
|---|---|---|---|---|---|---|
| **Grammar Types** | | | | | | |
| Context-free | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 |
| Context-sensitive | 🟢 | 🔴 | 🔴 | 🟡 | 🟢 | 🟡 |
| Ambiguous grammars | 🟢 | 🟡 | 🔴 | 🟢 | 🟢 | 🟢 |
| **Advanced Features** | | | | | | |
| Left recursion | 🟡* | 🔴 | 🟢 | 🔴 | 🟡* | 🔴 |
| Operator precedence | Manual | 🟢 | 🟢 | 🟢 | Manual | Manual |
| Parameterized rules | 🟢 | 🔴 | 🟢 | 🟢 | 🟢 | 🟢 |
| Semantic predicates | 🟢 | 🟢 | 🔴 | 🟢 | 🟢 | 🟢 |
*Can be handled with special combinators or techniques
tl;dr Recommendations
Choose nom when:
- Parsing binary formats or network protocols
- Need streaming/incremental parsing
- Performance is critical
- Want fine-grained control over parsing
Choose pest when:
- Rapid prototyping of new languages
- Grammar readability is important
- Need good error messages out of the box
- Working with configuration languages or DSLs
Choose lalrpop when:
- Building production programming language compilers
- Grammar has left recursion
- Need maximum parsing performance
- Want compile-time grammar validation
Choose chumsky when:
- Error recovery is critical (IDE/LSP scenarios)
- Need excellent error messages
- Building development tools
- Want modern combinator API
Choose winnow when:
- Starting a new project (nom successor)
- Want cleaner API than nom
- Need streaming support
- Performance is important
Choose pom when:
- Learning parser combinators
- Building simple parsers
- Want minimal dependencies
- Prefer simple, readable code
In short:
- For maximum performance: LALRPOP (with grammar restrictions) or nom/winnow (with flexibility)
- For best developer experience: pest (external grammars) or chumsky (error recovery)
- For binary formats: nom or winnow are specifically designed for this
- For production compilers: LALRPOP provides the traditional compiler construction approach
- For learning: pom offers the simplest mental model
Each parser makes different trade-offs between performance, expressiveness, error handling, and ease of use. Consider your specific requirements carefully when making a selection.
chumsky
Chumsky is a parser combinator library that emphasizes error recovery, performance, and ease of use. Unlike traditional parser generators, chumsky builds parsers from small, composable functions that can be combined to parse complex grammars. The library excels at providing detailed error messages and recovering from parse errors to continue processing malformed input.
Parser combinators in chumsky follow a functional programming style where parsers are values that can be composed using combinator functions. Each parser is a function that consumes input and produces either a parsed value or an error. The library provides extensive built-in combinators for common patterns like repetition, choice, and sequencing.
Core Parser Types
```rust
use std::fmt;

/// AST node for expressions
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    Number(f64),
    Identifier(String),
    Binary(BinOp, Box<Expr>, Box<Expr>),
    Unary(UnOp, Box<Expr>),
    Call(String, Vec<Expr>),
    Let(String, Box<Expr>, Box<Expr>),
}

/// Binary operators
#[derive(Debug, Clone, PartialEq)]
pub enum BinOp {
    Add,
    Sub,
    Mul,
    Div,
    Eq,
    Lt,
}

/// Unary operators
#[derive(Debug, Clone, PartialEq)]
pub enum UnOp {
    Neg,
    Not,
}

/// Token type for lexing
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    Number(f64),
    Identifier(String),
    Keyword(String),
    Op(char),
    Delimiter(char),
}

impl fmt::Display for Token {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Token::Number(n) => write!(f, "{}", n),
            Token::Identifier(s) | Token::Keyword(s) => write!(f, "{}", s),
            Token::Op(c) | Token::Delimiter(c) => write!(f, "{}", c),
        }
    }
}
```
The expression type represents the abstract syntax tree nodes that parsers produce. Chumsky parsers transform character streams into structured data like this AST.
```rust
use chumsky::error::Rich;
use chumsky::extra;
use chumsky::prelude::*;

/// Parse a simple expression language with operator precedence
pub fn expr_parser<'src>() -> impl Parser<'src, &'src str, Expr, extra::Err<Rich<'src, char>>> {
    let ident = text::ident().padded().to_slice();

    let number = text::int(10)
        .then(just('.').then(text::digits(10)).or_not())
        .to_slice()
        .map(|s: &str| Expr::Number(s.parse().unwrap()))
        .padded();

    let atom = recursive(|expr| {
        let args = expr
            .clone()
            .separated_by(just(','))
            .allow_trailing()
            .collect()
            .delimited_by(just('('), just(')'));

        let call = ident
            .then(args)
            .map(|(name, args): (&str, Vec<Expr>)| Expr::Call(name.to_string(), args));

        let let_binding = text::keyword("let")
            .ignore_then(ident)
            .then_ignore(just('='))
            .then(expr.clone())
            .then_ignore(text::keyword("in"))
            .then(expr.clone())
            .map(|((name, value), body): ((&str, Expr), Expr)| {
                Expr::Let(name.to_string(), Box::new(value), Box::new(body))
            });

        choice((
            number,
            call,
            let_binding,
            ident.map(|s: &str| Expr::Identifier(s.to_string())),
            expr.delimited_by(just('('), just(')')),
        ))
    })
    .padded();

    let unary = just('-')
        .repeated()
        .collect::<Vec<_>>()
        .then(atom.clone())
        .map(|(ops, expr)| {
            ops.into_iter()
                .fold(expr, |expr, _| Expr::Unary(UnOp::Neg, Box::new(expr)))
        });

    let product = unary.clone().foldl(
        choice((just('*').to(BinOp::Mul), just('/').to(BinOp::Div)))
            .then(unary)
            .repeated(),
        |left, (op, right)| Expr::Binary(op, Box::new(left), Box::new(right)),
    );

    let sum = product.clone().foldl(
        choice((just('+').to(BinOp::Add), just('-').to(BinOp::Sub)))
            .then(product)
            .repeated(),
        |left, (op, right)| Expr::Binary(op, Box::new(left), Box::new(right)),
    );

    sum.then_ignore(end())
}
```
Binary operators demonstrate how parsers handle operator precedence and associativity through careful combinator composition.
Building Expression Parsers
```rust
use chumsky::error::Rich;
use chumsky::extra;
use chumsky::prelude::*;

/// AST node for expressions
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    Number(f64),
    Identifier(String),
    Binary(BinOp, Box<Expr>, Box<Expr>),
    Unary(UnOp, Box<Expr>),
    Call(String, Vec<Expr>),
    Let(String, Box<Expr>, Box<Expr>),
}

/// Binary operators
#[derive(Debug, Clone, PartialEq)]
pub enum BinOp {
    Add,
    Sub,
    Mul,
    Div,
    Eq,
    Lt,
}

/// Unary operators
#[derive(Debug, Clone, PartialEq)]
pub enum UnOp {
    Neg,
    Not,
}

/// Parse a simple expression language with operator precedence
pub fn expr_parser<'src>() -> impl Parser<'src, &'src str, Expr, extra::Err<Rich<'src, char>>> {
    let ident = text::ident().padded().to_slice();

    let number = text::int(10)
        .then(just('.').then(text::digits(10)).or_not())
        .to_slice()
        .map(|s: &str| Expr::Number(s.parse().unwrap()))
        .padded();

    let atom = recursive(|expr| {
        let args = expr
            .clone()
            .separated_by(just(','))
            .allow_trailing()
            .collect()
            .delimited_by(just('('), just(')'));

        let call = ident
            .then(args)
            .map(|(name, args): (&str, Vec<Expr>)| Expr::Call(name.to_string(), args));

        let let_binding = text::keyword("let")
            .ignore_then(ident)
            .then_ignore(just('='))
            .then(expr.clone())
            .then_ignore(text::keyword("in"))
            .then(expr.clone())
            .map(|((name, value), body): ((&str, Expr), Expr)| {
                Expr::Let(name.to_string(), Box::new(value), Box::new(body))
            });

        choice((
            number,
            call,
            let_binding,
            ident.map(|s: &str| Expr::Identifier(s.to_string())),
            expr.delimited_by(just('('), just(')')),
        ))
    })
    .padded();

    let unary = just('-')
        .repeated()
        .collect::<Vec<_>>()
        .then(atom.clone())
        .map(|(ops, expr)| {
            ops.into_iter()
                .fold(expr, |expr, _| Expr::Unary(UnOp::Neg, Box::new(expr)))
        });

    let product = unary.clone().foldl(
        choice((just('*').to(BinOp::Mul), just('/').to(BinOp::Div)))
            .then(unary)
            .repeated(),
        |left, (op, right)| Expr::Binary(op, Box::new(left), Box::new(right)),
    );

    let sum = product.clone().foldl(
        choice((just('+').to(BinOp::Add), just('-').to(BinOp::Sub)))
            .then(product)
            .repeated(),
        |left, (op, right)| Expr::Binary(op, Box::new(left), Box::new(right)),
    );

    sum.then_ignore(end())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_expr_parser() {
        let parser = expr_parser();
        let result = parser.parse("2 + 3 * 4");
        assert!(!result.has_errors());
        match result.into_output().unwrap() {
            Expr::Binary(BinOp::Add, left, right) => {
                assert_eq!(*left, Expr::Number(2.0));
                match *right {
                    Expr::Binary(BinOp::Mul, l, r) => {
                        assert_eq!(*l, Expr::Number(3.0));
                        assert_eq!(*r, Expr::Number(4.0));
                    }
                    _ => panic!("Expected multiplication on right"),
                }
            }
            _ => panic!("Expected addition at top level"),
        }
    }
}
```
The expression parser showcases several key chumsky features. The `recursive` combinator enables parsing self-referential structures such as nested expressions, the `choice` combinator tries multiple alternatives until one succeeds, and the `foldl` combinator builds left-associative binary operations by folding a sequence of operators and operands.
Operator precedence emerges naturally from the parser's structure. Parsers for higher-precedence operators like multiplication are nested deeper in the combinator chain, so they bind more tightly than addition or subtraction. The `then` combinator sequences parsers, while `map` transforms parsed values into AST nodes.
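The left fold can be seen in isolation. The following chumsky-free sketch (using a hypothetical mini-AST `E` and helper `fold_left`, not the chapter's `Expr`) shows how folding a list of (operator, operand) pairs produces the left-associative shape that `foldl` builds:

```rust
// Chumsky-free sketch: folding (operator, operand) pairs onto an initial
// operand yields a left-associative tree, which is what `foldl` builds.
#[derive(Debug, PartialEq)]
enum E {
    Num(f64),
    Bin(char, Box<E>, Box<E>),
}

fn fold_left(first: E, rest: Vec<(char, E)>) -> E {
    rest.into_iter().fold(first, |left, (op, right)| {
        E::Bin(op, Box::new(left), Box::new(right))
    })
}

fn main() {
    // "1 - 2 - 3" groups as (1 - 2) - 3, not 1 - (2 - 3)
    let tree = fold_left(E::Num(1.0), vec![('-', E::Num(2.0)), ('-', E::Num(3.0))]);
    let expected = E::Bin(
        '-',
        Box::new(E::Bin('-', Box::new(E::Num(1.0)), Box::new(E::Num(2.0)))),
        Box::new(E::Num(3.0)),
    );
    assert_eq!(tree, expected);
}
```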
Lexical Analysis
```rust
use std::fmt;

/// Token type for lexing
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    Number(f64),
    Identifier(String),
    Keyword(String),
    Op(char),
    Delimiter(char),
}

impl fmt::Display for Token {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Token::Number(n) => write!(f, "{}", n),
            Token::Identifier(s) | Token::Keyword(s) => write!(f, "{}", s),
            Token::Op(c) | Token::Delimiter(c) => write!(f, "{}", c),
        }
    }
}
```
While chumsky can parse character streams directly, separate lexical analysis often improves performance and error messages for complex languages.
```rust
/// Lexer that produces tokens with spans
pub fn lexer<'src>(
) -> impl Parser<'src, &'src str, Vec<(Token, SimpleSpan)>, extra::Err<Rich<'src, char>>> {
    let number = text::int(10)
        .then(just('.').then(text::digits(10)).or_not())
        .to_slice()
        .map(|s: &str| Token::Number(s.parse().unwrap()));

    let identifier = text::ident().to_slice().map(|s: &str| match s {
        "let" | "in" | "if" | "then" | "else" => Token::Keyword(s.to_string()),
        _ => Token::Identifier(s.to_string()),
    });

    let op = one_of("+-*/=<>!&|").map(Token::Op);
    let delimiter = one_of("(){}[],;").map(Token::Delimiter);

    let token = choice((number, identifier, op, delimiter)).padded_by(text::whitespace());

    token
        .map_with(|tok, e| (tok, e.span()))
        .repeated()
        .collect()
        .then_ignore(end())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lexer() {
        let lexer = lexer();
        let result = lexer.parse("let x = 42 + 3.14");
        assert!(!result.has_errors());
        let tokens = result.into_output().unwrap();
        assert_eq!(tokens.len(), 6); // let, x, =, 42, +, 3.14
        assert_eq!(tokens[0].0, Token::Keyword("let".to_string()));
        assert_eq!(tokens[1].0, Token::Identifier("x".to_string()));
    }
}
```
The lexer demonstrates span tracking, which records the source location of each token. The `map_with` combinator exposes parse metadata, and `e.span()` attaches each token's span, enabling precise error reporting. Keywords are distinguished from identifiers during lexing rather than parsing, which simplifies the grammar.
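Span tracking itself is independent of chumsky. As a hand-rolled illustration (the `tokenize` helper below is hypothetical and splits only on whitespace), each token can be paired with the byte range it occupies in the source:

```rust
// Pair each whitespace-separated token with its byte range in the source,
// mirroring the (Token, SimpleSpan) pairs the chumsky lexer produces.
fn tokenize(src: &str) -> Vec<(&str, std::ops::Range<usize>)> {
    let mut out = Vec::new();
    let mut start = None;
    for (i, c) in src.char_indices() {
        if c.is_whitespace() {
            if let Some(s) = start.take() {
                out.push((&src[s..i], s..i));
            }
        } else if start.is_none() {
            start = Some(i);
        }
    }
    if let Some(s) = start {
        out.push((&src[s..], s..src.len()));
    }
    out
}

fn main() {
    let toks = tokenize("let x = 42");
    // "42" starts at byte 8 and ends before byte 10
    assert_eq!(toks[3], ("42", 8..10));
    assert_eq!(toks.len(), 4);
}
```

A diagnostic can then slice the original source with the stored range to underline exactly the offending token.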
Error Recovery
```rust
/// Parser with error recovery
pub fn robust_parser<'src>() -> impl Parser<'src, &'src str, Vec<Expr>, extra::Err<Rich<'src, char>>>
{
    let ident = text::ident().padded().to_slice();

    let number = text::int(10)
        .then(just('.').then(text::digits(10)).or_not())
        .to_slice()
        .map(|s: &str| Expr::Number(s.parse().unwrap_or(0.0)))
        .padded();

    let expr = recursive(|expr| {
        choice((
            number,
            ident.map(|s: &str| Expr::Identifier(s.to_string())),
            expr.clone()
                .delimited_by(just('('), just(')'))
                .recover_with(via_parser(nested_delimiters(
                    '(',
                    ')',
                    [('{', '}'), ('[', ']')],
                    |_| Expr::Number(0.0),
                ))),
        ))
    });

    expr.separated_by(just(';'))
        .allow_leading()
        .allow_trailing()
        .collect()
        .then_ignore(end())
}

/// Parser with custom error types
#[derive(Debug, Clone, PartialEq)]
pub enum ParseError {
    UnexpectedToken(String),
    UnclosedDelimiter(char),
    InvalidNumber(String),
}

pub fn validated_parser<'src>() -> impl Parser<'src, &'src str, Expr, extra::Err<Rich<'src, char>>>
{
    let number = text::int(10)
        .then(just('.').then(text::digits(10)).or_not())
        .to_slice()
        .try_map(|s: &str, span| {
            s.parse::<f64>()
                .map(Expr::Number)
                .map_err(|_| Rich::custom(span, format!("Invalid number: {}", s)))
        });

    let ident = text::ident().to_slice().try_map(|s: &str, span| {
        if s.len() > 100 {
            Err(Rich::custom(span, "Identifier too long"))
        } else {
            Ok(Expr::Identifier(s.to_string()))
        }
    });

    choice((number, ident)).then_ignore(end())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_robust_parser() {
        let parser = robust_parser();

        // Valid input parses cleanly
        let result = parser.parse("42; x; y");
        assert!(!result.has_errors());
        assert_eq!(result.into_output().unwrap().len(), 3);

        // Recovery: an unclosed paren reports errors without aborting
        let result = parser.parse("42; (x; y");
        assert!(result.has_errors());
    }

    #[test]
    fn test_validated_parser() {
        let parser = validated_parser();

        let result = parser.parse("42");
        assert!(!result.has_errors());
        assert_eq!(result.into_output().unwrap(), Expr::Number(42.0));

        // Invalid number produces a custom error
        let result = parser.parse("12.34.56");
        assert!(result.has_errors());
    }
}
```
Error recovery allows parsers to continue processing after encountering errors, producing partial results and multiple error messages. The `recover_with` combinator specifies recovery strategies for specific error conditions, and the `nested_delimiters` strategy handles mismatched parentheses by scanning for the appropriate closing delimiter. The `try_map` combinator complements recovery with validation during parsing, turning domain errors such as malformed numbers into `Rich::custom` diagnostics.
Recovery strategies help development tools provide better user experiences: IDEs can show multiple syntax errors simultaneously, and compilers can report more problems in a single run. The `separated_by` combinator with `allow_leading` and `allow_trailing` handles separator-delimited lists gracefully, tolerating stray leading or trailing separators.
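The principle behind recovery — report an error, resynchronize at a separator, keep going — can be sketched without chumsky (the `parse_stmts` helper below is hypothetical and only parses numbers):

```rust
// On a malformed statement, record an error and resynchronize at the
// next ';' so the remaining statements still produce output.
fn parse_stmts(src: &str) -> (Vec<f64>, Vec<String>) {
    let mut out = Vec::new();
    let mut errs = Vec::new();
    for stmt in src.split(';') {
        let s = stmt.trim();
        if s.is_empty() {
            continue;
        }
        match s.parse::<f64>() {
            Ok(n) => out.push(n),
            Err(_) => errs.push(format!("bad statement: {:?}", s)),
        }
    }
    (out, errs)
}

fn main() {
    let (vals, errs) = parse_stmts("1; oops; 3");
    assert_eq!(vals, vec![1.0, 3.0]); // parsing continued past the error
    assert_eq!(errs.len(), 1); // and the error was still reported
}
```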
Custom Combinators
```rust
/// Custom parser combinator for binary operators with precedence
pub fn binary_op_parser<'src>(
    ops: &[(&'src str, BinOp)],
    next: impl Parser<'src, &'src str, Expr, extra::Err<Rich<'src, char>>> + Clone + 'src,
) -> impl Parser<'src, &'src str, Expr, extra::Err<Rich<'src, char>>> + Clone + 'src {
    let op = choice(
        ops.iter()
            .map(|(s, op)| just(*s).to(op.clone()))
            .collect::<Vec<_>>(),
    );

    next.clone()
        .foldl(op.then(next).repeated(), |left, (op, right)| {
            Expr::Binary(op, Box::new(left), Box::new(right))
        })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_binary_op_parser() {
        // BinOp::Eq stands in for both operators; the test only checks
        // that the combinator parses an operator chain without errors.
        let ops = &[("&&", BinOp::Eq), ("||", BinOp::Eq)];
        let atom = text::int(10)
            .then(just('.').then(text::digits(10)).or_not())
            .to_slice()
            .map(|s: &str| Expr::Number(s.parse().unwrap()))
            .padded();
        let parser = binary_op_parser(ops, atom);
        let result = parser.parse("1 && 2 || 3");
        assert!(!result.has_errors());
    }
}
```
Custom combinators encapsulate common parsing patterns for reuse across different parts of a grammar. This binary operator parser handles any set of operators at the same precedence level, building left-associative expressions. The generic implementation works with any operator type and expression parser.
Creating domain-specific combinators improves grammar readability and reduces duplication. Common patterns in a language can be abstracted into reusable components that compose naturally with built-in combinators.
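The left-associative fold at the heart of binary_op_parser can be seen in isolation without any parser machinery. This std-only sketch (the trimmed Expr shape is a stand-in for the chapter's AST) folds a first operand and a list of (operator, operand) pairs into a nested tree:

```rust
// Minimal AST mirroring the shape built by binary_op_parser.
#[derive(Debug, PartialEq)]
enum Expr {
    Number(f64),
    Binary(char, Box<Expr>, Box<Expr>),
}

// Left-associative fold: a op b op c groups as ((a op b) op c).
fn fold_left(first: Expr, rest: Vec<(char, Expr)>) -> Expr {
    rest.into_iter().fold(first, |left, (op, right)| {
        Expr::Binary(op, Box::new(left), Box::new(right))
    })
}

fn main() {
    // 1 - 2 - 3 groups as (1 - 2) - 3, not 1 - (2 - 3).
    let tree = fold_left(
        Expr::Number(1.0),
        vec![('-', Expr::Number(2.0)), ('-', Expr::Number(3.0))],
    );
    assert_eq!(
        tree,
        Expr::Binary(
            '-',
            Box::new(Expr::Binary(
                '-',
                Box::new(Expr::Number(1.0)),
                Box::new(Expr::Number(2.0))
            )),
            Box::new(Expr::Number(3.0))
        )
    );
}
```

The fold is what makes the associativity explicit: swapping it for a right fold would produce the opposite grouping with the same operator list.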
Validation and Semantic Analysis
```rust
/// Parser that performs semantic checks during parsing
pub fn validated_parser<'src>() -> impl Parser<'src, &'src str, Expr, extra::Err<Rich<'src, char>>>
{
    let number = text::int(10)
        .then(just('.').then(text::digits(10)).or_not())
        .to_slice()
        .try_map(|s: &str, span| {
            s.parse::<f64>()
                .map(Expr::Number)
                .map_err(|_| Rich::custom(span, format!("Invalid number: {}", s)))
        });

    let ident = text::ident().to_slice().try_map(|s: &str, span| {
        if s.len() > 100 {
            Err(Rich::custom(span, "Identifier too long"))
        } else {
            Ok(Expr::Identifier(s.to_string()))
        }
    });

    choice((number, ident)).then_ignore(end())
}
```
The try_map combinator used here performs semantic checks during parsing, rejecting constructs that are syntactically valid but semantically invalid. chumsky's related validate combinator goes further, emitting an error while still producing output so parsing can continue. Either way, both syntactic and semantic errors can be reported in a single pass, and the checks can cover numeric ranges, identifier validity, or any other semantic constraint.
Combining parsing and validation reduces the number of passes over the input and provides better error messages by retaining parse context. The error emission mechanism allows multiple errors from a single validation, supporting comprehensive error reporting.
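The emit-and-continue style of validation can be sketched without chumsky at all. In this hypothetical std-only example (the port-range constraint is invented for illustration), each failing check pushes a spanned error and processing continues, so one pass reports every violation:

```rust
// Emit-and-continue validation: instead of stopping at the first
// problem, each check records an error and analysis proceeds.
fn validate_ports(tokens: &[(&str, usize)]) -> (Vec<u16>, Vec<String>) {
    let mut out = Vec::new();
    let mut errors = Vec::new();
    for (tok, offset) in tokens {
        match tok.parse::<u32>() {
            Ok(n) if n <= 65535 => out.push(n as u16),
            Ok(n) => {
                errors.push(format!("offset {}: port {} out of range", offset, n));
                out.push(0); // recover with a placeholder and continue
            }
            Err(_) => errors.push(format!("offset {}: invalid number {:?}", offset, tok)),
        }
    }
    (out, errors)
}

fn main() {
    let (ports, errors) = validate_ports(&[("80", 0), ("99999", 3), ("x", 9)]);
    // Both problems are reported from a single pass.
    assert_eq!(ports, vec![80, 0]);
    assert_eq!(errors.len(), 2);
}
```

The placeholder value keeps downstream passes running, which is the same trade-off chumsky's recovery strategies make.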
Performance Considerations
Chumsky parsers achieve good performance through several optimizations. The library uses zero-copy parsing where possible, avoiding string allocation for tokens and identifiers. The combinator tree is constructed once and reused across parses, and grammars structured to commit early minimize backtracking.
Choice combinators try alternatives in order, so placing common cases first improves performance. The or combinator creates more efficient parsers than choice when only two alternatives exist. Memoization can be added to recursive parsers to avoid reparsing the same input multiple times.
Integration Patterns
Chumsky integrates well with other compiler infrastructure. The span information works with error reporting libraries like ariadne or codespan-reporting to display beautiful error messages. AST nodes can implement visitor patterns or be processed by subsequent compiler passes.
The streaming API supports parsing large files without loading them entirely into memory. Incremental parsing can be implemented by caching parse results for unchanged portions of input. The modular parser design allows testing individual components in isolation.
Best Practices
Structure parsers hierarchically, with each level handling one precedence level or syntactic category. Use meaningful names for intermediate parsers to improve readability. Keep individual parsers focused on a single responsibility.
Test parsers thoroughly with both valid and invalid input. Error recovery strategies should be tested to ensure they produce reasonable partial results. Use property-based testing to verify parser properties like round-tripping through pretty-printing.
Profile parser performance on realistic input to identify bottlenecks. Complex lookahead or backtracking can dramatically impact performance. Consider using a separate lexer for languages with complex tokenization rules.
Document grammar ambiguities and their resolution strategies. Explain why certain parser structures were chosen, especially for complex precedence hierarchies. Provide examples of valid and invalid syntax to clarify language rules.
combine
combine provides a powerful parser combinator library for building parsers from composable pieces. Unlike parser generators that require separate grammar files, combine constructs parsers entirely in Rust code through a rich set of combinators. The library emphasizes flexibility and error recovery, making it well-suited for both simple configuration files and complex programming languages.
The library’s streaming approach enables parsing of large files without loading them entirely into memory, while its error handling system provides detailed information about parse failures. combine supports multiple input types including strings, byte arrays, and custom token streams, adapting to various parsing scenarios from network protocols to source code.
Basic Expression Parser
```rust
use combine::error::ParseError;
use combine::parser::char::{char, digit, spaces};
use combine::parser::repeat::sep_by;
use combine::{attempt, between, choice, many1, parser, Parser, Stream};

#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    Number(i32),
    Add(Box<Expr>, Box<Expr>),
    Sub(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
    Div(Box<Expr>, Box<Expr>),
}

parser! {
    fn expr[Input]()(Input) -> Expr
    where [Input: Stream<Token = char>]
    {
        expr_()
    }
}

fn expr_<Input>() -> impl Parser<Input, Output = Expr>
where
    Input: Stream<Token = char>,
    Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
{
    // attempt allows backtracking to the bare `term` alternative when a
    // term parses but no operator follows; without it, `or` would not be
    // tried after input has been consumed.
    attempt(
        term()
            .and(many1::<Vec<_>, _, _>((choice((char('+'), char('-'))), term())))
            .map(|(first, rest): (Expr, Vec<(char, Expr)>)| {
                rest.into_iter().fold(first, |acc, (op, val)| match op {
                    '+' => Expr::Add(Box::new(acc), Box::new(val)),
                    '-' => Expr::Sub(Box::new(acc), Box::new(val)),
                    _ => unreachable!(),
                })
            }),
    )
    .or(term())
}
```
The expression parser demonstrates combine’s approach to operator precedence through parser composition. The expr_ function handles addition and subtraction as left-associative operations by parsing a term followed by zero or more operator-term pairs. The fold operation builds the abstract syntax tree left-to-right, ensuring proper associativity. The parser macro generates a wrapper function that simplifies the parser’s type signature, making it easier to use in larger compositions.
The choice combinator selects between multiple alternatives, while many1 requires at least one match. The and method chains parsers sequentially, passing results as tuples. The map method transforms parse results, converting from the parser’s output format to the desired AST representation. The or combinator provides fallback behavior, attempting the simpler term parser if the complex expression parsing fails.
Parser Combinators
```rust
fn term<Input>() -> impl Parser<Input, Output = Expr>
where
    Input: Stream<Token = char>,
    Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
{
    // As in expr_, attempt permits falling back to a bare factor once a
    // factor has been consumed but no '*' or '/' follows.
    attempt(
        factor()
            .and(many1::<Vec<_>, _, _>((choice((char('*'), char('/'))), factor())))
            .map(|(first, rest): (Expr, Vec<(char, Expr)>)| {
                rest.into_iter().fold(first, |acc, (op, val)| match op {
                    '*' => Expr::Mul(Box::new(acc), Box::new(val)),
                    '/' => Expr::Div(Box::new(acc), Box::new(val)),
                    _ => unreachable!(),
                })
            }),
    )
    .or(factor())
}

fn factor<Input>() -> impl Parser<Input, Output = Expr>
where
    Input: Stream<Token = char>,
    Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
{
    choice((between(char('('), char(')'), expr()), number()))
}

fn number<Input>() -> impl Parser<Input, Output = Expr>
where
    Input: Stream<Token = char>,
    Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
{
    many1::<String, _, _>(digit()).map(|s| Expr::Number(s.parse().unwrap()))
}
```
The term parser handles multiplication and division with higher precedence than addition and subtraction. This precedence hierarchy emerges naturally from the parser structure, with term calling factor for its operands, and expr calling term. Each level of the hierarchy handles operators of the same precedence, delegating to the next level for higher-precedence operations.
The factor parser demonstrates parenthesis handling using the between combinator, which parses delimited content while discarding the delimiters. The recursive call to expr within parentheses allows arbitrary expression nesting. The number parser combines multiple digit characters into a string, then parses the result into an integer wrapped in the Number variant.
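The same precedence tower can be sketched as a tiny hand-rolled recursive-descent evaluator in plain Rust (a hypothetical illustration, not combine code): each level handles its own operators and delegates tighter-binding work to the next level down.

```rust
// expr handles +/-, term handles */, factor handles atoms.
// Error handling is omitted; malformed input will panic.
fn expr(s: &mut &str) -> i64 {
    let mut acc = term(s);
    while let Some(op) = eat(s, &['+', '-']) {
        let rhs = term(s);
        acc = if op == '+' { acc + rhs } else { acc - rhs };
    }
    acc
}

fn term(s: &mut &str) -> i64 {
    let mut acc = factor(s);
    while let Some(op) = eat(s, &['*', '/']) {
        let rhs = factor(s);
        acc = if op == '*' { acc * rhs } else { acc / rhs };
    }
    acc
}

fn factor(s: &mut &str) -> i64 {
    if eat(s, &['(']).is_some() {
        let v = expr(s); // recursion allows arbitrary nesting
        eat(s, &[')']);
        return v;
    }
    let digits: String = s.chars().take_while(|c| c.is_ascii_digit()).collect();
    *s = &s[digits.len()..];
    digits.parse().unwrap()
}

// Consume the next character if it is one of `ops`.
fn eat(s: &mut &str, ops: &[char]) -> Option<char> {
    let c = s.chars().next()?;
    if ops.contains(&c) {
        *s = &s[c.len_utf8()..];
        Some(c)
    } else {
        None
    }
}

fn main() {
    assert_eq!(expr(&mut "1+2*3"), 7); // multiplication binds tighter
    assert_eq!(expr(&mut "(1+2)*3"), 9); // parentheses override precedence
}
```

The call structure (expr calls term, term calls factor, factor recurses back to expr) is exactly the delegation pattern the combine parsers encode declaratively.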
JSON Parser
```rust
use combine::parser::char::{char, string};
use combine::parser::choice::optional;

#[derive(Debug, Clone, PartialEq)]
pub enum Json {
    Null,
    Bool(bool),
    Number(f64),
    String(String),
    Array(Vec<Json>),
    Object(Vec<(String, Json)>),
}

parser! {
    fn json[Input]()(Input) -> Json
    where [Input: Stream<Token = char>]
    {
        spaces().with(json_value())
    }
}

fn json_value<Input>() -> impl Parser<Input, Output = Json>
where
    Input: Stream<Token = char>,
    Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
{
    choice((
        string("null").map(|_| Json::Null),
        string("true").map(|_| Json::Bool(true)),
        string("false").map(|_| Json::Bool(false)),
        json_number(),
        json_string(),
        json_array(),
        json_object(),
    ))
}
```
The JSON parser showcases combine’s ability to handle complex recursive data structures. The json_value function uses choice to select among all possible JSON types, with each alternative parser returning the appropriate Json enum variant. The string parser matches exact character sequences, while map transforms the successful parse into the corresponding JSON value.
The spaces().with() pattern at the entry point consumes leading whitespace before parsing the actual JSON value. This pattern appears throughout the parser to handle optional whitespace between tokens. The parser macro again simplifies the type signature, hiding the complex return type that would otherwise be required.
String and Number Parsing
```rust
use combine::satisfy;

fn json_string<Input>() -> impl Parser<Input, Output = Json>
where
    Input: Stream<Token = char>,
    Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
{
    between(
        char('"'),
        char('"'),
        many1::<String, _, _>(satisfy(|c| c != '"' && c != '\\')),
    )
    .map(Json::String)
}

fn json_number<Input>() -> impl Parser<Input, Output = Json>
where
    Input: Stream<Token = char>,
    Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
{
    optional(char('-'))
        .and(many1::<String, _, _>(digit()))
        .and(optional(char('.').and(many1::<String, _, _>(digit()))))
        .and(optional(
            choice((char('e'), char('E')))
                .and(optional(choice((char('+'), char('-')))))
                .and(many1::<String, _, _>(digit())),
        ))
        .map(|(((sign, int), frac), exp)| {
            let mut num = String::new();
            if sign.is_some() {
                num.push('-');
            }
            num.push_str(&int);
            if let Some((_, f)) = frac {
                num.push('.');
                num.push_str(&f);
            }
            if let Some(((e, sign), exp_digits)) = exp {
                num.push(e);
                if let Some(s) = sign {
                    num.push(s);
                }
                num.push_str(&exp_digits);
            }
            Json::Number(num.parse().unwrap())
        })
}
```
The json_string parser demonstrates basic string parsing, though without escape sequence support. The satisfy combinator accepts characters matching a predicate, building a string from all non-quote, non-backslash characters. Real JSON parsing would additionally require escape sequence handling and support for empty strings, but this simplified version illustrates the core concept.
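The omitted escape handling can be sketched independently of combine. This hypothetical helper covers a few common sequences (real JSON also requires \b, \f, \r, and \uXXXX):

```rust
// After the lexer sees a backslash, the next character selects a
// replacement; anything else is an error.
fn unescape(raw: &str) -> Result<String, String> {
    let mut out = String::with_capacity(raw.len());
    let mut chars = raw.chars();
    while let Some(c) = chars.next() {
        if c != '\\' {
            out.push(c);
            continue;
        }
        match chars.next() {
            Some('"') => out.push('"'),
            Some('\\') => out.push('\\'),
            Some('/') => out.push('/'),
            Some('n') => out.push('\n'),
            Some('t') => out.push('\t'),
            Some(other) => return Err(format!("unknown escape: \\{}", other)),
            None => return Err("dangling backslash".to_string()),
        }
    }
    Ok(out)
}

fn main() {
    assert_eq!(unescape(r#"a\"b\n"#), Ok("a\"b\n".to_string()));
    assert!(unescape(r"bad \q").is_err());
}
```

In a combine parser this logic would live in a second alternative inside the string body, chosen whenever a backslash is encountered.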
The json_number parser handles the full JSON number format including optional signs, decimal points, and scientific notation. The nested tuple structure from chained and combinators captures each component of the number. The map function reconstructs the string representation before parsing it as a floating-point value. This approach ensures correct handling of all valid JSON number formats while maintaining parse accuracy.
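The string-reassembly step in the map closure can be extracted into a plain function to see the logic on its own (the component names here are illustrative, not combine API):

```rust
// Rebuild the textual form of a JSON number from its parsed parts,
// then delegate to f64's own parser for the final conversion.
fn assemble(
    sign: Option<char>,
    int: &str,
    frac: Option<&str>,
    exp: Option<(char, Option<char>, &str)>,
) -> f64 {
    let mut num = String::new();
    if sign.is_some() {
        num.push('-');
    }
    num.push_str(int);
    if let Some(f) = frac {
        num.push('.');
        num.push_str(f);
    }
    if let Some((e, s, digits)) = exp {
        num.push(e);
        if let Some(s) = s {
            num.push(s);
        }
        num.push_str(digits);
    }
    num.parse().unwrap()
}

fn main() {
    // -12.5e+2 reassembles and parses to -1250.0
    assert_eq!(
        assemble(Some('-'), "12", Some("5"), Some(('e', Some('+'), "2"))),
        -1250.0
    );
    assert_eq!(assemble(None, "3", None, None), 3.0);
}
```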
Array and Object Parsing
```rust
fn json_array<Input>() -> impl Parser<Input, Output = Json>
where
    Input: Stream<Token = char>,
    Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
{
    between(
        char('[').skip(spaces()),
        char(']'),
        sep_by(json_value().skip(spaces()), char(',').skip(spaces())),
    )
    .map(Json::Array)
}

fn json_object<Input>() -> impl Parser<Input, Output = Json>
where
    Input: Stream<Token = char>,
    Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
{
    let member = json_string()
        .skip(spaces())
        .skip(char(':'))
        .skip(spaces())
        .and(json_value());

    between(
        char('{').skip(spaces()),
        char('}'),
        sep_by(member.skip(spaces()), char(',').skip(spaces())),
    )
    .map(|members: Vec<(Json, Json)>| {
        Json::Object(
            members
                .into_iter()
                .map(|(k, v)| match k {
                    Json::String(s) => (s, v),
                    _ => unreachable!(),
                })
                .collect(),
        )
    })
}
```
The array parser uses sep_by to handle comma-separated values, automatically managing both empty arrays and trailing comma issues. The skip method discards whitespace after each element and separator, maintaining clean separation between structural parsing and whitespace handling. The between combinator ensures proper bracket matching while the map function wraps the result in the Array variant.
Object parsing combines string keys with arbitrary values using the member parser. The skip chain removes colons and whitespace without including them in the result. The map function extracts string values from the Json::String variant for use as object keys, transforming the vector of tuples into the expected format. This design maintains type safety while parsing the heterogeneous structure of JSON objects.
S-Expression Parser
```rust
#[derive(Debug, Clone, PartialEq)]
pub enum SExpr {
    Symbol(String),
    Number(i64),
    String(String),
    List(Vec<SExpr>),
}

parser! {
    fn s_expression[Input]()(Input) -> SExpr
    where [Input: Stream<Token = char>]
    {
        spaces().with(s_expr())
    }
}

fn s_expr<Input>() -> impl Parser<Input, Output = SExpr>
where
    Input: Stream<Token = char>,
    Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
{
    choice((s_list(), s_atom()))
}

fn s_atom<Input>() -> impl Parser<Input, Output = SExpr>
where
    Input: Stream<Token = char>,
    Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
{
    choice((s_string(), s_number(), s_symbol()))
}
```
The S-expression parser handles LISP-style symbolic expressions with minimal complexity. The grammar’s recursive nature maps directly to Rust’s enum system, with each variant corresponding to a fundamental S-expression type. The parser structure mirrors the data structure, making the implementation intuitive and maintainable.
The separation between s_expr and s_atom clarifies the grammar structure, distinguishing compound lists from atomic values. This organization simplifies error messages and makes the parser’s intent clear. The choice combinator tries each alternative in order, selecting the first successful parse.
Symbol and List Parsing
```rust
use combine::many;

fn s_symbol<Input>() -> impl Parser<Input, Output = SExpr>
where
    Input: Stream<Token = char>,
    Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
{
    many1::<String, _, _>(satisfy(|c: char| {
        c.is_alphanumeric() || "+-*/<>=!?_".contains(c)
    }))
    .map(SExpr::Symbol)
}

fn s_list<Input>() -> impl Parser<Input, Output = SExpr>
where
    Input: Stream<Token = char>,
    Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
{
    between(
        char('(').skip(spaces()),
        char(')'),
        many::<Vec<_>, _, _>(s_expr().skip(spaces())),
    )
    .map(SExpr::List)
}
```
Symbol parsing accepts standard LISP identifier characters including operators and special symbols. The satisfy predicate defines the valid character set, while many1 ensures at least one character. This approach handles both function names and operators uniformly, reflecting LISP’s treatment of operators as regular symbols.
List parsing recursively invokes s_expr for each element, enabling arbitrary nesting. The many combinator accepts zero or more elements, properly handling empty lists. Whitespace handling occurs after each element through skip, maintaining clean separation between elements without requiring explicit whitespace in the grammar.
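The recursive structure runs just as naturally in reverse. A printer for the SExpr type (sketched here std-only with the String variant omitted for brevity) is useful both for debugging and for roundtrip tests:

```rust
// Trimmed SExpr mirroring the chapter's type, minus the String variant.
#[derive(Debug)]
enum SExpr {
    Symbol(String),
    Number(i64),
    List(Vec<SExpr>),
}

// Recursive printer: lists render as parenthesized, space-separated items.
fn print_sexpr(e: &SExpr) -> String {
    match e {
        SExpr::Symbol(s) => s.clone(),
        SExpr::Number(n) => n.to_string(),
        SExpr::List(items) => {
            let inner: Vec<String> = items.iter().map(print_sexpr).collect();
            format!("({})", inner.join(" "))
        }
    }
}

fn main() {
    let e = SExpr::List(vec![
        SExpr::Symbol("+".into()),
        SExpr::Number(1),
        SExpr::List(vec![
            SExpr::Symbol("*".into()),
            SExpr::Number(2),
            SExpr::Number(3),
        ]),
    ]);
    assert_eq!(print_sexpr(&e), "(+ 1 (* 2 3))");
}
```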
Configuration Parser
```rust
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
    pub entries: Vec<ConfigEntry>,
}

#[derive(Debug, Clone, PartialEq)]
pub struct ConfigEntry {
    pub key: String,
    pub value: ConfigValue,
}

#[derive(Debug, Clone, PartialEq)]
pub enum ConfigValue {
    String(String),
    Number(f64),
    Bool(bool),
    List(Vec<ConfigValue>),
}

fn config_value<Input>() -> impl Parser<Input, Output = ConfigValue>
where
    Input: Stream<Token = char>,
    Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
{
    choice((
        string("true").map(|_| ConfigValue::Bool(true)),
        string("false").map(|_| ConfigValue::Bool(false)),
        config_number(),
        config_string(),
        config_list(),
    ))
}
```
The configuration parser demonstrates parsing of key-value pairs with multiple value types. The ConfigValue enum supports common configuration types including strings, numbers, booleans, and lists. This structure handles most configuration file formats while remaining simple to extend.
The choice ordering matters for ambiguous cases. Parsing boolean literals before numbers prevents misinterpretation, while parsing strings after literals avoids consuming quoted keywords. This careful ordering ensures correct parsing without complex lookahead.
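The effect of alternative ordering can be seen in a std-only classifier that mirrors config_value's choice order (a hypothetical sketch, not combine code): literal keywords are checked before numbers, and bare strings are the last resort.

```rust
#[derive(Debug, PartialEq)]
enum Value {
    Bool(bool),
    Number(f64),
    Text(String),
}

// Ordered alternatives: first match wins, so the most specific
// interpretation must be tried first.
fn classify(token: &str) -> Value {
    match token {
        "true" => Value::Bool(true),
        "false" => Value::Bool(false),
        _ => token
            .parse::<f64>()
            .map(Value::Number)
            .unwrap_or_else(|_| Value::Text(token.to_string())),
    }
}

fn main() {
    assert_eq!(classify("true"), Value::Bool(true)); // not Text("true")
    assert_eq!(classify("3.5"), Value::Number(3.5));
    assert_eq!(classify("hostname"), Value::Text("hostname".into()));
}
```

Reversing the order (string fallback first) would swallow every token as text, which is the misinterpretation the prose warns about.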
Error Recovery
```rust
use combine::stream::position;

pub fn parse_with_position(input: &str) -> Result<Json, String> {
    let mut parser = json();
    let stream = position::Stream::new(input);

    match parser.parse(stream) {
        Ok((result, _)) => Ok(result),
        Err(err) => {
            let pos = err.position;
            Err(format!(
                "Parse error at line {}, column {}: {}",
                pos.line, pos.column, err
            ))
        }
    }
}
```
combine provides detailed error information including position tracking and error context. The position::Stream wrapper adds line and column information to the input stream, enabling precise error reporting. This information helps users quickly locate and fix syntax errors in their input.
Error messages include both the location and nature of the failure, with combine attempting to provide helpful suggestions based on the expected tokens. The library’s error recovery mechanisms allow parsing to continue after errors in some cases, useful for IDE integration where partial results remain valuable.
Stream Abstraction
```rust
use combine::stream::easy;
use combine::stream::state::State;

pub fn parse_with_state<'a>(
    input: &'a str,
    filename: String,
) -> Result<Json, easy::ParseError<&'a str>> {
    let stream = State::new(easy::Stream(input))
        .with_positioner(position::SourcePosition::default());

    let mut parser = json();
    parser.easy_parse(stream).map(|t| t.0)
}
```
combine’s stream abstraction supports various input types beyond simple strings. The State wrapper adds user-defined state to the parsing process, useful for maintaining symbol tables or configuration during parsing. The easy module provides enhanced error messages at a small performance cost.
Custom stream types enable parsing of token sequences from lexers, byte arrays from network protocols, or any other sequential data source. This flexibility makes combine suitable for diverse parsing tasks while maintaining a consistent interface across different input types.
Performance Optimization
```rust
use combine::parser::combinator::recognize;
use combine::parser::range::take_while1;

fn optimized_number<Input>() -> impl Parser<Input, Output = f64>
where
    Input: Stream<Token = char>,
    Input::Error: ParseError<Input::Token, Input::Range, Input::Position>,
{
    recognize((
        optional(char('-')),
        take_while1(|c: char| c.is_ascii_digit()),
        optional((char('.'), take_while1(|c: char| c.is_ascii_digit()))),
    ))
    .map(|s: String| s.parse().unwrap())
}
```
The recognize combinator captures the entire matched input as a string without building intermediate structures. This approach reduces allocations when parsing numbers or identifiers, improving performance for large inputs. The take_while1 combinator efficiently consumes characters matching a predicate without creating temporary collections.
combine’s lazy evaluation model ensures parsers only perform necessary work. Failed alternatives don’t consume input beyond the failure point, enabling efficient backtracking. The attempt combinator explicitly marks backtrack points, giving fine control over parser performance.
Testing Strategies
```rust
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_expression_precedence() {
        // 1+2*3 must parse as 1 + (2 * 3), not (1 + 2) * 3.
        let (result, _) = expr().parse("1+2*3").unwrap();
        assert_eq!(
            result,
            Expr::Add(
                Box::new(Expr::Number(1)),
                Box::new(Expr::Mul(
                    Box::new(Expr::Number(2)),
                    Box::new(Expr::Number(3))
                ))
            )
        );
    }

    #[test]
    fn test_json_nested() {
        let input = r#"{"a": [1, {"b": true}]}"#;
        let (result, _) = json().parse(input).unwrap();
        match result {
            Json::Object(obj) => {
                assert_eq!(obj.len(), 1);
                assert_eq!(obj[0].0, "a");
                match &obj[0].1 {
                    Json::Array(arr) => assert_eq!(arr.len(), 2),
                    _ => panic!("Expected array"),
                }
            }
            _ => panic!("Expected object"),
        }
    }

    #[test]
    fn test_error_position() {
        let input = "{ invalid json }";
        let result = parse_with_position(input);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.contains("line"));
        assert!(err.contains("column"));
    }
}
```
Testing parsers requires validating both successful parses and error handling. Precedence tests ensure operators combine correctly, while nested structure tests verify recursive parsing. Error tests confirm that position information and error messages provide useful debugging information.
Property-based testing works well with parser combinators. Generate random valid inputs according to the grammar, parse them, and verify properties like roundtrip printing or semantic equivalence. This approach finds edge cases that manual tests might miss.
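The roundtrip property can be illustrated on a deliberately trivial grammar rather than the parsers above: parsing a comma-separated integer list and printing it back should be the identity on canonical inputs. This std-only sketch checks a few small cases exhaustively instead of generating random inputs:

```rust
// Parse "1,2,3" into Some(vec![1, 2, 3]); any bad element yields None.
fn parse_list(s: &str) -> Option<Vec<i64>> {
    s.split(',').map(|t| t.parse().ok()).collect()
}

// The inverse: print a list back to its canonical textual form.
fn print_list(xs: &[i64]) -> String {
    xs.iter().map(|n| n.to_string()).collect::<Vec<_>>().join(",")
}

fn main() {
    // Roundtrip property: parse(print(xs)) == xs for every sample.
    for xs in [vec![0], vec![1, 2], vec![-3, 40, 5]] {
        let printed = print_list(&xs);
        assert_eq!(parse_list(&printed), Some(xs));
    }
    // Invalid elements are rejected rather than silently dropped.
    assert_eq!(parse_list("1,oops"), None);
}
```

With a property-testing crate the sample lists would be generated randomly, but the assertion being checked is exactly the one shown here.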
Best Practices
Structure parsers to match the desired AST closely, using Rust’s type system to enforce invariants. Separate lexical concerns like whitespace handling from structural parsing using skip and trim combinators. This separation simplifies both the parser and error messages.
Use the parser! macro for public interfaces to hide complex type signatures. The macro generates cleaner function signatures while preserving full type safety. Internal helper parsers can use impl Parser return types for better compile times and simpler code.
Order choice alternatives carefully, considering both correctness and performance. Place more specific patterns before general ones to avoid incorrect matches. Use attempt when backtracking is needed, but minimize its use for better performance.
Build parsers incrementally, testing each component before composing them. Start with simple atoms, then build expressions, then statements. This approach makes debugging easier and ensures each piece works correctly before integration.
combine provides a flexible and powerful approach to parsing that scales from simple configuration files to complete programming languages. Its combinator-based design encourages modular, testable parsers while maintaining excellent performance and error reporting. The library’s stream abstraction and careful API design make it an excellent choice for compiler frontends and data processing pipelines.
LALRPOP
LALRPOP is an LR(1) parser generator for Rust that produces efficient, table-driven parsers from declarative grammars. Unlike parser combinators, LALRPOP generates parsers at compile time from grammar files, providing excellent parsing performance and compile-time validation of grammar correctness. The generated parsers handle left recursion naturally and provide precise error messages with conflict detection during grammar compilation.
For compiler development, LALRPOP offers the traditional parser generator experience familiar to users of yacc, bison, or ANTLR, but with Rust’s type safety and zero-cost abstractions. The generated code integrates seamlessly with Rust’s ownership system, producing typed ASTs without runtime overhead. LALRPOP excels at parsing programming languages where performance matters and grammar complexity requires the power of LR parsing.
Basic Calculator
A simple arithmetic expression parser demonstrates LALRPOP’s syntax:
```
use std::str::FromStr;

grammar;

pub Expr: i32 = {
    <l:Expr> "+" <r:Factor> => l + r,
    <l:Expr> "-" <r:Factor> => l - r,
    Factor,
};

Factor: i32 = {
    <l:Factor> "*" <r:Term> => l * r,
    <l:Factor> "/" <r:Term> => l / r,
    Term,
};

Term: i32 = {
    Number,
    "(" <Expr> ")",
};

Number: i32 = {
    r"[0-9]+" => i32::from_str(<>).unwrap(),
};
```
This grammar handles operator precedence through rule stratification: addition and subtraction are parsed at a lower precedence level than multiplication and division. Because the rules are left-recursive, the parser is left-associative, parsing `10 - 2 - 3` as `(10 - 2) - 3 = 5`.
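The associativity difference can be checked directly by folding the operand list in both directions; this is a standalone illustration, not LALRPOP output:

```rust
fn main() {
    let nums = [10, 2, 3];

    // Left-associative: ((10 - 2) - 3)
    let left = nums.iter().copied().reduce(|acc, n| acc - n).unwrap();

    // Right-associative: (10 - (2 - 3)), built by folding from the right
    let right = nums.iter().copied().rev().reduce(|acc, n| n - acc).unwrap();

    assert_eq!(left, 5);
    assert_eq!(right, 11);
}
```

For subtraction and division the two groupings produce different results, which is why the grammar must commit to one; LALRPOP's left recursion gives the conventional left-associative reading.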
AST Construction
Building typed ASTs requires defining the types and constructing them in grammar actions:
```rust
/// AST for a simple expression language
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    Number(f64),
    Add(Box<Expr>, Box<Expr>),
    Subtract(Box<Expr>, Box<Expr>),
    Multiply(Box<Expr>, Box<Expr>),
    Divide(Box<Expr>, Box<Expr>),
    Negate(Box<Expr>),
    Variable(String),
    Call(String, Vec<Expr>),
}

impl Expr {
    /// Evaluate the expression
    pub fn eval(&self) -> f64 {
        match self {
            Expr::Number(n) => *n,
            Expr::Add(l, r) => l.eval() + r.eval(),
            Expr::Subtract(l, r) => l.eval() - r.eval(),
            Expr::Multiply(l, r) => l.eval() * r.eval(),
            Expr::Divide(l, r) => l.eval() / r.eval(),
            Expr::Negate(e) => -e.eval(),
            Expr::Variable(_) => panic!("Cannot evaluate variable without context"),
            Expr::Call(name, args) => match name.as_str() {
                "max" if args.len() == 2 => f64::max(args[0].eval(), args[1].eval()),
                "min" if args.len() == 2 => f64::min(args[0].eval(), args[1].eval()),
                "sqrt" if args.len() == 1 => args[0].eval().sqrt(),
                _ => panic!("Unknown function: {}", name),
            },
        }
    }
}

/// Statement types for a simple language
#[derive(Debug, Clone, PartialEq)]
pub enum Statement {
    Expression(Expr),
    Assignment(String, Expr),
    Print(Expr),
    If(Expr, Vec<Statement>, Option<Vec<Statement>>),
    While(Expr, Vec<Statement>),
}

/// A complete program
#[derive(Debug, Clone, PartialEq)]
pub struct Program {
    pub statements: Vec<Statement>,
}
```
The grammar constructs this AST:
```
use std::str::FromStr;

use crate::ast::Expr;

grammar;

pub Expr: Expr = {
    <l:Expr> "+" <r:Term> => Expr::Add(Box::new(l), Box::new(r)),
    <l:Expr> "-" <r:Term> => Expr::Subtract(Box::new(l), Box::new(r)),
    Term,
};

Term: Expr = {
    <l:Term> "*" <r:Factor> => Expr::Multiply(Box::new(l), Box::new(r)),
    <l:Term> "/" <r:Factor> => Expr::Divide(Box::new(l), Box::new(r)),
    Factor,
};

Factor: Expr = {
    Primary,
    "-" <e:Factor> => Expr::Negate(Box::new(e)),
};

Primary: Expr = {
    Number => Expr::Number(<>),
    Identifier => Expr::Variable(<>),
    "(" <Expr> ")",
};

Number: f64 = {
    r"[0-9]+(\.[0-9]+)?" => f64::from_str(<>).unwrap(),
};

Identifier: String = {
    r"[a-zA-Z][a-zA-Z0-9_]*" => <>.to_string(),
};
```
Each production rule returns a value of the specified type. The angle-bracket syntax `<name:Rule>` binds matched values to variables used in the action code.
Statement Grammar
A more complete language with statements demonstrates complex grammar patterns:
```rust
// Expr is the expression AST defined in the previous section.

/// Statement types for a simple language
#[derive(Debug, Clone, PartialEq)]
pub enum Statement {
    Expression(Expr),
    Assignment(String, Expr),
    Print(Expr),
    If(Expr, Vec<Statement>, Option<Vec<Statement>>),
    While(Expr, Vec<Statement>),
}

/// A complete program
#[derive(Debug, Clone, PartialEq)]
pub struct Program {
    pub statements: Vec<Statement>,
}
```
The grammar for this language:
```
Statement: Statement = {
    <e:Expr> ";" => Statement::Expression(e),
    "let" <name:Identifier> "=" <e:Expr> ";" => Statement::Assignment(name, e),
    "print" <e:Expr> ";" => Statement::Print(e),
    "if" <cond:Expr> "{" <then:Statement*> "}" <els:("else" "{" <Statement*> "}")?> => {
        Statement::If(cond, then, els)
    },
    "while" <cond:Expr> "{" <body:Statement*> "}" => Statement::While(cond, body),
};
```
The `*` operator creates lists of zero or more items. The `?` operator makes productions optional. Parentheses group sub-patterns for clarity.
Using External Lexers
LALRPOP can use external lexers like logos for improved performance and features:
```rust
use std::fmt;

use logos::Logos;

/// Token types for our language using logos
#[derive(Logos, Debug, PartialEq, Clone)]
pub enum Token {
    // Operators
    #[token("+")]
    Plus,
    #[token("-")]
    Minus,
    #[token("*")]
    Star,
    #[token("/")]
    Slash,
    #[token("=")]
    Equals,

    // Delimiters
    #[token("(")]
    LeftParen,
    #[token(")")]
    RightParen,
    #[token("{")]
    LeftBrace,
    #[token("}")]
    RightBrace,
    #[token(",")]
    Comma,
    #[token(";")]
    Semicolon,

    // Keywords
    #[token("let")]
    Let,
    #[token("if")]
    If,
    #[token("else")]
    Else,
    #[token("while")]
    While,
    #[token("print")]
    Print,
    #[token("true")]
    True,
    #[token("false")]
    False,

    // Literals
    #[regex(r"[0-9]+(\.[0-9]+)?", |lex| lex.slice().parse::<f64>().ok())]
    Number(f64),
    #[regex(r"[a-zA-Z][a-zA-Z0-9_]*", |lex| lex.slice().to_string())]
    Identifier(String),
    #[regex(r#""([^"\\]|\\.)*""#, |lex| {
        let s = lex.slice();
        s[1..s.len() - 1].to_string()
    })]
    StringLiteral(String),

    // Skip whitespace and comments
    #[regex(r"[ \t\n\f]+", logos::skip)]
    #[regex(r"//[^\n]*", logos::skip)]
    Error,
}

impl fmt::Display for Token {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Token::Plus => write!(f, "+"),
            Token::Minus => write!(f, "-"),
            Token::Star => write!(f, "*"),
            Token::Slash => write!(f, "/"),
            Token::Equals => write!(f, "="),
            Token::LeftParen => write!(f, "("),
            Token::RightParen => write!(f, ")"),
            Token::LeftBrace => write!(f, "{{"),
            Token::RightBrace => write!(f, "}}"),
            Token::Comma => write!(f, ","),
            Token::Semicolon => write!(f, ";"),
            Token::Let => write!(f, "let"),
            Token::If => write!(f, "if"),
            Token::Else => write!(f, "else"),
            Token::While => write!(f, "while"),
            Token::Print => write!(f, "print"),
            Token::True => write!(f, "true"),
            Token::False => write!(f, "false"),
            Token::Number(n) => write!(f, "{}", n),
            Token::Identifier(s) => write!(f, "{}", s),
            Token::StringLiteral(s) => write!(f, "\"{}\"", s),
            Token::Error => write!(f, "<error>"),
        }
    }
}
```
The grammar declares the external token type:
```
extern {
    type Location = usize;
    type Error = String;

    enum Token {
        "+" => Token::Plus,
        "-" => Token::Minus,
        "*" => Token::Star,
        "/" => Token::Slash,
        "=" => Token::Equals,
        "(" => Token::LeftParen,
        ")" => Token::RightParen,
        "number" => Token::Number(<f64>),
        "identifier" => Token::Identifier(<String>),
    }
}
```
Terminal symbols in the grammar now refer to token variants. The angle brackets extract associated data from token variants.
Integration with Logos
Connecting logos lexer to LALRPOP parser:
```rust
use lalrpop_util::lalrpop_mod;
use logos::Logos;

lalrpop_mod!(pub expression_logos);

/// Parse using logos for lexing
pub fn parse_with_logos(input: &str) -> Result<ast::Program, String> {
    let lexer = token::Token::lexer(input);
    let tokens: Result<Vec<_>, _> = lexer
        .spanned()
        .map(|(tok, span)| match tok {
            Ok(t) => Ok((span.start, t, span.end)),
            Err(_) => Err("Lexer error"),
        })
        .collect();
    match tokens {
        Ok(tokens) => expression_logos::ProgramParser::new()
            .parse(tokens)
            .map_err(|e| format!("Parse error: {:?}", e)),
        Err(e) => Err(e.to_string()),
    }
}
```
The lexer produces tokens with location information that LALRPOP uses for error reporting. This separation of concerns allows optimizing lexer and parser independently.
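For intuition about what those triples look like, here is a hand-rolled sketch of a spanned lexer producing the `(start, token, end)` tuples that LALRPOP's external lexer interface consumes. The `Tok` type and its two rules are illustrative, not the chapter's `Token` type:

```rust
// Illustrative token type; real grammars have many more variants.
#[derive(Debug, PartialEq)]
enum Tok {
    Number(i64),
    Plus,
}

// Produce (start, token, end) triples over byte offsets,
// the shape LALRPOP's external lexer interface expects.
fn lex(input: &str) -> Vec<(usize, Tok, usize)> {
    let bytes = input.as_bytes();
    let mut out = Vec::new();
    let mut i = 0;
    while i < bytes.len() {
        let c = bytes[i] as char;
        if c.is_ascii_whitespace() {
            i += 1;
        } else if c == '+' {
            out.push((i, Tok::Plus, i + 1));
            i += 1;
        } else if c.is_ascii_digit() {
            let start = i;
            while i < bytes.len() && (bytes[i] as char).is_ascii_digit() {
                i += 1;
            }
            out.push((start, Tok::Number(input[start..i].parse().unwrap()), i));
        } else {
            panic!("unexpected character at byte {}", i);
        }
    }
    out
}

fn main() {
    let toks = lex("12 + 34");
    assert_eq!(toks[0], (0, Tok::Number(12), 2));
    assert_eq!(toks[1], (3, Tok::Plus, 4));
    assert_eq!(toks[2], (5, Tok::Number(34), 7));
}
```

logos computes these spans for you via `lexer.spanned()`; the point is only that every token carries byte offsets the parser can feed back into error messages.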
Helper Rules
Common patterns can be abstracted into parameterized rules:
```
Comma<T>: Vec<T> = {
    <mut v:(<T> ",")*> <e:T?> => match e {
        None => v,
        Some(e) => {
            v.push(e);
            v
        }
    }
};
```
This generic rule parses comma-separated lists of any type, allowing an optional trailing comma. Use it like `<args:Comma<Expr>>` to parse function arguments.
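In plain Rust, the semantics the rule encodes (zero or more "element, " pairs followed by an optional final element) look roughly like this. `parse_comma_list` is a hypothetical helper, and unlike the grammar it is looser in one respect: it also tolerates empty elements such as `"1,,2"`:

```rust
// Sketch of Comma<T>'s trailing-comma semantics for i32 elements.
fn parse_comma_list(input: &str) -> Result<Vec<i32>, String> {
    input
        .split(',')
        .map(str::trim)
        .filter(|s| !s.is_empty()) // permits a trailing comma (and empty input)
        .map(|s| s.parse::<i32>().map_err(|e| e.to_string()))
        .collect()
}

fn main() {
    assert_eq!(parse_comma_list("1, 2, 3").unwrap(), vec![1, 2, 3]);
    assert_eq!(parse_comma_list("1, 2, 3,").unwrap(), vec![1, 2, 3]); // trailing comma ok
    assert_eq!(parse_comma_list("").unwrap(), Vec::<i32>::new());    // empty list ok
    assert!(parse_comma_list("1, x").is_err());
}
```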
Build Configuration
LALRPOP requires a build script to generate parsers:
```rust
// build.rs
extern crate lalrpop;

fn main() {
    lalrpop::process_root().unwrap();
}
```
The build script processes all `.lalrpop` files in the source tree, generating corresponding Rust modules.
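For reference, the typical Cargo.toml wiring places lalrpop in `build-dependencies` and the runtime support crate in `dependencies`. The version numbers below are indicative; check crates.io for current releases, and enable `lalrpop-util`'s `lexer` feature only if you use the built-in lexer:

```toml
[build-dependencies]
lalrpop = "0.20"

[dependencies]
lalrpop-util = { version = "0.20", features = ["lexer"] }
logos = "0.14"
```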
Using Generated Parsers
Import generated parsers with the lalrpop_mod macro and use them to parse input:
```rust
pub mod ast;
pub mod token;

use lalrpop_util::lalrpop_mod;
use lalrpop_util::ParseError;

lalrpop_mod!(pub calculator_builtin);
lalrpop_mod!(pub expression);
lalrpop_mod!(pub expression_logos);
lalrpop_mod!(pub left_recursion);

/// Parse a simple calculator expression using built-in lexer
pub fn parse_calculator(input: &str) -> Result<i32, String> {
    calculator_builtin::ExprParser::new()
        .parse(input)
        .map_err(|e| format!("Parse error: {:?}", e))
}

/// Parse an expression language program using built-in lexer
pub fn parse_expression(input: &str) -> Result<ast::Program, String> {
    expression::ProgramParser::new()
        .parse(input)
        .map_err(|e| format!("Parse error: {:?}", e))
}

/// Example of detailed error handling for parse errors
pub fn parse_with_detailed_errors(input: &str) -> Result<i32, String> {
    let parser = calculator_builtin::ExprParser::new();
    match parser.parse(input) {
        Ok(result) => Ok(result),
        Err(ParseError::InvalidToken { location }) => {
            Err(format!("Invalid token at position {}", location))
        }
        Err(ParseError::UnrecognizedToken { token, expected }) => {
            let (start, _, end) = token;
            Err(format!(
                "Unexpected '{}' at position {}-{}, expected one of: {:?}",
                &input[start..end], start, end, expected
            ))
        }
        Err(ParseError::UnrecognizedEof { location, expected }) => Err(format!(
            "Unexpected end of input at position {}, expected: {:?}",
            location, expected
        )),
        Err(ParseError::ExtraToken { token }) => {
            let (start, _, end) = token;
            Err(format!(
                "Extra token '{}' at position {}-{} after valid input",
                &input[start..end], start, end
            ))
        }
        Err(ParseError::User { error }) => Err(format!("Parse error: {}", error)),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_calculator() {
        assert_eq!(parse_calculator("2 + 3 * 4").unwrap(), 14);
        assert_eq!(parse_calculator("(2 + 3) * 4").unwrap(), 20);
        assert_eq!(parse_calculator("10 - 2 - 3").unwrap(), 5);
    }

    #[test]
    fn test_expression_parser() {
        let program = parse_expression("let x = 10; let y = 20; print x + y;").unwrap();
        assert_eq!(program.statements.len(), 3);
    }
}
```
Each public rule in the grammar generates a corresponding parser struct with a `parse` method.
Error Handling
LALRPOP provides detailed error information with location tracking and expected tokens:
```rust
pub mod ast;
pub mod token;

use lalrpop_util::{lalrpop_mod, ParseError};
use logos::Logos;

lalrpop_mod!(pub calculator_builtin);
lalrpop_mod!(pub expression);
lalrpop_mod!(pub expression_logos);
lalrpop_mod!(pub left_recursion);

/// Parse a simple calculator expression using the built-in lexer
pub fn parse_calculator(input: &str) -> Result<i32, String> {
    calculator_builtin::ExprParser::new()
        .parse(input)
        .map_err(|e| format!("Parse error: {:?}", e))
}

/// Parse an expression language program using the built-in lexer
pub fn parse_expression(input: &str) -> Result<ast::Program, String> {
    expression::ProgramParser::new()
        .parse(input)
        .map_err(|e| format!("Parse error: {:?}", e))
}

/// Parse using logos for lexing
pub fn parse_with_logos(input: &str) -> Result<ast::Program, String> {
    let lexer = token::Token::lexer(input);
    let tokens: Result<Vec<_>, _> = lexer
        .spanned()
        .map(|(tok, span)| match tok {
            Ok(t) => Ok((span.start, t, span.end)),
            Err(_) => Err("Lexer error"),
        })
        .collect();
    match tokens {
        Ok(tokens) => expression_logos::ProgramParser::new()
            .parse(tokens)
            .map_err(|e| format!("Parse error: {:?}", e)),
        Err(e) => Err(e.to_string()),
    }
}

/// Example of detailed error handling for parse errors
pub fn parse_with_detailed_errors(input: &str) -> Result<i32, String> {
    let parser = calculator_builtin::ExprParser::new();
    match parser.parse(input) {
        Ok(result) => Ok(result),
        Err(ParseError::InvalidToken { location }) => {
            Err(format!("Invalid token at position {}", location))
        }
        Err(ParseError::UnrecognizedToken { token, expected }) => {
            let (start, _, end) = token;
            Err(format!(
                "Unexpected '{}' at position {}-{}, expected one of: {:?}",
                &input[start..end], start, end, expected
            ))
        }
        Err(ParseError::UnrecognizedEof { location, expected }) => Err(format!(
            "Unexpected end of input at position {}, expected: {:?}",
            location, expected
        )),
        Err(ParseError::ExtraToken { token }) => {
            let (start, _, end) = token;
            Err(format!(
                "Extra token '{}' at position {}-{} after valid input",
                &input[start..end], start, end
            ))
        }
        Err(ParseError::User { error }) => Err(format!("Parse error: {}", error)),
    }
}
```
The error types include location information and expected tokens, enabling high-quality error messages. The parser tracks byte offsets, which can be converted to line and column numbers for user-friendly error reporting.
Precedence and Associativity
Operator precedence is controlled by grammar structure:
```
// Lower precedence
Expr: Expr = {
    <l:Expr> "||" <r:AndExpr> => Expr::Or(Box::new(l), Box::new(r)),
    AndExpr,
};

AndExpr: Expr = {
    <l:AndExpr> "&&" <r:CmpExpr> => Expr::And(Box::new(l), Box::new(r)),
    CmpExpr,
};

// Higher precedence
CmpExpr: Expr = {
    <l:CmpExpr> "==" <r:AddExpr> => Expr::Equal(Box::new(l), Box::new(r)),
    <l:CmpExpr> "!=" <r:AddExpr> => Expr::NotEqual(Box::new(l), Box::new(r)),
    AddExpr,
};
```
Rules lower in the grammar hierarchy have higher precedence. Left recursion creates left associativity; right recursion creates right associativity.
Left Recursion
Unlike recursive descent parsers and PEG parsers, LALRPOP handles left recursion naturally and efficiently. This is a fundamental advantage of LR parsing that enables intuitive grammar definitions for left-associative operators and list constructions.
Consider the difference between left and right associative parsing for subtraction:
```
// LEFT RECURSIVE - parses "10 - 5 - 2" as (10 - 5) - 2 = 3
pub LeftAssociative: Expr = {
    <l:LeftAssociative> "-" <r:Term> => Expr::Subtract(Box::new(l), Box::new(r)),
    Term,
};

// RIGHT RECURSIVE - parses "10 - 5 - 2" as 10 - (5 - 2) = 7
pub RightAssociative: Expr = {
    <l:Term> "-" <r:RightAssociative> => Expr::Subtract(Box::new(l), Box::new(r)),
    Term,
};
```
The left recursive version correctly implements the standard mathematical interpretation where operations associate left to right. This natural expression of grammar rules is impossible in top-down parsers without transformation.
Left recursion excels at parsing lists that build incrementally:
```
// Builds the list as items are encountered
pub CommaSeparatedLeft: Vec<i32> = {
    <mut list:CommaSeparatedLeft> "," <item:Number> => {
        list.push(item);
        list
    },
    <n:Number> => vec![n],
};
```
Field access and method chaining naturally use left recursion:
```
// Parses "obj.field1.field2" correctly
pub FieldAccess: String = {
    <obj:FieldAccess> "." <field:Identifier> => format!("{}.{}", obj, field),
    Identifier,
};

// Parses "obj.method1().method2()" correctly
pub MethodChain: String = {
    <obj:MethodChain> "." <method:Identifier> "(" ")" => format!("{}.{}()", obj, method),
    Identifier,
};
```
These patterns appear frequently in programming languages where operations chain from left to right. The ability to express them directly as left recursive rules simplifies grammar development and improves parser performance.
Postfix operators also benefit from left recursion:
```
// Array indexing, function calls, and postfix increment
PostfixExpr: Expr = {
    <e:PostfixExpr> "[" <index:Expr> "]" => Expr::Index(Box::new(e), Box::new(index)),
    <func:PostfixExpr> "(" <args:Arguments> ")" => Expr::Call(Box::new(func), args),
    <e:PostfixExpr> "++" => Expr::PostIncrement(Box::new(e)),
    PrimaryExpr,
};
```
Testing associativity demonstrates the difference:
```rust
/// Demonstrate left vs right associativity parsing
pub fn demonstrate_associativity(input: &str) -> (String, String) {
    let left = left_recursion::LeftAssociativeParser::new()
        .parse(input)
        .map(|e| format!("{:?}", e))
        .unwrap_or_else(|e| format!("Error: {:?}", e));
    let right = left_recursion::RightAssociativeParser::new()
        .parse(input)
        .map(|e| format!("{:?}", e))
        .unwrap_or_else(|e| format!("Error: {:?}", e));
    (left, right)
}
```
The function parses the same input with both left and right associative grammars, revealing how the parse tree structure differs. For the expression “10 - 5 - 2”, left association produces 3 while right association produces 7.
Complex expressions with multiple precedence levels all use left recursion:
```
BinaryOp: Expr = {
    <l:BinaryOp> "||" <r:AndExpr> => Expr::Or(Box::new(l), Box::new(r)),
    AndExpr,
};

AndExpr: Expr = {
    <l:AndExpr> "&&" <r:EqExpr> => Expr::And(Box::new(l), Box::new(r)),
    EqExpr,
};

AddExpr: Expr = {
    <l:AddExpr> "+" <r:MulExpr> => Expr::Add(Box::new(l), Box::new(r)),
    <l:AddExpr> "-" <r:MulExpr> => Expr::Subtract(Box::new(l), Box::new(r)),
    MulExpr,
};
```
Each level of the precedence hierarchy uses left recursion to ensure operators associate correctly. This pattern scales to arbitrarily complex expression grammars while maintaining readability and performance.
The LR parsing algorithm builds the parse tree bottom-up, naturally handling left recursion without stack overflow issues that plague recursive descent parsers. This fundamental difference makes LALRPOP ideal for parsing programming languages with complex expression syntax.
Building an Interpreter
LALRPOP-generated parsers integrate well with interpreters:
```rust
/// Example: Building a simple interpreter
pub struct Interpreter {
    variables: std::collections::HashMap<String, f64>,
}

impl Interpreter {
    pub fn new() -> Self {
        Self {
            variables: std::collections::HashMap::new(),
        }
    }

    pub fn execute(&mut self, program: &ast::Program) -> Result<(), String> {
        for statement in &program.statements {
            self.execute_statement(statement)?;
        }
        Ok(())
    }

    fn execute_statement(&mut self, stmt: &ast::Statement) -> Result<(), String> {
        match stmt {
            ast::Statement::Expression(expr) => {
                self.eval_expr(expr)?;
                Ok(())
            }
            ast::Statement::Assignment(name, expr) => {
                let value = self.eval_expr(expr)?;
                self.variables.insert(name.clone(), value);
                Ok(())
            }
            ast::Statement::Print(expr) => {
                let value = self.eval_expr(expr)?;
                println!("{}", value);
                Ok(())
            }
            ast::Statement::If(cond, then_block, else_block) => {
                let cond_value = self.eval_expr(cond)?;
                if cond_value != 0.0 {
                    for stmt in then_block {
                        self.execute_statement(stmt)?;
                    }
                } else if let Some(else_stmts) = else_block {
                    for stmt in else_stmts {
                        self.execute_statement(stmt)?;
                    }
                }
                Ok(())
            }
            ast::Statement::While(cond, body) => {
                while self.eval_expr(cond)? != 0.0 {
                    for stmt in body {
                        self.execute_statement(stmt)?;
                    }
                }
                Ok(())
            }
        }
    }

    fn eval_expr(&self, expr: &ast::Expr) -> Result<f64, String> {
        match expr {
            ast::Expr::Number(n) => Ok(*n),
            ast::Expr::Add(l, r) => Ok(self.eval_expr(l)? + self.eval_expr(r)?),
            ast::Expr::Subtract(l, r) => Ok(self.eval_expr(l)? - self.eval_expr(r)?),
            ast::Expr::Multiply(l, r) => Ok(self.eval_expr(l)? * self.eval_expr(r)?),
            ast::Expr::Divide(l, r) => {
                let divisor = self.eval_expr(r)?;
                if divisor == 0.0 {
                    Err("Division by zero".to_string())
                } else {
                    Ok(self.eval_expr(l)? / divisor)
                }
            }
            ast::Expr::Negate(e) => Ok(-self.eval_expr(e)?),
            ast::Expr::Variable(name) => self
                .variables
                .get(name)
                .copied()
                .ok_or_else(|| format!("Undefined variable: {}", name)),
            ast::Expr::Call(name, args) => {
                let arg_values: Result<Vec<_>, _> =
                    args.iter().map(|e| self.eval_expr(e)).collect();
                let arg_values = arg_values?;
                match name.as_str() {
                    "max" if arg_values.len() == 2 => Ok(f64::max(arg_values[0], arg_values[1])),
                    "min" if arg_values.len() == 2 => Ok(f64::min(arg_values[0], arg_values[1])),
                    "sqrt" if arg_values.len() == 1 => Ok(arg_values[0].sqrt()),
                    "abs" if arg_values.len() == 1 => Ok(arg_values[0].abs()),
                    _ => Err(format!("Unknown function: {}", name)),
                }
            }
        }
    }
}

impl Default for Interpreter {
    fn default() -> Self {
        Self::new()
    }
}
```
The interpreter walks the AST, maintaining variable bindings and executing statements. Separating parsing from execution leaves room to run optimization and analysis passes over the AST in between.
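To make the point concrete, here is a tiny constant-folding pass that could run between parsing and interpretation. The expression type is illustrative, not the chapter's AST:

```rust
// A minimal AST (illustrative, not the chapter's) and a
// constant-folding pass that simplifies it before execution.
#[derive(Debug, PartialEq)]
enum Expr {
    Num(f64),
    Add(Box<Expr>, Box<Expr>),
    Var(String),
}

fn fold(e: Expr) -> Expr {
    match e {
        Expr::Add(l, r) => match (fold(*l), fold(*r)) {
            // Both sides constant: evaluate at compile time.
            (Expr::Num(a), Expr::Num(b)) => Expr::Num(a + b),
            (l, r) => Expr::Add(Box::new(l), Box::new(r)),
        },
        other => other,
    }
}

fn main() {
    // (1 + 2) + x folds to 3 + x.
    let e = Expr::Add(
        Box::new(Expr::Add(Box::new(Expr::Num(1.0)), Box::new(Expr::Num(2.0)))),
        Box::new(Expr::Var("x".into())),
    );
    println!("{:?}", fold(e)); // Add(Num(3.0), Var("x"))
}
```

Because the interpreter consumes a plain data structure, a pass like this can be inserted without touching the parser or the evaluator.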
Conflict Resolution
LALRPOP detects grammar conflicts at compile time:
error: ambiguity detected
The following symbols can be reduced in two ways:
Expr "+" Expr
They could be reduced like so:
Expr = Expr "+" Expr
Or they could be reduced like so:
Expr = Expr, "+" Expr
Resolve conflicts by restructuring the grammar or using precedence annotations. LALRPOP’s error messages pinpoint the exact productions causing conflicts.
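For an ambiguity like the one above, the standard restructuring encodes precedence and associativity in layered nonterminals. A sketch in LALRPOP-style notation (rule names and types are illustrative):

```
// Ambiguous: Expr = Expr "+" Expr and Expr = Expr "*" Expr
// Unambiguous tiers: "+" binds looser than "*", and the
// left recursion makes both operators left-associative.
Expr: f64 = {
    <l:Expr> "+" <r:Term> => l + r,
    Term,
};

Term: f64 = {
    <l:Term> "*" <r:Factor> => l * r,
    Factor,
};

Factor: f64 = {
    Num,
    "(" <Expr> ")",
};
```

Each tier can only recurse into the tier below it, so the parser has exactly one way to reduce any input.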
Performance Optimization
LALRPOP generates table-driven parsers with excellent performance characteristics. The parsing algorithm is O(n) for valid input with no backtracking. Tables are computed at compile time, so runtime overhead is minimal.
For maximum performance, use external lexers like logos that produce tokens in a single pass. The combination of logos lexing and LALRPOP parsing can process millions of lines per second.
Best Practices
Structure grammars for clarity and maintainability. Group related productions together and use comments to explain complex patterns. Keep action code simple, delegating complex logic to separate functions.
Use typed ASTs to catch errors at compile time. The type system ensures grammar productions and AST construction remain synchronized. Changes to the AST that break the grammar are caught during compilation.
Test grammars thoroughly with both valid and invalid input. LALRPOP’s error reporting helps debug grammar issues, but comprehensive tests ensure the parser accepts the intended language.
Profile parser performance on realistic input. While LALRPOP generates efficient parsers, grammar structure affects performance. Minimize ambiguity and left-factor common prefixes when performance matters.
logos
The logos crate provides a fast, derive-based lexer generator for Rust. Unlike traditional lexer generators that produce separate source files, logos integrates directly into your Rust code through derive macros. It generates highly optimized, table-driven lexers that can process millions of tokens per second, making it ideal for production compilers and language servers.
Logos excels at tokenizing programming languages with its declarative syntax for defining token patterns. The generated lexers use efficient DFA-based matching with automatic longest-match semantics. The crate handles common compiler requirements like source location tracking, error recovery, and stateful lexing for context-sensitive tokens like string literals or nested comments.
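Longest-match (maximal munch) semantics mean the lexer always takes the longest pattern that matches at the current position, so == lexes as one token rather than two = tokens. A hand-rolled illustration of the rule, not logos-generated code:

```rust
// Maximal munch over a fixed operator set: candidates are
// ordered longest-first, so "==" wins over "=" at the same position.
fn next_op(input: &str) -> Option<(&'static str, usize)> {
    const OPS: [&str; 4] = ["==", "!=", "=", "!"];
    OPS.iter()
        .find(|op| input.starts_with(**op))
        .map(|op| (*op, op.len()))
}

fn main() {
    assert_eq!(next_op("== 1"), Some(("==", 2))); // one token, not "=" "="
    assert_eq!(next_op("=1"), Some(("=", 1)));
    println!("ok");
}
```

Logos compiles all patterns into a single DFA, so it gets this behavior in one pass without trying candidates one at a time.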
Basic Token Definition
Token types in logos are defined as enums with the #[derive(Logos)]
attribute. Each variant represents a different token type, annotated with patterns that match the token.
#[derive(Logos, Debug, PartialEq, Clone)]
#[logos(skip r"[ \t\f]+")] // Skip whitespace except newlines
pub enum Token {
    // Keywords
    #[token("fn")]
    Function,
    #[token("let")]
    Let,
    #[token("if")]
    If,
    #[token("else")]
    Else,

    // Identifiers and literals
    #[regex("[a-zA-Z_][a-zA-Z0-9_]*", |lex| lex.slice().to_string())]
    Identifier(String),
    #[regex(r"-?[0-9]+", |lex| lex.slice().parse::<i64>().ok())]
    Integer(Option<i64>),

    // Operators
    #[token("+")]
    Plus,
    #[token("-")]
    Minus,
    #[token("==")]
    Equal,
    #[token("!=")]
    NotEqual,
}
The #[token]
attribute matches exact strings, while #[regex]
matches regular expressions. The skip directive tells logos to automatically skip whitespace between tokens. Tokens can capture data by providing a closure that processes the matched text.
Using the Lexer
Logos generates an iterator-based lexer that processes input incrementally. The lexer provides access to the matched token, its span in the source, and the matched text.
use logos::Logos;

#[derive(Logos, Debug, PartialEq, Clone)]
#[logos(skip r"[ \t\f]+")] // Skip whitespace except newlines
pub enum Token {
    // Keywords
    #[token("fn")] Function,
    #[token("let")] Let,
    #[token("const")] Const,
    #[token("if")] If,
    #[token("else")] Else,
    #[token("while")] While,
    #[token("for")] For,
    #[token("return")] Return,
    #[token("struct")] Struct,
    #[token("enum")] Enum,
    #[token("impl")] Impl,
    #[token("trait")] Trait,
    #[token("pub")] Pub,
    #[token("mod")] Mod,
    #[token("use")] Use,
    #[token("mut")] Mut,
    #[token("true")] True,
    #[token("false")] False,

    // Identifiers and literals
    #[regex("[a-zA-Z_][a-zA-Z0-9_]*", |lex| lex.slice().to_string())]
    Identifier(String),
    #[regex(r"-?[0-9]+", |lex| lex.slice().parse::<i64>().ok())]
    Integer(Option<i64>),
    #[regex(r"-?[0-9]+\.[0-9]+", |lex| lex.slice().parse::<f64>().ok())]
    Float(Option<f64>),
    #[regex(r#""([^"\\]|\\.)*""#, |lex| { let s = lex.slice(); s[1..s.len()-1].to_string() })]
    String(String),
    #[regex(r"'([^'\\]|\\.)'")]
    Char,

    // Comments
    #[regex(r"//[^\n]*", logos::skip)]
    #[regex(r"/\*([^*]|\*[^/])*\*/", logos::skip)]
    Comment,

    // Operators
    #[token("+")] Plus,
    #[token("-")] Minus,
    #[token("*")] Star,
    #[token("/")] Slash,
    #[token("%")] Percent,
    #[token("=")] Assign,
    #[token("==")] Equal,
    #[token("!=")] NotEqual,
    #[token("<")] Less,
    #[token("<=")] LessEqual,
    #[token(">")] Greater,
    #[token(">=")] GreaterEqual,
    #[token("&&")] And,
    #[token("||")] Or,
    #[token("!")] Not,
    #[token("&")] Ampersand,
    #[token("|")] Pipe,
    #[token("^")] Caret,
    #[token("<<")] LeftShift,
    #[token(">>")] RightShift,
    #[token("+=")] PlusAssign,
    #[token("-=")] MinusAssign,
    #[token("*=")] StarAssign,
    #[token("/=")] SlashAssign,
    #[token("->")] Arrow,
    #[token("=>")] FatArrow,
    #[token("::")] PathSeparator,

    // Punctuation
    #[token("(")] LeftParen,
    #[token(")")] RightParen,
    #[token("{")] LeftBrace,
    #[token("}")] RightBrace,
    #[token("[")] LeftBracket,
    #[token("]")] RightBracket,
    #[token(";")] Semicolon,
    #[token(":")] Colon,
    #[token(",")] Comma,
    #[token(".")] Dot,
    #[token("..")] DotDot,
    #[token("...")] DotDotDot,
    #[token("?")] Question,

    // Special handling for newlines (for line counting)
    #[token("\n")] Newline,
}

pub type TokenSpan = (Token, std::ops::Range<usize>);

pub fn tokenize(input: &str) -> Vec<TokenSpan> {
    let mut tokens = Vec::new();
    let mut lexer = Token::lexer(input);
    while let Some(result) = lexer.next() {
        if let Ok(token) = result {
            tokens.push((token, lexer.span()));
        }
    }
    tokens
}
This function demonstrates basic tokenization, collecting all tokens with their source spans. The spans are byte offsets that can be used to extract the original text or generate error messages with precise locations.
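Because spans are plain byte ranges, they compose directly with string slicing. A std-only illustration, with hand-written spans standing in for lexer output:

```rust
// Recover each token's original text from its byte span.
// The spans here are hand-written stand-ins for lexer output.
fn token_texts<'a>(source: &'a str, spans: &[std::ops::Range<usize>]) -> Vec<&'a str> {
    spans.iter().map(|s| &source[s.clone()]).collect()
}

fn main() {
    let source = "let x = 42;";
    let spans = [0..3, 4..5, 6..7, 8..10];
    assert_eq!(token_texts(source, &spans), ["let", "x", "=", "42"]);
    println!("{:?}", token_texts(source, &spans)); // ["let", "x", "=", "42"]
}
```

The same ranges can later drive error underlining or syntax highlighting without storing copies of the token text.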
Error Handling
Real compilers need robust error handling for invalid input. Logos returns Result<Token, ()>
for each token, allowing graceful handling of lexical errors.
pub type TokenSpan = (Token, std::ops::Range<usize>);
pub type ErrorSpan = std::ops::Range<usize>;

pub fn tokenize_with_errors(input: &str) -> (Vec<TokenSpan>, Vec<ErrorSpan>) {
    let mut tokens = Vec::new();
    let mut errors = Vec::new();
    let mut lexer = Token::lexer(input);
    while let Some(result) = lexer.next() {
        match result {
            Ok(token) => tokens.push((token, lexer.span())),
            Err(()) => errors.push(lexer.span()),
        }
    }
    (tokens, errors)
}
This approach separates valid tokens from error spans, enabling the compiler to continue processing after encountering invalid characters. Error spans can be used to generate diagnostics showing exactly where the problem occurred.
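For example, an error span can be rendered as a simple caret diagnostic. This is only a sketch; a real compiler would hand the span to a library like ariadne or codespan-reporting:

```rust
// Render a one-line caret marker under the offending span.
// Sketch only: assumes a single-line source and ASCII columns.
fn caret_line(span: &std::ops::Range<usize>) -> String {
    format!("{}{}", " ".repeat(span.start), "^".repeat(span.end - span.start))
}

fn main() {
    let source = "let x = 42 @ invalid";
    let error_span = 11..12; // the span a lexer might report for '@'
    println!("error: unexpected character");
    println!("{source}");
    println!("{} here", caret_line(&error_span));
}
```

Multi-line sources additionally need the line/column mapping described in the next section to place the caret on the right line.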
Source Location Tracking
Compilers need to map byte offsets to human-readable line and column numbers for error reporting. The SourceTracker maintains this mapping efficiently.
#[derive(Debug, Clone)]
pub struct SourceLocation {
    pub line: usize,
    pub column: usize,
    pub byte_offset: usize,
}

pub struct SourceTracker<'source> {
    source: &'source str,
    line_starts: Vec<usize>,
}

impl<'source> SourceTracker<'source> {
    pub fn new(source: &'source str) -> Self {
        let mut line_starts = vec![0];
        for (i, ch) in source.char_indices() {
            if ch == '\n' {
                line_starts.push(i + 1);
            }
        }
        Self { source, line_starts }
    }

    pub fn location(&self, byte_offset: usize) -> SourceLocation {
        let line = self
            .line_starts
            .binary_search(&byte_offset)
            .unwrap_or_else(|i| i.saturating_sub(1));
        let line_start = self.line_starts[line];
        let column = self.source[line_start..byte_offset].chars().count();
        SourceLocation {
            line: line + 1,     // 1-based
            column: column + 1, // 1-based
            byte_offset,
        }
    }

    pub fn line_content(&self, line: usize) -> &'source str {
        if line == 0 || line > self.line_starts.len() {
            return "";
        }
        let start = self.line_starts[line - 1];
        let end = if line < self.line_starts.len() {
            self.line_starts[line] - 1
        } else {
            self.source.len()
        };
        &self.source[start..end]
    }
}
The SourceTracker builds an index of line start positions for efficient location queries, scanning the input once at initialization to find all newline positions.
The tracker pre-computes line boundaries for O(log n) location lookups. This is crucial for language servers that need to convert between byte offsets and editor positions frequently.
Token Streams
For parsing, it’s often useful to wrap the lexer in a stream that supports peeking at the next token without consuming it.
use logos::{Lexer, Logos};

#[derive(Logos, Debug, PartialEq, Clone)]
#[logos(skip r"[ \t\f]+")] // Skip whitespace except newlines
pub enum Token {
    // Keywords
    #[token("fn")] Function,
    #[token("let")] Let,
    #[token("const")] Const,
    #[token("if")] If,
    #[token("else")] Else,
    #[token("while")] While,
    #[token("for")] For,
    #[token("return")] Return,
    #[token("struct")] Struct,
    #[token("enum")] Enum,
    #[token("impl")] Impl,
    #[token("trait")] Trait,
    #[token("pub")] Pub,
    #[token("mod")] Mod,
    #[token("use")] Use,
    #[token("mut")] Mut,
    #[token("true")] True,
    #[token("false")] False,

    // Identifiers and literals
    #[regex("[a-zA-Z_][a-zA-Z0-9_]*", |lex| lex.slice().to_string())]
    Identifier(String),
    #[regex(r"-?[0-9]+", |lex| lex.slice().parse::<i64>().ok())]
    Integer(Option<i64>),
    #[regex(r"-?[0-9]+\.[0-9]+", |lex| lex.slice().parse::<f64>().ok())]
    Float(Option<f64>),
    #[regex(r#""([^"\\]|\\.)*""#, |lex| {
        let s = lex.slice();
        s[1..s.len()-1].to_string()
    })]
    String(String),
    #[regex(r"'([^'\\]|\\.)'")]
    Char,

    // Comments
    #[regex(r"//[^\n]*", logos::skip)]
    #[regex(r"/\*([^*]|\*[^/])*\*/", logos::skip)]
    Comment,

    // Operators
    #[token("+")] Plus,
    #[token("-")] Minus,
    #[token("*")] Star,
    #[token("/")] Slash,
    #[token("%")] Percent,
    #[token("=")] Assign,
    #[token("==")] Equal,
    #[token("!=")] NotEqual,
    #[token("<")] Less,
    #[token("<=")] LessEqual,
    #[token(">")] Greater,
    #[token(">=")] GreaterEqual,
    #[token("&&")] And,
    #[token("||")] Or,
    #[token("!")] Not,
    #[token("&")] Ampersand,
    #[token("|")] Pipe,
    #[token("^")] Caret,
    #[token("<<")] LeftShift,
    #[token(">>")] RightShift,
    #[token("+=")] PlusAssign,
    #[token("-=")] MinusAssign,
    #[token("*=")] StarAssign,
    #[token("/=")] SlashAssign,
    #[token("->")] Arrow,
    #[token("=>")] FatArrow,
    #[token("::")] PathSeparator,

    // Punctuation
    #[token("(")] LeftParen,
    #[token(")")] RightParen,
    #[token("{")] LeftBrace,
    #[token("}")] RightBrace,
    #[token("[")] LeftBracket,
    #[token("]")] RightBracket,
    #[token(";")] Semicolon,
    #[token(":")] Colon,
    #[token(",")] Comma,
    #[token(".")] Dot,
    #[token("..")] DotDot,
    #[token("...")] DotDotDot,
    #[token("?")] Question,

    // Special handling for newlines (for line counting)
    #[token("\n")] Newline,
}

#[derive(Debug, Clone)]
pub struct SourceLocation {
    pub line: usize,
    pub column: usize,
    pub byte_offset: usize,
}

impl<'source> SourceTracker<'source> {
    pub fn new(source: &'source str) -> Self {
        let mut line_starts = vec![0];
        for (i, ch) in source.char_indices() {
            if ch == '\n' {
                line_starts.push(i + 1);
            }
        }
        Self { source, line_starts }
    }

    pub fn location(&self, byte_offset: usize) -> SourceLocation {
        let line = self
            .line_starts
            .binary_search(&byte_offset)
            .unwrap_or_else(|i| i.saturating_sub(1));
        let line_start = self.line_starts[line];
        let column = self.source[line_start..byte_offset].chars().count();
        SourceLocation {
            line: line + 1,     // 1-based
            column: column + 1, // 1-based
            byte_offset,
        }
    }

    pub fn line_content(&self, line: usize) -> &'source str {
        if line == 0 || line > self.line_starts.len() {
            return "";
        }
        let start = self.line_starts[line - 1];
        let end = if line < self.line_starts.len() {
            self.line_starts[line] - 1
        } else {
            self.source.len()
        };
        &self.source[start..end]
    }
}

pub type TokenSpan = (Token, std::ops::Range<usize>);
pub type ErrorSpan = std::ops::Range<usize>;

pub fn tokenize(input: &str) -> Vec<TokenSpan> {
    let mut tokens = Vec::new();
    let mut lexer = Token::lexer(input);
    while let Some(result) = lexer.next() {
        if let Ok(token) = result {
            tokens.push((token, lexer.span()));
        }
    }
    tokens
}

pub fn tokenize_with_errors(input: &str) -> (Vec<TokenSpan>, Vec<ErrorSpan>) {
    let mut tokens = Vec::new();
    let mut errors = Vec::new();
    let mut lexer = Token::lexer(input);
    while let Some(result) = lexer.next() {
        match result {
            Ok(token) => tokens.push((token, lexer.span())),
            Err(()) => errors.push(lexer.span()),
        }
    }
    (tokens, errors)
}

#[derive(Logos, Debug, PartialEq)]
pub enum ExprToken {
    #[regex(r"[0-9]+", |lex| lex.slice().parse::<i32>().ok())]
    Number(Option<i32>),
    #[regex(r"[a-zA-Z_][a-zA-Z0-9_]*", |lex| lex.slice().to_string())]
    Identifier(String),
    #[token("+")] Plus,
    #[token("-")] Minus,
    #[token("*")] Times,
    #[token("/")] Divide,
    #[token("(")] LeftParen,
    #[token(")")] RightParen,
    #[regex(r"[ \t\n]+", logos::skip)]
    Whitespace,
}

pub fn parse_expression(input: &str) -> Vec<ExprToken> {
    ExprToken::lexer(input).filter_map(Result::ok).collect()
}

pub struct TokenStream<'source> {
    lexer: Lexer<'source, Token>,
    peeked: Option<Result<Token, ()>>,
}

impl<'source> TokenStream<'source> {
    pub fn new(source: &'source str) -> Self {
        Self {
            lexer: Token::lexer(source),
            peeked: None,
        }
    }

    pub fn next_token(&mut self) -> Option<Result<Token, ()>> {
        if let Some(token) = self.peeked.take() {
            return Some(token);
        }
        self.lexer.next()
    }

    pub fn peek_token(&mut self) -> Option<&Result<Token, ()>> {
        if self.peeked.is_none() {
            self.peeked = self.lexer.next();
        }
        self.peeked.as_ref()
    }

    pub fn span(&self) -> std::ops::Range<usize> {
        self.lexer.span()
    }

    pub fn slice(&self) -> &'source str {
        self.lexer.slice()
    }

    pub fn remainder(&self) -> &'source str {
        self.lexer.remainder()
    }
}
The peek_token method allows the parser to look at the next token without consuming it, enabling predictive parsing algorithms to make decisions based on lookahead.
The token stream maintains a one-token lookahead buffer, essential for predictive parsing techniques like recursive descent.
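This one-token buffer is the same mechanism that `std::iter::Peekable` provides in the standard library. As a dependency-free sketch of lookahead-driven parsing (the `Tok` type and `parse_sum` function here are illustrative only, not part of the lexer above), a parser can peek at the next token to decide whether to keep consuming:

```rust
// Minimal sketch: one-token lookahead via std::iter::Peekable.
#[derive(Debug, PartialEq)]
enum Tok {
    Num(i64),
    Plus,
}

// Sum a `Num (+ Num)*` sequence, using peek to decide whether to continue.
fn parse_sum(tokens: &[Tok]) -> i64 {
    let mut it = tokens.iter().peekable();
    let mut total = 0;
    while let Some(tok) = it.next() {
        if let Tok::Num(n) = tok {
            total += n;
        }
        // Peek: only continue if a '+' follows, without consuming anything else.
        if it.peek() != Some(&&Tok::Plus) {
            break;
        }
        it.next(); // consume the '+'
    }
    total
}

fn main() {
    let toks = vec![Tok::Num(1), Tok::Plus, Tok::Num(2), Tok::Plus, Tok::Num(39)];
    println!("{}", parse_sum(&toks)); // prints 42
}
```

The `TokenStream` above plays the same role for a logos lexer, with the added benefit of exposing the underlying span and slice for diagnostics.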
Advanced Patterns
Logos supports complex token patterns including comments, escape sequences in strings, and numeric literals with different bases. The derive macro generates efficient matching code for all these patterns.
Comments can be handled by marking them with logos::skip to automatically discard them, or by capturing them as tokens if needed for documentation processing. String literals can use regex patterns to handle escape sequences, while numeric literals can parse different bases and formats.
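As a sketch of the post-processing such callbacks typically perform, here are two std-only helpers of the kind a token callback might call. The function names, prefix syntax, and supported escapes are assumptions for illustration, not part of the logos API:

```rust
/// Parse "0x", "0b", or "0o" prefixed integers, falling back to decimal.
fn parse_int_literal(s: &str) -> Option<i64> {
    if let Some(hex) = s.strip_prefix("0x") {
        i64::from_str_radix(hex, 16).ok()
    } else if let Some(bin) = s.strip_prefix("0b") {
        i64::from_str_radix(bin, 2).ok()
    } else if let Some(oct) = s.strip_prefix("0o") {
        i64::from_str_radix(oct, 8).ok()
    } else {
        s.parse().ok()
    }
}

/// Decode common escape sequences in a raw (already unquoted) string body.
fn decode_escapes(raw: &str) -> String {
    let mut out = String::with_capacity(raw.len());
    let mut chars = raw.chars();
    while let Some(c) = chars.next() {
        if c != '\\' {
            out.push(c);
            continue;
        }
        match chars.next() {
            Some('n') => out.push('\n'),
            Some('t') => out.push('\t'),
            Some('r') => out.push('\r'),
            Some('\\') => out.push('\\'),
            Some('"') => out.push('"'),
            Some(other) => out.push(other), // unknown escape: keep as-is
            None => break,                  // trailing backslash: drop it
        }
    }
    out
}

fn main() {
    assert_eq!(parse_int_literal("0xFF"), Some(255));
    assert_eq!(parse_int_literal("0b1010"), Some(10));
    assert_eq!(decode_escapes(r"line\n"), "line\n");
}
```

Doing this work in the callback keeps the token type small: the parser sees a decoded `String` or `i64`, not raw source text.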
Stateful Lexing
Some languages require context-sensitive lexing, such as Python’s significant indentation. Logos supports stateful lexing through the extras system.
#[derive(Logos, Debug, PartialEq)]
#[logos(extras = IndentationTracker)]
pub enum IndentedToken {
    #[token("\n", |lex| { lex.extras.newline(); logos::Skip })]
    Newline,
    #[regex(r"[ ]+", |lex| {
        let spaces = lex.slice().len();
        lex.extras.track_indent(spaces)
    })]
    Indent(IndentLevel),
    #[regex(r"[a-zA-Z_][a-zA-Z0-9_]*", |lex| lex.slice().to_string())]
    Identifier(String),
    #[token(":")] Colon,
    #[regex(r"#[^\n]*", logos::skip)]
    Comment,
}

#[derive(Default)]
pub struct IndentationTracker {
    at_line_start: bool,
    _current_indent: usize,
    indent_stack: Vec<usize>,
}

impl IndentationTracker {
    fn newline(&mut self) {
        self.at_line_start = true;
    }

    fn track_indent(&mut self, spaces: usize) -> IndentLevel {
        if !self.at_line_start {
            return IndentLevel::None;
        }
        self.at_line_start = false;
        if self.indent_stack.is_empty() {
            self.indent_stack.push(0);
        }
        let previous = *self.indent_stack.last().unwrap();
        use std::cmp::Ordering;
        match spaces.cmp(&previous) {
            Ordering::Greater => {
                self.indent_stack.push(spaces);
                IndentLevel::Indent
            }
            Ordering::Less => {
                let mut dedent_count = 0;
                while let Some(&level) = self.indent_stack.last() {
                    if level <= spaces {
                        break;
                    }
                    self.indent_stack.pop();
                    dedent_count += 1;
                }
                IndentLevel::Dedent(dedent_count)
            }
            Ordering::Equal => IndentLevel::None,
        }
    }
}
The extras field is accessible in token callbacks, allowing the lexer to maintain state between tokens. This enables proper handling of indent/dedent tokens in indentation-sensitive languages.
Performance Characteristics
Logos generates table-driven DFAs that process input in linear time with minimal branching. The generated code uses Rust’s zero-cost abstractions effectively, with performance comparable to hand-written lexers.
The lexer allocates only for captured token data like identifiers and string literals. Token matching itself is allocation-free, making logos suitable for incremental lexing in language servers where performance is critical.
Best Practices
Design your token enum to minimize allocations. Use zero-sized variants for keywords and operators, capturing data only when necessary. Where patterns overlap, remember that logos disambiguates by pattern specificity rather than declaration order, and set an explicit priority on a pattern when the default resolution is not what you want.
Keep regular expressions simple and avoid backtracking patterns. Logos compiles regexes to DFAs at compile time, so complex patterns increase compilation time and generated code size. For truly complex patterns like string interpolation, consider using a two-phase approach with a simpler lexer followed by specialized parsing.
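To make the two-phase idea concrete, here is a minimal std-only sketch: the lexer captures the whole interpolated literal as one token, and a second pass splits its body into text and expression segments. The `${...}` syntax and the `Segment` type are assumptions for illustration, not from any particular language:

```rust
// Phase two of a two-phase interpolated-string lexer: split an already
// captured (and unquoted) literal body into text and `${...}` segments.
#[derive(Debug, PartialEq)]
enum Segment {
    Text(String),
    Expr(String), // in a real compiler, re-lexed and parsed as an expression
}

fn split_interpolation(body: &str) -> Vec<Segment> {
    let mut segments = Vec::new();
    let mut rest = body;
    while let Some(start) = rest.find("${") {
        if start > 0 {
            segments.push(Segment::Text(rest[..start].to_string()));
        }
        let after = &rest[start + 2..];
        match after.find('}') {
            Some(end) => {
                segments.push(Segment::Expr(after[..end].to_string()));
                rest = &after[end + 1..];
            }
            None => {
                // Unterminated interpolation: keep the remainder as text.
                segments.push(Segment::Text(rest[start..].to_string()));
                rest = "";
            }
        }
    }
    if !rest.is_empty() {
        segments.push(Segment::Text(rest.to_string()));
    }
    segments
}

fn main() {
    let segs = split_interpolation("Hello ${name}, you are ${age}!");
    assert_eq!(segs[1], Segment::Expr("name".to_string()));
    assert_eq!(segs[4], Segment::Text("!".to_string()));
}
```

This keeps the first-phase lexer's patterns simple and backtracking-free, while the second phase deals with nesting and error recovery on a much smaller input.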
Handle whitespace and comments consistently. Use the skip directive for insignificant whitespace, but consider preserving comments if you need them for documentation generation or code formatting. The lexer can emit comment tokens that the parser can then choose to ignore.
Logos integrates well with parser combinators and hand-written recursive descent parsers. Its iterator interface and error handling make it a natural fit for Rust’s parsing ecosystem, providing a solid foundation for building efficient language processors.
nom
Nom is a parser combinator library focused on binary formats and text protocols, emphasizing zero-copy parsing and streaming capabilities. The library uses a functional programming approach where small parsers combine into larger ones through combinator functions. Nom excels at parsing network protocols, file formats, and configuration languages with excellent performance characteristics.
The core abstraction in nom is the IResult type, which represents the outcome of a parser. Every parser consumes input and produces either a successful parse with remaining input or an error. This design enables parsers to chain naturally, with each parser consuming part of the input and passing the remainder to the next parser.
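The chaining behavior can be sketched without nom itself. This simplified `PResult` mirrors the (remaining input, output) shape of nom's IResult, with a unit error type standing in for nom's richer error machinery; the parser names are illustrative:

```rust
// Dependency-free sketch of nom's core idea: a parser maps input to
// Result<(remaining_input, output), error>.
type PResult<'a, O> = Result<(&'a str, O), ()>;

/// Consume one or more ASCII digits from the front of the input.
fn digits(input: &str) -> PResult<'_, u32> {
    let end = input
        .find(|c: char| !c.is_ascii_digit())
        .unwrap_or(input.len());
    if end == 0 {
        return Err(());
    }
    Ok((&input[end..], input[..end].parse().map_err(|_| ())?))
}

/// Consume a single expected character.
fn expect(input: &str, ch: char) -> PResult<'_, char> {
    let mut chars = input.chars();
    match chars.next() {
        Some(c) if c == ch => Ok((chars.as_str(), c)),
        _ => Err(()),
    }
}

/// Chain parsers: each consumes a prefix and hands the remainder to the next.
fn pair_of_numbers(input: &str) -> PResult<'_, (u32, u32)> {
    let (rest, a) = digits(input)?;
    let (rest, _) = expect(rest, ',')?;
    let (rest, b) = digits(rest)?;
    Ok((rest, (a, b)))
}

fn main() {
    assert_eq!(pair_of_numbers("10,32 rest"), Ok((" rest", (10, 32))));
    assert!(pair_of_numbers("x,1").is_err());
}
```

nom's combinators generalize exactly this pattern: `pair`, `delimited`, and friends thread the remaining input through automatically so individual parsers stay small.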
Core Types and Parsers
#![allow(unused)]
fn main() {
use nom::branch::alt;
use nom::bytes::complete::{escaped, tag, take_while1};
use nom::character::complete::{alpha1, alphanumeric1, char, digit1, multispace0, one_of};
use nom::combinator::{map, map_res, opt, recognize, value};
use nom::multi::{fold_many0, many0, separated_list0};
use nom::sequence::{delimited, pair, preceded};
use nom::{IResult, Parser};

#[derive(Debug, Clone, PartialEq)]
pub enum BinOp { Add, Sub, Mul, Div }

/// Parse a floating-point number
pub fn float(input: &str) -> IResult<&str, f64> {
    map_res(
        recognize((opt(char('-')), digit1, opt((char('.'), digit1)))),
        |s: &str| s.parse::<f64>(),
    )
    .parse(input)
}

/// Parse an integer
pub fn integer(input: &str) -> IResult<&str, i64> {
    map_res(recognize(pair(opt(char('-')), digit1)), |s: &str| {
        s.parse::<i64>()
    })
    .parse(input)
}

/// Parse a string literal with escape sequences
pub fn string_literal(input: &str) -> IResult<&str, String> {
    delimited(
        char('"'),
        map(
            escaped(
                take_while1(|c: char| c != '"' && c != '\\'),
                '\\',
                one_of(r#""n\rt"#),
            ),
            |s: &str| s.to_string(),
        ),
        char('"'),
    )
    .parse(input)
}

/// Parse an identifier
pub fn identifier(input: &str) -> IResult<&str, String> {
    map(
        recognize(pair(
            alt((alpha1, tag("_"))),
            many0(alt((alphanumeric1, tag("_")))),
        )),
        |s: &str| s.to_string(),
    )
    .parse(input)
}

/// Parse whitespace - wraps a parser with optional whitespace
fn ws<'a, O, F>(mut inner: F) -> impl FnMut(&'a str) -> IResult<&'a str, O>
where
    F: FnMut(&'a str) -> IResult<&'a str, O>,
{
    move |input| {
        let (input, _) = multispace0.parse(input)?;
        let (input, result) = inner(input)?;
        let (input, _) = multispace0.parse(input)?;
        Ok((input, result))
    }
}

/// Parse a function call
pub fn function_call(input: &str) -> IResult<&str, Expr> {
    map(
        (
            |i| identifier.parse(i),
            ws(|i| {
                delimited(
                    char('('),
                    separated_list0(ws(|input| char(',').parse(input)), |i| expression.parse(i)),
                    char(')'),
                )
                .parse(i)
            }),
        ),
        |(name, args)| Expr::Call(name, args),
    )
    .parse(input)
}

/// Parse an array literal
pub fn array(input: &str) -> IResult<&str, Expr> {
    map(
        delimited(
            ws(|input| char('[').parse(input)),
            separated_list0(ws(|input| char(',').parse(input)), |i| expression.parse(i)),
            ws(|input| char(']').parse(input)),
        ),
        Expr::Array,
    )
    .parse(input)
}

/// Parse a primary expression
pub fn primary(input: &str) -> IResult<&str, Expr> {
    alt((
        map(|i| float.parse(i), Expr::Float),
        map(|i| integer.parse(i), Expr::Number),
        map(|i| string_literal.parse(i), Expr::String),
        |i| function_call.parse(i),
        |i| array.parse(i),
        map(|i| identifier.parse(i), Expr::Identifier),
        delimited(
            ws(|input| char('(').parse(input)),
            |i| expression.parse(i),
            ws(|input| char(')').parse(input)),
        ),
    ))
    .parse(input)
}

/// Parse a term (multiplication and division)
pub fn term(input: &str) -> IResult<&str, Expr> {
    let (input, init) = primary.parse(input)?;
    fold_many0(
        pair(
            ws(|input| {
                alt((value(BinOp::Mul, char('*')), value(BinOp::Div, char('/')))).parse(input)
            }),
            |i| primary.parse(i),
        ),
        move || init.clone(),
        |acc, (op, val)| Expr::Binary(op, Box::new(acc), Box::new(val)),
    )
    .parse(input)
}

/// Parse an expression (addition and subtraction)
pub fn expression(input: &str) -> IResult<&str, Expr> {
    let (input, init) = term.parse(input)?;
    fold_many0(
        pair(
            ws(|input| {
                alt((value(BinOp::Add, char('+')), value(BinOp::Sub, char('-')))).parse(input)
            }),
            |i| term.parse(i),
        ),
        move || init.clone(),
        |acc, (op, val)| Expr::Binary(op, Box::new(acc), Box::new(val)),
    )
    .parse(input)
}

/// Configuration file parser example
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
    pub sections: Vec<Section>,
}

#[derive(Debug, Clone, PartialEq)]
pub struct Section {
    pub name: String,
    pub entries: Vec<(String, Value)>,
}

#[derive(Debug, Clone, PartialEq)]
pub enum Value {
    String(String),
    Number(f64),
    Boolean(bool),
    List(Vec<Value>),
}

/// Parse a configuration value
pub fn config_value(input: &str) -> IResult<&str, Value> {
    alt((
        map(float, Value::Number),
        map(string_literal, Value::String),
        map(tag("true"), |_| Value::Boolean(true)),
        map(tag("false"), |_| Value::Boolean(false)),
        map(
            delimited(
                ws(|input| char('[').parse(input)),
                separated_list0(ws(|input| char(',').parse(input)), |i| {
                    config_value.parse(i)
                }),
                ws(|input| char(']').parse(input)),
            ),
            Value::List,
        ),
    ))
    .parse(input)
}

/// Parse a configuration entry
pub fn config_entry(input: &str) -> IResult<&str, (String, Value)> {
    map(
        (
            ws(|input| identifier.parse(input)),
            ws(|input| char('=').parse(input)),
            ws(|input| config_value.parse(input)),
        ),
        |(key, _, value)| (key, value),
    )
    .parse(input)
}

/// Parse a configuration section
pub fn config_section(input: &str) -> IResult<&str, Section> {
    map(
        (
            delimited(
                ws(|input| char('[').parse(input)),
                identifier,
                ws(|input| char(']').parse(input)),
            ),
            many0(config_entry),
        ),
        |(name, entries)| Section { name, entries },
    )
    .parse(input)
}

/// Parse a complete configuration file
pub fn parse_config(input: &str) -> IResult<&str, Config> {
    map(many0(ws(|input| config_section.parse(input))), |sections| {
        Config { sections }
    })
    .parse(input)
}

/// Custom error handling with context
pub fn parse_with_context(input: &str) -> IResult<&str, Expr> {
    alt((
        map(|i| float.parse(i), Expr::Float),
        map(|i| identifier.parse(i), Expr::Identifier),
        delimited(
            |i| delimited(multispace0, char('('), multispace0).parse(i),
            |i| parse_with_context.parse(i),
            |i| delimited(multispace0, char(')'), multispace0).parse(i),
        ),
    ))
    .parse(input)
}

/// Streaming parser for large files
pub fn streaming_parser(input: &str) -> IResult<&str, Vec<Expr>> {
    many0(delimited(
        |i| multispace0.parse(i),
        |i| expression.parse(i),
        |i| {
            alt((
                map(char(';'), |_| ()),
                map(|i2| multispace0.parse(i2), |_| ()),
            ))
            .parse(i)
        },
    ))
    .parse(input)
}

/// Binary format parser
pub fn parse_binary_header(input: &[u8]) -> IResult<&[u8], (u32, u32)> {
    use nom::number::complete::{be_u32, le_u32};
    (preceded(tag(&b"MAGIC"[..]), le_u32), be_u32).parse(input)
}

/// Parser with custom error type
#[derive(Debug, PartialEq)]
pub enum CustomError {
    InvalidNumber,
    UnexpectedToken,
    MissingDelimiter,
}

pub fn custom_error_parser(input: &str) -> IResult<&str, Expr> {
    alt((
        map(
            |i| float.parse(i),
            |n| {
                if n.is_finite() {
                    Expr::Float(n)
                } else {
                    Expr::Float(0.0) // Return default value for invalid numbers
                }
            },
        ),
        map(|i| identifier.parse(i), Expr::Identifier),
    ))
    .parse(input)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_float_parser() {
        assert_eq!(float.parse("3.14"), Ok(("", 3.14)));
        assert_eq!(float.parse("-2.5"), Ok(("", -2.5)));
        assert_eq!(float.parse("42"), Ok(("", 42.0)));
    }

    #[test]
    fn test_expression_parser() {
        use nom::Parser;
        let result = expression.parse("2 + 3 * 4").unwrap();
        assert_eq!(
            result.1,
            Expr::Binary(
                BinOp::Add,
                Box::new(Expr::Float(2.0)),
                Box::new(Expr::Binary(
                    BinOp::Mul,
                    Box::new(Expr::Float(3.0)),
                    Box::new(Expr::Float(4.0))
                ))
            )
        );
    }

    #[test]
    fn test_function_call() {
        use nom::Parser;
        let result = function_call.parse("max(1, 2, 3)").unwrap();
        assert_eq!(
            result.1,
            Expr::Call(
                "max".to_string(),
                vec![Expr::Float(1.0), Expr::Float(2.0), Expr::Float(3.0)]
            )
        );
    }

    #[test]
    fn test_config_parser() {
        use nom::Parser;
        let config = "[database]\nhost = \"localhost\"\nport = 5432\n";
        let result = parse_config.parse(config).unwrap();
        assert_eq!(result.1.sections.len(), 1);
        assert_eq!(result.1.sections[0].name, "database");
        assert_eq!(result.1.sections[0].entries.len(), 2);
    }
}

/// AST for a simple expression language
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    Number(i64),
    Float(f64),
    String(String),
    Identifier(String),
    Binary(BinOp, Box<Expr>, Box<Expr>),
    Call(String, Vec<Expr>),
    Array(Vec<Expr>),
}
}
The expression type demonstrates a typical AST that nom parsers produce. Each variant represents a different syntactic construct that the parser recognizes.
Number Parsing
#![allow(unused)]
fn main() {
use nom::character::complete::{char, digit1};
use nom::combinator::{map_res, opt, recognize};
use nom::sequence::pair;
use nom::{IResult, Parser};

/// Parse a floating-point number
pub fn float(input: &str) -> IResult<&str, f64> {
    map_res(
        recognize((opt(char('-')), digit1, opt((char('.'), digit1)))),
        |s: &str| s.parse::<f64>(),
    )
    .parse(input)
}

/// Parse an integer
pub fn integer(input: &str) -> IResult<&str, i64> {
    map_res(recognize(pair(opt(char('-')), digit1)), |s: &str| {
        s.parse::<i64>()
    })
    .parse(input)
}

/// Parser with custom 
error type #[derive(Debug, PartialEq)] pub enum CustomError { InvalidNumber, UnexpectedToken, MissingDelimiter, } pub fn custom_error_parser(input: &str) -> IResult<&str, Expr> { alt(( map( |i| float.parse(i), |n| { if n.is_finite() { Expr::Float(n) } else { Expr::Float(0.0) // Return default value for invalid numbers } }, ), map(|i| identifier.parse(i), Expr::Identifier), )) .parse(input) } #[cfg(test)] mod tests { use super::*; #[test] fn test_float_parser() { assert_eq!(float.parse("3.14"), Ok(("", 3.14))); assert_eq!(float.parse("-2.5"), Ok(("", -2.5))); assert_eq!(float.parse("42"), Ok(("", 42.0))); } #[test] fn test_expression_parser() { use nom::Parser; let result = expression.parse("2 + 3 * 4").unwrap(); assert_eq!( result.1, Expr::Binary( BinOp::Add, Box::new(Expr::Float(2.0)), Box::new(Expr::Binary( BinOp::Mul, Box::new(Expr::Float(3.0)), Box::new(Expr::Float(4.0)) )) ) ); } #[test] fn test_function_call() { use nom::Parser; let result = function_call.parse("max(1, 2, 3)").unwrap(); assert_eq!( result.1, Expr::Call( "max".to_string(), vec![Expr::Float(1.0), Expr::Float(2.0), Expr::Float(3.0)] ) ); } #[test] fn test_config_parser() { use nom::Parser; let config = "[database]\nhost = \"localhost\"\nport = 5432\n"; let result = parse_config.parse(config).unwrap(); assert_eq!(result.1.sections.len(), 1); assert_eq!(result.1.sections[0].name, "database"); assert_eq!(result.1.sections[0].entries.len(), 2); } } /// Parse a floating-point number pub fn float(input: &str) -> IResult<&str, f64> { map_res( recognize((opt(char('-')), digit1, opt((char('.'), digit1)))), |s: &str| s.parse::<f64>(), ) .parse(input) } }
The float parser showcases nom’s approach to parsing numeric values. The recognize combinator captures the matched input as a string slice, while map_res applies a fallible transformation. This pattern avoids allocation by working directly with input slices.
/// Parse an integer
pub fn integer(input: &str) -> IResult<&str, i64> {
    map_res(recognize(pair(opt(char('-')), digit1)), |s: &str| {
        s.parse::<i64>()
    })
    .parse(input)
}
Integer parsing follows a similar pattern but handles signed integers. The pair combinator sequences two parsers, and opt makes a parser optional, enabling parsing of both positive and negative numbers.
String and Identifier Parsing
/// Parse a string literal with escape sequences
pub fn string_literal(input: &str) -> IResult<&str, String> {
    delimited(
        char('"'),
        map(
            escaped(
                take_while1(|c: char| c != '"' && c != '\\'),
                '\\',
                one_of(r#""n\rt"#),
            ),
            |s: &str| s.to_string(),
        ),
        char('"'),
    )
    .parse(input)
}
String literal parsing demonstrates nom’s handling of escape sequences. The escaped combinator recognizes escaped characters within strings, supporting common escape sequences like newlines and quotes; note that it returns the matched text verbatim rather than decoding it, so use escaped_transform when you need the decoded string. The delimited combinator extracts content between delimiters.
/// Parse an identifier
pub fn identifier(input: &str) -> IResult<&str, String> {
    map(
        recognize(pair(
            alt((alpha1, tag("_"))),
            many0(alt((alphanumeric1, tag("_")))),
        )),
        |s: &str| s.to_string(),
    )
    .parse(input)
}
Identifier parsing shows how to build parsers for programming language tokens. The recognize combinator returns the matched input slice rather than the parsed components, avoiding string allocation. The alt combinator tries multiple alternatives until one succeeds.
Expression Parsing
/// Parse a primary expression
pub fn primary(input: &str) -> IResult<&str, Expr> {
    alt((
        map(|i| float.parse(i), Expr::Float),
        map(|i| integer.parse(i), Expr::Number),
        map(|i| string_literal.parse(i), Expr::String),
        |i| function_call.parse(i),
        |i| array.parse(i),
        map(|i| identifier.parse(i), Expr::Identifier),
        delimited(
            ws(|input| char('(').parse(input)),
            |i| expression.parse(i),
            ws(|input| char(')').parse(input)),
        ),
    ))
    .parse(input)
}

/// Parse a term (multiplication and division)
pub fn term(input: &str) -> IResult<&str, Expr> {
    let (input, init) = primary.parse(input)?;
    fold_many0(
        pair(
            ws(|input| {
                alt((value(BinOp::Mul, char('*')), value(BinOp::Div, char('/')))).parse(input)
            }),
            |i| primary.parse(i),
        ),
        move || init.clone(),
        |acc, (op, val)| Expr::Binary(op, Box::new(acc), Box::new(val)),
    )
    .parse(input)
}

/// Parse an expression (addition and subtraction)
pub fn expression(input: &str) -> IResult<&str, Expr> {
    let (input, init) = term.parse(input)?;
    fold_many0(
        pair(
            ws(|input| {
                alt((value(BinOp::Add, char('+')), value(BinOp::Sub, char('-')))).parse(input)
            }),
            |i| term.parse(i),
        ),
        move || init.clone(),
        |acc, (op, val)| Expr::Binary(op, Box::new(acc), Box::new(val)),
    )
    .parse(input)
}
Expression parsing demonstrates operator precedence through parser layering. The `fold_many0` combinator implements left-associative binary operators by folding a sequence of operations into an accumulator. Higher-precedence operations like multiplication are parsed in the `term` function, which is called from within expression parsing.
The separation of the `term` and `expression` functions creates the precedence hierarchy: terms handle multiplication and division, while expressions handle addition and subtraction. This structure guarantees correct operator precedence without explicit precedence declarations.
Function Calls and Arrays
```rust
/// Parse a function call
pub fn function_call(input: &str) -> IResult<&str, Expr> {
    map(
        (
            |i| identifier.parse(i),
            ws(|i| {
                delimited(
                    char('('),
                    separated_list0(ws(|input| char(',').parse(input)), |i| expression.parse(i)),
                    char(')'),
                )
                .parse(i)
            }),
        ),
        |(name, args)| Expr::Call(name, args),
    )
    .parse(input)
}
```
Function call parsing combines several nom features. A tuple of parsers acts as a sequencing combinator, running each parser in order and capturing all of the results, while the `separated_list0` combinator handles the comma-separated argument list, a common pattern in programming languages.
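Sequencing is the essence of what the tuple does here. A hand-rolled sketch (not from the chapter; parsers are modeled as `Fn(&str) -> Option<(&str, T)>` returning the remaining input and the parsed value) makes the plumbing explicit:

```rust
// Run two parsers in order and pair their results; fail if either fails.
// This is the behavior nom's tuple/pair combinators provide.
fn pair<'a, A, B>(
    first: impl Fn(&'a str) -> Option<(&'a str, A)>,
    second: impl Fn(&'a str) -> Option<(&'a str, B)>,
) -> impl Fn(&'a str) -> Option<(&'a str, (A, B))> {
    move |input| {
        let (rest, a) = first(input)?;
        let (rest, b) = second(rest)?;
        Some((rest, (a, b)))
    }
}

// A tiny identifier parser: take a run of alphanumeric/underscore characters.
fn ident(input: &str) -> Option<(&str, &str)> {
    let end = input
        .find(|c: char| !c.is_alphanumeric() && c != '_')
        .unwrap_or(input.len());
    if end == 0 {
        None
    } else {
        Some((&input[end..], &input[..end]))
    }
}

fn main() {
    // An identifier followed by '(' is the start of a call.
    let call_head = pair(ident, |i: &str| i.strip_prefix('(').map(|r| (r, '(')));
    assert_eq!(call_head("max(1, 2)"), Some(("1, 2)", ("max", '('))));
    assert_eq!(call_head("123(..."), None);
    println!("ok");
}
```

Threading the leftover input from one parser into the next is what distinguishes sequencing from alternation (`alt`), which instead retries the same input.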
```rust
/// Parse an array literal
pub fn array(input: &str) -> IResult<&str, Expr> {
    map(
        delimited(
            ws(|input| char('[').parse(input)),
            separated_list0(ws(|input| char(',').parse(input)), |i| expression.parse(i)),
            ws(|input| char(']').parse(input)),
        ),
        Expr::Array,
    )
    .parse(input)
}
```
Array parsing uses the same techniques with different delimiters. The `ws` helper wraps each token parser with optional surrounding whitespace, a critical aspect of parsing human-readable formats.
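The whitespace-wrapping pattern behind `ws` can be illustrated without nom. This sketch (our own names, with parsers again modeled as `Fn(&str) -> Option<(&str, T)>`) shows why it composes so well: it is just another combinator, so any token parser can be dropped into it.

```rust
// Wrap any parser so it skips leading and trailing whitespace,
// mirroring the ws helper used throughout the chapter.
fn ws<'a, T>(
    inner: impl Fn(&'a str) -> Option<(&'a str, T)>,
) -> impl Fn(&'a str) -> Option<(&'a str, T)> {
    move |input: &'a str| {
        let input = input.trim_start();
        let (rest, value) = inner(input)?;
        Some((rest.trim_start(), value))
    }
}

// A single-character token parser to wrap.
fn char_p<'a>(expected: char) -> impl Fn(&'a str) -> Option<(&'a str, char)> {
    move |input: &'a str| {
        let mut chars = input.chars();
        if chars.next()? == expected {
            Some((chars.as_str(), expected))
        } else {
            None
        }
    }
}

fn main() {
    let comma = ws(char_p(','));
    assert_eq!(comma("  ,  rest"), Some(("rest", ',')));
    assert_eq!(comma("x"), None);
    println!("ok");
}
```

Handling whitespace at the token level, rather than sprinkling `multispace0` through every grammar rule, keeps the higher-level parsers readable.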
Configuration File Parsing
```rust
/// Configuration file parser example
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
    pub sections: Vec<Section>,
}

#[derive(Debug, Clone, PartialEq)]
pub struct Section {
    pub name: String,
    pub entries: Vec<(String, Value)>,
}

#[derive(Debug, Clone, PartialEq)]
pub enum Value {
    String(String),
    Number(f64),
    Boolean(bool),
    List(Vec<Value>),
}
```
Configuration parsing demonstrates nom’s suitability for structured data formats. The types represent a typical configuration file structure with sections and key-value pairs.
#![allow(unused)]
fn main() {
use nom::branch::alt;
use nom::bytes::complete::tag;
use nom::character::complete::char;
use nom::combinator::map;
use nom::multi::{many0, separated_list0};
use nom::sequence::delimited;
use nom::{IResult, Parser};

// (float, string_literal, identifier, and ws are defined in the listing above)

/// Configuration file parser example
#[derive(Debug, Clone, PartialEq)]
pub struct Config {
    pub sections: Vec<Section>,
}

#[derive(Debug, Clone, PartialEq)]
pub struct Section {
    pub name: String,
    pub entries: Vec<(String, Value)>,
}

#[derive(Debug, Clone, PartialEq)]
pub enum Value {
    String(String),
    Number(f64),
    Boolean(bool),
    List(Vec<Value>),
}

/// Parse a configuration value
pub fn config_value(input: &str) -> IResult<&str, Value> {
    alt((
        map(float, Value::Number),
        map(string_literal, Value::String),
        map(tag("true"), |_| Value::Boolean(true)),
        map(tag("false"), |_| Value::Boolean(false)),
        map(
            delimited(
                ws(|input| char('[').parse(input)),
                separated_list0(ws(|input| char(',').parse(input)), |i| {
                    config_value.parse(i)
                }),
                ws(|input| char(']').parse(input)),
            ),
            Value::List,
        ),
    ))
    .parse(input)
}

/// Parse a configuration entry
pub fn config_entry(input: &str) -> IResult<&str, (String, Value)> {
    map(
        (
            ws(|input| identifier.parse(input)),
            ws(|input| char('=').parse(input)),
            ws(|input| config_value.parse(input)),
        ),
        |(key, _, value)| (key, value),
    )
    .parse(input)
}

/// Parse a configuration section
pub fn config_section(input: &str) -> IResult<&str, Section> {
    map(
        (
            delimited(
                ws(|input| char('[').parse(input)),
                identifier,
                ws(|input| char(']').parse(input)),
            ),
            many0(config_entry),
        ),
        |(name, entries)| Section { name, entries },
    )
    .parse(input)
}

/// Parse a complete configuration file
pub fn parse_config(input: &str) -> IResult<&str, Config> {
    map(many0(ws(|input| config_section.parse(input))), |sections| {
        Config { sections }
    })
    .parse(input)
}
}
The configuration parser builds up from smaller parsers for values, entries, and sections. Each parser focuses on one aspect of the format, combining through nom’s compositional approach. The `many0` combinator parses zero or more occurrences, building collections incrementally.
Error Handling
#![allow(unused)]
fn main() {
use nom::branch::alt;
use nom::character::complete::{char, multispace0};
use nom::combinator::map;
use nom::sequence::delimited;
use nom::{IResult, Parser};

// (float, identifier, and Expr are defined in the listing above)

/// Custom error handling with context
pub fn parse_with_context(input: &str) -> IResult<&str, Expr> {
    alt((
        map(|i| float.parse(i), Expr::Float),
        map(|i| identifier.parse(i), Expr::Identifier),
        delimited(
            |i| delimited(multispace0, char('('), multispace0).parse(i),
            |i| parse_with_context.parse(i),
            |i| delimited(multispace0, char(')'), multispace0).parse(i),
        ),
    ))
    .parse(input)
}

/// Parser with custom error type
#[derive(Debug, PartialEq)]
pub enum CustomError {
    InvalidNumber,
    UnexpectedToken,
    MissingDelimiter,
}

pub fn custom_error_parser(input: &str) -> IResult<&str, Expr> {
    alt((
        map(
            |i| float.parse(i),
            |n| {
                if n.is_finite() {
                    Expr::Float(n)
                } else {
                    Expr::Float(0.0) // Return default value for invalid numbers
                }
            },
        ),
        map(|i| identifier.parse(i), Expr::Identifier),
    ))
    .parse(input)
}
}
Context-aware parsing improves error messages by annotating parsers with descriptive labels. The `context` combinator wraps parsers with error context, while `cut` prevents backtracking after partial matches. This combination provides precise error messages indicating exactly where parsing failed.
The `VerboseError` type collects detailed error information, including the error location and a trace of attempted parses. This information helps developers understand why parsing failed and where in the grammar the error occurred.
Streaming and Binary Parsing
#![allow(unused)]
fn main() {
use nom::branch::alt;
use nom::character::complete::{char, multispace0};
use nom::combinator::map;
use nom::multi::many0;
use nom::sequence::delimited;
use nom::{IResult, Parser};

// (expression and Expr are defined in the listing above)

/// Streaming parser for large files
pub fn streaming_parser(input: &str) -> IResult<&str, Vec<Expr>> {
    many0(delimited(
        |i| multispace0.parse(i),
        |i| expression.parse(i),
        |i| {
            alt((
                map(char(';'), |_| ()),
                map(|i2| multispace0.parse(i2), |_| ()),
            ))
            .parse(i)
        },
    ))
    .parse(input)
}
}
Streaming parsing handles input that may not be completely available. The parser processes available data and indicates how much input was consumed. This approach works well for network protocols and large files that cannot fit in memory.
#![allow(unused)]
fn main() {
use nom::bytes::complete::tag;
use nom::sequence::preceded;
use nom::{IResult, Parser};

/// Binary format parser
pub fn parse_binary_header(input: &[u8]) -> IResult<&[u8], (u32, u32)> {
    use nom::number::complete::{be_u32, le_u32};
    (preceded(tag(&b"MAGIC"[..]), le_u32), be_u32).parse(input)
}
}
Binary format parsing showcases nom’s byte-level parsing capabilities. The library provides parsers for various integer encodings, network byte order, and fixed-size data. The take combinator extracts a specific number of bytes, while endian-specific parsers handle byte order conversions.
Performance Optimization
Nom achieves excellent performance through zero-copy parsing. Parsers work directly with input slices, avoiding string allocation until necessary. The recognize combinator returns the matched input slice, and parsers can return subslices of the input rather than copying data.
Careful combinator choice impacts performance. The alt combinator tries alternatives sequentially, so placing common cases first reduces average parsing time. The many0 and many1 combinators can be replaced with fold_many0 and fold_many1 to avoid intermediate vector allocation.
Nom’s combinators are generic functions that the compiler monomorphizes for each parser combination, eliminating most function call overhead. The generated code often compiles to machine code comparable to hand-written parsers.
Integration Patterns
Nom parsers integrate well with other Rust libraries. The &str and &[u8] input types work with standard library types, while the IResult type integrates with error handling libraries. Parsed ASTs can be processed by subsequent compiler passes or serialized to other formats.
For incremental parsing, nom parsers can save state between invocations. The remaining input from one parse becomes the starting point for the next, enabling parsing of streaming data or interactive input.
Custom input types allow parsing from non-standard sources. Implementing nom’s input traits enables parsing from rope data structures, memory-mapped files, or network streams.
Best Practices
Structure parsers hierarchically with clear separation of concerns. Each parser should handle one grammatical construct, making the grammar evident from the code structure. Use descriptive names that match the grammar terminology.
Test parsers extensively with both valid and invalid input. Property-based testing verifies parser properties like consuming all valid input or rejecting invalid constructs. Fuzzing finds edge cases in parser implementations.
Profile parsers on representative input to identify performance bottlenecks. Complex alternatives or excessive backtracking impact performance. Consider using peek to look ahead without consuming input when making parsing decisions.
Handle errors gracefully with appropriate error types. The VerboseError type aids development, while custom error types provide better user experience. Use context and cut to improve error messages.
Document the grammar alongside the parser implementation. Comments should explain the grammatical constructs being parsed and any deviations from standard grammar notation. Examples of valid input clarify the parser’s behavior.
peg
The peg crate provides a parser generator based on Parsing Expression Grammars (PEGs). PEGs offer a powerful alternative to traditional parsing approaches, combining the ease of writing recursive descent parsers with the declarative nature of grammar specifications. Unlike context-free grammars used by tools like yacc or LALR parsers, PEGs are unambiguous by design: the first matching alternative always wins, eliminating shift/reduce conflicts.
For compiler construction, peg excels at rapidly prototyping language parsers. The grammar syntax closely mirrors how you think about your language’s structure, and the generated parser includes automatic error reporting with position information. PEGs handle unlimited lookahead naturally and support semantic actions directly in the grammar, making it straightforward to build ASTs during parsing.
Grammar Definition
PEG grammars are defined using Rust macros that generate efficient parsers at compile time:
peg::parser! {
    pub grammar functional_parser() for str {
        /// Parse a complete program
        pub rule program() -> Program
            = _ statements:statement()* _ { Program { statements } }

        /// Parse a statement
        rule statement() -> Statement
            = definition() / type_definition() / expression_statement()

        /// Parse a variable definition
        rule definition() -> Statement
            = "def" _ name:identifier() _ "=" _ value:expression() _ {
                Statement::Definition { name, value }
            }

        // ... remaining rules are shown in the sections below
    }
}
This concise grammar definition expands into a complete recursive descent parser with error handling and position tracking.
Expression Parsing
The grammar handles expressions with proper operator precedence:
#![allow(unused)] fn main() { use std::collections::HashMap; #[derive(Debug, Clone, PartialEq)] pub enum BinaryOp { Add, Sub, Mul, Div, Mod, Pow, Eq, Ne, Lt, Le, Gt, Ge, And, Or, Cons, Append, } #[derive(Debug, Clone, PartialEq)] pub enum UnaryOp { Neg, Not, } #[derive(Debug, Clone, PartialEq)] pub enum Statement { Expression(Expr), Definition { name: String, value: Expr, }, TypeDef { name: String, constructors: Vec<(String, Vec<String>)>, }, } #[derive(Debug, Clone, PartialEq)] pub struct Program { pub statements: Vec<Statement>, } /// Error type for parser errors #[derive(Debug, Clone)] pub struct ParseError { pub message: String, pub line: usize, pub column: usize, pub expected: Vec<String>, } impl std::fmt::Display for ParseError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, "Parse error at line {}, column {}: {}", self.line, self.column, self.message ) } } impl std::error::Error for ParseError {} peg::parser! { pub grammar functional_parser() for str { /// Parse a complete program pub rule program() -> Program = _ statements:statement()* _ { Program { statements } } /// Parse a statement rule statement() -> Statement = definition() / type_definition() / expression_statement() /// Parse a variable definition rule definition() -> Statement = "def" _ name:identifier() _ "=" _ value:expression() _ { Statement::Definition { name, value } } /// Parse a type definition rule type_definition() -> Statement = "type" _ name:identifier() _ "=" _ constructors:constructor_list() _ { Statement::TypeDef { name, constructors } } /// Parse constructor list for type definitions rule constructor_list() -> Vec<(String, Vec<String>)> = head:constructor() tail:(_ "|" _ c:constructor() { c })* { let mut result = vec![head]; result.extend(tail); result } /// Parse a constructor rule constructor() -> (String, Vec<String>) = name:identifier() args:(_ "(" _ args:type_list() _ ")" { args })? 
{ (name, args.unwrap_or_default()) } /// Parse a list of types rule type_list() -> Vec<String> = head:identifier() tail:(_ "," _ t:identifier() { t })* { let mut result = vec![head]; result.extend(tail); result } /// Parse an expression statement rule expression_statement() -> Statement = expr:expression() { Statement::Expression(expr) } /// Parse expressions with left-associative operators pub rule expression() -> Expr = precedence!{ x:(@) _ "||" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Or, right: Box::new(y) } } -- x:(@) _ "&&" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::And, right: Box::new(y) } } -- x:(@) _ "==" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Eq, right: Box::new(y) } } x:(@) _ "!=" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Ne, right: Box::new(y) } } -- x:(@) _ "<=" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Le, right: Box::new(y) } } x:(@) _ ">=" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Ge, right: Box::new(y) } } x:(@) _ "<" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Lt, right: Box::new(y) } } x:(@) _ ">" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Gt, right: Box::new(y) } } -- x:(@) _ "+" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Add, right: Box::new(y) } } x:(@) _ "-" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Sub, right: Box::new(y) } } -- x:(@) _ "*" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Mul, right: Box::new(y) } } x:(@) _ "/" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Div, right: Box::new(y) } } x:(@) _ "%" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Mod, right: Box::new(y) } } -- x:@ _ "**" _ y:(@) { Expr::Binary { left: Box::new(x), op: BinaryOp::Pow, right: Box::new(y) } } -- "-" _ e:@ { Expr::Unary { op: UnaryOp::Neg, expr: Box::new(e) } } "not" _ e:@ { Expr::Unary { op: UnaryOp::Not, expr: Box::new(e) } } -- e:postfix() { e } } /// Postfix expressions (function 
calls) rule postfix() -> Expr = e:atom() calls:call_suffix()* { calls.into_iter().fold(e, |func, args| { Expr::Call { func: Box::new(func), args } }) } rule call_suffix() -> Vec<Expr> = _ "(" _ args:argument_list() _ ")" { args } /// Parse atomic expressions rule atom() -> Expr = float() // Must come before number / number() / string_literal() / boolean() / list() / record() / lambda() / let_expression() / if_expression() / identifier_expr() / "(" _ e:expression() _ ")" { e } /// Parse numbers (integers only) rule number() -> Expr = n:$("-"? ['0'..='9']+) !("." ['0'..='9']) {? n.parse::<i64>() .map(Expr::Number) .map_err(|_| "number") } /// Parse floating-point numbers rule float() -> Expr = n:$("-"? ['0'..='9']+ "." ['0'..='9']+) {? n.parse::<f64>() .map(Expr::Float) .map_err(|_| "float") } /// Parse string literals rule string_literal() -> Expr = "\"" chars:string_char()* "\"" { Expr::String(chars.into_iter().collect()) } /// Parse string characters with escape sequences rule string_char() -> char = "\\\\" { '\\' } / "\\\"" { '"' } / "\\n" { '\n' } / "\\t" { '\t' } / "\\r" { '\r' } / !['"' | '\\'] c:char() { c } /// Parse any character rule char() -> char = c:$([_]) { c.chars().next().unwrap() } /// Parse boolean literals rule boolean() -> Expr = "true" !identifier_char() { Expr::Bool(true) } / "false" !identifier_char() { Expr::Bool(false) } /// Parse lists rule list() -> Expr = "[" _ elements:expression_list() _ "]" { Expr::List(elements) } /// Parse expression lists rule expression_list() -> Vec<Expr> = head:expression() tail:(_ "," _ e:expression() { e })* { let mut result = vec![head]; result.extend(tail); result } / { vec![] } /// Parse argument lists (for function calls) rule argument_list() -> Vec<Expr> = expression_list() /// Parse records (key-value mappings) rule record() -> Expr = "{" _ fields:field_list() _ "}" { Expr::Record(fields.into_iter().collect()) } /// Parse field lists for records rule field_list() -> Vec<(String, Expr)> = head:field() 
tail:(_ "," _ f:field() { f })* { let mut result = vec![head]; result.extend(tail); result } / { vec![] } /// Parse a single field rule field() -> (String, Expr) = key:identifier() _ ":" _ value:expression() { (key, value) } /// Parse lambda expressions rule lambda() -> Expr = "\\" _ params:parameter_list() _ "->" _ body:expression() { Expr::Lambda { params, body: Box::new(body) } } / "fn" _ params:parameter_list() _ "->" _ body:expression() { Expr::Lambda { params, body: Box::new(body) } } /// Parse parameter lists rule parameter_list() -> Vec<String> = "(" _ params:identifier_list() _ ")" { params } / param:identifier() { vec![param] } /// Parse identifier lists rule identifier_list() -> Vec<String> = head:identifier() tail:(_ "," _ id:identifier() { id })* { let mut result = vec![head]; result.extend(tail); result } / { vec![] } /// Parse let expressions rule let_expression() -> Expr = "let" _ bindings:binding_list() _ "in" _ body:expression() { Expr::Let { bindings, body: Box::new(body) } } /// Parse binding lists for let expressions rule binding_list() -> Vec<(String, Expr)> = head:binding() tail:(_ "," _ b:binding() { b })* { let mut result = vec![head]; result.extend(tail); result } /// Parse a single binding rule binding() -> (String, Expr) = name:identifier() _ "=" _ value:expression() { (name, value) } /// Parse if expressions rule if_expression() -> Expr = "if" _ cond:expression() _ "then" _ then_branch:expression() else_branch:(_ "else" _ e:expression() { e })? 
{ Expr::If { condition: Box::new(cond), then_branch: Box::new(then_branch), else_branch: else_branch.map(Box::new), } } /// Parse identifier expressions rule identifier_expr() -> Expr = id:identifier() { Expr::Identifier(id) } /// Parse identifiers rule identifier() -> String = !reserved_word() s:$(identifier_start() identifier_char()*) { s.to_string() } rule identifier_start() -> () = ['a'..='z' | 'A'..='Z' | '_'] {} rule identifier_char() -> () = ['a'..='z' | 'A'..='Z' | '0'..='9' | '_'] {} /// Reserved words that can't be identifiers rule reserved_word() = ("if" / "then" / "else" / "let" / "in" / "fn" / "def" / "type" / "true" / "false" / "not") !identifier_char() /// Whitespace rule _() = quiet!{ (whitespace() / comment())* } rule whitespace() = [' ' | '\t' | '\n' | '\r']+ rule comment() = "//" (!"\n" [_])* / "/*" (!"*/" [_])* "*/" } } /// Simple evaluator for mathematical expressions pub fn evaluate(expr: &Expr) -> Result<f64, String> { match expr { Expr::Number(n) => Ok(*n as f64), Expr::Float(f) => Ok(*f), Expr::Binary { left, op, right } => { let l = evaluate(left)?; let r = evaluate(right)?; match op { BinaryOp::Add => Ok(l + r), BinaryOp::Sub => Ok(l - r), BinaryOp::Mul => Ok(l * r), BinaryOp::Div => { if r == 0.0 { Err("Division by zero".to_string()) } else { Ok(l / r) } } BinaryOp::Pow => Ok(l.powf(r)), _ => Err(format!("Cannot evaluate operator {:?}", op)), } } Expr::Unary { op: UnaryOp::Neg, expr, } => Ok(-evaluate(expr)?), _ => Err("Cannot evaluate this expression".to_string()), } } /// Parse a simple expression pub fn parse_expression(input: &str) -> Result<Expr, peg::error::ParseError<peg::str::LineCol>> { functional_parser::expression(input) } /// Parse a complete program pub fn parse_program(input: &str) -> Result<Program, peg::error::ParseError<peg::str::LineCol>> { functional_parser::program(input) } #[cfg(test)] mod tests { use super::*; #[test] fn test_number_parsing() { let result = parse_expression("42").unwrap(); assert_eq!(result, 
Expr::Number(42)); let result = parse_expression("-17").unwrap(); assert_eq!( result, Expr::Unary { op: UnaryOp::Neg, expr: Box::new(Expr::Number(17)) } ); } #[test] fn test_binary_expression() { let result = parse_expression("2 + 3").unwrap(); if let Expr::Binary { left, op, right } = result { assert_eq!(*left, Expr::Number(2)); assert_eq!(op, BinaryOp::Add); assert_eq!(*right, Expr::Number(3)); } else { panic!("Expected binary expression"); } } #[test] fn test_operator_precedence() { let result = parse_expression("2 + 3 * 4").unwrap(); // Should parse as 2 + (3 * 4) if let Expr::Binary { left, op, right } = result { assert_eq!(*left, Expr::Number(2)); assert_eq!(op, BinaryOp::Add); if let Expr::Binary { left: rl, op: rop, right: rr, } = right.as_ref() { assert_eq!(rl.as_ref(), &Expr::Number(3)); assert_eq!(*rop, BinaryOp::Mul); assert_eq!(rr.as_ref(), &Expr::Number(4)); } else { panic!("Expected binary expression on right"); } } else { panic!("Expected binary expression"); } } #[test] fn test_evaluation() { let expr = parse_expression("2 + 3 * 4").unwrap(); let result = evaluate(&expr).unwrap(); assert_eq!(result, 14.0); let expr = parse_expression("(2 + 3) * 4").unwrap(); let result = evaluate(&expr).unwrap(); assert_eq!(result, 20.0); let expr = parse_expression("2 ** 3").unwrap(); let result = evaluate(&expr).unwrap(); assert_eq!(result, 8.0); } #[test] fn test_function_call() { let result = parse_expression("foo(1, 2, 3)").unwrap(); if let Expr::Call { func, args } = result { assert_eq!(*func, Expr::Identifier("foo".to_string())); assert_eq!(args.len(), 3); assert_eq!(args[0], Expr::Number(1)); assert_eq!(args[1], Expr::Number(2)); assert_eq!(args[2], Expr::Number(3)); } else { panic!("Expected function call"); } } #[test] fn test_string_literals() { let result = parse_expression("\"hello world\"").unwrap(); assert_eq!(result, Expr::String("hello world".to_string())); let result = parse_expression("\"escaped\\nnewline\"").unwrap(); assert_eq!(result, 
Expr::String("escaped\nnewline".to_string())); } #[test] fn test_list_parsing() { let result = parse_expression("[1, 2, 3]").unwrap(); assert_eq!( result, Expr::List(vec![Expr::Number(1), Expr::Number(2), Expr::Number(3)]) ); let result = parse_expression("[]").unwrap(); assert_eq!(result, Expr::List(vec![])); } #[test] fn test_let_expression() { let result = parse_expression("let x = 5 in x + 1").unwrap(); if let Expr::Let { bindings, body } = result { assert_eq!(bindings.len(), 1); assert_eq!(bindings[0].0, "x"); assert_eq!(bindings[0].1, Expr::Number(5)); if let Expr::Binary { left, op, right } = &*body { assert_eq!(**left, Expr::Identifier("x".to_string())); assert_eq!(*op, BinaryOp::Add); assert_eq!(**right, Expr::Number(1)); } else { panic!("Expected binary expression in body"); } } else { panic!("Expected let expression"); } } #[test] fn test_if_expression() { let result = parse_expression("if true then 1 else 2").unwrap(); if let Expr::If { condition, then_branch, else_branch, } = result { assert_eq!(*condition, Expr::Bool(true)); assert_eq!(*then_branch, Expr::Number(1)); assert_eq!( else_branch.as_ref().map(|b| b.as_ref()), Some(&Expr::Number(2)) ); } else { panic!("Expected if expression"); } } #[test] fn test_lambda_expression() { let result = parse_expression("\\x -> x + 1").unwrap(); if let Expr::Lambda { params, body } = result { assert_eq!(params, vec!["x"]); if let Expr::Binary { left, op, right } = &*body { assert_eq!(**left, Expr::Identifier("x".to_string())); assert_eq!(*op, BinaryOp::Add); assert_eq!(**right, Expr::Number(1)); } else { panic!("Expected binary expression in body"); } } else { panic!("Expected lambda expression"); } } #[test] fn test_error_reporting() { let result = parse_expression("2 + "); assert!(result.is_err()); } } /// AST nodes for a functional programming language #[derive(Debug, Clone, PartialEq)] pub enum Expr { Number(i64), Float(f64), String(String), Bool(bool), Identifier(String), Binary { left: Box<Expr>, op: 
BinaryOp, right: Box<Expr>, }, Unary { op: UnaryOp, expr: Box<Expr>, }, Call { func: Box<Expr>, args: Vec<Expr>, }, Lambda { params: Vec<String>, body: Box<Expr>, }, Let { bindings: Vec<(String, Expr)>, body: Box<Expr>, }, If { condition: Box<Expr>, then_branch: Box<Expr>, else_branch: Option<Box<Expr>>, }, List(Vec<Expr>), Record(HashMap<String, Expr>), } }
The precedence climbing in the grammar ensures that 1 + 2 * 3 parses as 1 + (2 * 3) rather than (1 + 2) * 3.
Literal Parsing
PEG excels at parsing various literal formats with precise control:
#![allow(unused)] fn main() { use std::collections::HashMap; /// AST nodes for a functional programming language #[derive(Debug, Clone, PartialEq)] pub enum Expr { Number(i64), Float(f64), String(String), Bool(bool), Identifier(String), Binary { left: Box<Expr>, op: BinaryOp, right: Box<Expr>, }, Unary { op: UnaryOp, expr: Box<Expr>, }, Call { func: Box<Expr>, args: Vec<Expr>, }, Lambda { params: Vec<String>, body: Box<Expr>, }, Let { bindings: Vec<(String, Expr)>, body: Box<Expr>, }, If { condition: Box<Expr>, then_branch: Box<Expr>, else_branch: Option<Box<Expr>>, }, List(Vec<Expr>), Record(HashMap<String, Expr>), } #[derive(Debug, Clone, PartialEq)] pub enum BinaryOp { Add, Sub, Mul, Div, Mod, Pow, Eq, Ne, Lt, Le, Gt, Ge, And, Or, Cons, Append, } #[derive(Debug, Clone, PartialEq)] pub enum UnaryOp { Neg, Not, } #[derive(Debug, Clone, PartialEq)] pub enum Statement { Expression(Expr), Definition { name: String, value: Expr, }, TypeDef { name: String, constructors: Vec<(String, Vec<String>)>, }, } #[derive(Debug, Clone, PartialEq)] pub struct Program { pub statements: Vec<Statement>, } /// Error type for parser errors #[derive(Debug, Clone)] pub struct ParseError { pub message: String, pub line: usize, pub column: usize, pub expected: Vec<String>, } impl std::fmt::Display for ParseError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, "Parse error at line {}, column {}: {}", self.line, self.column, self.message ) } } impl std::error::Error for ParseError {} peg::parser! 
{ pub grammar functional_parser() for str { /// Parse a complete program pub rule program() -> Program = _ statements:statement()* _ { Program { statements } } /// Parse a statement rule statement() -> Statement = definition() / type_definition() / expression_statement() /// Parse a variable definition rule definition() -> Statement = "def" _ name:identifier() _ "=" _ value:expression() _ { Statement::Definition { name, value } } /// Parse a type definition rule type_definition() -> Statement = "type" _ name:identifier() _ "=" _ constructors:constructor_list() _ { Statement::TypeDef { name, constructors } } /// Parse constructor list for type definitions rule constructor_list() -> Vec<(String, Vec<String>)> = head:constructor() tail:(_ "|" _ c:constructor() { c })* { let mut result = vec![head]; result.extend(tail); result } /// Parse a constructor rule constructor() -> (String, Vec<String>) = name:identifier() args:(_ "(" _ args:type_list() _ ")" { args })? { (name, args.unwrap_or_default()) } /// Parse a list of types rule type_list() -> Vec<String> = head:identifier() tail:(_ "," _ t:identifier() { t })* { let mut result = vec![head]; result.extend(tail); result } /// Parse an expression statement rule expression_statement() -> Statement = expr:expression() { Statement::Expression(expr) } /// Parse expressions with left-associative operators pub rule expression() -> Expr = precedence!{ x:(@) _ "||" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Or, right: Box::new(y) } } -- x:(@) _ "&&" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::And, right: Box::new(y) } } -- x:(@) _ "==" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Eq, right: Box::new(y) } } x:(@) _ "!=" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Ne, right: Box::new(y) } } -- x:(@) _ "<=" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Le, right: Box::new(y) } } x:(@) _ ">=" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Ge, right: Box::new(y) } } 
x:(@) _ "<" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Lt, right: Box::new(y) } } x:(@) _ ">" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Gt, right: Box::new(y) } } -- x:(@) _ "+" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Add, right: Box::new(y) } } x:(@) _ "-" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Sub, right: Box::new(y) } } -- x:(@) _ "*" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Mul, right: Box::new(y) } } x:(@) _ "/" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Div, right: Box::new(y) } } x:(@) _ "%" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Mod, right: Box::new(y) } } -- x:@ _ "**" _ y:(@) { Expr::Binary { left: Box::new(x), op: BinaryOp::Pow, right: Box::new(y) } } -- "-" _ e:@ { Expr::Unary { op: UnaryOp::Neg, expr: Box::new(e) } } "not" _ e:@ { Expr::Unary { op: UnaryOp::Not, expr: Box::new(e) } } -- e:postfix() { e } } /// Postfix expressions (function calls) rule postfix() -> Expr = e:atom() calls:call_suffix()* { calls.into_iter().fold(e, |func, args| { Expr::Call { func: Box::new(func), args } }) } rule call_suffix() -> Vec<Expr> = _ "(" _ args:argument_list() _ ")" { args } /// Parse atomic expressions rule atom() -> Expr = float() // Must come before number / number() / string_literal() / boolean() / list() / record() / lambda() / let_expression() / if_expression() / identifier_expr() / "(" _ e:expression() _ ")" { e } /// Parse numbers (integers only) rule number() -> Expr = n:$("-"? ['0'..='9']+) !("." ['0'..='9']) {? n.parse::<i64>() .map(Expr::Number) .map_err(|_| "number") } /// Parse floating-point numbers rule float() -> Expr = n:$("-"? ['0'..='9']+ "." ['0'..='9']+) {? 
n.parse::<f64>() .map(Expr::Float) .map_err(|_| "float") } /// Parse string literals rule string_literal() -> Expr = "\"" chars:string_char()* "\"" { Expr::String(chars.into_iter().collect()) } /// Parse string characters with escape sequences rule string_char() -> char = "\\\\" { '\\' } / "\\\"" { '"' } / "\\n" { '\n' } / "\\t" { '\t' } / "\\r" { '\r' } / !['"' | '\\'] c:char() { c } /// Parse any character rule char() -> char = c:$([_]) { c.chars().next().unwrap() } /// Parse boolean literals rule boolean() -> Expr = "true" !identifier_char() { Expr::Bool(true) } / "false" !identifier_char() { Expr::Bool(false) } /// Parse lists rule list() -> Expr = "[" _ elements:expression_list() _ "]" { Expr::List(elements) } /// Parse expression lists rule expression_list() -> Vec<Expr> = head:expression() tail:(_ "," _ e:expression() { e })* { let mut result = vec![head]; result.extend(tail); result } / { vec![] } /// Parse argument lists (for function calls) rule argument_list() -> Vec<Expr> = expression_list() /// Parse records (key-value mappings) rule record() -> Expr = "{" _ fields:field_list() _ "}" { Expr::Record(fields.into_iter().collect()) } /// Parse field lists for records rule field_list() -> Vec<(String, Expr)> = head:field() tail:(_ "," _ f:field() { f })* { let mut result = vec![head]; result.extend(tail); result } / { vec![] } /// Parse a single field rule field() -> (String, Expr) = key:identifier() _ ":" _ value:expression() { (key, value) } /// Parse lambda expressions rule lambda() -> Expr = "\\" _ params:parameter_list() _ "->" _ body:expression() { Expr::Lambda { params, body: Box::new(body) } } / "fn" _ params:parameter_list() _ "->" _ body:expression() { Expr::Lambda { params, body: Box::new(body) } } /// Parse parameter lists rule parameter_list() -> Vec<String> = "(" _ params:identifier_list() _ ")" { params } / param:identifier() { vec![param] } /// Parse identifier lists rule identifier_list() -> Vec<String> = head:identifier() tail:(_ "," _ 
id:identifier() { id })* { let mut result = vec![head]; result.extend(tail); result } / { vec![] } /// Parse let expressions rule let_expression() -> Expr = "let" _ bindings:binding_list() _ "in" _ body:expression() { Expr::Let { bindings, body: Box::new(body) } } /// Parse binding lists for let expressions rule binding_list() -> Vec<(String, Expr)> = head:binding() tail:(_ "," _ b:binding() { b })* { let mut result = vec![head]; result.extend(tail); result } /// Parse a single binding rule binding() -> (String, Expr) = name:identifier() _ "=" _ value:expression() { (name, value) } /// Parse if expressions rule if_expression() -> Expr = "if" _ cond:expression() _ "then" _ then_branch:expression() else_branch:(_ "else" _ e:expression() { e })? { Expr::If { condition: Box::new(cond), then_branch: Box::new(then_branch), else_branch: else_branch.map(Box::new), } } /// Parse identifier expressions rule identifier_expr() -> Expr = id:identifier() { Expr::Identifier(id) } /// Parse identifiers rule identifier() -> String = !reserved_word() s:$(identifier_start() identifier_char()*) { s.to_string() } rule identifier_start() -> () = ['a'..='z' | 'A'..='Z' | '_'] {} rule identifier_char() -> () = ['a'..='z' | 'A'..='Z' | '0'..='9' | '_'] {} /// Reserved words that can't be identifiers rule reserved_word() = ("if" / "then" / "else" / "let" / "in" / "fn" / "def" / "type" / "true" / "false" / "not") !identifier_char() /// Whitespace rule _() = quiet!{ (whitespace() / comment())* } rule whitespace() = [' ' | '\t' | '\n' | '\r']+ rule comment() = "//" (!"\n" [_])* / "/*" (!"*/" [_])* "*/" } } /// Simple evaluator for mathematical expressions pub fn evaluate(expr: &Expr) -> Result<f64, String> { match expr { Expr::Number(n) => Ok(*n as f64), Expr::Float(f) => Ok(*f), Expr::Binary { left, op, right } => { let l = evaluate(left)?; let r = evaluate(right)?; match op { BinaryOp::Add => Ok(l + r), BinaryOp::Sub => Ok(l - r), BinaryOp::Mul => Ok(l * r), BinaryOp::Div => { if r == 0.0 { 
Err("Division by zero".to_string()) } else { Ok(l / r) } } BinaryOp::Pow => Ok(l.powf(r)), _ => Err(format!("Cannot evaluate operator {:?}", op)), } } Expr::Unary { op: UnaryOp::Neg, expr, } => Ok(-evaluate(expr)?), _ => Err("Cannot evaluate this expression".to_string()), } } /// Parse a complete program pub fn parse_program(input: &str) -> Result<Program, peg::error::ParseError<peg::str::LineCol>> { functional_parser::program(input) } #[cfg(test)] mod tests { use super::*; #[test] fn test_number_parsing() { let result = parse_expression("42").unwrap(); assert_eq!(result, Expr::Number(42)); let result = parse_expression("-17").unwrap(); assert_eq!( result, Expr::Unary { op: UnaryOp::Neg, expr: Box::new(Expr::Number(17)) } ); } #[test] fn test_binary_expression() { let result = parse_expression("2 + 3").unwrap(); if let Expr::Binary { left, op, right } = result { assert_eq!(*left, Expr::Number(2)); assert_eq!(op, BinaryOp::Add); assert_eq!(*right, Expr::Number(3)); } else { panic!("Expected binary expression"); } } #[test] fn test_operator_precedence() { let result = parse_expression("2 + 3 * 4").unwrap(); // Should parse as 2 + (3 * 4) if let Expr::Binary { left, op, right } = result { assert_eq!(*left, Expr::Number(2)); assert_eq!(op, BinaryOp::Add); if let Expr::Binary { left: rl, op: rop, right: rr, } = right.as_ref() { assert_eq!(rl.as_ref(), &Expr::Number(3)); assert_eq!(*rop, BinaryOp::Mul); assert_eq!(rr.as_ref(), &Expr::Number(4)); } else { panic!("Expected binary expression on right"); } } else { panic!("Expected binary expression"); } } #[test] fn test_evaluation() { let expr = parse_expression("2 + 3 * 4").unwrap(); let result = evaluate(&expr).unwrap(); assert_eq!(result, 14.0); let expr = parse_expression("(2 + 3) * 4").unwrap(); let result = evaluate(&expr).unwrap(); assert_eq!(result, 20.0); let expr = parse_expression("2 ** 3").unwrap(); let result = evaluate(&expr).unwrap(); assert_eq!(result, 8.0); } #[test] fn test_function_call() { let result = 
parse_expression("foo(1, 2, 3)").unwrap(); if let Expr::Call { func, args } = result { assert_eq!(*func, Expr::Identifier("foo".to_string())); assert_eq!(args.len(), 3); assert_eq!(args[0], Expr::Number(1)); assert_eq!(args[1], Expr::Number(2)); assert_eq!(args[2], Expr::Number(3)); } else { panic!("Expected function call"); } } #[test] fn test_string_literals() { let result = parse_expression("\"hello world\"").unwrap(); assert_eq!(result, Expr::String("hello world".to_string())); let result = parse_expression("\"escaped\\nnewline\"").unwrap(); assert_eq!(result, Expr::String("escaped\nnewline".to_string())); } #[test] fn test_list_parsing() { let result = parse_expression("[1, 2, 3]").unwrap(); assert_eq!( result, Expr::List(vec![Expr::Number(1), Expr::Number(2), Expr::Number(3)]) ); let result = parse_expression("[]").unwrap(); assert_eq!(result, Expr::List(vec![])); } #[test] fn test_let_expression() { let result = parse_expression("let x = 5 in x + 1").unwrap(); if let Expr::Let { bindings, body } = result { assert_eq!(bindings.len(), 1); assert_eq!(bindings[0].0, "x"); assert_eq!(bindings[0].1, Expr::Number(5)); if let Expr::Binary { left, op, right } = &*body { assert_eq!(**left, Expr::Identifier("x".to_string())); assert_eq!(*op, BinaryOp::Add); assert_eq!(**right, Expr::Number(1)); } else { panic!("Expected binary expression in body"); } } else { panic!("Expected let expression"); } } #[test] fn test_if_expression() { let result = parse_expression("if true then 1 else 2").unwrap(); if let Expr::If { condition, then_branch, else_branch, } = result { assert_eq!(*condition, Expr::Bool(true)); assert_eq!(*then_branch, Expr::Number(1)); assert_eq!( else_branch.as_ref().map(|b| b.as_ref()), Some(&Expr::Number(2)) ); } else { panic!("Expected if expression"); } } #[test] fn test_lambda_expression() { let result = parse_expression("\\x -> x + 1").unwrap(); if let Expr::Lambda { params, body } = result { assert_eq!(params, vec!["x"]); if let Expr::Binary { left, 
op, right } = &*body { assert_eq!(**left, Expr::Identifier("x".to_string())); assert_eq!(*op, BinaryOp::Add); assert_eq!(**right, Expr::Number(1)); } else { panic!("Expected binary expression in body"); } } else { panic!("Expected lambda expression"); } } #[test] fn test_error_reporting() { let result = parse_expression("2 + "); assert!(result.is_err()); } } /// Parse a simple expression pub fn parse_expression(input: &str) -> Result<Expr, peg::error::ParseError<peg::str::LineCol>> { functional_parser::expression(input) } }
The grammar handles integers, floats, strings with escape sequences, booleans, and identifiers with appropriate validation rules.
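In the listing, float() is tried before number(), and number()'s negative lookahead !("." ['0'..='9']) keeps 3.14 from half-parsing as the integer 3. A plain-Rust sketch of that ordering decision for a literal in isolation; Lit and classify are illustrative names, not part of the grammar:

```rust
// Mirrors the grammar's literal disambiguation: accept a float only when
// digits appear on both sides of the dot (['0'..='9']+ "." ['0'..='9']+),
// and accept an integer only when no ".<digit>" follows the digits.
#[derive(Debug, PartialEq)]
enum Lit {
    Int(i64),
    Float(f64),
}

fn classify(s: &str) -> Option<Lit> {
    let body = s.strip_prefix('-').unwrap_or(s);
    if let Some((int_part, frac)) = body.split_once('.') {
        // Float rule: digits required on both sides of the dot.
        if !int_part.is_empty()
            && !frac.is_empty()
            && int_part.chars().all(|c| c.is_ascii_digit())
            && frac.chars().all(|c| c.is_ascii_digit())
        {
            return s.parse::<f64>().ok().map(Lit::Float);
        }
        None
    } else if !body.is_empty() && body.chars().all(|c| c.is_ascii_digit()) {
        // Integer rule: plain digits, nothing resembling a fraction ahead.
        s.parse::<i64>().ok().map(Lit::Int)
    } else {
        None
    }
}

fn main() {
    assert_eq!(classify("42"), Some(Lit::Int(42)));
    assert_eq!(classify("-3.5"), Some(Lit::Float(-3.5)));
    // A trailing dot with no fractional digits is rejected, as in the grammar.
    assert_eq!(classify("3."), None);
}
```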
Function Calls and Lambda Expressions
Parsing function application and lambda expressions demonstrates PEG’s ability to handle complex nested structures. The grammar supports simple calls like f(x, y) and, because the postfix rule folds call suffixes left to right, chained applications like f(x)(y), as well as lambda expressions written with either \ or fn and one or more parameters.
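The postfix rule's fold is what turns f(1)(2) into nested Call nodes. A stripped-down sketch of just that fold; the cut-down Expr and apply_suffixes are illustrative, not the full AST from the listing:

```rust
#[derive(Debug, PartialEq)]
enum Expr {
    Var(String),
    Num(i64),
    Call(Box<Expr>, Vec<Expr>),
}

// Given an atom and the argument lists gathered by call_suffix()*, fold
// left so the earliest suffix binds innermost: f with [[1], [2]] becomes
// Call(Call(f, [1]), [2]).
fn apply_suffixes(atom: Expr, suffixes: Vec<Vec<Expr>>) -> Expr {
    suffixes
        .into_iter()
        .fold(atom, |func, args| Expr::Call(Box::new(func), args))
}

fn main() {
    let e = apply_suffixes(
        Expr::Var("f".into()),
        vec![vec![Expr::Num(1)], vec![Expr::Num(2)]],
    );
    assert_eq!(
        e,
        Expr::Call(
            Box::new(Expr::Call(
                Box::new(Expr::Var("f".into())),
                vec![Expr::Num(1)]
            )),
            vec![Expr::Num(2)]
        )
    );
}
```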
Let Expressions and Conditionals
Structured expressions show how PEG handles keyword-based constructs. The let and if expressions demonstrate how PEG naturally handles indentation-insensitive syntax with clear keyword boundaries.
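The evaluator in the listing handles only arithmetic and returns an error for Let and If. A minimal sketch of how an environment-passing evaluator might extend it; the cut-down Expr and eval here are illustrative, not the guide's evaluator:

```rust
use std::collections::HashMap;

#[derive(Debug, Clone)]
enum Expr {
    Num(f64),
    Var(String),
    Add(Box<Expr>, Box<Expr>),
    // let name = value in body
    Let(String, Box<Expr>, Box<Expr>),
    // if cond != 0 then t else f (a numeric stand-in for Bool)
    If(Box<Expr>, Box<Expr>, Box<Expr>),
}

fn eval(e: &Expr, env: &HashMap<String, f64>) -> Result<f64, String> {
    match e {
        Expr::Num(n) => Ok(*n),
        Expr::Var(v) => env.get(v).copied().ok_or(format!("unbound: {v}")),
        Expr::Add(l, r) => Ok(eval(l, env)? + eval(r, env)?),
        Expr::Let(name, value, body) => {
            // Evaluate the binding, then extend the environment for the body.
            let v = eval(value, env)?;
            let mut inner = env.clone();
            inner.insert(name.clone(), v);
            eval(body, &inner)
        }
        Expr::If(c, t, f) => {
            if eval(c, env)? != 0.0 {
                eval(t, env)
            } else {
                eval(f, env)
            }
        }
    }
}

fn main() {
    // let x = 5 in x + 1
    let e = Expr::Let(
        "x".into(),
        Box::new(Expr::Num(5.0)),
        Box::new(Expr::Add(
            Box::new(Expr::Var("x".into())),
            Box::new(Expr::Num(1.0)),
        )),
    );
    assert_eq!(eval(&e, &HashMap::new()), Ok(6.0));
}
```

Cloning the environment at each binding keeps the sketch simple; a production evaluator would use scoped or persistent maps instead.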
Lists and Records
Collection types showcase PEG’s repetition and separator handling. rust-peg offers a dedicated ** operator for separator-delimited repetition (for example, expression() ** ","); this grammar builds its comma-separated lists with explicit head/tail rules instead, and it does not accept trailing commas.
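The head/tail pattern in expression_list generalizes to any separated list. A plain-Rust sketch of the same shape over a token slice; parse_separated is an illustrative helper, not rust-peg API:

```rust
// Parse "item (sep item)*" over a slice of string tokens, returning the
// items and how many tokens were consumed. An empty list is allowed,
// matching the grammar's `/ { vec![] }` fallback, and a separator is only
// consumed when an item follows it, so trailing separators are rejected.
fn parse_separated<'a>(
    tokens: &[&'a str],
    sep: &str,
    is_item: impl Fn(&str) -> bool,
) -> (Vec<&'a str>, usize) {
    let mut items = Vec::new();
    let mut pos = 0;
    if pos < tokens.len() && is_item(tokens[pos]) {
        items.push(tokens[pos]);
        pos += 1;
        while pos + 1 < tokens.len() && tokens[pos] == sep && is_item(tokens[pos + 1]) {
            items.push(tokens[pos + 1]);
            pos += 2;
        }
    }
    (items, pos)
}

fn main() {
    let digit = |t: &str| !t.is_empty() && t.chars().all(|c| c.is_ascii_digit());

    let (items, used) = parse_separated(&["1", ",", "2", ",", "3", "]"], ",", digit);
    assert_eq!(items, vec!["1", "2", "3"]);
    assert_eq!(used, 5);

    // The empty alternative: nothing matches, nothing is consumed.
    let (empty, used) = parse_separated(&["]"], ",", digit);
    assert!(empty.is_empty());
    assert_eq!(used, 0);
}
```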
Program Structure
A complete program consists of multiple statements:
#![allow(unused)]
fn main() {
#[derive(Debug, Clone, PartialEq)]
pub enum Statement {
    Expression(Expr),
    Definition {
        name: String,
        value: Expr,
    },
    TypeDef {
        name: String,
        constructors: Vec<(String, Vec<String>)>,
    },
}

#[derive(Debug, Clone, PartialEq)]
pub struct Program {
    pub statements: Vec<Statement>,
}

peg::parser! {
    pub grammar functional_parser() for str {
        /// Parse a complete program
        pub rule program() -> Program
            = _ statements:statement()* _ { Program { statements } }

        /// Parse a statement
        rule statement() -> Statement
            = definition() / type_definition() / expression_statement()

        /// Parse a variable definition
        rule definition() -> Statement
            = "def" _ name:identifier() _ "=" _ value:expression() _ {
                Statement::Definition { name, value }
            }

        /// Parse a type definition
        rule type_definition() -> Statement
            = "type" _ name:identifier() _ "=" _ constructors:constructor_list() _ {
                Statement::TypeDef { name, constructors }
            }

        /// Parse an expression statement
        rule expression_statement() -> Statement
            = expr:expression() { Statement::Expression(expr) }

        // The expression, identifier, constructor_list, and whitespace
        // rules are unchanged from the full listing above.
    }
}
}
#![allow(unused)] fn main() { use std::collections::HashMap; /// AST nodes for a functional programming language #[derive(Debug, Clone, PartialEq)] pub enum Expr { Number(i64), Float(f64), String(String), Bool(bool), Identifier(String), Binary { left: Box<Expr>, op: BinaryOp, right: Box<Expr>, }, Unary { op: UnaryOp, expr: Box<Expr>, }, Call { func: Box<Expr>, args: Vec<Expr>, }, Lambda { params: Vec<String>, body: Box<Expr>, }, Let { bindings: Vec<(String, Expr)>, body: Box<Expr>, }, If { condition: Box<Expr>, then_branch: Box<Expr>, else_branch: Option<Box<Expr>>, }, List(Vec<Expr>), Record(HashMap<String, Expr>), } #[derive(Debug, Clone, PartialEq)] pub enum BinaryOp { Add, Sub, Mul, Div, Mod, Pow, Eq, Ne, Lt, Le, Gt, Ge, And, Or, Cons, Append, } #[derive(Debug, Clone, PartialEq)] pub enum UnaryOp { Neg, Not, } #[derive(Debug, Clone, PartialEq)] pub enum Statement { Expression(Expr), Definition { name: String, value: Expr, }, TypeDef { name: String, constructors: Vec<(String, Vec<String>)>, }, } /// Error type for parser errors #[derive(Debug, Clone)] pub struct ParseError { pub message: String, pub line: usize, pub column: usize, pub expected: Vec<String>, } impl std::fmt::Display for ParseError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, "Parse error at line {}, column {}: {}", self.line, self.column, self.message ) } } impl std::error::Error for ParseError {} peg::parser! 
{
pub grammar functional_parser() for str {
    /// Parse a complete program
    pub rule program() -> Program
        = _ statements:statement()* _ { Program { statements } }

    /// Parse a statement
    rule statement() -> Statement
        = definition()
        / type_definition()
        / expression_statement()

    /// Parse a variable definition
    rule definition() -> Statement
        = "def" _ name:identifier() _ "=" _ value:expression() _ {
            Statement::Definition { name, value }
        }

    /// Parse a type definition
    rule type_definition() -> Statement
        = "type" _ name:identifier() _ "=" _ constructors:constructor_list() _ {
            Statement::TypeDef { name, constructors }
        }

    /// Parse constructor list for type definitions
    rule constructor_list() -> Vec<(String, Vec<String>)>
        = head:constructor() tail:(_ "|" _ c:constructor() { c })* {
            let mut result = vec![head];
            result.extend(tail);
            result
        }

    /// Parse a constructor
    rule constructor() -> (String, Vec<String>)
        = name:identifier() args:(_ "(" _ args:type_list() _ ")" { args })? {
            (name, args.unwrap_or_default())
        }

    /// Parse a list of types
    rule type_list() -> Vec<String>
        = head:identifier() tail:(_ "," _ t:identifier() { t })* {
            let mut result = vec![head];
            result.extend(tail);
            result
        }

    /// Parse an expression statement
    rule expression_statement() -> Statement
        = expr:expression() { Statement::Expression(expr) }

    /// Parse expressions with left-associative operators
    pub rule expression() -> Expr = precedence!{
        x:(@) _ "||" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Or, right: Box::new(y) } }
        --
        x:(@) _ "&&" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::And, right: Box::new(y) } }
        --
        x:(@) _ "==" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Eq, right: Box::new(y) } }
        x:(@) _ "!=" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Ne, right: Box::new(y) } }
        --
        x:(@) _ "<=" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Le, right: Box::new(y) } }
        x:(@) _ ">=" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Ge, right: Box::new(y) } }
        x:(@) _ "<" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Lt, right: Box::new(y) } }
        x:(@) _ ">" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Gt, right: Box::new(y) } }
        --
        x:(@) _ "+" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Add, right: Box::new(y) } }
        x:(@) _ "-" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Sub, right: Box::new(y) } }
        --
        x:(@) _ "*" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Mul, right: Box::new(y) } }
        x:(@) _ "/" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Div, right: Box::new(y) } }
        x:(@) _ "%" _ y:@ { Expr::Binary { left: Box::new(x), op: BinaryOp::Mod, right: Box::new(y) } }
        --
        x:@ _ "**" _ y:(@) { Expr::Binary { left: Box::new(x), op: BinaryOp::Pow, right: Box::new(y) } }
        --
        "-" _ e:@ { Expr::Unary { op: UnaryOp::Neg, expr: Box::new(e) } }
        "not" _ e:@ { Expr::Unary { op: UnaryOp::Not, expr: Box::new(e) } }
        --
        e:postfix() { e }
    }

    /// Postfix expressions (function calls)
    rule postfix() -> Expr
        = e:atom() calls:call_suffix()* {
            calls.into_iter().fold(e, |func, args| {
                Expr::Call { func: Box::new(func), args }
            })
        }

    rule call_suffix() -> Vec<Expr>
        = _ "(" _ args:argument_list() _ ")" { args }

    /// Parse atomic expressions
    rule atom() -> Expr
        = float() // Must come before number
        / number()
        / string_literal()
        / boolean()
        / list()
        / record()
        / lambda()
        / let_expression()
        / if_expression()
        / identifier_expr()
        / "(" _ e:expression() _ ")" { e }

    /// Parse numbers (integers only)
    rule number() -> Expr
        = n:$("-"? ['0'..='9']+) !("." ['0'..='9']) {?
            n.parse::<i64>()
                .map(Expr::Number)
                .map_err(|_| "number")
        }

    /// Parse floating-point numbers
    rule float() -> Expr
        = n:$("-"? ['0'..='9']+ "." ['0'..='9']+) {?
            n.parse::<f64>()
                .map(Expr::Float)
                .map_err(|_| "float")
        }

    /// Parse string literals
    rule string_literal() -> Expr
        = "\"" chars:string_char()* "\"" { Expr::String(chars.into_iter().collect()) }

    /// Parse string characters with escape sequences
    rule string_char() -> char
        = "\\\\" { '\\' }
        / "\\\"" { '"' }
        / "\\n" { '\n' }
        / "\\t" { '\t' }
        / "\\r" { '\r' }
        / !['"' | '\\'] c:char() { c }

    /// Parse any character
    rule char() -> char
        = c:$([_]) { c.chars().next().unwrap() }

    /// Parse boolean literals
    rule boolean() -> Expr
        = "true" !identifier_char() { Expr::Bool(true) }
        / "false" !identifier_char() { Expr::Bool(false) }

    /// Parse lists
    rule list() -> Expr
        = "[" _ elements:expression_list() _ "]" { Expr::List(elements) }

    /// Parse expression lists
    rule expression_list() -> Vec<Expr>
        = head:expression() tail:(_ "," _ e:expression() { e })* {
            let mut result = vec![head];
            result.extend(tail);
            result
        }
        / { vec![] }

    /// Parse argument lists (for function calls)
    rule argument_list() -> Vec<Expr>
        = expression_list()

    /// Parse records (key-value mappings)
    rule record() -> Expr
        = "{" _ fields:field_list() _ "}" { Expr::Record(fields.into_iter().collect()) }

    /// Parse field lists for records
    rule field_list() -> Vec<(String, Expr)>
        = head:field() tail:(_ "," _ f:field() { f })* {
            let mut result = vec![head];
            result.extend(tail);
            result
        }
        / { vec![] }

    /// Parse a single field
    rule field() -> (String, Expr)
        = key:identifier() _ ":" _ value:expression() { (key, value) }

    /// Parse lambda expressions
    rule lambda() -> Expr
        = "\\" _ params:parameter_list() _ "->" _ body:expression() {
            Expr::Lambda { params, body: Box::new(body) }
        }
        / "fn" _ params:parameter_list() _ "->" _ body:expression() {
            Expr::Lambda { params, body: Box::new(body) }
        }

    /// Parse parameter lists
    rule parameter_list() -> Vec<String>
        = "(" _ params:identifier_list() _ ")" { params }
        / param:identifier() { vec![param] }

    /// Parse identifier lists
    rule identifier_list() -> Vec<String>
        = head:identifier() tail:(_ "," _ id:identifier() { id })* {
            let mut result = vec![head];
            result.extend(tail);
            result
        }
        / { vec![] }

    /// Parse let expressions
    rule let_expression() -> Expr
        = "let" _ bindings:binding_list() _ "in" _ body:expression() {
            Expr::Let { bindings, body: Box::new(body) }
        }

    /// Parse binding lists for let expressions
    rule binding_list() -> Vec<(String, Expr)>
        = head:binding() tail:(_ "," _ b:binding() { b })* {
            let mut result = vec![head];
            result.extend(tail);
            result
        }

    /// Parse a single binding
    rule binding() -> (String, Expr)
        = name:identifier() _ "=" _ value:expression() { (name, value) }

    /// Parse if expressions
    rule if_expression() -> Expr
        = "if" _ cond:expression() _ "then" _ then_branch:expression()
          else_branch:(_ "else" _ e:expression() { e })? {
            Expr::If {
                condition: Box::new(cond),
                then_branch: Box::new(then_branch),
                else_branch: else_branch.map(Box::new),
            }
        }

    /// Parse identifier expressions
    rule identifier_expr() -> Expr
        = id:identifier() { Expr::Identifier(id) }

    /// Parse identifiers
    rule identifier() -> String
        = !reserved_word() s:$(identifier_start() identifier_char()*) { s.to_string() }

    rule identifier_start() -> ()
        = ['a'..='z' | 'A'..='Z' | '_'] {}

    rule identifier_char() -> ()
        = ['a'..='z' | 'A'..='Z' | '0'..='9' | '_'] {}

    /// Reserved words that can't be identifiers
    rule reserved_word()
        = ("if" / "then" / "else" / "let" / "in" / "fn" / "def" / "type"
            / "true" / "false" / "not") !identifier_char()

    /// Whitespace
    rule _() = quiet!{ (whitespace() / comment())* }

    rule whitespace() = [' ' | '\t' | '\n' | '\r']+

    rule comment()
        = "//" (!"\n" [_])*
        / "/*" (!"*/" [_])* "*/"
}
}

/// Simple evaluator for mathematical expressions
pub fn evaluate(expr: &Expr) -> Result<f64, String> {
    match expr {
        Expr::Number(n) => Ok(*n as f64),
        Expr::Float(f) => Ok(*f),
        Expr::Binary { left, op, right } => {
            let l = evaluate(left)?;
            let r = evaluate(right)?;
            match op {
                BinaryOp::Add => Ok(l + r),
                BinaryOp::Sub => Ok(l - r),
                BinaryOp::Mul => Ok(l * r),
                BinaryOp::Div => {
                    if r == 0.0 {
                        Err("Division by zero".to_string())
                    } else {
                        Ok(l / r)
                    }
                }
                BinaryOp::Pow => Ok(l.powf(r)),
                _ => Err(format!("Cannot evaluate operator {:?}", op)),
            }
        }
        Expr::Unary { op: UnaryOp::Neg, expr } => Ok(-evaluate(expr)?),
        _ => Err("Cannot evaluate this expression".to_string()),
    }
}

/// Parse a simple expression
pub fn parse_expression(input: &str) -> Result<Expr, peg::error::ParseError<peg::str::LineCol>> {
    functional_parser::expression(input)
}

/// Parse a complete program
pub fn parse_program(input: &str) -> Result<Program, peg::error::ParseError<peg::str::LineCol>> {
    functional_parser::program(input)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_number_parsing() {
        let result = parse_expression("42").unwrap();
        assert_eq!(result, Expr::Number(42));

        let result = parse_expression("-17").unwrap();
        assert_eq!(
            result,
            Expr::Unary {
                op: UnaryOp::Neg,
                expr: Box::new(Expr::Number(17))
            }
        );
    }

    #[test]
    fn test_binary_expression() {
        let result = parse_expression("2 + 3").unwrap();
        if let Expr::Binary { left, op, right } = result {
            assert_eq!(*left, Expr::Number(2));
            assert_eq!(op, BinaryOp::Add);
            assert_eq!(*right, Expr::Number(3));
        } else {
            panic!("Expected binary expression");
        }
    }

    #[test]
    fn test_operator_precedence() {
        let result = parse_expression("2 + 3 * 4").unwrap();
        // Should parse as 2 + (3 * 4)
        if let Expr::Binary { left, op, right } = result {
            assert_eq!(*left, Expr::Number(2));
            assert_eq!(op, BinaryOp::Add);
            if let Expr::Binary { left: rl, op: rop, right: rr } = right.as_ref() {
                assert_eq!(rl.as_ref(), &Expr::Number(3));
                assert_eq!(*rop, BinaryOp::Mul);
                assert_eq!(rr.as_ref(), &Expr::Number(4));
            } else {
                panic!("Expected binary expression on right");
            }
        } else {
            panic!("Expected binary expression");
        }
    }

    #[test]
    fn test_evaluation() {
        let expr = parse_expression("2 + 3 * 4").unwrap();
        let result = evaluate(&expr).unwrap();
        assert_eq!(result, 14.0);

        let expr = parse_expression("(2 + 3) * 4").unwrap();
        let result = evaluate(&expr).unwrap();
        assert_eq!(result, 20.0);

        let expr = parse_expression("2 ** 3").unwrap();
        let result = evaluate(&expr).unwrap();
        assert_eq!(result, 8.0);
    }

    #[test]
    fn test_function_call() {
        let result = parse_expression("foo(1, 2, 3)").unwrap();
        if let Expr::Call { func, args } = result {
            assert_eq!(*func, Expr::Identifier("foo".to_string()));
            assert_eq!(args.len(), 3);
            assert_eq!(args[0], Expr::Number(1));
            assert_eq!(args[1], Expr::Number(2));
            assert_eq!(args[2], Expr::Number(3));
        } else {
            panic!("Expected function call");
        }
    }

    #[test]
    fn test_string_literals() {
        let result = parse_expression("\"hello world\"").unwrap();
        assert_eq!(result, Expr::String("hello world".to_string()));

        let result = parse_expression("\"escaped\\nnewline\"").unwrap();
        assert_eq!(result, Expr::String("escaped\nnewline".to_string()));
    }

    #[test]
    fn test_list_parsing() {
        let result = parse_expression("[1, 2, 3]").unwrap();
        assert_eq!(
            result,
            Expr::List(vec![Expr::Number(1), Expr::Number(2), Expr::Number(3)])
        );

        let result = parse_expression("[]").unwrap();
        assert_eq!(result, Expr::List(vec![]));
    }

    #[test]
    fn test_let_expression() {
        let result = parse_expression("let x = 5 in x + 1").unwrap();
        if let Expr::Let { bindings, body } = result {
            assert_eq!(bindings.len(), 1);
            assert_eq!(bindings[0].0, "x");
            assert_eq!(bindings[0].1, Expr::Number(5));
            if let Expr::Binary { left, op, right } = &*body {
                assert_eq!(**left, Expr::Identifier("x".to_string()));
                assert_eq!(*op, BinaryOp::Add);
                assert_eq!(**right, Expr::Number(1));
            } else {
                panic!("Expected binary expression in body");
            }
        } else {
            panic!("Expected let expression");
        }
    }

    #[test]
    fn test_if_expression() {
        let result = parse_expression("if true then 1 else 2").unwrap();
        if let Expr::If { condition, then_branch, else_branch } = result {
            assert_eq!(*condition, Expr::Bool(true));
            assert_eq!(*then_branch, Expr::Number(1));
            assert_eq!(
                else_branch.as_ref().map(|b| b.as_ref()),
                Some(&Expr::Number(2))
            );
        } else {
            panic!("Expected if expression");
        }
    }

    #[test]
    fn test_lambda_expression() {
        let result = parse_expression("\\x -> x + 1").unwrap();
        if let Expr::Lambda { params, body } = result {
            assert_eq!(params, vec!["x"]);
            if let Expr::Binary { left, op, right } = &*body {
                assert_eq!(**left, Expr::Identifier("x".to_string()));
                assert_eq!(*op, BinaryOp::Add);
                assert_eq!(**right, Expr::Number(1));
            } else {
                panic!("Expected binary expression in body");
            }
        } else {
            panic!("Expected lambda expression");
        }
    }

    #[test]
    fn test_error_reporting() {
        let result = parse_expression("2 + ");
        assert!(result.is_err());
    }
}

#[derive(Debug, Clone, PartialEq)]
pub struct Program {
    pub statements: Vec<Statement>,
}
}
#![allow(unused)]
fn main() {
use std::collections::HashMap;

/// AST nodes for a functional programming language
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    Number(i64),
    Float(f64),
    String(String),
    Bool(bool),
    Identifier(String),
    Binary {
        left: Box<Expr>,
        op: BinaryOp,
        right: Box<Expr>,
    },
    Unary {
        op: UnaryOp,
        expr: Box<Expr>,
    },
    Call {
        func: Box<Expr>,
        args: Vec<Expr>,
    },
    Lambda {
        params: Vec<String>,
        body: Box<Expr>,
    },
    Let {
        bindings: Vec<(String, Expr)>,
        body: Box<Expr>,
    },
    If {
        condition: Box<Expr>,
        then_branch: Box<Expr>,
        else_branch: Option<Box<Expr>>,
    },
    List(Vec<Expr>),
    Record(HashMap<String, Expr>),
}

#[derive(Debug, Clone, PartialEq)]
pub enum BinaryOp {
    Add, Sub, Mul, Div, Mod, Pow,
    Eq, Ne, Lt, Le, Gt, Ge,
    And, Or, Cons, Append,
}

#[derive(Debug, Clone, PartialEq)]
pub enum UnaryOp {
    Neg,
    Not,
}

#[derive(Debug, Clone, PartialEq)]
pub enum Statement {
    Expression(Expr),
    Definition {
        name: String,
        value: Expr,
    },
    TypeDef {
        name: String,
        constructors: Vec<(String, Vec<String>)>,
    },
}

#[derive(Debug, Clone, PartialEq)]
pub struct Program {
    pub statements: Vec<Statement>,
}

/// Error type for parser errors
#[derive(Debug, Clone)]
pub struct ParseError {
    pub message: String,
    pub line: usize,
    pub column: usize,
    pub expected: Vec<String>,
}

impl std::fmt::Display for ParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Parse error at line {}, column {}: {}",
            self.line, self.column, self.message
        )
    }
}

impl std::error::Error for ParseError {}
}
This structure supports mixing expressions, definitions, and type declarations in a program.
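As a concrete illustration, here is a hypothetical input in the toy surface language (not Rust) that parse_program accepts, mixing all three statement forms:

```text
// A hypothetical program mixing all three statement forms
type Shape = Circle(Float) | Rect(Float, Float)
def double = \x -> x * 2
double(21)
```

Each statement matches one alternative of the statement rule: type_definition, definition, and expression_statement, respectively.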
Type Definitions
The grammar includes support for algebraic data types through the Statement enum:
#![allow(unused)]
fn main() {
#[derive(Debug, Clone, PartialEq)]
pub enum Statement {
    Expression(Expr),
    Definition {
        name: String,
        value: Expr,
    },
    TypeDef {
        name: String,
        constructors: Vec<(String, Vec<String>)>,
    },
}

peg::parser! {
    pub grammar functional_parser() for str {
        /// Parse a type definition
        rule type_definition() -> Statement
            = "type" _ name:identifier() _ "=" _ constructors:constructor_list() _ {
                Statement::TypeDef { name, constructors }
            }

        /// Parse constructor list for type definitions
        rule constructor_list() -> Vec<(String, Vec<String>)>
            = head:constructor() tail:(_ "|" _ c:constructor() { c })* {
                let mut result = vec![head];
                result.extend(tail);
                result
            }

        /// Parse a constructor
        rule constructor() -> (String, Vec<String>)
            = name:identifier() args:(_ "(" _ args:type_list() _ ")" { args })? {
                (name, args.unwrap_or_default())
            }

        /// Parse a list of types
        rule type_list() -> Vec<String>
            = head:identifier() tail:(_ "," _ t:identifier() { t })* {
                let mut result = vec![head];
                result.extend(tail);
                result
            }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lambda_expression()
{ let result = parse_expression("\\x -> x + 1").unwrap(); if let Expr::Lambda { params, body } = result { assert_eq!(params, vec!["x"]); if let Expr::Binary { left, op, right } = &*body { assert_eq!(**left, Expr::Identifier("x".to_string())); assert_eq!(*op, BinaryOp::Add); assert_eq!(**right, Expr::Number(1)); } else { panic!("Expected binary expression in body"); } } else { panic!("Expected lambda expression"); } } #[test] fn test_error_reporting() { let result = parse_expression("2 + "); assert!(result.is_err()); } } #[derive(Debug, Clone, PartialEq)] pub enum Statement { Expression(Expr), Definition { name: String, value: Expr, }, TypeDef { name: String, constructors: Vec<(String, Vec<String>)>, }, } }
Type definitions enable pattern matching and custom data structures in the parsed language. The TypeDef variant stores the type name and its constructors with their parameter types.
Comment Handling
PEG makes it easy to handle both line and block comments:
Comments are automatically skipped in whitespace, simplifying the rest of the grammar.
Expression Evaluation
A simple evaluator demonstrates working with the parsed AST:
#![allow(unused)]
fn main() {
/// Simple evaluator for mathematical expressions
pub fn evaluate(expr: &Expr) -> Result<f64, String> {
    match expr {
        Expr::Number(n) => Ok(*n as f64),
        Expr::Float(f) => Ok(*f),
        Expr::Binary { left, op, right } => {
            let l = evaluate(left)?;
            let r = evaluate(right)?;
            match op {
                BinaryOp::Add => Ok(l + r),
                BinaryOp::Sub => Ok(l - r),
                BinaryOp::Mul => Ok(l * r),
                BinaryOp::Div => {
                    if r == 0.0 {
                        Err("Division by zero".to_string())
                    } else {
                        Ok(l / r)
                    }
                }
                BinaryOp::Pow => Ok(l.powf(r)),
                _ => Err(format!("Cannot evaluate operator {:?}", op)),
            }
        }
        Expr::Unary { op: UnaryOp::Neg, expr } => Ok(-evaluate(expr)?),
        _ => Err("Cannot evaluate this expression".to_string()),
    }
}
}
This evaluator handles basic arithmetic operations with proper error handling for cases like division by zero.
Error Reporting
PEG automatically generates error messages with position information:
#![allow(unused)]
fn main() {
    let result = parse_expression("2 + ");
    assert!(result.is_err());
}
The generated parser tracks the furthest position reached and expected tokens, providing helpful error messages for syntax errors.
Precedence and Associativity
The grammar correctly handles operator precedence through its structure:
#![allow(unused)]
fn main() {
    // Parses as 2 + (3 * 4), not (2 + 3) * 4
    let result = parse_expression("2 + 3 * 4").unwrap();
    if let Expr::Binary { left, op, right } = result {
        assert_eq!(*left, Expr::Number(2));
        assert_eq!(op, BinaryOp::Add);
        assert!(matches!(*right, Expr::Binary { op: BinaryOp::Mul, .. }));
    }
}
Higher-precedence operators are parsed in deeper rules, ensuring correct parse trees without ambiguity.
Advanced Grammar Features
PEG supports several advanced features useful for compiler construction:
Syntactic Predicates: Use & for positive lookahead and ! for negative lookahead without consuming input.
Semantic Actions: Embed Rust code directly in the grammar to build ASTs or perform validation during parsing.
Rule Parameters: Pass parameters to rules for context-sensitive parsing.
Position Tracking: Access the current position in the input for error reporting or source mapping.
Custom Error Types: Define your own error types for domain-specific error reporting.
Performance Characteristics
PEG parsers have predictable performance characteristics:
Linear Time: with memoization (packrat parsing), PEGs parse in guaranteed linear time; without it, typical grammars still run in near-linear time, though pathological backtracking can be slower.
Memory Usage: Packrat parsing trades memory for guaranteed linear time by memoizing all rule applications.
Limited Backtracking: well-written PEG grammars minimize backtracking through careful ordering of alternatives.
Direct Execution: The generated parser is direct Rust code, avoiding interpretation overhead.
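The packrat idea can be sketched in plain std-only Rust: results are cached per (rule, position) pair, so each pair is computed at most once. This is illustrative of the caching scheme, not how any particular crate implements it:

```rust
use std::collections::HashMap;

// Illustrative packrat cache: each (rule id, start offset) pair is parsed once.
struct Packrat<'a> {
    input: &'a [u8],
    // (rule id, start offset) -> Some(end offset) on success, None on failure
    memo: HashMap<(u32, usize), Option<usize>>,
}

impl<'a> Packrat<'a> {
    fn new(input: &'a str) -> Self {
        Packrat { input: input.as_bytes(), memo: HashMap::new() }
    }

    // Rule 0: one or more ASCII digits.
    fn digits(&mut self, pos: usize) -> Option<usize> {
        if let Some(&cached) = self.memo.get(&(0, pos)) {
            return cached; // cache hit: this is what guarantees linear time
        }
        let mut end = pos;
        while end < self.input.len() && self.input[end].is_ascii_digit() {
            end += 1;
        }
        let result = if end > pos { Some(end) } else { None };
        self.memo.insert((0, pos), result);
        result
    }
}

fn main() {
    let mut p = Packrat::new("1234x");
    assert_eq!(p.digits(0), Some(4));
    assert_eq!(p.digits(0), Some(4)); // second call answered from the memo table
    assert_eq!(p.digits(4), None);
}
```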
Grammar Design Best Practices
Structure your PEG grammar for clarity and performance:
Order alternatives from most specific to least specific. Since PEGs use ordered choice, put more specific patterns first to avoid incorrect matches.
Factor out common prefixes to reduce redundant parsing: instead of "ifx" / "if", write "if" "x"?.
Where your PEG implementation provides cut points, use them to commit to a parse once distinguishing syntax is recognized, improving error messages.
Keep semantic actions simple. Complex AST construction is better done in a separate pass.
Design for positive matching rather than negative. PEGs work best when describing what syntax looks like, not what it doesn’t.
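The ordering advice can be illustrated with a plain std-only function that mimics PEG ordered choice, where the first alternative that matches wins:

```rust
// PEG ordered choice in miniature: alternatives are tried in order,
// and the first match wins, so more specific patterns must come first.
fn match_keyword<'a>(input: &'a str, alts: &[&str]) -> Option<&'a str> {
    alts.iter()
        .find(|k| input.starts_with(*k))
        .map(|k| &input[..k.len()])
}

fn main() {
    // Wrong order: "if" shadows "ifx" because it is tried first.
    assert_eq!(match_keyword("ifx", &["if", "ifx"]), Some("if"));
    // Right order: most specific alternative first.
    assert_eq!(match_keyword("ifx", &["ifx", "if"]), Some("ifx"));
}
```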
pest
Pest is a PEG (Parsing Expression Grammar) parser generator that uses a dedicated grammar syntax to generate parsers at compile time. Unlike parser combinators, pest separates grammar definition from parsing logic, enabling clear and maintainable parser specifications. The library excels at parsing complex languages with automatic whitespace handling, built-in error reporting, and elegant precedence climbing for expressions.
The core philosophy of pest centers on declarative grammar definitions that closely resemble formal language specifications. Grammars are written in separate .pest files using a clean, readable syntax that supports modifiers for different parsing behaviors. The pest_derive macro generates efficient parsers from these grammars at compile time.
Grammar Fundamentals
#![allow(unused)] fn main() { // Whitespace handling - pest automatically skips whitespace between rules WHITESPACE = _{ " " | "\t" | "\n" | "\r" } COMMENT = _{ "//" ~ (!NEWLINE ~ ANY)* } // ===== EXPRESSION PARSER ===== // Demonstrates precedence climbing and operators // Main expression entry point using precedence climbing expression = { term ~ (binary_op ~ term)* } // Terms in expressions term = { number | identifier | "(" ~ expression ~ ")" } // Binary operators for precedence climbing binary_op = _{ add | subtract | multiply | divide | power | eq | ne | lt | le | gt | ge } add = { "+" } subtract = { "-" } multiply = { "*" } divide = { "/" } power = { "^" } eq = { "==" } ne = { "!=" } lt = { "<" } le = { "<=" } gt = { ">" } ge = { ">=" } // Basic tokens number = @{ ASCII_DIGIT+ ~ ("." ~ ASCII_DIGIT+)? } identifier = @{ ASCII_ALPHA ~ (ASCII_ALPHANUMERIC | "_")* } // ===== JSON PARSER ===== // Demonstrates recursive structures and string handling json_value = { object | array | string | number | boolean | null } object = { "{" ~ (pair ~ ("," ~ pair)*)? ~ "}" } pair = { string ~ ":" ~ json_value } array = { "[" ~ (json_value ~ ("," ~ json_value)*)? ~ "]" } string = ${ "\"" ~ inner_string ~ "\"" } inner_string = @{ char* } char = { !("\"" | "\\") ~ ANY | "\\" ~ ("\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t") | "\\" ~ ("u" ~ ASCII_HEX_DIGIT{4}) } boolean = { "true" | "false" } null = { "null" } // ===== PROGRAMMING LANGUAGE CONSTRUCTS ===== // Demonstrates complex language parsing with statements and control flow program = { SOI ~ statement* ~ EOI } statement = { if_statement | while_statement | function_def | assignment | expression_statement | block } // Control flow statements if_statement = { "if" ~ expression ~ block ~ ("else" ~ (if_statement | block))? } while_statement = { "while" ~ expression ~ block } // Function definition function_def = { "fn" ~ identifier ~ "(" ~ parameter_list? 
~ ")" ~ "->" ~ type_name ~ block } parameter_list = { parameter ~ ("," ~ parameter)* } parameter = { identifier ~ ":" ~ type_name } // Assignment and expressions assignment = { identifier ~ "=" ~ expression ~ ";" } expression_statement = { expression ~ ";" } // Block structure block = { "{" ~ statement* ~ "}" } // Type system type_name = { "int" | "float" | "bool" | "string" | identifier } // ===== CALCULATOR WITH PRECEDENCE ===== // Demonstrates pest's precedence climbing capabilities calculation = { SOI ~ calc_expression ~ EOI } calc_expression = { calc_term ~ (calc_add_op ~ calc_term)* } calc_term = { calc_factor ~ (calc_mul_op ~ calc_factor)* } calc_factor = { calc_power } calc_power = { calc_atom ~ (calc_pow_op ~ calc_power)? } calc_atom = { calc_number | "(" ~ calc_expression ~ ")" | calc_unary } calc_unary = { (calc_plus | calc_minus) ~ calc_number | calc_number } calc_add_op = { calc_plus | calc_minus } calc_mul_op = { calc_multiply | calc_divide } calc_pow_op = { calc_power_op } calc_plus = { "+" } calc_minus = { "-" } calc_multiply = { "*" } calc_divide = { "/" } calc_power_op = { "^" } calc_number = @{ ASCII_DIGIT+ ~ ("." ~ ASCII_DIGIT+)? } // ===== CUSTOM ERROR HANDLING ===== // Demonstrates custom error messages and recovery error_prone = { SOI ~ error_statement* ~ EOI } error_statement = { good_statement | expected_semicolon } good_statement = { identifier ~ "=" ~ number ~ ";" } expected_semicolon = { identifier ~ "=" ~ number ~ !(";" | NEWLINE) } // ===== LEXER-LIKE TOKENS ===== // Demonstrates atomic rules and token extraction token_stream = { SOI ~ token* ~ EOI } token = _{ keyword | operator_token | punctuation | literal | identifier_token } keyword = { "if" | "else" | "while" | "fn" | "let" | "return" | "mut" } operator_token = { "+=" | "-=" | "==" | "!=" | "<=" | ">=" | "&&" | "||" | "++" | "--" | "->" | "+" | "-" | "*" | "/" | "=" | "<" | ">" | "!" 
} punctuation = { "(" | ")" | "{" | "}" | "[" | "]" | ";" | "," | ":" } literal = { string_literal | number_literal | boolean_literal } string_literal = ${ "\"" ~ string_content ~ "\"" } string_content = @{ (!"\"" ~ ANY)* } number_literal = @{ ASCII_DIGIT+ ~ ("." ~ ASCII_DIGIT+)? } boolean_literal = { "true" | "false" } identifier_token = @{ ASCII_ALPHA ~ (ASCII_ALPHANUMERIC | "_")* } }
The grammar file demonstrates pest's PEG syntax with various rule types. Silent rules prefixed with an underscore don't appear in the parse tree, simplifying AST construction. Atomic rules marked with @ disable implicit whitespace and produce their entire match as a single token. Compound-atomic rules marked with $ also disable implicit whitespace but still expose their inner pairs.
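A compact sketch of the three modifiers side by side, using illustrative rule names rather than ones from the chapter's grammar:

```pest
WHITESPACE = _{ " " }                  // silent: skipped implicitly, never in the tree
ident      = @{ ASCII_ALPHA+ }         // atomic: no inner whitespace, one token
string     = ${ "\"" ~ inner ~ "\"" }  // compound-atomic: no whitespace, inner pairs kept
inner      = @{ (!"\"" ~ ANY)* }
```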
AST Construction
#![allow(unused)] fn main() { use std::fmt; use pest::error::Error; use pest::iterators::{Pair, Pairs}; use pest::pratt_parser::{Assoc, Op, PrattParser}; use pest::Parser; use pest_derive::Parser; use thiserror::Error; #[derive(Parser)] #[grammar = "grammar.pest"] pub struct GrammarParser; #[derive(Error, Debug)] pub enum ParseError { #[error("Pest parsing error: {0}")] Pest(#[from] Box<Error<Rule>>), #[error("Invalid number format: {0}")] InvalidNumber(String), #[error("Unknown operator: {0}")] UnknownOperator(String), #[error("Invalid JSON value")] InvalidJson, #[error("Unexpected end of input")] UnexpectedEOF, } pub type Result<T> = std::result::Result<T, ParseError>; #[derive(Debug, Clone, PartialEq)] pub enum BinOperator { Add, Subtract, Multiply, Divide, Power, Eq, Ne, Lt, Le, Gt, Ge, } impl fmt::Display for BinOperator { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { BinOperator::Add => write!(f, "+"), BinOperator::Subtract => write!(f, "-"), BinOperator::Multiply => write!(f, "*"), BinOperator::Divide => write!(f, "/"), BinOperator::Power => write!(f, "^"), BinOperator::Eq => write!(f, "=="), BinOperator::Ne => write!(f, "!="), BinOperator::Lt => write!(f, "<"), BinOperator::Le => write!(f, "<="), BinOperator::Gt => write!(f, ">"), BinOperator::Ge => write!(f, ">="), } } } #[derive(Debug, Clone, PartialEq)] pub enum JsonValue { Object(Vec<(String, JsonValue)>), Array(Vec<JsonValue>), String(String), Number(f64), Boolean(bool), Null, } #[derive(Debug, Clone, PartialEq)] pub struct Program { pub statements: Vec<Statement>, } #[derive(Debug, Clone, PartialEq)] pub enum Statement { If { condition: Expr, then_block: Vec<Statement>, else_block: Option<Vec<Statement>>, }, While { condition: Expr, body: Vec<Statement>, }, Function { name: String, parameters: Vec<Parameter>, return_type: String, body: Vec<Statement>, }, Assignment { name: String, value: Expr, }, Expression(Expr), Block(Vec<Statement>), } #[derive(Debug, Clone, PartialEq)] pub 
struct Parameter { pub name: String, pub type_name: String, } #[derive(Debug, Clone, PartialEq)] pub enum Token { Keyword(String), Operator(String), Punctuation(String), Literal(LiteralValue), Identifier(String), } #[derive(Debug, Clone, PartialEq)] pub enum LiteralValue { String(String), Number(f64), Boolean(bool), } impl GrammarParser { /// Parse an expression using pest's built-in precedence climbing pub fn parse_expression(input: &str) -> Result<Expr> { let pairs = Self::parse(Rule::expression, input).map_err(Box::new)?; Self::build_expression(pairs) } fn build_expression(pairs: Pairs<Rule>) -> Result<Expr> { let pratt = PrattParser::new() .op(Op::infix(Rule::eq, Assoc::Left) | Op::infix(Rule::ne, Assoc::Left)) .op(Op::infix(Rule::lt, Assoc::Left) | Op::infix(Rule::le, Assoc::Left) | Op::infix(Rule::gt, Assoc::Left) | Op::infix(Rule::ge, Assoc::Left)) .op(Op::infix(Rule::add, Assoc::Left) | Op::infix(Rule::subtract, Assoc::Left)) .op(Op::infix(Rule::multiply, Assoc::Left) | Op::infix(Rule::divide, Assoc::Left)) .op(Op::infix(Rule::power, Assoc::Right)); pratt .map_primary(|primary| match primary.as_rule() { Rule::term => { let inner = primary .into_inner() .next() .ok_or(ParseError::UnexpectedEOF)?; match inner.as_rule() { Rule::number => { let num = inner.as_str().parse::<f64>().map_err(|_| { ParseError::InvalidNumber(inner.as_str().to_string()) })?; Ok(Expr::Number(num)) } Rule::identifier => Ok(Expr::Identifier(inner.as_str().to_string())), Rule::expression => Self::build_expression(inner.into_inner()), _ => unreachable!("Unexpected term rule: {:?}", inner.as_rule()), } } Rule::number => { let num = primary .as_str() .parse::<f64>() .map_err(|_| ParseError::InvalidNumber(primary.as_str().to_string()))?; Ok(Expr::Number(num)) } Rule::identifier => Ok(Expr::Identifier(primary.as_str().to_string())), Rule::expression => Self::build_expression(primary.into_inner()), _ => unreachable!("Unexpected primary rule: {:?}", primary.as_rule()), }) .map_infix(|left, op, 
right| { let op = match op.as_rule() { Rule::add => BinOperator::Add, Rule::subtract => BinOperator::Subtract, Rule::multiply => BinOperator::Multiply, Rule::divide => BinOperator::Divide, Rule::power => BinOperator::Power, Rule::eq => BinOperator::Eq, Rule::ne => BinOperator::Ne, Rule::lt => BinOperator::Lt, Rule::le => BinOperator::Le, Rule::gt => BinOperator::Gt, Rule::ge => BinOperator::Ge, _ => return Err(ParseError::UnknownOperator(op.as_str().to_string())), }; Ok(Expr::BinOp { left: Box::new(left?), op, right: Box::new(right?), }) }) .parse(pairs) } /// Parse a calculator expression with explicit precedence rules pub fn parse_calculation(input: &str) -> Result<f64> { let pairs = Self::parse(Rule::calculation, input).map_err(Box::new)?; Self::evaluate_calculation( pairs .into_iter() .next() .ok_or(ParseError::UnexpectedEOF)? .into_inner() .next() .ok_or(ParseError::UnexpectedEOF)?, ) } fn evaluate_calculation(pair: Pair<Rule>) -> Result<f64> { match pair.as_rule() { Rule::calc_expression => { let mut pairs = pair.into_inner(); let mut result = Self::evaluate_calculation(pairs.next().ok_or(ParseError::UnexpectedEOF)?)?; while let (Some(op), Some(term)) = (pairs.next(), pairs.next()) { let term_val = Self::evaluate_calculation(term)?; match op.as_rule() { Rule::calc_add_op => { let inner_op = op.into_inner().next().ok_or(ParseError::UnexpectedEOF)?; match inner_op.as_rule() { Rule::calc_plus => result += term_val, Rule::calc_minus => result -= term_val, _ => unreachable!("Unexpected add op: {:?}", inner_op.as_rule()), } } Rule::calc_plus => result += term_val, Rule::calc_minus => result -= term_val, _ => unreachable!("Unexpected calc expression op: {:?}", op.as_rule()), } } Ok(result) } Rule::calc_term => { let mut pairs = pair.into_inner(); let mut result = Self::evaluate_calculation(pairs.next().ok_or(ParseError::UnexpectedEOF)?)?; while let (Some(op), Some(factor)) = (pairs.next(), pairs.next()) { let factor_val = Self::evaluate_calculation(factor)?; match 
op.as_rule() { Rule::calc_mul_op => { let inner_op = op.into_inner().next().ok_or(ParseError::UnexpectedEOF)?; match inner_op.as_rule() { Rule::calc_multiply => result *= factor_val, Rule::calc_divide => result /= factor_val, _ => unreachable!("Unexpected mul op: {:?}", inner_op.as_rule()), } } Rule::calc_multiply => result *= factor_val, Rule::calc_divide => result /= factor_val, _ => unreachable!("Unexpected calc term op: {:?}", op.as_rule()), } } Ok(result) } Rule::calc_factor => Self::evaluate_calculation( pair.into_inner().next().ok_or(ParseError::UnexpectedEOF)?, ), Rule::calc_power => { let mut pairs = pair.into_inner(); let base = Self::evaluate_calculation(pairs.next().ok_or(ParseError::UnexpectedEOF)?)?; if let Some(op) = pairs.next() { if op.as_rule() == Rule::calc_pow_op { let exponent = Self::evaluate_calculation( pairs.next().ok_or(ParseError::UnexpectedEOF)?, )?; Ok(base.powf(exponent)) } else { unreachable!("Expected calc_pow_op, got: {:?}", op.as_rule()); } } else { Ok(base) } } Rule::calc_atom => Self::evaluate_calculation( pair.into_inner().next().ok_or(ParseError::UnexpectedEOF)?, ), Rule::calc_unary => { let mut pairs = pair.into_inner(); let first = pairs.next().ok_or(ParseError::UnexpectedEOF)?; match first.as_rule() { Rule::calc_minus => { let val = Self::evaluate_calculation( pairs.next().ok_or(ParseError::UnexpectedEOF)?, )?; Ok(-val) } Rule::calc_plus => { let val = Self::evaluate_calculation( pairs.next().ok_or(ParseError::UnexpectedEOF)?, )?; Ok(val) } Rule::calc_number => first .as_str() .parse::<f64>() .map_err(|_| ParseError::InvalidNumber(first.as_str().to_string())), _ => Self::evaluate_calculation(first), } } Rule::calc_number => pair .as_str() .parse::<f64>() .map_err(|_| ParseError::InvalidNumber(pair.as_str().to_string())), _ => unreachable!("Unexpected rule: {:?}", pair.as_rule()), } } } impl GrammarParser { /// Parse JSON input into a JsonValue AST pub fn parse_json(input: &str) -> Result<JsonValue> { let pairs = 
Self::parse(Rule::json_value, input).map_err(Box::new)?; Self::build_json_value(pairs.into_iter().next().ok_or(ParseError::UnexpectedEOF)?) } fn build_json_value(pair: Pair<Rule>) -> Result<JsonValue> { match pair.as_rule() { Rule::json_value => { Self::build_json_value(pair.into_inner().next().ok_or(ParseError::UnexpectedEOF)?) } Rule::object => { let mut object = Vec::new(); for pair in pair.into_inner() { if let Rule::pair = pair.as_rule() { let mut inner = pair.into_inner(); let key = Self::parse_string(inner.next().ok_or(ParseError::UnexpectedEOF)?)?; let value = Self::build_json_value(inner.next().ok_or(ParseError::UnexpectedEOF)?)?; object.push((key, value)); } } Ok(JsonValue::Object(object)) } Rule::array => { let mut array = Vec::new(); for pair in pair.into_inner() { array.push(Self::build_json_value(pair)?); } Ok(JsonValue::Array(array)) } Rule::string => Ok(JsonValue::String(Self::parse_string(pair)?)), Rule::number => { let num = pair .as_str() .parse::<f64>() .map_err(|_| ParseError::InvalidNumber(pair.as_str().to_string()))?; Ok(JsonValue::Number(num)) } Rule::boolean => Ok(JsonValue::Boolean(pair.as_str() == "true")), Rule::null => Ok(JsonValue::Null), _ => Err(ParseError::InvalidJson), } } fn parse_string(pair: Pair<Rule>) -> Result<String> { let inner = pair.into_inner().next().ok_or(ParseError::UnexpectedEOF)?; Ok(inner.as_str().to_string()) } } impl GrammarParser { /// Parse a complete program pub fn parse_program(input: &str) -> Result<Program> { let pairs = Self::parse(Rule::program, input).map_err(Box::new)?; let mut statements = Vec::new(); for pair in pairs .into_iter() .next() .ok_or(ParseError::UnexpectedEOF)? .into_inner() { if pair.as_rule() != Rule::EOI { statements.push(Self::build_statement(pair)?); } } Ok(Program { statements }) } fn build_statement(pair: Pair<Rule>) -> Result<Statement> { match pair.as_rule() { Rule::statement => { Self::build_statement(pair.into_inner().next().ok_or(ParseError::UnexpectedEOF)?) 
} Rule::if_statement => { let mut inner = pair.into_inner(); let condition = Self::build_expression_from_pair( inner.next().ok_or(ParseError::UnexpectedEOF)?, )?; let then_block = Self::build_block(inner.next().ok_or(ParseError::UnexpectedEOF)?)?; let else_block = inner .next() .map(|p| match p.as_rule() { Rule::block => Self::build_block(p), Rule::if_statement => Ok(vec![Self::build_statement(p)?]), _ => unreachable!(), }) .transpose()?; Ok(Statement::If { condition, then_block, else_block, }) } Rule::while_statement => { let mut inner = pair.into_inner(); let condition = Self::build_expression_from_pair( inner.next().ok_or(ParseError::UnexpectedEOF)?, )?; let body = Self::build_block(inner.next().ok_or(ParseError::UnexpectedEOF)?)?; Ok(Statement::While { condition, body }) } Rule::function_def => { let mut inner = pair.into_inner(); let name = inner .next() .ok_or(ParseError::UnexpectedEOF)? .as_str() .to_string(); let mut parameters = Vec::new(); let mut next = inner.next().ok_or(ParseError::UnexpectedEOF)?; if next.as_rule() == Rule::parameter_list { for param_pair in next.into_inner() { let mut param_inner = param_pair.into_inner(); let param_name = param_inner .next() .ok_or(ParseError::UnexpectedEOF)? .as_str() .to_string(); let param_type = param_inner .next() .ok_or(ParseError::UnexpectedEOF)? .as_str() .to_string(); parameters.push(Parameter { name: param_name, type_name: param_type, }); } next = inner.next().ok_or(ParseError::UnexpectedEOF)?; } let return_type = next.as_str().to_string(); let body = Self::build_block(inner.next().ok_or(ParseError::UnexpectedEOF)?)?; Ok(Statement::Function { name, parameters, return_type, body, }) } Rule::assignment => { let mut inner = pair.into_inner(); let name = inner .next() .ok_or(ParseError::UnexpectedEOF)? 
.as_str() .to_string(); let value = Self::build_expression_from_pair( inner.next().ok_or(ParseError::UnexpectedEOF)?, )?; Ok(Statement::Assignment { name, value }) } Rule::expression_statement => { let expr = Self::build_expression_from_pair( pair.into_inner().next().ok_or(ParseError::UnexpectedEOF)?, )?; Ok(Statement::Expression(expr)) } Rule::block => Ok(Statement::Block(Self::build_block(pair)?)), _ => unreachable!("Unexpected statement rule: {:?}", pair.as_rule()), } } fn build_block(pair: Pair<Rule>) -> Result<Vec<Statement>> { let mut statements = Vec::new(); for stmt_pair in pair.into_inner() { statements.push(Self::build_statement(stmt_pair)?); } Ok(statements) } fn build_expression_from_pair(pair: Pair<Rule>) -> Result<Expr> { Self::build_expression(pair.into_inner()) } } impl GrammarParser { /// Parse input into a stream of tokens pub fn parse_tokens(input: &str) -> Result<Vec<Token>> { let pairs = Self::parse(Rule::token_stream, input).map_err(Box::new)?; let mut tokens = Vec::new(); for pair in pairs .into_iter() .next() .ok_or(ParseError::UnexpectedEOF)? .into_inner() { if pair.as_rule() != Rule::EOI { tokens.push(Self::build_token(pair)?); } } Ok(tokens) } fn build_token(pair: Pair<Rule>) -> Result<Token> { match pair.as_rule() { Rule::keyword => Ok(Token::Keyword(pair.as_str().to_string())), Rule::operator_token => Ok(Token::Operator(pair.as_str().to_string())), Rule::punctuation => Ok(Token::Punctuation(pair.as_str().to_string())), Rule::literal => { let inner = pair.into_inner().next().ok_or(ParseError::UnexpectedEOF)?; match inner.as_rule() { Rule::string_literal => { let content = inner .into_inner() .next() .ok_or(ParseError::UnexpectedEOF)? 
.as_str(); Ok(Token::Literal(LiteralValue::String(content.to_string()))) } Rule::number_literal => { let num = inner .as_str() .parse::<f64>() .map_err(|_| ParseError::InvalidNumber(inner.as_str().to_string()))?; Ok(Token::Literal(LiteralValue::Number(num))) } Rule::boolean_literal => Ok(Token::Literal(LiteralValue::Boolean( inner.as_str() == "true", ))), _ => unreachable!(), } } Rule::identifier_token => Ok(Token::Identifier(pair.as_str().to_string())), _ => unreachable!("Unexpected token rule: {:?}", pair.as_rule()), } } } impl GrammarParser { /// Parse and print pest parse tree for debugging pub fn debug_parse(rule: Rule, input: &str) -> Result<()> { let pairs = Self::parse(rule, input).map_err(Box::new)?; for pair in pairs { Self::print_pair(&pair, 0); } Ok(()) } fn print_pair(pair: &Pair<Rule>, indent: usize) { let indent_str = " ".repeat(indent); println!("{}{:?}: \"{}\"", indent_str, pair.as_rule(), pair.as_str()); for inner_pair in pair.clone().into_inner() { Self::print_pair(&inner_pair, indent + 1); } } /// Extract all identifiers from an expression pub fn extract_identifiers(expr: &Expr) -> Vec<String> { match expr { Expr::Identifier(name) => vec![name.clone()], Expr::BinOp { left, right, .. } => { let mut ids = Self::extract_identifiers(left); ids.extend(Self::extract_identifiers(right)); ids } Expr::Number(_) => vec![], } } /// Check if a rule matches the complete input pub fn can_parse(rule: Rule, input: &str) -> bool { match Self::parse(rule, input) { Ok(pairs) => { // Check that the entire input is consumed let input_len = input.len(); let parsed_len = pairs.as_str().len(); parsed_len == input_len } Err(_) => false, } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_expression_parsing() { let expr = GrammarParser::parse_expression("2 + 3 * 4").unwrap(); match expr { Expr::BinOp { op: BinOperator::Add, .. 
} => (), _ => panic!("Expected addition at top level"), } } #[test] fn test_calculation() { assert_eq!(GrammarParser::parse_calculation("2 + 3 * 4").unwrap(), 14.0); assert_eq!( GrammarParser::parse_calculation("(2 + 3) * 4").unwrap(), 20.0 ); assert_eq!( GrammarParser::parse_calculation("2 ^ 3 ^ 2").unwrap(), 512.0 ); } #[test] fn test_json_parsing() { let json = r#"{"name": "test", "value": 42, "active": true}"#; let result = GrammarParser::parse_json(json).unwrap(); if let JsonValue::Object(obj) = result { assert_eq!(obj.len(), 3); } else { panic!("Expected JSON object"); } } #[test] fn test_program_parsing() { let program = r#" fn add(x: int, y: int) -> int { x + y; } if x > 0 { y = 42; } "#; let result = GrammarParser::parse_program(program).unwrap(); assert_eq!(result.statements.len(), 2); } #[test] fn test_token_parsing() { let input = "if x == 42 { return true; }"; let tokens = GrammarParser::parse_tokens(input).unwrap(); assert!(tokens.len() > 5); match &tokens[0] { Token::Keyword(kw) => assert_eq!(kw, "if"), _ => panic!("Expected keyword"), } } #[test] fn test_identifier_extraction() { let expr = GrammarParser::parse_expression("x + y * z").unwrap(); let ids = GrammarParser::extract_identifiers(&expr); assert_eq!(ids, vec!["x", "y", "z"]); } #[test] fn test_debug_features() { assert!(GrammarParser::can_parse(Rule::expression, "2 + 3")); assert!(!GrammarParser::can_parse(Rule::expression, "2 +")); } } #[derive(Debug, Clone, PartialEq)] pub enum Expr { Number(f64), Identifier(String), BinOp { left: Box<Expr>, op: BinOperator, right: Box<Expr>, }, } }
The expression type represents the abstract syntax tree that parsers construct from pest’s parse pairs. Each variant corresponds to a grammatical construct defined in the pest grammar.
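To make the shape of that tree concrete, the following standalone snippet (with simplified copies of the chapter’s `Expr` and `BinOperator` types inlined so it compiles on its own) shows what a precedence-aware parser produces for `2 + 3 * 4`:

```rust
#[derive(Debug, Clone, PartialEq)]
enum BinOperator { Add, Multiply }

#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Number(f64),
    BinOp { left: Box<Expr>, op: BinOperator, right: Box<Expr> },
}

// The tree for "2 + 3 * 4": multiplication binds tighter,
// so it nests under the addition.
fn example_ast() -> Expr {
    Expr::BinOp {
        left: Box::new(Expr::Number(2.0)),
        op: BinOperator::Add,
        right: Box::new(Expr::BinOp {
            left: Box::new(Expr::Number(3.0)),
            op: BinOperator::Multiply,
            right: Box::new(Expr::Number(4.0)),
        }),
    }
}

fn main() {
    // The root is always the lowest-precedence operator in the input.
    let ast = example_ast();
    assert!(matches!(ast, Expr::BinOp { op: BinOperator::Add, .. }));
}
```

The root of the tree is the operator that binds loosest, which is exactly what the chapter’s `test_expression_parsing` test checks.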
```rust
/// Parse an expression using pest's built-in precedence climbing
pub fn parse_expression(input: &str) -> Result<Expr> {
    let pairs = Self::parse(Rule::expression, input).map_err(Box::new)?;
    Self::build_expression(pairs)
}
```
Expression parsing demonstrates the transformation from pest’s generic parse tree to a typed AST. The PrattParser handles operator precedence through precedence climbing, supporting both left and right associative operators. The parser processes pairs recursively, matching rule names to construct appropriate AST nodes.
Precedence Climbing
The Pratt parser configuration is built inline within the expression parser, defining operator precedence and associativity declaratively. Left associative operators like addition and multiplication are specified with `Op::infix(rule, Assoc::Left)`, while the right associative exponentiation operator uses `Op::infix(rule, Assoc::Right)`. Prefix operators for unary expressions integrate seamlessly with the precedence system.
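Stripped of pest, the precedence-climbing loop itself is short. The following dependency-free sketch (illustrative only; it uses numeric binding powers rather than pest’s `Op` builders, and handles just single digits, `+ - * / ^`, and parentheses) shows how left and right associativity fall out of the algorithm:

```rust
// Minimal precedence-climbing evaluator over single-character tokens.
fn eval(input: &str) -> f64 {
    let tokens: Vec<char> = input.chars().filter(|c| !c.is_whitespace()).collect();
    let mut pos = 0;
    parse_expr(&tokens, &mut pos, 0)
}

fn parse_expr(tokens: &[char], pos: &mut usize, min_bp: u8) -> f64 {
    // Primary: a digit or a parenthesized sub-expression.
    let mut lhs = match tokens[*pos] {
        '(' => {
            *pos += 1;
            let v = parse_expr(tokens, pos, 0);
            *pos += 1; // consume ')'
            v
        }
        c => {
            *pos += 1;
            c.to_digit(10).expect("expected a digit") as f64
        }
    };
    while *pos < tokens.len() {
        // Binding powers: higher binds tighter. '^' is right-associative
        // because its right-hand power is *lower* than its left-hand power.
        let (lbp, rbp) = match tokens[*pos] {
            '+' | '-' => (1, 2),
            '*' | '/' => (3, 4),
            '^' => (6, 5),
            _ => break, // e.g. ')'
        };
        if lbp < min_bp {
            break;
        }
        let op = tokens[*pos];
        *pos += 1;
        let rhs = parse_expr(tokens, pos, rbp);
        lhs = match op {
            '+' => lhs + rhs,
            '-' => lhs - rhs,
            '*' => lhs * rhs,
            '/' => lhs / rhs,
            _ => lhs.powf(rhs),
        };
    }
    lhs
}

fn main() {
    assert_eq!(eval("2 + 3 * 4"), 14.0);
    assert_eq!(eval("(2 + 3) * 4"), 20.0);
    assert_eq!(eval("2 ^ 3 ^ 2"), 512.0); // 2 ^ (3 ^ 2), not (2 ^ 3) ^ 2
}
```

Because `^` passes a lower binding power to its recursive call, the right operand may consume another `^`, which is precisely the behavior `Assoc::Right` requests in the PrattParser configuration.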
```rust
/// Parse a calculator expression with explicit precedence rules
pub fn parse_calculation(input: &str) -> Result<f64> {
    let pairs = Self::parse(Rule::calculation, input).map_err(Box::new)?;
    Self::evaluate_calculation(
        pairs
            .into_iter()
            .next()
            .ok_or(ParseError::UnexpectedEOF)?
            .into_inner()
            .next()
            .ok_or(ParseError::UnexpectedEOF)?,
    )
}
```
The calculator parser evaluates expressions directly during parsing, demonstrating how precedence climbing produces correct results. Right associative exponentiation evaluates from right to left, while other operators evaluate left to right. This approach combines parsing and evaluation for simple expression languages.
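The evaluator mirrors a layered grammar. The `grammar.pest` file itself is not reproduced in this chapter, but from the rule names used in the Rust code, the calculator portion plausibly looks something like the following (a hedged reconstruction, not the repository’s exact grammar):

```pest
calculation     = { SOI ~ calc_expression ~ EOI }
calc_expression = { calc_term ~ (calc_add_op ~ calc_term)* }
calc_add_op     = { calc_plus | calc_minus }
calc_term       = { calc_factor ~ (calc_mul_op ~ calc_factor)* }
calc_mul_op     = { calc_multiply | calc_divide }
calc_factor     = { calc_power }
calc_power      = { calc_unary ~ (calc_pow_op ~ calc_power)? }
calc_unary      = { (calc_plus | calc_minus)? ~ calc_atom }
calc_atom       = { calc_number | "(" ~ calc_expression ~ ")" }
calc_plus       = { "+" }
calc_minus      = { "-" }
calc_multiply   = { "*" }
calc_divide     = { "/" }
calc_pow_op     = { "^" }
calc_number     = @{ ASCII_DIGIT+ ~ ("." ~ ASCII_DIGIT*)? }
```

Note how `calc_power` recurses on itself to the right of `calc_pow_op`; that recursion is what makes exponentiation right associative in the grammar, matching the evaluator’s right-to-left handling.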
JSON Parsing
```rust
#[derive(Debug, Clone, PartialEq)]
pub enum JsonValue {
    Object(Vec<(String, JsonValue)>),
    Array(Vec<JsonValue>),
    String(String),
    Number(f64),
    Boolean(bool),
    Null,
}
```
JSON parsing showcases pest’s handling of recursive data structures. The grammar defines objects, arrays, strings, numbers, and literals with appropriate nesting rules.
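The JSON rules of `grammar.pest` are not shown in this chapter; the rule names below come from the Rust code, while the rule bodies (and the `inner` rule that `parse_string` consumes) are plausible reconstructions rather than the repository’s exact grammar:

```pest
json_value = { object | array | string | number | boolean | null }
object     = { "{" ~ (pair ~ ("," ~ pair)*)? ~ "}" }
pair       = { string ~ ":" ~ json_value }
array      = { "[" ~ (json_value ~ ("," ~ json_value)*)? ~ "]" }
string     = ${ "\"" ~ inner ~ "\"" }
inner      = @{ (!"\"" ~ ANY)* }
number     = @{ "-"? ~ ASCII_DIGIT+ ~ ("." ~ ASCII_DIGIT+)? }
boolean    = { "true" | "false" }
null       = { "null" }
```

The `$` (compound-atomic) marker on `string` keeps the `inner` pair available to the Rust code while suppressing implicit whitespace, which is why `parse_string` can extract the unquoted contents with a single `into_inner().next()`.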
```rust
/// Parse JSON input into a JsonValue AST
pub fn parse_json(input: &str) -> Result<JsonValue> {
    let pairs = Self::parse(Rule::json_value, input).map_err(Box::new)?;
    Self::build_json_value(pairs.into_iter().next().ok_or(ParseError::UnexpectedEOF)?)
}
```
The JSON parser transforms pest pairs into a typed representation. Match expressions on rule types drive the recursive construction of nested structures. String escape sequences and number parsing are handled during AST construction rather than in the grammar.
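The chapter’s `parse_string` keeps the raw span as-is; a fuller implementation would decode escape sequences at that point in AST construction. A sketch of such a decoding pass (a hypothetical helper, not part of the example crate, and omitting `\uXXXX` handling for brevity):

```rust
// Hypothetical post-processing step: decode JSON escape sequences from the
// raw span captured by the grammar.
fn unescape_json(raw: &str) -> String {
    let mut out = String::with_capacity(raw.len());
    let mut chars = raw.chars();
    while let Some(c) = chars.next() {
        if c != '\\' {
            out.push(c);
            continue;
        }
        match chars.next() {
            Some('n') => out.push('\n'),
            Some('t') => out.push('\t'),
            Some('r') => out.push('\r'),
            Some('"') => out.push('"'),
            Some('\\') => out.push('\\'),
            Some('/') => out.push('/'),
            // Pass unknown escapes through unchanged.
            Some(other) => {
                out.push('\\');
                out.push(other);
            }
            None => out.push('\\'),
        }
    }
    out
}

fn main() {
    assert_eq!(unescape_json(r#"line1\nline2"#), "line1\nline2");
    assert_eq!(unescape_json(r#"quote: \" done"#), "quote: \" done");
}
```

Doing this during AST construction keeps the grammar simple: the grammar only needs to recognize where a string ends, not interpret its contents.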
Programming Language Constructs
#![allow(unused)] fn main() { use std::fmt; use pest::error::Error; use pest::iterators::{Pair, Pairs}; use pest::pratt_parser::{Assoc, Op, PrattParser}; use pest::Parser; use pest_derive::Parser; use thiserror::Error; #[derive(Parser)] #[grammar = "grammar.pest"] pub struct GrammarParser; #[derive(Error, Debug)] pub enum ParseError { #[error("Pest parsing error: {0}")] Pest(#[from] Box<Error<Rule>>), #[error("Invalid number format: {0}")] InvalidNumber(String), #[error("Unknown operator: {0}")] UnknownOperator(String), #[error("Invalid JSON value")] InvalidJson, #[error("Unexpected end of input")] UnexpectedEOF, } pub type Result<T> = std::result::Result<T, ParseError>; #[derive(Debug, Clone, PartialEq)] pub enum Expr { Number(f64), Identifier(String), BinOp { left: Box<Expr>, op: BinOperator, right: Box<Expr>, }, } #[derive(Debug, Clone, PartialEq)] pub enum BinOperator { Add, Subtract, Multiply, Divide, Power, Eq, Ne, Lt, Le, Gt, Ge, } impl fmt::Display for BinOperator { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { BinOperator::Add => write!(f, "+"), BinOperator::Subtract => write!(f, "-"), BinOperator::Multiply => write!(f, "*"), BinOperator::Divide => write!(f, "/"), BinOperator::Power => write!(f, "^"), BinOperator::Eq => write!(f, "=="), BinOperator::Ne => write!(f, "!="), BinOperator::Lt => write!(f, "<"), BinOperator::Le => write!(f, "<="), BinOperator::Gt => write!(f, ">"), BinOperator::Ge => write!(f, ">="), } } } #[derive(Debug, Clone, PartialEq)] pub enum JsonValue { Object(Vec<(String, JsonValue)>), Array(Vec<JsonValue>), String(String), Number(f64), Boolean(bool), Null, } #[derive(Debug, Clone, PartialEq)] pub struct Program { pub statements: Vec<Statement>, } #[derive(Debug, Clone, PartialEq)] pub struct Parameter { pub name: String, pub type_name: String, } #[derive(Debug, Clone, PartialEq)] pub enum Token { Keyword(String), Operator(String), Punctuation(String), Literal(LiteralValue), Identifier(String), } #[derive(Debug, 
Clone, PartialEq)]
pub enum LiteralValue {
    String(String),
    Number(f64),
    Boolean(bool),
}

impl GrammarParser {
    /// Parse an expression using pest's built-in precedence climbing
    pub fn parse_expression(input: &str) -> Result<Expr> {
        let pairs = Self::parse(Rule::expression, input).map_err(Box::new)?;
        Self::build_expression(pairs)
    }

    fn build_expression(pairs: Pairs<Rule>) -> Result<Expr> {
        let pratt = PrattParser::new()
            .op(Op::infix(Rule::eq, Assoc::Left) | Op::infix(Rule::ne, Assoc::Left))
            .op(Op::infix(Rule::lt, Assoc::Left)
                | Op::infix(Rule::le, Assoc::Left)
                | Op::infix(Rule::gt, Assoc::Left)
                | Op::infix(Rule::ge, Assoc::Left))
            .op(Op::infix(Rule::add, Assoc::Left) | Op::infix(Rule::subtract, Assoc::Left))
            .op(Op::infix(Rule::multiply, Assoc::Left) | Op::infix(Rule::divide, Assoc::Left))
            .op(Op::infix(Rule::power, Assoc::Right));

        pratt
            .map_primary(|primary| match primary.as_rule() {
                Rule::term => {
                    let inner = primary.into_inner().next().ok_or(ParseError::UnexpectedEOF)?;
                    match inner.as_rule() {
                        Rule::number => {
                            let num = inner.as_str().parse::<f64>().map_err(|_| {
                                ParseError::InvalidNumber(inner.as_str().to_string())
                            })?;
                            Ok(Expr::Number(num))
                        }
                        Rule::identifier => Ok(Expr::Identifier(inner.as_str().to_string())),
                        Rule::expression => Self::build_expression(inner.into_inner()),
                        _ => unreachable!("Unexpected term rule: {:?}", inner.as_rule()),
                    }
                }
                Rule::number => {
                    let num = primary
                        .as_str()
                        .parse::<f64>()
                        .map_err(|_| ParseError::InvalidNumber(primary.as_str().to_string()))?;
                    Ok(Expr::Number(num))
                }
                Rule::identifier => Ok(Expr::Identifier(primary.as_str().to_string())),
                Rule::expression => Self::build_expression(primary.into_inner()),
                _ => unreachable!("Unexpected primary rule: {:?}", primary.as_rule()),
            })
            .map_infix(|left, op, right| {
                let op = match op.as_rule() {
                    Rule::add => BinOperator::Add,
                    Rule::subtract => BinOperator::Subtract,
                    Rule::multiply => BinOperator::Multiply,
                    Rule::divide => BinOperator::Divide,
                    Rule::power => BinOperator::Power,
                    Rule::eq => BinOperator::Eq,
                    Rule::ne => BinOperator::Ne,
                    Rule::lt => BinOperator::Lt,
                    Rule::le => BinOperator::Le,
                    Rule::gt => BinOperator::Gt,
                    Rule::ge => BinOperator::Ge,
                    _ => return Err(ParseError::UnknownOperator(op.as_str().to_string())),
                };
                Ok(Expr::BinOp {
                    left: Box::new(left?),
                    op,
                    right: Box::new(right?),
                })
            })
            .parse(pairs)
    }

    /// Parse a calculator expression with explicit precedence rules
    pub fn parse_calculation(input: &str) -> Result<f64> {
        let pairs = Self::parse(Rule::calculation, input).map_err(Box::new)?;
        Self::evaluate_calculation(
            pairs
                .into_iter()
                .next()
                .ok_or(ParseError::UnexpectedEOF)?
                .into_inner()
                .next()
                .ok_or(ParseError::UnexpectedEOF)?,
        )
    }

    fn evaluate_calculation(pair: Pair<Rule>) -> Result<f64> {
        match pair.as_rule() {
            Rule::calc_expression => {
                let mut pairs = pair.into_inner();
                let mut result =
                    Self::evaluate_calculation(pairs.next().ok_or(ParseError::UnexpectedEOF)?)?;
                while let (Some(op), Some(term)) = (pairs.next(), pairs.next()) {
                    let term_val = Self::evaluate_calculation(term)?;
                    match op.as_rule() {
                        Rule::calc_add_op => {
                            let inner_op =
                                op.into_inner().next().ok_or(ParseError::UnexpectedEOF)?;
                            match inner_op.as_rule() {
                                Rule::calc_plus => result += term_val,
                                Rule::calc_minus => result -= term_val,
                                _ => unreachable!("Unexpected add op: {:?}", inner_op.as_rule()),
                            }
                        }
                        Rule::calc_plus => result += term_val,
                        Rule::calc_minus => result -= term_val,
                        _ => unreachable!("Unexpected calc expression op: {:?}", op.as_rule()),
                    }
                }
                Ok(result)
            }
            Rule::calc_term => {
                let mut pairs = pair.into_inner();
                let mut result =
                    Self::evaluate_calculation(pairs.next().ok_or(ParseError::UnexpectedEOF)?)?;
                while let (Some(op), Some(factor)) = (pairs.next(), pairs.next()) {
                    let factor_val = Self::evaluate_calculation(factor)?;
                    match op.as_rule() {
                        Rule::calc_mul_op => {
                            let inner_op =
                                op.into_inner().next().ok_or(ParseError::UnexpectedEOF)?;
                            match inner_op.as_rule() {
                                Rule::calc_multiply => result *= factor_val,
                                Rule::calc_divide => result /= factor_val,
                                _ => unreachable!("Unexpected mul op: {:?}", inner_op.as_rule()),
                            }
                        }
                        Rule::calc_multiply => result *= factor_val,
                        Rule::calc_divide => result /= factor_val,
                        _ => unreachable!("Unexpected calc term op: {:?}", op.as_rule()),
                    }
                }
                Ok(result)
            }
            Rule::calc_factor => Self::evaluate_calculation(
                pair.into_inner().next().ok_or(ParseError::UnexpectedEOF)?,
            ),
            Rule::calc_power => {
                let mut pairs = pair.into_inner();
                let base =
                    Self::evaluate_calculation(pairs.next().ok_or(ParseError::UnexpectedEOF)?)?;
                if let Some(op) = pairs.next() {
                    if op.as_rule() == Rule::calc_pow_op {
                        let exponent = Self::evaluate_calculation(
                            pairs.next().ok_or(ParseError::UnexpectedEOF)?,
                        )?;
                        Ok(base.powf(exponent))
                    } else {
                        unreachable!("Expected calc_pow_op, got: {:?}", op.as_rule());
                    }
                } else {
                    Ok(base)
                }
            }
            Rule::calc_atom => Self::evaluate_calculation(
                pair.into_inner().next().ok_or(ParseError::UnexpectedEOF)?,
            ),
            Rule::calc_unary => {
                let mut pairs = pair.into_inner();
                let first = pairs.next().ok_or(ParseError::UnexpectedEOF)?;
                match first.as_rule() {
                    Rule::calc_minus => {
                        let val = Self::evaluate_calculation(
                            pairs.next().ok_or(ParseError::UnexpectedEOF)?,
                        )?;
                        Ok(-val)
                    }
                    Rule::calc_plus => {
                        let val = Self::evaluate_calculation(
                            pairs.next().ok_or(ParseError::UnexpectedEOF)?,
                        )?;
                        Ok(val)
                    }
                    Rule::calc_number => first
                        .as_str()
                        .parse::<f64>()
                        .map_err(|_| ParseError::InvalidNumber(first.as_str().to_string())),
                    _ => Self::evaluate_calculation(first),
                }
            }
            Rule::calc_number => pair
                .as_str()
                .parse::<f64>()
                .map_err(|_| ParseError::InvalidNumber(pair.as_str().to_string())),
            _ => unreachable!("Unexpected rule: {:?}", pair.as_rule()),
        }
    }
}

impl GrammarParser {
    /// Parse JSON input into a JsonValue AST
    pub fn parse_json(input: &str) -> Result<JsonValue> {
        let pairs = Self::parse(Rule::json_value, input).map_err(Box::new)?;
        Self::build_json_value(pairs.into_iter().next().ok_or(ParseError::UnexpectedEOF)?)
    }

    fn build_json_value(pair: Pair<Rule>) -> Result<JsonValue> {
        match pair.as_rule() {
            Rule::json_value => {
                Self::build_json_value(pair.into_inner().next().ok_or(ParseError::UnexpectedEOF)?)
            }
            Rule::object => {
                let mut object = Vec::new();
                for pair in pair.into_inner() {
                    if let Rule::pair = pair.as_rule() {
                        let mut inner = pair.into_inner();
                        let key =
                            Self::parse_string(inner.next().ok_or(ParseError::UnexpectedEOF)?)?;
                        let value = Self::build_json_value(
                            inner.next().ok_or(ParseError::UnexpectedEOF)?,
                        )?;
                        object.push((key, value));
                    }
                }
                Ok(JsonValue::Object(object))
            }
            Rule::array => {
                let mut array = Vec::new();
                for pair in pair.into_inner() {
                    array.push(Self::build_json_value(pair)?);
                }
                Ok(JsonValue::Array(array))
            }
            Rule::string => Ok(JsonValue::String(Self::parse_string(pair)?)),
            Rule::number => {
                let num = pair
                    .as_str()
                    .parse::<f64>()
                    .map_err(|_| ParseError::InvalidNumber(pair.as_str().to_string()))?;
                Ok(JsonValue::Number(num))
            }
            Rule::boolean => Ok(JsonValue::Boolean(pair.as_str() == "true")),
            Rule::null => Ok(JsonValue::Null),
            _ => Err(ParseError::InvalidJson),
        }
    }

    fn parse_string(pair: Pair<Rule>) -> Result<String> {
        let inner = pair.into_inner().next().ok_or(ParseError::UnexpectedEOF)?;
        Ok(inner.as_str().to_string())
    }
}

impl GrammarParser {
    /// Parse a complete program
    pub fn parse_program(input: &str) -> Result<Program> {
        let pairs = Self::parse(Rule::program, input).map_err(Box::new)?;
        let mut statements = Vec::new();
        for pair in pairs
            .into_iter()
            .next()
            .ok_or(ParseError::UnexpectedEOF)?
            .into_inner()
        {
            if pair.as_rule() != Rule::EOI {
                statements.push(Self::build_statement(pair)?);
            }
        }
        Ok(Program { statements })
    }

    fn build_statement(pair: Pair<Rule>) -> Result<Statement> {
        match pair.as_rule() {
            Rule::statement => {
                Self::build_statement(pair.into_inner().next().ok_or(ParseError::UnexpectedEOF)?)
            }
            Rule::if_statement => {
                let mut inner = pair.into_inner();
                let condition = Self::build_expression_from_pair(
                    inner.next().ok_or(ParseError::UnexpectedEOF)?,
                )?;
                let then_block =
                    Self::build_block(inner.next().ok_or(ParseError::UnexpectedEOF)?)?;
                let else_block = inner
                    .next()
                    .map(|p| match p.as_rule() {
                        Rule::block => Self::build_block(p),
                        Rule::if_statement => Ok(vec![Self::build_statement(p)?]),
                        _ => unreachable!(),
                    })
                    .transpose()?;
                Ok(Statement::If {
                    condition,
                    then_block,
                    else_block,
                })
            }
            Rule::while_statement => {
                let mut inner = pair.into_inner();
                let condition = Self::build_expression_from_pair(
                    inner.next().ok_or(ParseError::UnexpectedEOF)?,
                )?;
                let body = Self::build_block(inner.next().ok_or(ParseError::UnexpectedEOF)?)?;
                Ok(Statement::While { condition, body })
            }
            Rule::function_def => {
                let mut inner = pair.into_inner();
                let name = inner
                    .next()
                    .ok_or(ParseError::UnexpectedEOF)?
                    .as_str()
                    .to_string();
                let mut parameters = Vec::new();
                let mut next = inner.next().ok_or(ParseError::UnexpectedEOF)?;
                if next.as_rule() == Rule::parameter_list {
                    for param_pair in next.into_inner() {
                        let mut param_inner = param_pair.into_inner();
                        let param_name = param_inner
                            .next()
                            .ok_or(ParseError::UnexpectedEOF)?
                            .as_str()
                            .to_string();
                        let param_type = param_inner
                            .next()
                            .ok_or(ParseError::UnexpectedEOF)?
                            .as_str()
                            .to_string();
                        parameters.push(Parameter {
                            name: param_name,
                            type_name: param_type,
                        });
                    }
                    next = inner.next().ok_or(ParseError::UnexpectedEOF)?;
                }
                let return_type = next.as_str().to_string();
                let body = Self::build_block(inner.next().ok_or(ParseError::UnexpectedEOF)?)?;
                Ok(Statement::Function {
                    name,
                    parameters,
                    return_type,
                    body,
                })
            }
            Rule::assignment => {
                let mut inner = pair.into_inner();
                let name = inner
                    .next()
                    .ok_or(ParseError::UnexpectedEOF)?
                    .as_str()
                    .to_string();
                let value = Self::build_expression_from_pair(
                    inner.next().ok_or(ParseError::UnexpectedEOF)?,
                )?;
                Ok(Statement::Assignment { name, value })
            }
            Rule::expression_statement => {
                let expr = Self::build_expression_from_pair(
                    pair.into_inner().next().ok_or(ParseError::UnexpectedEOF)?,
                )?;
                Ok(Statement::Expression(expr))
            }
            Rule::block => Ok(Statement::Block(Self::build_block(pair)?)),
            _ => unreachable!("Unexpected statement rule: {:?}", pair.as_rule()),
        }
    }

    fn build_block(pair: Pair<Rule>) -> Result<Vec<Statement>> {
        let mut statements = Vec::new();
        for stmt_pair in pair.into_inner() {
            statements.push(Self::build_statement(stmt_pair)?);
        }
        Ok(statements)
    }

    fn build_expression_from_pair(pair: Pair<Rule>) -> Result<Expr> {
        Self::build_expression(pair.into_inner())
    }
}

impl GrammarParser {
    /// Parse input into a stream of tokens
    pub fn parse_tokens(input: &str) -> Result<Vec<Token>> {
        let pairs = Self::parse(Rule::token_stream, input).map_err(Box::new)?;
        let mut tokens = Vec::new();
        for pair in pairs
            .into_iter()
            .next()
            .ok_or(ParseError::UnexpectedEOF)?
            .into_inner()
        {
            if pair.as_rule() != Rule::EOI {
                tokens.push(Self::build_token(pair)?);
            }
        }
        Ok(tokens)
    }

    fn build_token(pair: Pair<Rule>) -> Result<Token> {
        match pair.as_rule() {
            Rule::keyword => Ok(Token::Keyword(pair.as_str().to_string())),
            Rule::operator_token => Ok(Token::Operator(pair.as_str().to_string())),
            Rule::punctuation => Ok(Token::Punctuation(pair.as_str().to_string())),
            Rule::literal => {
                let inner = pair.into_inner().next().ok_or(ParseError::UnexpectedEOF)?;
                match inner.as_rule() {
                    Rule::string_literal => {
                        let content = inner
                            .into_inner()
                            .next()
                            .ok_or(ParseError::UnexpectedEOF)?
                            .as_str();
                        Ok(Token::Literal(LiteralValue::String(content.to_string())))
                    }
                    Rule::number_literal => {
                        let num = inner
                            .as_str()
                            .parse::<f64>()
                            .map_err(|_| ParseError::InvalidNumber(inner.as_str().to_string()))?;
                        Ok(Token::Literal(LiteralValue::Number(num)))
                    }
                    Rule::boolean_literal => Ok(Token::Literal(LiteralValue::Boolean(
                        inner.as_str() == "true",
                    ))),
                    _ => unreachable!(),
                }
            }
            Rule::identifier_token => Ok(Token::Identifier(pair.as_str().to_string())),
            _ => unreachable!("Unexpected token rule: {:?}", pair.as_rule()),
        }
    }
}

impl GrammarParser {
    /// Parse and print pest parse tree for debugging
    pub fn debug_parse(rule: Rule, input: &str) -> Result<()> {
        let pairs = Self::parse(rule, input).map_err(Box::new)?;
        for pair in pairs {
            Self::print_pair(&pair, 0);
        }
        Ok(())
    }

    fn print_pair(pair: &Pair<Rule>, indent: usize) {
        let indent_str = " ".repeat(indent);
        println!("{}{:?}: \"{}\"", indent_str, pair.as_rule(), pair.as_str());
        for inner_pair in pair.clone().into_inner() {
            Self::print_pair(&inner_pair, indent + 1);
        }
    }

    /// Extract all identifiers from an expression
    pub fn extract_identifiers(expr: &Expr) -> Vec<String> {
        match expr {
            Expr::Identifier(name) => vec![name.clone()],
            Expr::BinOp { left, right, .. } => {
                let mut ids = Self::extract_identifiers(left);
                ids.extend(Self::extract_identifiers(right));
                ids
            }
            Expr::Number(_) => vec![],
        }
    }

    /// Check if a rule matches the complete input
    pub fn can_parse(rule: Rule, input: &str) -> bool {
        match Self::parse(rule, input) {
            Ok(pairs) => {
                // Check that the entire input is consumed
                let input_len = input.len();
                let parsed_len = pairs.as_str().len();
                parsed_len == input_len
            }
            Err(_) => false,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_expression_parsing() {
        let expr = GrammarParser::parse_expression("2 + 3 * 4").unwrap();
        match expr {
            Expr::BinOp {
                op: BinOperator::Add,
                ..
            } => (),
            _ => panic!("Expected addition at top level"),
        }
    }

    #[test]
    fn test_calculation() {
        assert_eq!(GrammarParser::parse_calculation("2 + 3 * 4").unwrap(), 14.0);
        assert_eq!(GrammarParser::parse_calculation("(2 + 3) * 4").unwrap(), 20.0);
        assert_eq!(GrammarParser::parse_calculation("2 ^ 3 ^ 2").unwrap(), 512.0);
    }

    #[test]
    fn test_json_parsing() {
        let json = r#"{"name": "test", "value": 42, "active": true}"#;
        let result = GrammarParser::parse_json(json).unwrap();
        if let JsonValue::Object(obj) = result {
            assert_eq!(obj.len(), 3);
        } else {
            panic!("Expected JSON object");
        }
    }

    #[test]
    fn test_program_parsing() {
        let program = r#"
            fn add(x: int, y: int) -> int { x + y; }
            if x > 0 { y = 42; }
        "#;
        let result = GrammarParser::parse_program(program).unwrap();
        assert_eq!(result.statements.len(), 2);
    }

    #[test]
    fn test_token_parsing() {
        let input = "if x == 42 { return true; }";
        let tokens = GrammarParser::parse_tokens(input).unwrap();
        assert!(tokens.len() > 5);
        match &tokens[0] {
            Token::Keyword(kw) => assert_eq!(kw, "if"),
            _ => panic!("Expected keyword"),
        }
    }

    #[test]
    fn test_identifier_extraction() {
        let expr = GrammarParser::parse_expression("x + y * z").unwrap();
        let ids = GrammarParser::extract_identifiers(&expr);
        assert_eq!(ids, vec!["x", "y", "z"]);
    }

    #[test]
    fn test_debug_features() {
        assert!(GrammarParser::can_parse(Rule::expression, "2 + 3"));
        assert!(!GrammarParser::can_parse(Rule::expression, "2 +"));
    }
}

#[derive(Debug, Clone, PartialEq)]
pub enum Statement {
    If {
        condition: Expr,
        then_block: Vec<Statement>,
        else_block: Option<Vec<Statement>>,
    },
    While {
        condition: Expr,
        body: Vec<Statement>,
    },
    Function {
        name: String,
        parameters: Vec<Parameter>,
        return_type: String,
        body: Vec<Statement>,
    },
    Assignment {
        name: String,
        value: Expr,
    },
    Expression(Expr),
    Block(Vec<Statement>),
}
}
Statement types demonstrate parsing of control flow and program structure. The grammar supports if statements, while loops, function definitions, and variable assignments.
```rust
/// Parse a complete program
pub fn parse_program(input: &str) -> Result<Program> {
    let pairs = Self::parse(Rule::program, input).map_err(Box::new)?;
    let mut statements = Vec::new();
    for pair in pairs
        .into_iter()
        .next()
        .ok_or(ParseError::UnexpectedEOF)?
        .into_inner()
    {
        if pair.as_rule() != Rule::EOI {
            statements.push(Self::build_statement(pair)?);
        }
    }
    Ok(Program { statements })
}
```
Program parsing builds complex ASTs from multiple statement types. The recursive nature of the parser handles nested blocks and control structures. Pattern matching on rule names provides clear correspondence between grammar and implementation.
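The rule names matched in `build_statement` come from the chapter's `grammar.pest` file, which is not reproduced here. As a rough illustration of how those rules might be declared in pest syntax (a hypothetical sketch consistent with the names `statement`, `if_statement`, `while_statement`, `function_def`, `parameter_list`, `assignment`, and `block` used above, not the repository's actual grammar):

```pest
// Hypothetical sketch; the real grammar.pest in the repository may differ.
statement       = { if_statement | while_statement | function_def | assignment | expression_statement }
if_statement    = { "if" ~ expression ~ block ~ ("else" ~ (block | if_statement))? }
while_statement = { "while" ~ expression ~ block }
function_def    = { "fn" ~ identifier ~ "(" ~ parameter_list? ~ ")" ~ "->" ~ identifier ~ block }
parameter_list  = { parameter ~ ("," ~ parameter)* }
parameter       = { identifier ~ ":" ~ identifier }
assignment      = { identifier ~ "=" ~ expression ~ ";" }
block           = { "{" ~ statement* ~ "}" }
```

Each grammar rule corresponds one-to-one with an arm of the `match` in `build_statement`, which is what keeps the grammar and the AST-building code easy to evolve together.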
Token Stream Parsing
```rust
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    Keyword(String),
    Operator(String),
    Punctuation(String),
    Literal(LiteralValue),
    Identifier(String),
}
```
Token extraction demonstrates pest’s ability to function as a lexer. The grammar identifies different token types while preserving their textual representation.
```rust
/// Parse input into a stream of tokens
pub fn parse_tokens(input: &str) -> Result<Vec<Token>> {
    let pairs = Self::parse(Rule::token_stream, input).map_err(Box::new)?;
    let mut tokens = Vec::new();
    for pair in pairs
        .into_iter()
        .next()
        .ok_or(ParseError::UnexpectedEOF)?
        .into_inner()
    {
        if pair.as_rule() != Rule::EOI {
            tokens.push(Self::build_token(pair)?);
        }
    }
    Ok(tokens)
}
```
Token stream parsing extracts lexical elements from source text. Each token captures its textual content, and the underlying pest pairs carry span information that can be recovered for error reporting and source mapping. This approach separates lexical analysis from syntactic parsing when needed.
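For contrast, here is what the grammar-driven lexer above replaces. This is a hypothetical, dependency-free sketch (not part of the chapter's crate) that hand-rolls the same token categories over whitespace-separated words; the pest version gets the classification rules from the grammar instead of from this cascade of ad hoc checks:

```rust
// Hypothetical hand-written lexer over the same token categories.
// The pest-driven parse_tokens replaces this kind of manual classification.
#[derive(Debug, Clone, PartialEq)]
enum Token {
    Keyword(String),
    Operator(String),
    Punctuation(String),
    Number(f64),
    Identifier(String),
}

fn tokenize(input: &str) -> Vec<Token> {
    const KEYWORDS: &[&str] = &["if", "else", "while", "fn", "return"];
    input
        .split_whitespace()
        .map(|word| {
            if KEYWORDS.contains(&word) {
                Token::Keyword(word.to_string())
            } else if let Ok(n) = word.parse::<f64>() {
                Token::Number(n)
            } else if word.chars().all(|c| c.is_alphanumeric() || c == '_') {
                Token::Identifier(word.to_string())
            } else if matches!(word, "{" | "}" | "(" | ")" | ";" | ",") {
                Token::Punctuation(word.to_string())
            } else {
                Token::Operator(word.to_string())
            }
        })
        .collect()
}

fn main() {
    let tokens = tokenize("if x == 42 { y = 1 ; }");
    assert_eq!(tokens[0], Token::Keyword("if".to_string()));
    assert_eq!(tokens[2], Token::Operator("==".to_string()));
    assert_eq!(tokens[3], Token::Number(42.0));
    println!("{:?}", tokens);
}
```

Note that this sketch requires tokens to be whitespace-separated, which the grammar-based lexer does not; that limitation is exactly the kind of detail a declarative grammar handles for free.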
Error Handling
The parser defines a dedicated error type with thiserror, wrapping pest's own error alongside domain-specific failures:

```rust
#[derive(Error, Debug)]
pub enum ParseError {
    #[error("Pest parsing error: {0}")]
    Pest(#[from] Box<Error<Rule>>),
    #[error("Invalid number format: {0}")]
    InvalidNumber(String),
    #[error("Unknown operator: {0}")]
    UnknownOperator(String),
    #[error("Invalid JSON value")]
    InvalidJson,
    #[error("Unexpected end of input")]
    UnexpectedEOF,
}

pub type Result<T> = std::result::Result<T, ParseError>;
```
*= factor_val, Rule::calc_divide => result /= factor_val, _ => unreachable!("Unexpected mul op: {:?}", inner_op.as_rule()), } } Rule::calc_multiply => result *= factor_val, Rule::calc_divide => result /= factor_val, _ => unreachable!("Unexpected calc term op: {:?}", op.as_rule()), } } Ok(result) } Rule::calc_factor => Self::evaluate_calculation( pair.into_inner().next().ok_or(ParseError::UnexpectedEOF)?, ), Rule::calc_power => { let mut pairs = pair.into_inner(); let base = Self::evaluate_calculation(pairs.next().ok_or(ParseError::UnexpectedEOF)?)?; if let Some(op) = pairs.next() { if op.as_rule() == Rule::calc_pow_op { let exponent = Self::evaluate_calculation( pairs.next().ok_or(ParseError::UnexpectedEOF)?, )?; Ok(base.powf(exponent)) } else { unreachable!("Expected calc_pow_op, got: {:?}", op.as_rule()); } } else { Ok(base) } } Rule::calc_atom => Self::evaluate_calculation( pair.into_inner().next().ok_or(ParseError::UnexpectedEOF)?, ), Rule::calc_unary => { let mut pairs = pair.into_inner(); let first = pairs.next().ok_or(ParseError::UnexpectedEOF)?; match first.as_rule() { Rule::calc_minus => { let val = Self::evaluate_calculation( pairs.next().ok_or(ParseError::UnexpectedEOF)?, )?; Ok(-val) } Rule::calc_plus => { let val = Self::evaluate_calculation( pairs.next().ok_or(ParseError::UnexpectedEOF)?, )?; Ok(val) } Rule::calc_number => first .as_str() .parse::<f64>() .map_err(|_| ParseError::InvalidNumber(first.as_str().to_string())), _ => Self::evaluate_calculation(first), } } Rule::calc_number => pair .as_str() .parse::<f64>() .map_err(|_| ParseError::InvalidNumber(pair.as_str().to_string())), _ => unreachable!("Unexpected rule: {:?}", pair.as_rule()), } } } impl GrammarParser { /// Parse JSON input into a JsonValue AST pub fn parse_json(input: &str) -> Result<JsonValue> { let pairs = Self::parse(Rule::json_value, input).map_err(Box::new)?; Self::build_json_value(pairs.into_iter().next().ok_or(ParseError::UnexpectedEOF)?) 
} fn build_json_value(pair: Pair<Rule>) -> Result<JsonValue> { match pair.as_rule() { Rule::json_value => { Self::build_json_value(pair.into_inner().next().ok_or(ParseError::UnexpectedEOF)?) } Rule::object => { let mut object = Vec::new(); for pair in pair.into_inner() { if let Rule::pair = pair.as_rule() { let mut inner = pair.into_inner(); let key = Self::parse_string(inner.next().ok_or(ParseError::UnexpectedEOF)?)?; let value = Self::build_json_value(inner.next().ok_or(ParseError::UnexpectedEOF)?)?; object.push((key, value)); } } Ok(JsonValue::Object(object)) } Rule::array => { let mut array = Vec::new(); for pair in pair.into_inner() { array.push(Self::build_json_value(pair)?); } Ok(JsonValue::Array(array)) } Rule::string => Ok(JsonValue::String(Self::parse_string(pair)?)), Rule::number => { let num = pair .as_str() .parse::<f64>() .map_err(|_| ParseError::InvalidNumber(pair.as_str().to_string()))?; Ok(JsonValue::Number(num)) } Rule::boolean => Ok(JsonValue::Boolean(pair.as_str() == "true")), Rule::null => Ok(JsonValue::Null), _ => Err(ParseError::InvalidJson), } } fn parse_string(pair: Pair<Rule>) -> Result<String> { let inner = pair.into_inner().next().ok_or(ParseError::UnexpectedEOF)?; Ok(inner.as_str().to_string()) } } impl GrammarParser { /// Parse a complete program pub fn parse_program(input: &str) -> Result<Program> { let pairs = Self::parse(Rule::program, input).map_err(Box::new)?; let mut statements = Vec::new(); for pair in pairs .into_iter() .next() .ok_or(ParseError::UnexpectedEOF)? .into_inner() { if pair.as_rule() != Rule::EOI { statements.push(Self::build_statement(pair)?); } } Ok(Program { statements }) } fn build_statement(pair: Pair<Rule>) -> Result<Statement> { match pair.as_rule() { Rule::statement => { Self::build_statement(pair.into_inner().next().ok_or(ParseError::UnexpectedEOF)?) 
} Rule::if_statement => { let mut inner = pair.into_inner(); let condition = Self::build_expression_from_pair( inner.next().ok_or(ParseError::UnexpectedEOF)?, )?; let then_block = Self::build_block(inner.next().ok_or(ParseError::UnexpectedEOF)?)?; let else_block = inner .next() .map(|p| match p.as_rule() { Rule::block => Self::build_block(p), Rule::if_statement => Ok(vec![Self::build_statement(p)?]), _ => unreachable!(), }) .transpose()?; Ok(Statement::If { condition, then_block, else_block, }) } Rule::while_statement => { let mut inner = pair.into_inner(); let condition = Self::build_expression_from_pair( inner.next().ok_or(ParseError::UnexpectedEOF)?, )?; let body = Self::build_block(inner.next().ok_or(ParseError::UnexpectedEOF)?)?; Ok(Statement::While { condition, body }) } Rule::function_def => { let mut inner = pair.into_inner(); let name = inner .next() .ok_or(ParseError::UnexpectedEOF)? .as_str() .to_string(); let mut parameters = Vec::new(); let mut next = inner.next().ok_or(ParseError::UnexpectedEOF)?; if next.as_rule() == Rule::parameter_list { for param_pair in next.into_inner() { let mut param_inner = param_pair.into_inner(); let param_name = param_inner .next() .ok_or(ParseError::UnexpectedEOF)? .as_str() .to_string(); let param_type = param_inner .next() .ok_or(ParseError::UnexpectedEOF)? .as_str() .to_string(); parameters.push(Parameter { name: param_name, type_name: param_type, }); } next = inner.next().ok_or(ParseError::UnexpectedEOF)?; } let return_type = next.as_str().to_string(); let body = Self::build_block(inner.next().ok_or(ParseError::UnexpectedEOF)?)?; Ok(Statement::Function { name, parameters, return_type, body, }) } Rule::assignment => { let mut inner = pair.into_inner(); let name = inner .next() .ok_or(ParseError::UnexpectedEOF)? 
.as_str() .to_string(); let value = Self::build_expression_from_pair( inner.next().ok_or(ParseError::UnexpectedEOF)?, )?; Ok(Statement::Assignment { name, value }) } Rule::expression_statement => { let expr = Self::build_expression_from_pair( pair.into_inner().next().ok_or(ParseError::UnexpectedEOF)?, )?; Ok(Statement::Expression(expr)) } Rule::block => Ok(Statement::Block(Self::build_block(pair)?)), _ => unreachable!("Unexpected statement rule: {:?}", pair.as_rule()), } } fn build_block(pair: Pair<Rule>) -> Result<Vec<Statement>> { let mut statements = Vec::new(); for stmt_pair in pair.into_inner() { statements.push(Self::build_statement(stmt_pair)?); } Ok(statements) } fn build_expression_from_pair(pair: Pair<Rule>) -> Result<Expr> { Self::build_expression(pair.into_inner()) } } impl GrammarParser { /// Parse input into a stream of tokens pub fn parse_tokens(input: &str) -> Result<Vec<Token>> { let pairs = Self::parse(Rule::token_stream, input).map_err(Box::new)?; let mut tokens = Vec::new(); for pair in pairs .into_iter() .next() .ok_or(ParseError::UnexpectedEOF)? .into_inner() { if pair.as_rule() != Rule::EOI { tokens.push(Self::build_token(pair)?); } } Ok(tokens) } fn build_token(pair: Pair<Rule>) -> Result<Token> { match pair.as_rule() { Rule::keyword => Ok(Token::Keyword(pair.as_str().to_string())), Rule::operator_token => Ok(Token::Operator(pair.as_str().to_string())), Rule::punctuation => Ok(Token::Punctuation(pair.as_str().to_string())), Rule::literal => { let inner = pair.into_inner().next().ok_or(ParseError::UnexpectedEOF)?; match inner.as_rule() { Rule::string_literal => { let content = inner .into_inner() .next() .ok_or(ParseError::UnexpectedEOF)? 
.as_str(); Ok(Token::Literal(LiteralValue::String(content.to_string()))) } Rule::number_literal => { let num = inner .as_str() .parse::<f64>() .map_err(|_| ParseError::InvalidNumber(inner.as_str().to_string()))?; Ok(Token::Literal(LiteralValue::Number(num))) } Rule::boolean_literal => Ok(Token::Literal(LiteralValue::Boolean( inner.as_str() == "true", ))), _ => unreachable!(), } } Rule::identifier_token => Ok(Token::Identifier(pair.as_str().to_string())), _ => unreachable!("Unexpected token rule: {:?}", pair.as_rule()), } } } impl GrammarParser { /// Parse and print pest parse tree for debugging pub fn debug_parse(rule: Rule, input: &str) -> Result<()> { let pairs = Self::parse(rule, input).map_err(Box::new)?; for pair in pairs { Self::print_pair(&pair, 0); } Ok(()) } fn print_pair(pair: &Pair<Rule>, indent: usize) { let indent_str = " ".repeat(indent); println!("{}{:?}: \"{}\"", indent_str, pair.as_rule(), pair.as_str()); for inner_pair in pair.clone().into_inner() { Self::print_pair(&inner_pair, indent + 1); } } /// Extract all identifiers from an expression pub fn extract_identifiers(expr: &Expr) -> Vec<String> { match expr { Expr::Identifier(name) => vec![name.clone()], Expr::BinOp { left, right, .. } => { let mut ids = Self::extract_identifiers(left); ids.extend(Self::extract_identifiers(right)); ids } Expr::Number(_) => vec![], } } /// Check if a rule matches the complete input pub fn can_parse(rule: Rule, input: &str) -> bool { match Self::parse(rule, input) { Ok(pairs) => { // Check that the entire input is consumed let input_len = input.len(); let parsed_len = pairs.as_str().len(); parsed_len == input_len } Err(_) => false, } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_expression_parsing() { let expr = GrammarParser::parse_expression("2 + 3 * 4").unwrap(); match expr { Expr::BinOp { op: BinOperator::Add, .. 
} => (), _ => panic!("Expected addition at top level"), } } #[test] fn test_calculation() { assert_eq!(GrammarParser::parse_calculation("2 + 3 * 4").unwrap(), 14.0); assert_eq!( GrammarParser::parse_calculation("(2 + 3) * 4").unwrap(), 20.0 ); assert_eq!( GrammarParser::parse_calculation("2 ^ 3 ^ 2").unwrap(), 512.0 ); } #[test] fn test_json_parsing() { let json = r#"{"name": "test", "value": 42, "active": true}"#; let result = GrammarParser::parse_json(json).unwrap(); if let JsonValue::Object(obj) = result { assert_eq!(obj.len(), 3); } else { panic!("Expected JSON object"); } } #[test] fn test_program_parsing() { let program = r#" fn add(x: int, y: int) -> int { x + y; } if x > 0 { y = 42; } "#; let result = GrammarParser::parse_program(program).unwrap(); assert_eq!(result.statements.len(), 2); } #[test] fn test_token_parsing() { let input = "if x == 42 { return true; }"; let tokens = GrammarParser::parse_tokens(input).unwrap(); assert!(tokens.len() > 5); match &tokens[0] { Token::Keyword(kw) => assert_eq!(kw, "if"), _ => panic!("Expected keyword"), } } #[test] fn test_identifier_extraction() { let expr = GrammarParser::parse_expression("x + y * z").unwrap(); let ids = GrammarParser::extract_identifiers(&expr); assert_eq!(ids, vec!["x", "y", "z"]); } #[test] fn test_debug_features() { assert!(GrammarParser::can_parse(Rule::expression, "2 + 3")); assert!(!GrammarParser::can_parse(Rule::expression, "2 +")); } } #[derive(Error, Debug)] pub enum ParseError { #[error("Pest parsing error: {0}")] Pest(#[from] Box<Error<Rule>>), #[error("Invalid number format: {0}")] InvalidNumber(String), #[error("Unknown operator: {0}")] UnknownOperator(String), #[error("Invalid JSON value")] InvalidJson, #[error("Unexpected end of input")] UnexpectedEOF, } }
Custom error types provide context-specific error messages beyond pest’s built-in reporting. The error variants correspond to different failure modes in parsing and semantic analysis.
/// Parse and print pest parse tree for debugging
pub fn debug_parse(rule: Rule, input: &str) -> Result<()> {
    let pairs = Self::parse(rule, input).map_err(Box::new)?;
    for pair in pairs {
        Self::print_pair(&pair, 0);
    }
    Ok(())
}
The debug_parse function prints pest's parse tree for a given rule and input, showing exactly how the grammar matched. Parse failures are wrapped in the custom ParseError type, so the line and column information pest computes remains available for domain-specific error formatting.
Utility Functions
/// Extract all identifiers from an expression
pub fn extract_identifiers(expr: &Expr) -> Vec<String> {
    match expr {
        Expr::Identifier(name) => vec![name.clone()],
        Expr::BinOp { left, right, .. } => {
            let mut ids = Self::extract_identifiers(left);
            ids.extend(Self::extract_identifiers(right));
            ids
        }
        Expr::Number(_) => vec![],
    }
}
Helper functions simplify common analysis patterns. Identifier extraction demonstrates recursively traversing the AST to collect specific elements; the same visitor-style traversal works directly on pest's pair structure when information is needed from the raw parse tree.
Debug utilities help understand pest’s parse tree structure during development. The function recursively prints the tree with indentation showing nesting levels. Rule names and captured text provide insight into how the grammar matched input.
Grammar Design Patterns
Pest grammars benefit from consistent structure and naming conventions. Use snake_case for rule names and UPPER_CASE for special rules such as WHITESPACE and COMMENT. Group related rules together with comments explaining their purpose. Silent rules, written with the _{ ... } modifier, hide implementation details from the parse tree.
Whitespace handling deserves special attention in grammar design. Once defined, the special WHITESPACE rule is automatically skipped between tokens. Atomic rules disable this automatic skipping when exact matching is required. Use pest's PUSH and POP stack operations to track significant indentation in languages like Python.
Comments can be handled uniformly through the COMMENT rule. Single-line and multi-line comment patterns integrate naturally with automatic skipping. This approach keeps the main grammar rules clean and focused on language structure.
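These conventions can be seen in a minimal .pest grammar sketch. The rule names here are illustrative and are not taken from the book's example grammar:

```pest
// Trivia: once defined, these special rules are skipped implicitly
// between tokens, and the _ modifier keeps them out of the parse tree
WHITESPACE = _{ " " | "\t" | NEWLINE }
COMMENT    = _{ "//" ~ (!NEWLINE ~ ANY)* }

// Atomic rules (@): implicit whitespace skipping is disabled inside,
// so tokens must match exactly
number     = @{ ASCII_DIGIT+ ~ ("." ~ ASCII_DIGIT+)? }
identifier = @{ (ASCII_ALPHA | "_") ~ (ASCII_ALPHANUMERIC | "_")* }

// Silent rule (_): matches but contributes no node of its own
term       = _{ number | identifier | "(" ~ expression ~ ")" }

expression = { term ~ (operator ~ term)* }
operator   = { "+" | "-" | "*" | "/" }
```

Because term is silent, a parsed expression exposes its numbers and identifiers directly, keeping the tree aligned with language concepts rather than grammar plumbing.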
Performance Optimization
Pest generates efficient parsers through compile-time code generation. The generated code uses backtracking only when necessary, preferring deterministic parsing where possible. Memoization of rule results prevents redundant parsing of the same input.
Rule ordering impacts performance in choice expressions. Place more common alternatives first to reduce backtracking. Use atomic rules to prevent unnecessary whitespace checking in tight loops. Consider breaking complex rules into smaller components for better caching.
The precedence climbing algorithm provides optimal performance for expression parsing. Unlike naive recursive descent, it avoids deep recursion for left-associative operators. The algorithm handles arbitrary precedence levels efficiently without grammar transformations.
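The heart of precedence climbing is a single recursive loop parameterized by a minimum binding power. A stdlib-only sketch over a pre-tokenized input, independent of pest and using a hypothetical Tok type, looks like this:

```rust
// Minimal precedence-climbing evaluator over a token slice.
// `Tok` is a hypothetical token type for illustration only.
#[derive(Clone, Copy, PartialEq)]
enum Tok { Num(f64), Plus, Minus, Star, Slash }

// Binding power of each binary operator; None for non-operators.
fn prec(t: Tok) -> Option<u8> {
    match t {
        Tok::Plus | Tok::Minus => Some(1),
        Tok::Star | Tok::Slash => Some(2),
        _ => None,
    }
}

// Parse an expression whose operators all bind at least as tightly as `min`.
fn climb(toks: &[Tok], pos: &mut usize, min: u8) -> f64 {
    let mut lhs = match toks[*pos] {
        Tok::Num(n) => { *pos += 1; n }
        _ => panic!("expected number"),
    };
    while *pos < toks.len() {
        let op = toks[*pos];
        let Some(p) = prec(op) else { break };
        if p < min { break; }
        *pos += 1;
        // Left-associative: the recursive call requires strictly higher power.
        let rhs = climb(toks, pos, p + 1);
        lhs = match op {
            Tok::Plus => lhs + rhs,
            Tok::Minus => lhs - rhs,
            Tok::Star => lhs * rhs,
            Tok::Slash => lhs / rhs,
            _ => unreachable!(),
        };
    }
    lhs
}

fn main() {
    // 2 + 3 * 4 groups as 2 + (3 * 4)
    let toks = [Tok::Num(2.0), Tok::Plus, Tok::Num(3.0), Tok::Star, Tok::Num(4.0)];
    assert_eq!(climb(&toks, &mut 0, 0), 14.0);
    println!("ok");
}
```

Note that recursion depth tracks operator nesting, not operand count, so a long chain of left-associative operators is handled by the while loop rather than by deep recursion.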
Integration Patterns
Pest integrates well with other Rust compiler infrastructure. Parse results can be converted to spans for error reporting libraries like codespan or ariadne. The AST types can implement serde traits for serialization or visitor patterns for analysis passes.
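The span conversion amounts to mapping pest's error location onto a byte range. The sketch below uses a local enum mirroring the shape of pest's error::InputLocation rather than depending on the crate, so it can run standalone; in real code you would match on the location field of pest::error::Error directly:

```rust
use std::ops::Range;

// Mirrors the shape of pest's `error::InputLocation` for illustration.
enum InputLocation {
    Pos(usize),
    Span((usize, usize)),
}

// Convert a pest-style location into the byte range expected by
// diagnostic libraries such as ariadne or codespan-reporting.
fn to_byte_range(loc: &InputLocation) -> Range<usize> {
    match loc {
        // Point error: highlight a single byte at the failure position.
        InputLocation::Pos(p) => *p..*p + 1,
        InputLocation::Span((start, end)) => *start..*end,
    }
}

fn main() {
    assert_eq!(to_byte_range(&InputLocation::Pos(7)), 7..8);
    assert_eq!(to_byte_range(&InputLocation::Span((3, 9))), 3..9);
    println!("ok");
}
```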
Incremental parsing can be implemented by caching parse results for unchanged input sections. The stateless nature of pest parsers enables parallel parsing of independent input chunks. Custom pair processing can extract only needed information without full AST construction.
Testing pest grammars requires attention to both positive and negative cases. Use pest’s built-in testing syntax in grammar files for quick validation. Integration tests should verify AST construction and error handling. Property-based testing can validate grammar properties like precedence and associativity.
Best Practices
Keep grammars readable and maintainable by avoiding overly complex rules. Break down complicated patterns into named sub-rules that document their purpose. Use meaningful rule names that correspond to language concepts rather than implementation details.
Version control grammar files alongside implementation code. Document grammar changes and their rationale in commit messages. Consider grammar compatibility when evolving languages to avoid breaking existing code.
Profile parser performance on representative input to identify bottlenecks. Complex backtracking patterns or excessive rule nesting can impact performance. Use pest’s built-in debugging features to understand parsing behavior on problematic input.
Handle errors gracefully with informative messages. Pest’s automatic error reporting provides good default messages, but custom errors can add domain-specific context. Consider recovery strategies for IDE integration where partial results are valuable.
rowan
The rowan crate provides a foundation for building lossless syntax trees that preserve all source text including whitespace and comments. This architecture forms the basis of rust-analyzer and enables incremental reparsing, precise error recovery, and full-fidelity source transformations. Unlike traditional abstract syntax trees that discard formatting information, rowan maintains a complete representation of the input text while providing efficient tree traversal and manipulation operations.
The core innovation of rowan lies in its separation of untyped green trees from typed red trees. Green trees are immutable, shareable nodes that store the actual data, while red trees provide a typed API with parent pointers and absolute positions. This dual structure enables both memory efficiency through structural sharing and ergonomic traversal through the red tree API. The design supports incremental updates by allowing subtrees to be reused when portions of the source remain unchanged.
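The memory-sharing idea behind green trees can be illustrated with plain Rc. This is a conceptual sketch only; rowan's actual GreenNode is interned and stores text lengths, not this toy structure:

```rust
use std::rc::Rc;

// A toy "green" node: immutable, no parent pointer, no absolute offset,
// so identical subtrees can be shared freely between parents.
enum Green {
    Token(String),
    Node(Vec<Rc<Green>>),
}

// Total text length is derivable from the tree, as in a lossless syntax tree.
fn text_len(g: &Green) -> usize {
    match g {
        Green::Token(t) => t.len(),
        Green::Node(children) => children.iter().map(|c| text_len(c)).sum(),
    }
}

fn main() {
    // The subexpression `1 + 2` is built once...
    let one_plus_two = Rc::new(Green::Node(vec![
        Rc::new(Green::Token("1".into())),
        Rc::new(Green::Token(" + ".into())),
        Rc::new(Green::Token("2".into())),
    ]));
    // ...and shared by two positions in the parent, much as rowan reuses
    // unchanged subtrees across incremental reparses.
    let root = Green::Node(vec![
        Rc::clone(&one_plus_two),
        Rc::new(Green::Token(" * ".into())),
        Rc::clone(&one_plus_two),
    ]);
    assert_eq!(Rc::strong_count(&one_plus_two), 3); // shared, not copied
    assert_eq!(text_len(&root), 13); // "1 + 2 * 1 + 2"
    println!("ok");
}
```

Because a green node carries no position, the same allocation can appear at many offsets; absolute positions belong to the red layer, which is computed on demand during traversal.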
Basic Usage
Rowan operates through a language definition that maps syntax kinds to tree nodes:
#![allow(unused)] fn main() { use rowan::{GreenNode, GreenNodeBuilder, Language, SyntaxNode, SyntaxToken, TextRange, TextSize}; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(u16)] pub enum SyntaxKind { Whitespace = 0, Comment, Ident, Number, String, Plus, Minus, Star, Slash, Eq, Neq, Lt, Gt, LParen, RParen, LBrace, RBrace, Semicolon, Comma, Keyword, Error, Root, BinaryExpr, UnaryExpr, ParenExpr, Literal, BlockStmt, ExprStmt, LetStmt, IfStmt, WhileStmt, ReturnStmt, FnDef, ParamList, TypeRef, Path, CallExpr, ArgList, } impl From<SyntaxKind> for rowan::SyntaxKind { fn from(kind: SyntaxKind) -> Self { Self(kind as u16) } } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum Lang {} impl Language for Lang { type Kind = SyntaxKind; fn kind_from_raw(raw: rowan::SyntaxKind) -> Self::Kind { assert!(raw.0 <= SyntaxKind::ArgList as u16); unsafe { std::mem::transmute::<u16, SyntaxKind>(raw.0) } } fn kind_to_raw(kind: Self::Kind) -> rowan::SyntaxKind { kind.into() } } pub type SyntaxNodeRef = SyntaxNode<Lang>; pub type SyntaxTokenRef = SyntaxToken<Lang>; #[derive(Debug, Clone)] pub struct ParseResult { pub green_node: GreenNode, pub errors: Vec<ParseError>, } #[derive(Debug, Clone)] pub struct ParseError { pub message: String, pub range: TextRange, } pub struct Parser { builder: GreenNodeBuilder<'static>, errors: Vec<ParseError>, tokens: Vec<Token>, cursor: usize, } #[derive(Debug, Clone)] pub struct Token { pub kind: SyntaxKind, pub text: String, pub offset: TextSize, } impl Parser { pub fn new(tokens: Vec<Token>) -> Self { Self { builder: GreenNodeBuilder::new(), errors: Vec::new(), tokens, cursor: 0, } } pub fn parse(mut self) -> ParseResult { self.builder.start_node(SyntaxKind::Root.into()); while !self.at_end() { if self.at(SyntaxKind::Whitespace) || self.at(SyntaxKind::Comment) { self.trivia(); } else { self.statement(); } } self.builder.finish_node(); ParseResult { green_node: self.builder.finish(), errors: 
self.errors, } } fn statement(&mut self) { match self.current_kind() { Some(SyntaxKind::Keyword) if self.current_text() == Some("let") => { self.let_statement(); } Some(SyntaxKind::Keyword) if self.current_text() == Some("if") => { self.if_statement(); } Some(SyntaxKind::Keyword) if self.current_text() == Some("while") => { self.while_statement(); } Some(SyntaxKind::Keyword) if self.current_text() == Some("return") => { self.return_statement(); } Some(SyntaxKind::Keyword) if self.current_text() == Some("fn") => { self.function_definition(); } _ => { self.expression_statement(); } } } fn let_statement(&mut self) { self.builder.start_node(SyntaxKind::LetStmt.into()); self.consume(SyntaxKind::Keyword); self.skip_trivia(); self.consume(SyntaxKind::Ident); self.skip_trivia(); if self.at(SyntaxKind::Eq) { self.consume(SyntaxKind::Eq); self.skip_trivia(); self.expression(); } self.skip_trivia(); self.consume(SyntaxKind::Semicolon); self.builder.finish_node(); } fn if_statement(&mut self) { self.builder.start_node(SyntaxKind::IfStmt.into()); self.consume(SyntaxKind::Keyword); self.skip_trivia(); self.expression(); self.skip_trivia(); self.block(); if self.at_keyword("else") { self.consume(SyntaxKind::Keyword); self.skip_trivia(); if self.at_keyword("if") { self.if_statement(); } else { self.block(); } } self.builder.finish_node(); } fn while_statement(&mut self) { self.builder.start_node(SyntaxKind::WhileStmt.into()); self.consume(SyntaxKind::Keyword); self.skip_trivia(); self.expression(); self.skip_trivia(); self.block(); self.builder.finish_node(); } fn return_statement(&mut self) { self.builder.start_node(SyntaxKind::ReturnStmt.into()); self.consume(SyntaxKind::Keyword); self.skip_trivia(); if !self.at(SyntaxKind::Semicolon) { self.expression(); } self.skip_trivia(); self.consume(SyntaxKind::Semicolon); self.builder.finish_node(); } fn function_definition(&mut self) { self.builder.start_node(SyntaxKind::FnDef.into()); self.consume(SyntaxKind::Keyword); 
self.skip_trivia(); self.consume(SyntaxKind::Ident); self.skip_trivia(); self.parameter_list(); self.skip_trivia(); self.block(); self.builder.finish_node(); } fn parameter_list(&mut self) { self.builder.start_node(SyntaxKind::ParamList.into()); self.consume(SyntaxKind::LParen); self.skip_trivia(); if !self.at(SyntaxKind::RParen) { loop { self.consume(SyntaxKind::Ident); self.skip_trivia(); if self.at(SyntaxKind::Comma) { self.consume(SyntaxKind::Comma); self.skip_trivia(); } else { break; } } } self.consume(SyntaxKind::RParen); self.builder.finish_node(); } fn block(&mut self) { self.builder.start_node(SyntaxKind::BlockStmt.into()); self.consume(SyntaxKind::LBrace); while !self.at_end() && !self.at(SyntaxKind::RBrace) { if self.at(SyntaxKind::Whitespace) || self.at(SyntaxKind::Comment) { self.trivia(); } else { self.statement(); } } self.consume(SyntaxKind::RBrace); self.builder.finish_node(); } fn expression_statement(&mut self) { self.builder.start_node(SyntaxKind::ExprStmt.into()); self.expression(); self.skip_trivia(); self.consume(SyntaxKind::Semicolon); self.builder.finish_node(); } fn expression(&mut self) { self.binary_expression(0); } fn binary_expression(&mut self, min_precedence: u8) { self.unary_expression(); // Include whitespace in the tree while self.at(SyntaxKind::Whitespace) { self.trivia(); } while let Some(op_precedence) = self.current_binary_op_precedence() { if op_precedence < min_precedence { break; } let checkpoint = self.builder.checkpoint(); if let Some( k @ (SyntaxKind::Plus | SyntaxKind::Minus | SyntaxKind::Star | SyntaxKind::Slash | SyntaxKind::Eq | SyntaxKind::Neq | SyntaxKind::Lt | SyntaxKind::Gt), ) = self.current_kind() { self.consume(k); } // Include whitespace in the tree while self.at(SyntaxKind::Whitespace) { self.trivia(); } self.binary_expression(op_precedence + 1); self.builder .start_node_at(checkpoint, SyntaxKind::BinaryExpr.into()); self.builder.finish_node(); } } fn unary_expression(&mut self) { if 
self.at(SyntaxKind::Minus) || self.at(SyntaxKind::Plus) { self.builder.start_node(SyntaxKind::UnaryExpr.into()); self.consume(self.current_kind().unwrap()); while self.at(SyntaxKind::Whitespace) { self.trivia(); } self.unary_expression(); self.builder.finish_node(); } else { self.postfix_expression(); } } fn postfix_expression(&mut self) { self.primary_expression(); while self.at(SyntaxKind::LParen) { self.builder.start_node(SyntaxKind::CallExpr.into()); let checkpoint = self.builder.checkpoint(); self.argument_list(); self.builder .start_node_at(checkpoint, SyntaxKind::CallExpr.into()); self.builder.finish_node(); } } fn argument_list(&mut self) { self.builder.start_node(SyntaxKind::ArgList.into()); self.consume(SyntaxKind::LParen); self.skip_trivia(); if !self.at(SyntaxKind::RParen) { loop { self.expression(); self.skip_trivia(); if self.at(SyntaxKind::Comma) { self.consume(SyntaxKind::Comma); self.skip_trivia(); } else { break; } } } self.consume(SyntaxKind::RParen); self.builder.finish_node(); } fn primary_expression(&mut self) { match self.current_kind() { Some(SyntaxKind::Number) => { self.builder.start_node(SyntaxKind::Literal.into()); self.consume(SyntaxKind::Number); self.builder.finish_node(); } Some(SyntaxKind::String) => { self.builder.start_node(SyntaxKind::Literal.into()); self.consume(SyntaxKind::String); self.builder.finish_node(); } Some(SyntaxKind::Ident) => { self.consume(SyntaxKind::Ident); } Some(SyntaxKind::LParen) => { self.builder.start_node(SyntaxKind::ParenExpr.into()); self.consume(SyntaxKind::LParen); while self.at(SyntaxKind::Whitespace) { self.trivia(); } self.expression(); while self.at(SyntaxKind::Whitespace) { self.trivia(); } self.consume(SyntaxKind::RParen); self.builder.finish_node(); } _ => { self.error("Expected expression"); self.advance(); } } } fn current_binary_op_precedence(&self) -> Option<u8> { match self.current_kind()? 
{
            SyntaxKind::Star | SyntaxKind::Slash => Some(5),
            SyntaxKind::Plus | SyntaxKind::Minus => Some(4),
            SyntaxKind::Lt | SyntaxKind::Gt => Some(3),
            SyntaxKind::Eq | SyntaxKind::Neq => Some(2),
            _ => None,
        }
    }

    /// Consumes consecutive whitespace and comment tokens into the tree.
    fn trivia(&mut self) {
        while self.at(SyntaxKind::Whitespace) || self.at(SyntaxKind::Comment) {
            let kind = self.current_kind().unwrap();
            self.consume(kind);
        }
    }

    /// Advances past trivia without recording it in the tree.
    fn skip_trivia(&mut self) {
        while self.at(SyntaxKind::Whitespace) || self.at(SyntaxKind::Comment) {
            self.advance();
        }
    }

    fn consume(&mut self, expected: SyntaxKind) {
        if self.at(expected) {
            let token = &self.tokens[self.cursor];
            self.builder.token(expected.into(), &token.text);
            self.advance();
        } else {
            self.error(&format!("Expected {:?}", expected));
        }
    }

    fn at(&self, kind: SyntaxKind) -> bool {
        self.current_kind() == Some(kind)
    }

    fn at_keyword(&self, keyword: &str) -> bool {
        self.at(SyntaxKind::Keyword) && self.current_text() == Some(keyword)
    }

    fn current_kind(&self) -> Option<SyntaxKind> {
        self.tokens.get(self.cursor).map(|t| t.kind)
    }

    fn current_text(&self) -> Option<&str> {
        self.tokens.get(self.cursor).map(|t| t.text.as_str())
    }

    fn advance(&mut self) {
        if self.cursor < self.tokens.len() {
            self.cursor += 1;
        }
    }

    fn at_end(&self) -> bool {
        self.cursor >= self.tokens.len()
    }

    fn error(&mut self, message: &str) {
        let offset = self
            .tokens
            .get(self.cursor)
            .map(|t| t.offset)
            .unwrap_or_else(|| TextSize::from(0));
        self.errors.push(ParseError {
            message: message.to_string(),
            range: TextRange::empty(offset),
        });
    }
}

pub fn tokenize(input: &str) -> Vec<Token> {
    let mut tokens = Vec::new();
    let mut offset = TextSize::from(0);
    let mut chars = input.chars().peekable();
    while let Some(ch) = chars.next() {
        let start = offset;
        offset += TextSize::of(ch);
        let (kind, text) = match ch {
            ' ' | '\t' | '\n' | '\r' => {
                let mut text = String::from(ch);
                while let Some(&next) = chars.peek() {
                    if next.is_whitespace() {
                        text.push(next);
                        offset += TextSize::of(next);
                        chars.next();
                    } else {
                        break;
                    }
                }
                (SyntaxKind::Whitespace, text)
            }
            '/' if chars.peek() == Some(&'/') => {
                chars.next();
                offset += TextSize::of('/');
                let mut text = String::from("//");
                while let Some(&next) = chars.peek() {
                    if next != '\n' {
                        text.push(next);
                        offset += TextSize::of(next);
                        chars.next();
                    } else {
                        break;
                    }
                }
                (SyntaxKind::Comment, text)
            }
            '+' => (SyntaxKind::Plus, String::from("+")),
            '-' => (SyntaxKind::Minus, String::from("-")),
            '*' => (SyntaxKind::Star, String::from("*")),
            '/' => (SyntaxKind::Slash, String::from("/")),
            '=' if chars.peek() == Some(&'=') => {
                chars.next();
                offset += TextSize::of('=');
                (SyntaxKind::Eq, String::from("=="))
            }
            '=' => (SyntaxKind::Eq, String::from("=")),
            '!' if chars.peek() == Some(&'=') => {
                chars.next();
                offset += TextSize::of('=');
                (SyntaxKind::Neq, String::from("!="))
            }
            '<' => (SyntaxKind::Lt, String::from("<")),
            '>' => (SyntaxKind::Gt, String::from(">")),
            '(' => (SyntaxKind::LParen, String::from("(")),
            ')' => (SyntaxKind::RParen, String::from(")")),
            '{' => (SyntaxKind::LBrace, String::from("{")),
            '}' => (SyntaxKind::RBrace, String::from("}")),
            ';' => (SyntaxKind::Semicolon, String::from(";")),
            ',' => (SyntaxKind::Comma, String::from(",")),
            '"' => {
                let mut text = String::from("\"");
                while let Some(&next) = chars.peek() {
                    text.push(next);
                    offset += TextSize::of(next);
                    chars.next();
                    if next == '"' {
                        break;
                    }
                }
                (SyntaxKind::String, text)
            }
            c if c.is_ascii_digit() => {
                let mut text = String::from(c);
                while let Some(&next) = chars.peek() {
                    if next.is_ascii_digit() {
                        text.push(next);
                        offset += TextSize::of(next);
                        chars.next();
                    } else {
                        break;
                    }
                }
                (SyntaxKind::Number, text)
            }
            c if c.is_ascii_alphabetic() || c == '_' => {
                let mut text = String::from(c);
                while let Some(&next) = chars.peek() {
                    if next.is_ascii_alphanumeric() || next == '_' {
                        text.push(next);
                        offset += TextSize::of(next);
                        chars.next();
                    } else {
                        break;
                    }
                }
                let kind = match text.as_str() {
                    "let" | "if" | "else" | "while" | "for" | "fn" | "return" | "true"
                    | "false" | "struct" | "enum" | "impl" => SyntaxKind::Keyword,
                    _ => SyntaxKind::Ident,
                };
                (kind, text)
            }
            _ => (SyntaxKind::Error, String::from(ch)),
        };
        tokens.push(Token { kind, text, offset: start });
    }
    tokens
}

#[derive(Debug)]
pub struct IncrementalReparser {
    _old_tree: SyntaxNodeRef,
    edits: Vec<TextEdit>,
}

#[derive(Debug, Clone)]
pub struct TextEdit {
    pub range: TextRange,
    pub new_text: String,
}

impl IncrementalReparser {
    pub fn new(tree: SyntaxNodeRef) -> Self {
        Self { _old_tree: tree, edits: Vec::new() }
    }

    pub fn add_edit(&mut self, edit: TextEdit) {
        self.edits.push(edit);
    }

    pub fn reparse(&self, new_text: &str) -> ParseResult {
        let tokens = tokenize(new_text);
        let parser = Parser::new(tokens);
        parser.parse()
    }
}

pub struct SyntaxTreeBuilder {
    green: GreenNode,
}

impl SyntaxTreeBuilder {
    pub fn new(green: GreenNode) -> Self {
        Self { green }
    }

    pub fn build(self) -> SyntaxNodeRef {
        SyntaxNodeRef::new_root(self.green)
    }
}

pub struct AstNode {
    syntax: SyntaxNodeRef,
}

impl AstNode {
    pub fn cast(syntax: SyntaxNodeRef) -> Option<Self> {
        match syntax.kind() {
            SyntaxKind::Root
            | SyntaxKind::BinaryExpr
            | SyntaxKind::UnaryExpr
            | SyntaxKind::ParenExpr
            | SyntaxKind::Literal
            | SyntaxKind::BlockStmt
            | SyntaxKind::ExprStmt
            | SyntaxKind::LetStmt
            | SyntaxKind::IfStmt
            | SyntaxKind::WhileStmt
            | SyntaxKind::ReturnStmt
            | SyntaxKind::FnDef
            | SyntaxKind::CallExpr => Some(Self { syntax }),
            _ => None,
        }
    }

    pub fn syntax(&self) -> &SyntaxNodeRef {
        &self.syntax
    }
}

pub trait AstToken {
    fn cast(syntax: SyntaxTokenRef) -> Option<Self>
    where
        Self: Sized;
    fn syntax(&self) -> &SyntaxTokenRef;
}

pub struct Identifier {
    syntax: SyntaxTokenRef,
}

impl AstToken for Identifier {
    fn cast(syntax: SyntaxTokenRef) -> Option<Self> {
        if syntax.kind() == SyntaxKind::Ident {
            Some(Self { syntax })
        } else {
            None
        }
    }

    fn syntax(&self) -> &SyntaxTokenRef {
        &self.syntax
    }
}

impl Identifier {
    pub fn text(&self) -> &str {
        self.syntax.text()
    }
}

pub fn walk_tree(node: &SyntaxNodeRef, depth: usize) {
    let indent = " ".repeat(depth);
    println!("{}{:?}", indent, node.kind());
    for child in node.children() {
        walk_tree(&child, depth + 1);
    }
}

pub fn find_node_at_offset(root: &SyntaxNodeRef, offset: TextSize) -> Option<SyntaxNodeRef> {
    if !root.text_range().contains(offset) {
        return None;
    }
    let mut result = root.clone();
    for child in root.children() {
        if child.text_range().contains(offset) {
            if let Some(deeper) = find_node_at_offset(&child, offset) {
                result = deeper;
            }
        }
    }
    Some(result)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenization() {
        let input = "let x = 42 + y;";
        let tokens = tokenize(input);
        assert_eq!(tokens[0].kind, SyntaxKind::Keyword);
        assert_eq!(tokens[0].text, "let");
        assert_eq!(tokens[2].kind, SyntaxKind::Ident);
        assert_eq!(tokens[2].text, "x");
    }

    #[test]
    fn test_parse_expression() {
        let input = "x + 42 * (y - 3)";
        // First check tokenization
        let tokens = tokenize(input);
        assert!(tokens.len() > 5, "Expected more tokens, got: {:?}", tokens);
        // Check that the tokens include the operators
        let has_plus = tokens.iter().any(|t| t.kind == SyntaxKind::Plus);
        let has_star = tokens.iter().any(|t| t.kind == SyntaxKind::Star);
        assert!(has_plus, "Missing + operator in tokens: {:?}", tokens);
        assert!(has_star, "Missing * operator in tokens: {:?}", tokens);
        let tree = parse_expression(input);
        assert_eq!(tree.kind(), SyntaxKind::Root);
        // The tree should losslessly reproduce the input text
        let tree_text = tree.text().to_string();
        assert_eq!(
            tree_text.trim(),
            input,
            "Tree text doesn't match input. Got '{}' expected '{}'",
            tree_text.trim(),
            input
        );
    }

    #[test]
    fn test_incremental_reparse() {
        let input = "let x = 42;";
        let tree = parse_expression(input);
        let mut reparser = IncrementalReparser::new(tree);
        reparser.add_edit(TextEdit {
            range: TextRange::new(TextSize::from(8), TextSize::from(10)),
            new_text: "100".to_string(),
        });
        let new_tree = reparser.reparse("let x = 100;");
        assert_eq!(new_tree.errors.len(), 0);
    }

    #[test]
    fn test_ast_node_cast() {
        let input = "42 + x";
        let tree = parse_expression(input);
        if let Some(first_child) = tree.first_child() {
            let ast_node = AstNode::cast(first_child);
            assert!(ast_node.is_some());
        }
    }
}

pub fn parse_expression(input: &str) -> SyntaxNodeRef {
    let tokens = tokenize(input);
    let mut parser = Parser::new(tokens);
    // Build just an expression tree
    parser.builder.start_node(SyntaxKind::Root.into());
    // Include leading whitespace as trivia
    while parser.at(SyntaxKind::Whitespace) {
        parser.trivia();
    }
    // Parse the expression
    if !parser.at_end() {
        parser.expression();
    }
    // Include trailing whitespace as trivia
    while parser.at(SyntaxKind::Whitespace) {
        parser.trivia();
    }
    parser.builder.finish_node();
    let green_node = parser.builder.finish();
    SyntaxTreeBuilder::new(green_node).build()
}
}
The parse_expression function demonstrates the complete pipeline from source text to syntax tree: the tokenizer produces a stream of tokens with their kinds and offsets, the parser assembles those tokens into an immutable green tree, and the syntax tree builder wraps the green tree in a red tree (SyntaxNode) that adds parent links and absolute positions for traversal.
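The binary_expression method in the parser above is a standard precedence-climbing loop over the table in current_binary_op_precedence (where * and / bind at level 5, + and - at level 4). Stripped of green-tree construction, the same control flow can be sketched as a tiny self-contained evaluator; this is an illustrative sketch with a hypothetical `eval` helper over pre-split tokens, not part of the rowan example:

```rust
// Minimal precedence-climbing evaluator mirroring the parser's operator
// table: * and / bind at precedence 5, + and - at precedence 4.
fn eval(tokens: &[&str]) -> i64 {
    fn climb(tokens: &[&str], pos: &mut usize, min_prec: u8) -> i64 {
        // A "primary expression" is just an integer literal in this sketch.
        let mut lhs: i64 = tokens[*pos].parse().expect("expected a number");
        *pos += 1;
        while *pos < tokens.len() {
            let (prec, op) = match tokens[*pos] {
                "*" => (5, "*"),
                "/" => (5, "/"),
                "+" => (4, "+"),
                "-" => (4, "-"),
                _ => break,
            };
            // Stop when the operator binds looser than the current context,
            // exactly like `if op_precedence < min_precedence { break; }`.
            if prec < min_prec {
                break;
            }
            *pos += 1;
            // Recurse at one level tighter so equal-precedence operators
            // associate to the left, matching `binary_expression(prec + 1)`.
            let rhs = climb(tokens, pos, prec + 1);
            lhs = match op {
                "*" => lhs * rhs,
                "/" => lhs / rhs,
                "+" => lhs + rhs,
                "-" => lhs - rhs,
                _ => unreachable!(),
            };
        }
        lhs
    }
    let mut pos = 0;
    climb(tokens, &mut pos, 0)
}

fn main() {
    // 2 + 3 * 4 groups as 2 + (3 * 4)
    assert_eq!(eval(&["2", "+", "3", "*", "4"]), 14);
    // 10 - 4 - 3 groups as (10 - 4) - 3
    assert_eq!(eval(&["10", "-", "4", "-", "3"]), 3);
    println!("ok");
}
```

Where the evaluator folds operands into a number, the rowan parser instead wraps them into a BinaryExpr node via the checkpoint and start_node_at calls, which retroactively reparent the already-emitted left operand.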
Language Definition
Every rowan-based parser requires a language definition that specifies the syntax kinds and their conversions:
#![allow(unused)]
fn main() {
use rowan::{GreenNode, GreenNodeBuilder, Language, SyntaxNode, SyntaxToken, TextRange, TextSize};

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u16)]
pub enum SyntaxKind {
    // Tokens
    Whitespace = 0, Comment, Ident, Number, String,
    Plus, Minus, Star, Slash, Eq, Neq, Lt, Gt,
    LParen, RParen, LBrace, RBrace, Semicolon, Comma,
    Keyword, Error,
    // Nodes
    Root, BinaryExpr, UnaryExpr, ParenExpr, Literal,
    BlockStmt, ExprStmt, LetStmt, IfStmt, WhileStmt, ReturnStmt,
    FnDef, ParamList, TypeRef, Path, CallExpr, ArgList,
}

impl From<SyntaxKind> for rowan::SyntaxKind {
    fn from(kind: SyntaxKind) -> Self {
        Self(kind as u16)
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Lang {}

impl Language for Lang {
    type Kind = SyntaxKind;

    fn kind_from_raw(raw: rowan::SyntaxKind) -> Self::Kind {
        assert!(raw.0 <= SyntaxKind::ArgList as u16);
        unsafe { std::mem::transmute::<u16, SyntaxKind>(raw.0) }
    }

    fn kind_to_raw(kind: Self::Kind) -> rowan::SyntaxKind {
        kind.into()
    }
}

pub type SyntaxNodeRef = SyntaxNode<Lang>;
pub type SyntaxTokenRef = SyntaxToken<Lang>;
}
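The unsafe transmute in kind_from_raw is sound only because SyntaxKind is #[repr(u16)], starts at 0, and has no gaps in its discriminants; the preceding assertion rejects raw values past the last variant. The round-trip invariant this relies on can be checked in isolation, sketched here with a hypothetical three-variant `Kind` enum standing in for SyntaxKind (no rowan dependency):

```rust
// Toy demonstration of the u16 <-> enum round-trip that the Language
// implementation relies on. `Kind` stands in for SyntaxKind.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
enum Kind {
    Whitespace = 0,
    Ident,
    Number, // last variant: the upper bound for valid raw values
}

fn kind_to_raw(kind: Kind) -> u16 {
    kind as u16
}

fn kind_from_raw(raw: u16) -> Kind {
    // Transmuting a u16 into a #[repr(u16)] enum is sound only when the
    // value is a valid discriminant, hence the bounds check first.
    assert!(raw <= Kind::Number as u16, "raw kind out of range");
    unsafe { std::mem::transmute::<u16, Kind>(raw) }
}

fn main() {
    for kind in [Kind::Whitespace, Kind::Ident, Kind::Number] {
        // Every variant must survive conversion through its raw u16 form.
        assert_eq!(kind_from_raw(kind_to_raw(kind)), kind);
    }
    println!("round-trip ok");
}
```

If a variant is ever added after the sentinel used in the assertion (ArgList in the real code), the bounds check silently admits stale values, so the last variant in the enum must stay in sync with the guard.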
} pub trait AstToken { fn cast(syntax: SyntaxTokenRef) -> Option<Self> where Self: Sized; fn syntax(&self) -> &SyntaxTokenRef; } pub struct Identifier { syntax: SyntaxTokenRef, } impl AstToken for Identifier { fn cast(syntax: SyntaxTokenRef) -> Option<Self> { if syntax.kind() == SyntaxKind::Ident { Some(Self { syntax }) } else { None } } fn syntax(&self) -> &SyntaxTokenRef { &self.syntax } } impl Identifier { pub fn text(&self) -> &str { self.syntax.text() } } pub fn walk_tree(node: &SyntaxNodeRef, depth: usize) { let indent = " ".repeat(depth); println!("{}{:?}", indent, node.kind()); for child in node.children() { walk_tree(&child, depth + 1); } } pub fn find_node_at_offset(root: &SyntaxNodeRef, offset: TextSize) -> Option<SyntaxNodeRef> { if !root.text_range().contains(offset) { return None; } let mut result = root.clone(); for child in root.children() { if child.text_range().contains(offset) { if let Some(deeper) = find_node_at_offset(&child, offset) { result = deeper; } } } Some(result) } #[cfg(test)] mod tests { use super::*; #[test] fn test_tokenization() { let input = "let x = 42 + y;"; let tokens = tokenize(input); assert_eq!(tokens[0].kind, SyntaxKind::Keyword); assert_eq!(tokens[0].text, "let"); assert_eq!(tokens[2].kind, SyntaxKind::Ident); assert_eq!(tokens[2].text, "x"); } #[test] fn test_parse_expression() { let input = "x + 42 * (y - 3)"; // First check tokenization let tokens = tokenize(input); // The expression should tokenize properly assert!(tokens.len() > 5, "Expected more tokens, got: {:?}", tokens); // Check if tokens include operators let has_plus = tokens.iter().any(|t| t.kind == SyntaxKind::Plus); let has_star = tokens.iter().any(|t| t.kind == SyntaxKind::Star); assert!(has_plus, "Missing + operator in tokens: {:?}", tokens); assert!(has_star, "Missing * operator in tokens: {:?}", tokens); let tree = parse_expression(input); assert_eq!(tree.kind(), SyntaxKind::Root); // Check the tree contains the full expression let tree_text = 
tree.text().to_string(); assert_eq!( tree_text.trim(), input, "Tree text doesn't match input. Got '{}' expected '{}'", tree_text.trim(), input ); } #[test] fn test_incremental_reparse() { let input = "let x = 42;"; let tree = parse_expression(input); let mut reparser = IncrementalReparser::new(tree); reparser.add_edit(TextEdit { range: TextRange::new(TextSize::from(8), TextSize::from(10)), new_text: "100".to_string(), }); let new_tree = reparser.reparse("let x = 100;"); assert_eq!(new_tree.errors.len(), 0); } #[test] fn test_ast_node_cast() { let input = "42 + x"; let tree = parse_expression(input); if let Some(first_child) = tree.first_child() { let ast_node = AstNode::cast(first_child); assert!(ast_node.is_some()); } } } }
The SyntaxKind enumeration defines all possible node and token types in the language. Each variant represents either a terminal token like an identifier or operator, or a non-terminal node like an expression or statement. The Lang type implements the Language trait, providing the bridge between rowan’s generic infrastructure and the specific syntax kinds.
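The raw-kind round-trip that this bridge performs can be sketched in isolation. In the sketch below, `RawKind` is a hypothetical stand-in for `rowan::SyntaxKind` (which is a plain newtype over `u16`); the real conversion lives in the `Lang` implementation shown in the listing, and the enum here is truncated to a few variants for brevity.

```rust
#![allow(unused)]

// Truncated stand-in for the full SyntaxKind enum from the listing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
enum SyntaxKind {
    Whitespace = 0,
    Ident,
    Plus,
    ArgList, // last variant: the upper bound for the range check below
}

// Hypothetical stand-in for rowan::SyntaxKind, a newtype over u16.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct RawKind(u16);

fn kind_to_raw(kind: SyntaxKind) -> RawKind {
    // Typed kind -> type-erased raw kind is a plain integer cast.
    RawKind(kind as u16)
}

fn kind_from_raw(raw: RawKind) -> SyntaxKind {
    // The transmute is sound only because SyntaxKind is #[repr(u16)] and
    // the assert keeps the value within the declared discriminant range.
    assert!(raw.0 <= SyntaxKind::ArgList as u16);
    unsafe { std::mem::transmute::<u16, SyntaxKind>(raw.0) }
}

fn main() {
    let raw = kind_to_raw(SyntaxKind::Plus);
    assert_eq!(raw, RawKind(2));
    assert_eq!(kind_from_raw(raw), SyntaxKind::Plus);
}
```

The type erasure is deliberate: rowan's green tree machinery works over untyped `u16` kinds so that one tree implementation can be shared by every language, and each language recovers its typed view through this conversion.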
Green Tree Construction
The parser builds green trees using the GreenNodeBuilder API:
#![allow(unused)]
fn main() {
use rowan::GreenNodeBuilder;

// Supporting types (SyntaxKind, Token, ParseResult, ParseError) are
// defined in the full listing above.
pub struct Parser {
    builder: GreenNodeBuilder<'static>,
    errors: Vec<ParseError>,
    tokens: Vec<Token>,
    cursor: usize,
}

impl Parser {
    pub fn new(tokens: Vec<Token>) -> Self {
        Self {
            builder: GreenNodeBuilder::new(),
            errors: Vec::new(),
            tokens,
            cursor: 0,
        }
    }

    pub fn parse(mut self) -> ParseResult {
        self.builder.start_node(SyntaxKind::Root.into());
        while !self.at_end() {
            if self.at(SyntaxKind::Whitespace) || self.at(SyntaxKind::Comment) {
                self.trivia();
            } else {
                self.statement();
            }
        }
        self.builder.finish_node();
        ParseResult {
            green_node: self.builder.finish(),
            errors: self.errors,
        }
    }

    // statement, expression, and cursor helper methods as in the full
    // listing above
}
}
#![allow(unused)] fn main() { use rowan::{GreenNode, GreenNodeBuilder, Language, SyntaxNode, SyntaxToken, TextRange, TextSize}; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(u16)] pub enum SyntaxKind { Whitespace = 0, Comment, Ident, Number, String, Plus, Minus, Star, Slash, Eq, Neq, Lt, Gt, LParen, RParen, LBrace, RBrace, Semicolon, Comma, Keyword, Error, Root, BinaryExpr, UnaryExpr, ParenExpr, Literal, BlockStmt, ExprStmt, LetStmt, IfStmt, WhileStmt, ReturnStmt, FnDef, ParamList, TypeRef, Path, CallExpr, ArgList, } impl From<SyntaxKind> for rowan::SyntaxKind { fn from(kind: SyntaxKind) -> Self { Self(kind as u16) } } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum Lang {} impl Language for Lang { type Kind = SyntaxKind; fn kind_from_raw(raw: rowan::SyntaxKind) -> Self::Kind { assert!(raw.0 <= SyntaxKind::ArgList as u16); unsafe { std::mem::transmute::<u16, SyntaxKind>(raw.0) } } fn kind_to_raw(kind: Self::Kind) -> rowan::SyntaxKind { kind.into() } } pub type SyntaxNodeRef = SyntaxNode<Lang>; pub type SyntaxTokenRef = SyntaxToken<Lang>; #[derive(Debug, Clone)] pub struct ParseResult { pub green_node: GreenNode, pub errors: Vec<ParseError>, } #[derive(Debug, Clone)] pub struct ParseError { pub message: String, pub range: TextRange, } pub struct Parser { builder: GreenNodeBuilder<'static>, errors: Vec<ParseError>, tokens: Vec<Token>, cursor: usize, } #[derive(Debug, Clone)] pub struct Token { pub kind: SyntaxKind, pub text: String, pub offset: TextSize, } impl Parser { pub fn new(tokens: Vec<Token>) -> Self { Self { builder: GreenNodeBuilder::new(), errors: Vec::new(), tokens, cursor: 0, } } pub fn parse(mut self) -> ParseResult { self.builder.start_node(SyntaxKind::Root.into()); while !self.at_end() { if self.at(SyntaxKind::Whitespace) || self.at(SyntaxKind::Comment) { self.trivia(); } else { self.statement(); } } self.builder.finish_node(); ParseResult { green_node: self.builder.finish(), errors: 
self.errors, } } fn statement(&mut self) { match self.current_kind() { Some(SyntaxKind::Keyword) if self.current_text() == Some("let") => { self.let_statement(); } Some(SyntaxKind::Keyword) if self.current_text() == Some("if") => { self.if_statement(); } Some(SyntaxKind::Keyword) if self.current_text() == Some("while") => { self.while_statement(); } Some(SyntaxKind::Keyword) if self.current_text() == Some("return") => { self.return_statement(); } Some(SyntaxKind::Keyword) if self.current_text() == Some("fn") => { self.function_definition(); } _ => { self.expression_statement(); } } } fn let_statement(&mut self) { self.builder.start_node(SyntaxKind::LetStmt.into()); self.consume(SyntaxKind::Keyword); self.skip_trivia(); self.consume(SyntaxKind::Ident); self.skip_trivia(); if self.at(SyntaxKind::Eq) { self.consume(SyntaxKind::Eq); self.skip_trivia(); self.expression(); } self.skip_trivia(); self.consume(SyntaxKind::Semicolon); self.builder.finish_node(); } fn if_statement(&mut self) { self.builder.start_node(SyntaxKind::IfStmt.into()); self.consume(SyntaxKind::Keyword); self.skip_trivia(); self.expression(); self.skip_trivia(); self.block(); if self.at_keyword("else") { self.consume(SyntaxKind::Keyword); self.skip_trivia(); if self.at_keyword("if") { self.if_statement(); } else { self.block(); } } self.builder.finish_node(); } fn while_statement(&mut self) { self.builder.start_node(SyntaxKind::WhileStmt.into()); self.consume(SyntaxKind::Keyword); self.skip_trivia(); self.expression(); self.skip_trivia(); self.block(); self.builder.finish_node(); } fn return_statement(&mut self) { self.builder.start_node(SyntaxKind::ReturnStmt.into()); self.consume(SyntaxKind::Keyword); self.skip_trivia(); if !self.at(SyntaxKind::Semicolon) { self.expression(); } self.skip_trivia(); self.consume(SyntaxKind::Semicolon); self.builder.finish_node(); } fn function_definition(&mut self) { self.builder.start_node(SyntaxKind::FnDef.into()); self.consume(SyntaxKind::Keyword); 
self.skip_trivia(); self.consume(SyntaxKind::Ident); self.skip_trivia(); self.parameter_list(); self.skip_trivia(); self.block(); self.builder.finish_node(); } fn parameter_list(&mut self) { self.builder.start_node(SyntaxKind::ParamList.into()); self.consume(SyntaxKind::LParen); self.skip_trivia(); if !self.at(SyntaxKind::RParen) { loop { self.consume(SyntaxKind::Ident); self.skip_trivia(); if self.at(SyntaxKind::Comma) { self.consume(SyntaxKind::Comma); self.skip_trivia(); } else { break; } } } self.consume(SyntaxKind::RParen); self.builder.finish_node(); } fn block(&mut self) { self.builder.start_node(SyntaxKind::BlockStmt.into()); self.consume(SyntaxKind::LBrace); while !self.at_end() && !self.at(SyntaxKind::RBrace) { if self.at(SyntaxKind::Whitespace) || self.at(SyntaxKind::Comment) { self.trivia(); } else { self.statement(); } } self.consume(SyntaxKind::RBrace); self.builder.finish_node(); } fn expression_statement(&mut self) { self.builder.start_node(SyntaxKind::ExprStmt.into()); self.expression(); self.skip_trivia(); self.consume(SyntaxKind::Semicolon); self.builder.finish_node(); } fn expression(&mut self) { self.binary_expression(0); } fn binary_expression(&mut self, min_precedence: u8) { self.unary_expression(); // Include whitespace in the tree while self.at(SyntaxKind::Whitespace) { self.trivia(); } while let Some(op_precedence) = self.current_binary_op_precedence() { if op_precedence < min_precedence { break; } let checkpoint = self.builder.checkpoint(); if let Some( k @ (SyntaxKind::Plus | SyntaxKind::Minus | SyntaxKind::Star | SyntaxKind::Slash | SyntaxKind::Eq | SyntaxKind::Neq | SyntaxKind::Lt | SyntaxKind::Gt), ) = self.current_kind() { self.consume(k); } // Include whitespace in the tree while self.at(SyntaxKind::Whitespace) { self.trivia(); } self.binary_expression(op_precedence + 1); self.builder .start_node_at(checkpoint, SyntaxKind::BinaryExpr.into()); self.builder.finish_node(); } } fn unary_expression(&mut self) { if 
self.at(SyntaxKind::Minus) || self.at(SyntaxKind::Plus) { self.builder.start_node(SyntaxKind::UnaryExpr.into()); self.consume(self.current_kind().unwrap()); while self.at(SyntaxKind::Whitespace) { self.trivia(); } self.unary_expression(); self.builder.finish_node(); } else { self.postfix_expression(); } } fn postfix_expression(&mut self) { self.primary_expression(); while self.at(SyntaxKind::LParen) { self.builder.start_node(SyntaxKind::CallExpr.into()); let checkpoint = self.builder.checkpoint(); self.argument_list(); self.builder .start_node_at(checkpoint, SyntaxKind::CallExpr.into()); self.builder.finish_node(); } } fn argument_list(&mut self) { self.builder.start_node(SyntaxKind::ArgList.into()); self.consume(SyntaxKind::LParen); self.skip_trivia(); if !self.at(SyntaxKind::RParen) { loop { self.expression(); self.skip_trivia(); if self.at(SyntaxKind::Comma) { self.consume(SyntaxKind::Comma); self.skip_trivia(); } else { break; } } } self.consume(SyntaxKind::RParen); self.builder.finish_node(); } fn primary_expression(&mut self) { match self.current_kind() { Some(SyntaxKind::Number) => { self.builder.start_node(SyntaxKind::Literal.into()); self.consume(SyntaxKind::Number); self.builder.finish_node(); } Some(SyntaxKind::String) => { self.builder.start_node(SyntaxKind::Literal.into()); self.consume(SyntaxKind::String); self.builder.finish_node(); } Some(SyntaxKind::Ident) => { self.consume(SyntaxKind::Ident); } Some(SyntaxKind::LParen) => { self.builder.start_node(SyntaxKind::ParenExpr.into()); self.consume(SyntaxKind::LParen); while self.at(SyntaxKind::Whitespace) { self.trivia(); } self.expression(); while self.at(SyntaxKind::Whitespace) { self.trivia(); } self.consume(SyntaxKind::RParen); self.builder.finish_node(); } _ => { self.error("Expected expression"); self.advance(); } } } fn current_binary_op_precedence(&self) -> Option<u8> { match self.current_kind()? 
{ SyntaxKind::Star | SyntaxKind::Slash => Some(5), SyntaxKind::Plus | SyntaxKind::Minus => Some(4), SyntaxKind::Lt | SyntaxKind::Gt => Some(3), SyntaxKind::Eq | SyntaxKind::Neq => Some(2), _ => None, } } fn trivia(&mut self) { while self.at(SyntaxKind::Whitespace) || self.at(SyntaxKind::Comment) { let kind = self.current_kind().unwrap(); self.consume(kind); } } fn skip_trivia(&mut self) { while self.at(SyntaxKind::Whitespace) || self.at(SyntaxKind::Comment) { self.advance(); } } fn consume(&mut self, expected: SyntaxKind) { if self.at(expected) { let token = &self.tokens[self.cursor]; self.builder.token(expected.into(), &token.text); self.advance(); } else { self.error(&format!("Expected {:?}", expected)); } } fn at(&self, kind: SyntaxKind) -> bool { self.current_kind() == Some(kind) } fn at_keyword(&self, keyword: &str) -> bool { self.at(SyntaxKind::Keyword) && self.current_text() == Some(keyword) } fn current_kind(&self) -> Option<SyntaxKind> { self.tokens.get(self.cursor).map(|t| t.kind) } fn current_text(&self) -> Option<&str> { self.tokens.get(self.cursor).map(|t| t.text.as_str()) } fn advance(&mut self) { if self.cursor < self.tokens.len() { self.cursor += 1; } } fn at_end(&self) -> bool { self.cursor >= self.tokens.len() } fn error(&mut self, message: &str) { let offset = self .tokens .get(self.cursor) .map(|t| t.offset) .unwrap_or_else(|| TextSize::from(0)); self.errors.push(ParseError { message: message.to_string(), range: TextRange::empty(offset), }); } } pub fn tokenize(input: &str) -> Vec<Token> { let mut tokens = Vec::new(); let mut offset = TextSize::from(0); let mut chars = input.chars().peekable(); while let Some(ch) = chars.next() { let start = offset; offset += TextSize::of(ch); let (kind, text) = match ch { ' ' | '\t' | '\n' | '\r' => { let mut text = String::from(ch); while let Some(&next) = chars.peek() { if next.is_whitespace() { text.push(next); offset += TextSize::of(next); chars.next(); } else { break; } } (SyntaxKind::Whitespace, text) } 
'/' if chars.peek() == Some(&'/') => { chars.next(); offset += TextSize::of('/'); let mut text = String::from("//"); while let Some(&next) = chars.peek() { if next != '\n' { text.push(next); offset += TextSize::of(next); chars.next(); } else { break; } } (SyntaxKind::Comment, text) } '+' => (SyntaxKind::Plus, String::from("+")), '-' => (SyntaxKind::Minus, String::from("-")), '*' => (SyntaxKind::Star, String::from("*")), '/' => (SyntaxKind::Slash, String::from("/")), '=' if chars.peek() == Some(&'=') => { chars.next(); offset += TextSize::of('='); (SyntaxKind::Eq, String::from("==")) } '=' => (SyntaxKind::Eq, String::from("=")), '!' if chars.peek() == Some(&'=') => { chars.next(); offset += TextSize::of('='); (SyntaxKind::Neq, String::from("!=")) } '<' => (SyntaxKind::Lt, String::from("<")), '>' => (SyntaxKind::Gt, String::from(">")), '(' => (SyntaxKind::LParen, String::from("(")), ')' => (SyntaxKind::RParen, String::from(")")), '{' => (SyntaxKind::LBrace, String::from("{")), '}' => (SyntaxKind::RBrace, String::from("}")), ';' => (SyntaxKind::Semicolon, String::from(";")), ',' => (SyntaxKind::Comma, String::from(",")), '"' => { let mut text = String::from("\""); while let Some(&next) = chars.peek() { text.push(next); offset += TextSize::of(next); chars.next(); if next == '"' { break; } } (SyntaxKind::String, text) } c if c.is_ascii_digit() => { let mut text = String::from(c); while let Some(&next) = chars.peek() { if next.is_ascii_digit() { text.push(next); offset += TextSize::of(next); chars.next(); } else { break; } } (SyntaxKind::Number, text) } c if c.is_ascii_alphabetic() || c == '_' => { let mut text = String::from(c); while let Some(&next) = chars.peek() { if next.is_ascii_alphanumeric() || next == '_' { text.push(next); offset += TextSize::of(next); chars.next(); } else { break; } } let kind = match text.as_str() { "let" | "if" | "else" | "while" | "for" | "fn" | "return" | "true" | "false" | "struct" | "enum" | "impl" => SyntaxKind::Keyword, _ => 
SyntaxKind::Ident, }; (kind, text) } _ => (SyntaxKind::Error, String::from(ch)), }; tokens.push(Token { kind, text, offset: start, }); } tokens } #[derive(Debug)] pub struct IncrementalReparser { _old_tree: SyntaxNodeRef, edits: Vec<TextEdit>, } #[derive(Debug, Clone)] pub struct TextEdit { pub range: TextRange, pub new_text: String, } impl IncrementalReparser { pub fn new(tree: SyntaxNodeRef) -> Self { Self { _old_tree: tree, edits: Vec::new(), } } pub fn add_edit(&mut self, edit: TextEdit) { self.edits.push(edit); } pub fn reparse(&self, new_text: &str) -> ParseResult { let tokens = tokenize(new_text); let parser = Parser::new(tokens); parser.parse() } } pub struct SyntaxTreeBuilder { green: GreenNode, } impl SyntaxTreeBuilder { pub fn new(green: GreenNode) -> Self { Self { green } } pub fn build(self) -> SyntaxNodeRef { SyntaxNodeRef::new_root(self.green) } } pub struct AstNode { syntax: SyntaxNodeRef, } impl AstNode { pub fn cast(syntax: SyntaxNodeRef) -> Option<Self> { match syntax.kind() { SyntaxKind::Root | SyntaxKind::BinaryExpr | SyntaxKind::UnaryExpr | SyntaxKind::ParenExpr | SyntaxKind::Literal | SyntaxKind::BlockStmt | SyntaxKind::ExprStmt | SyntaxKind::LetStmt | SyntaxKind::IfStmt | SyntaxKind::WhileStmt | SyntaxKind::ReturnStmt | SyntaxKind::FnDef | SyntaxKind::CallExpr => Some(Self { syntax }), _ => None, } } pub fn syntax(&self) -> &SyntaxNodeRef { &self.syntax } } pub trait AstToken { fn cast(syntax: SyntaxTokenRef) -> Option<Self> where Self: Sized; fn syntax(&self) -> &SyntaxTokenRef; } pub struct Identifier { syntax: SyntaxTokenRef, } impl AstToken for Identifier { fn cast(syntax: SyntaxTokenRef) -> Option<Self> { if syntax.kind() == SyntaxKind::Ident { Some(Self { syntax }) } else { None } } fn syntax(&self) -> &SyntaxTokenRef { &self.syntax } } impl Identifier { pub fn text(&self) -> &str { self.syntax.text() } } pub fn walk_tree(node: &SyntaxNodeRef, depth: usize) { let indent = " ".repeat(depth); println!("{}{:?}", indent, node.kind()); for 
child in node.children() { walk_tree(&child, depth + 1); } } pub fn find_node_at_offset(root: &SyntaxNodeRef, offset: TextSize) -> Option<SyntaxNodeRef> { if !root.text_range().contains(offset) { return None; } let mut result = root.clone(); for child in root.children() { if child.text_range().contains(offset) { if let Some(deeper) = find_node_at_offset(&child, offset) { result = deeper; } } } Some(result) } #[cfg(test)] mod tests { use super::*; #[test] fn test_tokenization() { let input = "let x = 42 + y;"; let tokens = tokenize(input); assert_eq!(tokens[0].kind, SyntaxKind::Keyword); assert_eq!(tokens[0].text, "let"); assert_eq!(tokens[2].kind, SyntaxKind::Ident); assert_eq!(tokens[2].text, "x"); } #[test] fn test_parse_expression() { let input = "x + 42 * (y - 3)"; // First check tokenization let tokens = tokenize(input); // The expression should tokenize properly assert!(tokens.len() > 5, "Expected more tokens, got: {:?}", tokens); // Check if tokens include operators let has_plus = tokens.iter().any(|t| t.kind == SyntaxKind::Plus); let has_star = tokens.iter().any(|t| t.kind == SyntaxKind::Star); assert!(has_plus, "Missing + operator in tokens: {:?}", tokens); assert!(has_star, "Missing * operator in tokens: {:?}", tokens); let tree = parse_expression(input); assert_eq!(tree.kind(), SyntaxKind::Root); // Check the tree contains the full expression let tree_text = tree.text().to_string(); assert_eq!( tree_text.trim(), input, "Tree text doesn't match input. 
Got '{}' expected '{}'", tree_text.trim(), input ); } #[test] fn test_incremental_reparse() { let input = "let x = 42;"; let tree = parse_expression(input); let mut reparser = IncrementalReparser::new(tree); reparser.add_edit(TextEdit { range: TextRange::new(TextSize::from(8), TextSize::from(10)), new_text: "100".to_string(), }); let new_tree = reparser.reparse("let x = 100;"); assert_eq!(new_tree.errors.len(), 0); } #[test] fn test_ast_node_cast() { let input = "42 + x"; let tree = parse_expression(input); if let Some(first_child) = tree.first_child() { let ast_node = AstNode::cast(first_child); assert!(ast_node.is_some()); } } } pub fn parse_expression(input: &str) -> SyntaxNodeRef { let tokens = tokenize(input); let mut parser = Parser::new(tokens); // Build just an expression tree parser.builder.start_node(SyntaxKind::Root.into()); // Include leading whitespace as trivia while parser.at(SyntaxKind::Whitespace) { parser.trivia(); } // Parse the expression if !parser.at_end() { parser.expression(); } // Include trailing whitespace as trivia while parser.at(SyntaxKind::Whitespace) { parser.trivia(); } parser.builder.finish_node(); let green_node = parser.builder.finish(); SyntaxTreeBuilder::new(green_node).build() } }
The parser maintains a builder that constructs the green tree incrementally. The start_node and finish_node methods create hierarchical structure, while the token method adds leaf nodes. This approach allows the parser to build the tree in a single pass without backtracking or tree rewriting.
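The green-tree idea can be illustrated without rowan at all: a green node is immutable and stores no absolute offsets, so identical subtrees can be shared between parents and widths are recomputed from children. A minimal std-only sketch (the `Green` type and its names are illustrative, not rowan's API):

```rust
use std::rc::Rc;

// A minimal stand-in for a green node: immutable, offset-free, shareable.
#[derive(Debug)]
enum Green {
    Token { kind: &'static str, text: String },
    Node { kind: &'static str, children: Vec<Rc<Green>> },
}

impl Green {
    // Width is derived from the children, so a subtree is valid at any offset.
    fn width(&self) -> usize {
        match self {
            Green::Token { text, .. } => text.len(),
            Green::Node { children, .. } => children.iter().map(|c| c.width()).sum(),
        }
    }
}

fn main() {
    let literal = Rc::new(Green::Token { kind: "Number", text: "42".into() });
    // The same green subtree can appear under two different parents.
    let a = Rc::new(Green::Node { kind: "ExprStmt", children: vec![literal.clone()] });
    let b = Rc::new(Green::Node { kind: "ParenExpr", children: vec![literal.clone()] });
    assert_eq!(a.width(), 2);
    assert_eq!(b.width(), 2);
    assert_eq!(Rc::strong_count(&literal), 3);
}
```

In rowan the red tree (SyntaxNode) layers parent pointers and absolute offsets on top of this shared structure on demand, which is what makes incremental reuse cheap.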
Expression Parsing
The parser implements recursive descent with operator precedence for expressions:
#![allow(unused)]
fn main() {
impl Parser {
    fn expression(&mut self) {
        self.binary_expression(0);
    }

    fn binary_expression(&mut self, min_precedence: u8) {
        self.unary_expression();
        // Include whitespace in the tree
        while self.at(SyntaxKind::Whitespace) {
            self.trivia();
        }
        while let Some(op_precedence) = self.current_binary_op_precedence() {
            if op_precedence < min_precedence {
                break;
            }
            let checkpoint = self.builder.checkpoint();
            if let Some(
                k @ (SyntaxKind::Plus
                | SyntaxKind::Minus
                | SyntaxKind::Star
                | SyntaxKind::Slash
                | SyntaxKind::Eq
                | SyntaxKind::Neq
                | SyntaxKind::Lt
                | SyntaxKind::Gt),
            ) = self.current_kind()
            {
                self.consume(k);
            }
            // Include whitespace in the tree
            while self.at(SyntaxKind::Whitespace) {
                self.trivia();
            }
            self.binary_expression(op_precedence + 1);
            self.builder
                .start_node_at(checkpoint, SyntaxKind::BinaryExpr.into());
            self.builder.finish_node();
        }
    }

    fn unary_expression(&mut self) {
        if self.at(SyntaxKind::Minus) || self.at(SyntaxKind::Plus) {
            self.builder.start_node(SyntaxKind::UnaryExpr.into());
            self.consume(self.current_kind().unwrap());
            while self.at(SyntaxKind::Whitespace) {
                self.trivia();
            }
            self.unary_expression();
            self.builder.finish_node();
        } else {
            self.postfix_expression();
        }
    }

    fn postfix_expression(&mut self) {
        // Take the checkpoint before the callee so `start_node_at` wraps
        // both the callee and its argument list in a single CallExpr node.
        let checkpoint = self.builder.checkpoint();
        self.primary_expression();
        while self.at(SyntaxKind::LParen) {
            self.builder
                .start_node_at(checkpoint, SyntaxKind::CallExpr.into());
            self.argument_list();
            self.builder.finish_node();
        }
    }

    fn argument_list(&mut self) {
        self.builder.start_node(SyntaxKind::ArgList.into());
        self.consume(SyntaxKind::LParen);
        self.skip_trivia();
        if !self.at(SyntaxKind::RParen) {
            loop {
                self.expression();
                self.skip_trivia();
                if self.at(SyntaxKind::Comma) {
                    self.consume(SyntaxKind::Comma);
                    self.skip_trivia();
                } else {
                    break;
                }
            }
        }
        self.consume(SyntaxKind::RParen);
        self.builder.finish_node();
    }

    fn primary_expression(&mut self) {
        match self.current_kind() {
            Some(SyntaxKind::Number) => {
                self.builder.start_node(SyntaxKind::Literal.into());
                self.consume(SyntaxKind::Number);
                self.builder.finish_node();
            }
            Some(SyntaxKind::String) => {
                self.builder.start_node(SyntaxKind::Literal.into());
                self.consume(SyntaxKind::String);
                self.builder.finish_node();
            }
            Some(SyntaxKind::Ident) => {
                self.consume(SyntaxKind::Ident);
            }
            Some(SyntaxKind::LParen) => {
                self.builder.start_node(SyntaxKind::ParenExpr.into());
                self.consume(SyntaxKind::LParen);
                while self.at(SyntaxKind::Whitespace) {
                    self.trivia();
                }
                self.expression();
                while self.at(SyntaxKind::Whitespace) {
                    self.trivia();
                }
                self.consume(SyntaxKind::RParen);
                self.builder.finish_node();
            }
            _ => {
                self.error("Expected expression");
                self.advance();
            }
        }
    }

    fn current_binary_op_precedence(&self) -> Option<u8> {
        match self.current_kind()? {
            SyntaxKind::Star | SyntaxKind::Slash => Some(5),
            SyntaxKind::Plus | SyntaxKind::Minus => Some(4),
            SyntaxKind::Lt | SyntaxKind::Gt => Some(3),
            SyntaxKind::Eq | SyntaxKind::Neq => Some(2),
            _ => None,
        }
    }
}
}
Binary expression parsing uses precedence climbing to handle operator priorities correctly. The binary_expression method recursively parses the right-hand operand at one level higher minimum precedence, which yields a left-associative tree for operators of equal precedence. The checkpoint mechanism allows the parser to wrap nodes it has already emitted into a new parent during parsing, without rebuilding the entire subtree.
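The climbing algorithm itself is independent of rowan. A self-contained sketch over a plain token slice, which fully parenthesizes the input so the resulting precedence and associativity are visible (the `climb` function and its precedence table are illustrative, not the chapter's parser):

```rust
// Precedence climbing over a flat token list; returns a fully
// parenthesized string showing how the expression tree nests.
fn climb(tokens: &[&str], pos: &mut usize, min_prec: u8) -> String {
    let mut lhs = tokens[*pos].to_string(); // operand
    *pos += 1;
    while *pos < tokens.len() {
        let prec = match tokens[*pos] {
            "*" | "/" => 5,
            "+" | "-" => 4,
            _ => break,
        };
        if prec < min_prec {
            break;
        }
        let op = tokens[*pos];
        *pos += 1;
        // Parse the right-hand side at one level higher precedence,
        // which makes operators of equal precedence left-associative.
        let rhs = climb(tokens, pos, prec + 1);
        lhs = format!("({} {} {})", lhs, op, rhs);
    }
    lhs
}

fn main() {
    let mut pos = 0;
    assert_eq!(climb(&["1", "+", "2", "*", "3"], &mut pos, 0), "(1 + (2 * 3))");

    let mut pos = 0;
    assert_eq!(climb(&["8", "-", "3", "-", "2"], &mut pos, 0), "((8 - 3) - 2)");
}
```

The chapter's version does the same climb, but instead of building strings it drops a checkpoint before each left operand and later wraps it into a BinaryExpr node.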
Statement Parsing
Statement parsing demonstrates the handling of control flow constructs. The parser handles nested structures like if-else chains and function definitions by recursively invoking the appropriate parsing methods. Each statement type creates its own node in the syntax tree, preserving the complete structure including keywords, delimiters, and nested blocks.
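The else-if handling is worth spelling out: the parser treats `else if` as an else branch whose body is another IfStmt, so chains nest rather than flatten. A reduced model of that recursion, with conditions omitted and each letter standing in for a block (the `Stmt` type and `parse_if` are illustrative, not the chapter's code):

```rust
// `else if` recurses into another If rather than a separate chain node.
#[derive(Debug, PartialEq)]
enum Stmt {
    If { then_block: String, else_branch: Option<Box<Stmt>> },
    Block(String),
}

// Parse a chain like ["if", "A", "else", "if", "B", "else", "C"];
// returns the statement and how many words were consumed.
fn parse_if(words: &[&str]) -> (Stmt, usize) {
    assert_eq!(words[0], "if");
    let then_block = words[1].to_string();
    let mut used = 2;
    let else_branch = if words.get(2) == Some(&"else") {
        if words.get(3) == Some(&"if") {
            // Recurse: the else branch is itself an if-statement.
            let (nested, n) = parse_if(&words[3..]);
            used = 3 + n;
            Some(Box::new(nested))
        } else {
            used = 4;
            Some(Box::new(Stmt::Block(words[3].to_string())))
        }
    } else {
        None
    };
    (Stmt::If { then_block, else_branch }, used)
}

fn main() {
    let (stmt, used) = parse_if(&["if", "A", "else", "if", "B", "else", "C"]);
    assert_eq!(used, 7);
    // The chain nests as If(A, else: If(B, else: Block(C))).
    if let Stmt::If { else_branch: Some(e), .. } = &stmt {
        assert!(matches!(**e, Stmt::If { .. }));
    } else {
        panic!("expected IfStmt");
    }
}
```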
Tokenization
The tokenizer converts source text into a stream of typed tokens:
#![allow(unused)] fn main() { use rowan::{GreenNode, GreenNodeBuilder, Language, SyntaxNode, SyntaxToken, TextRange, TextSize}; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(u16)] pub enum SyntaxKind { Whitespace = 0, Comment, Ident, Number, String, Plus, Minus, Star, Slash, Eq, Neq, Lt, Gt, LParen, RParen, LBrace, RBrace, Semicolon, Comma, Keyword, Error, Root, BinaryExpr, UnaryExpr, ParenExpr, Literal, BlockStmt, ExprStmt, LetStmt, IfStmt, WhileStmt, ReturnStmt, FnDef, ParamList, TypeRef, Path, CallExpr, ArgList, } impl From<SyntaxKind> for rowan::SyntaxKind { fn from(kind: SyntaxKind) -> Self { Self(kind as u16) } } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum Lang {} impl Language for Lang { type Kind = SyntaxKind; fn kind_from_raw(raw: rowan::SyntaxKind) -> Self::Kind { assert!(raw.0 <= SyntaxKind::ArgList as u16); unsafe { std::mem::transmute::<u16, SyntaxKind>(raw.0) } } fn kind_to_raw(kind: Self::Kind) -> rowan::SyntaxKind { kind.into() } } pub type SyntaxNodeRef = SyntaxNode<Lang>; pub type SyntaxTokenRef = SyntaxToken<Lang>; #[derive(Debug, Clone)] pub struct ParseResult { pub green_node: GreenNode, pub errors: Vec<ParseError>, } #[derive(Debug, Clone)] pub struct ParseError { pub message: String, pub range: TextRange, } pub struct Parser { builder: GreenNodeBuilder<'static>, errors: Vec<ParseError>, tokens: Vec<Token>, cursor: usize, } #[derive(Debug, Clone)] pub struct Token { pub kind: SyntaxKind, pub text: String, pub offset: TextSize, } impl Parser { pub fn new(tokens: Vec<Token>) -> Self { Self { builder: GreenNodeBuilder::new(), errors: Vec::new(), tokens, cursor: 0, } } pub fn parse(mut self) -> ParseResult { self.builder.start_node(SyntaxKind::Root.into()); while !self.at_end() { if self.at(SyntaxKind::Whitespace) || self.at(SyntaxKind::Comment) { self.trivia(); } else { self.statement(); } } self.builder.finish_node(); ParseResult { green_node: self.builder.finish(), errors: 
self.errors, } } fn statement(&mut self) { match self.current_kind() { Some(SyntaxKind::Keyword) if self.current_text() == Some("let") => { self.let_statement(); } Some(SyntaxKind::Keyword) if self.current_text() == Some("if") => { self.if_statement(); } Some(SyntaxKind::Keyword) if self.current_text() == Some("while") => { self.while_statement(); } Some(SyntaxKind::Keyword) if self.current_text() == Some("return") => { self.return_statement(); } Some(SyntaxKind::Keyword) if self.current_text() == Some("fn") => { self.function_definition(); } _ => { self.expression_statement(); } } } fn let_statement(&mut self) { self.builder.start_node(SyntaxKind::LetStmt.into()); self.consume(SyntaxKind::Keyword); self.skip_trivia(); self.consume(SyntaxKind::Ident); self.skip_trivia(); if self.at(SyntaxKind::Eq) { self.consume(SyntaxKind::Eq); self.skip_trivia(); self.expression(); } self.skip_trivia(); self.consume(SyntaxKind::Semicolon); self.builder.finish_node(); } fn if_statement(&mut self) { self.builder.start_node(SyntaxKind::IfStmt.into()); self.consume(SyntaxKind::Keyword); self.skip_trivia(); self.expression(); self.skip_trivia(); self.block(); if self.at_keyword("else") { self.consume(SyntaxKind::Keyword); self.skip_trivia(); if self.at_keyword("if") { self.if_statement(); } else { self.block(); } } self.builder.finish_node(); } fn while_statement(&mut self) { self.builder.start_node(SyntaxKind::WhileStmt.into()); self.consume(SyntaxKind::Keyword); self.skip_trivia(); self.expression(); self.skip_trivia(); self.block(); self.builder.finish_node(); } fn return_statement(&mut self) { self.builder.start_node(SyntaxKind::ReturnStmt.into()); self.consume(SyntaxKind::Keyword); self.skip_trivia(); if !self.at(SyntaxKind::Semicolon) { self.expression(); } self.skip_trivia(); self.consume(SyntaxKind::Semicolon); self.builder.finish_node(); } fn function_definition(&mut self) { self.builder.start_node(SyntaxKind::FnDef.into()); self.consume(SyntaxKind::Keyword); 
self.skip_trivia(); self.consume(SyntaxKind::Ident); self.skip_trivia(); self.parameter_list(); self.skip_trivia(); self.block(); self.builder.finish_node(); } fn parameter_list(&mut self) { self.builder.start_node(SyntaxKind::ParamList.into()); self.consume(SyntaxKind::LParen); self.skip_trivia(); if !self.at(SyntaxKind::RParen) { loop { self.consume(SyntaxKind::Ident); self.skip_trivia(); if self.at(SyntaxKind::Comma) { self.consume(SyntaxKind::Comma); self.skip_trivia(); } else { break; } } } self.consume(SyntaxKind::RParen); self.builder.finish_node(); } fn block(&mut self) { self.builder.start_node(SyntaxKind::BlockStmt.into()); self.consume(SyntaxKind::LBrace); while !self.at_end() && !self.at(SyntaxKind::RBrace) { if self.at(SyntaxKind::Whitespace) || self.at(SyntaxKind::Comment) { self.trivia(); } else { self.statement(); } } self.consume(SyntaxKind::RBrace); self.builder.finish_node(); } fn expression_statement(&mut self) { self.builder.start_node(SyntaxKind::ExprStmt.into()); self.expression(); self.skip_trivia(); self.consume(SyntaxKind::Semicolon); self.builder.finish_node(); } fn expression(&mut self) { self.binary_expression(0); } fn binary_expression(&mut self, min_precedence: u8) { self.unary_expression(); // Include whitespace in the tree while self.at(SyntaxKind::Whitespace) { self.trivia(); } while let Some(op_precedence) = self.current_binary_op_precedence() { if op_precedence < min_precedence { break; } let checkpoint = self.builder.checkpoint(); if let Some( k @ (SyntaxKind::Plus | SyntaxKind::Minus | SyntaxKind::Star | SyntaxKind::Slash | SyntaxKind::Eq | SyntaxKind::Neq | SyntaxKind::Lt | SyntaxKind::Gt), ) = self.current_kind() { self.consume(k); } // Include whitespace in the tree while self.at(SyntaxKind::Whitespace) { self.trivia(); } self.binary_expression(op_precedence + 1); self.builder .start_node_at(checkpoint, SyntaxKind::BinaryExpr.into()); self.builder.finish_node(); } } fn unary_expression(&mut self) { if 
self.at(SyntaxKind::Minus) || self.at(SyntaxKind::Plus) { self.builder.start_node(SyntaxKind::UnaryExpr.into()); self.consume(self.current_kind().unwrap()); while self.at(SyntaxKind::Whitespace) { self.trivia(); } self.unary_expression(); self.builder.finish_node(); } else { self.postfix_expression(); } } fn postfix_expression(&mut self) { let checkpoint = self.builder.checkpoint(); self.primary_expression(); while self.at(SyntaxKind::LParen) { self.builder .start_node_at(checkpoint, SyntaxKind::CallExpr.into()); self.argument_list(); self.builder.finish_node(); } } fn argument_list(&mut self) { self.builder.start_node(SyntaxKind::ArgList.into()); self.consume(SyntaxKind::LParen); self.skip_trivia(); if !self.at(SyntaxKind::RParen) { loop { self.expression(); self.skip_trivia(); if self.at(SyntaxKind::Comma) { self.consume(SyntaxKind::Comma); self.skip_trivia(); } else { break; } } } self.consume(SyntaxKind::RParen); self.builder.finish_node(); } fn primary_expression(&mut self) { match self.current_kind() { Some(SyntaxKind::Number) => { self.builder.start_node(SyntaxKind::Literal.into()); self.consume(SyntaxKind::Number); self.builder.finish_node(); } Some(SyntaxKind::String) => { self.builder.start_node(SyntaxKind::Literal.into()); self.consume(SyntaxKind::String); self.builder.finish_node(); } Some(SyntaxKind::Ident) => { self.consume(SyntaxKind::Ident); } Some(SyntaxKind::LParen) => { self.builder.start_node(SyntaxKind::ParenExpr.into()); self.consume(SyntaxKind::LParen); while self.at(SyntaxKind::Whitespace) { self.trivia(); } self.expression(); while self.at(SyntaxKind::Whitespace) { self.trivia(); } self.consume(SyntaxKind::RParen); self.builder.finish_node(); } _ => { self.error("Expected expression"); self.advance(); } } } fn current_binary_op_precedence(&self) -> Option<u8> { match self.current_kind()? 
{ SyntaxKind::Star | SyntaxKind::Slash => Some(5), SyntaxKind::Plus | SyntaxKind::Minus => Some(4), SyntaxKind::Lt | SyntaxKind::Gt => Some(3), SyntaxKind::Eq | SyntaxKind::Neq => Some(2), _ => None, } } fn trivia(&mut self) { while self.at(SyntaxKind::Whitespace) || self.at(SyntaxKind::Comment) { let kind = self.current_kind().unwrap(); self.consume(kind); } } fn skip_trivia(&mut self) { while self.at(SyntaxKind::Whitespace) || self.at(SyntaxKind::Comment) { self.advance(); } } fn consume(&mut self, expected: SyntaxKind) { if self.at(expected) { let token = &self.tokens[self.cursor]; self.builder.token(expected.into(), &token.text); self.advance(); } else { self.error(&format!("Expected {:?}", expected)); } } fn at(&self, kind: SyntaxKind) -> bool { self.current_kind() == Some(kind) } fn at_keyword(&self, keyword: &str) -> bool { self.at(SyntaxKind::Keyword) && self.current_text() == Some(keyword) } fn current_kind(&self) -> Option<SyntaxKind> { self.tokens.get(self.cursor).map(|t| t.kind) } fn current_text(&self) -> Option<&str> { self.tokens.get(self.cursor).map(|t| t.text.as_str()) } fn advance(&mut self) { if self.cursor < self.tokens.len() { self.cursor += 1; } } fn at_end(&self) -> bool { self.cursor >= self.tokens.len() } fn error(&mut self, message: &str) { let offset = self .tokens .get(self.cursor) .map(|t| t.offset) .unwrap_or_else(|| TextSize::from(0)); self.errors.push(ParseError { message: message.to_string(), range: TextRange::empty(offset), }); } } #[derive(Debug)] pub struct IncrementalReparser { _old_tree: SyntaxNodeRef, edits: Vec<TextEdit>, } #[derive(Debug, Clone)] pub struct TextEdit { pub range: TextRange, pub new_text: String, } impl IncrementalReparser { pub fn new(tree: SyntaxNodeRef) -> Self { Self { _old_tree: tree, edits: Vec::new(), } } pub fn add_edit(&mut self, edit: TextEdit) { self.edits.push(edit); } pub fn reparse(&self, new_text: &str) -> ParseResult { let tokens = tokenize(new_text); let parser = Parser::new(tokens); 
parser.parse() } } pub struct SyntaxTreeBuilder { green: GreenNode, } impl SyntaxTreeBuilder { pub fn new(green: GreenNode) -> Self { Self { green } } pub fn build(self) -> SyntaxNodeRef { SyntaxNodeRef::new_root(self.green) } } pub fn parse_expression(input: &str) -> SyntaxNodeRef { let tokens = tokenize(input); let mut parser = Parser::new(tokens); // Build just an expression tree parser.builder.start_node(SyntaxKind::Root.into()); // Include leading whitespace as trivia while parser.at(SyntaxKind::Whitespace) { parser.trivia(); } // Parse the expression if !parser.at_end() { parser.expression(); } // Include trailing whitespace as trivia while parser.at(SyntaxKind::Whitespace) { parser.trivia(); } parser.builder.finish_node(); let green_node = parser.builder.finish(); SyntaxTreeBuilder::new(green_node).build() } pub struct AstNode { syntax: SyntaxNodeRef, } impl AstNode { pub fn cast(syntax: SyntaxNodeRef) -> Option<Self> { match syntax.kind() { SyntaxKind::Root | SyntaxKind::BinaryExpr | SyntaxKind::UnaryExpr | SyntaxKind::ParenExpr | SyntaxKind::Literal | SyntaxKind::BlockStmt | SyntaxKind::ExprStmt | SyntaxKind::LetStmt | SyntaxKind::IfStmt | SyntaxKind::WhileStmt | SyntaxKind::ReturnStmt | SyntaxKind::FnDef | SyntaxKind::CallExpr => Some(Self { syntax }), _ => None, } } pub fn syntax(&self) -> &SyntaxNodeRef { &self.syntax } } pub trait AstToken { fn cast(syntax: SyntaxTokenRef) -> Option<Self> where Self: Sized; fn syntax(&self) -> &SyntaxTokenRef; } pub struct Identifier { syntax: SyntaxTokenRef, } impl AstToken for Identifier { fn cast(syntax: SyntaxTokenRef) -> Option<Self> { if syntax.kind() == SyntaxKind::Ident { Some(Self { syntax }) } else { None } } fn syntax(&self) -> &SyntaxTokenRef { &self.syntax } } impl Identifier { pub fn text(&self) -> &str { self.syntax.text() } } pub fn walk_tree(node: &SyntaxNodeRef, depth: usize) { let indent = " ".repeat(depth); println!("{}{:?}", indent, node.kind()); for child in node.children() { walk_tree(&child, 
depth + 1); } } pub fn find_node_at_offset(root: &SyntaxNodeRef, offset: TextSize) -> Option<SyntaxNodeRef> { if !root.text_range().contains(offset) { return None; } let mut result = root.clone(); for child in root.children() { if child.text_range().contains(offset) { if let Some(deeper) = find_node_at_offset(&child, offset) { result = deeper; } } } Some(result) } #[cfg(test)] mod tests { use super::*; #[test] fn test_tokenization() { let input = "let x = 42 + y;"; let tokens = tokenize(input); assert_eq!(tokens[0].kind, SyntaxKind::Keyword); assert_eq!(tokens[0].text, "let"); assert_eq!(tokens[2].kind, SyntaxKind::Ident); assert_eq!(tokens[2].text, "x"); } #[test] fn test_parse_expression() { let input = "x + 42 * (y - 3)"; // First check tokenization let tokens = tokenize(input); // The expression should tokenize properly assert!(tokens.len() > 5, "Expected more tokens, got: {:?}", tokens); // Check if tokens include operators let has_plus = tokens.iter().any(|t| t.kind == SyntaxKind::Plus); let has_star = tokens.iter().any(|t| t.kind == SyntaxKind::Star); assert!(has_plus, "Missing + operator in tokens: {:?}", tokens); assert!(has_star, "Missing * operator in tokens: {:?}", tokens); let tree = parse_expression(input); assert_eq!(tree.kind(), SyntaxKind::Root); // Check the tree contains the full expression let tree_text = tree.text().to_string(); assert_eq!( tree_text.trim(), input, "Tree text doesn't match input. 
Got '{}' expected '{}'", tree_text.trim(), input ); } #[test] fn test_incremental_reparse() { let input = "let x = 42;"; let tree = parse_expression(input); let mut reparser = IncrementalReparser::new(tree); reparser.add_edit(TextEdit { range: TextRange::new(TextSize::from(8), TextSize::from(10)), new_text: "100".to_string(), }); let new_tree = reparser.reparse("let x = 100;"); assert_eq!(new_tree.errors.len(), 0); } #[test] fn test_ast_node_cast() { let input = "42 + x"; let tree = parse_expression(input); if let Some(first_child) = tree.first_child() { let ast_node = AstNode::cast(first_child); assert!(ast_node.is_some()); } } } pub fn tokenize(input: &str) -> Vec<Token> { let mut tokens = Vec::new(); let mut offset = TextSize::from(0); let mut chars = input.chars().peekable(); while let Some(ch) = chars.next() { let start = offset; offset += TextSize::of(ch); let (kind, text) = match ch { ' ' | '\t' | '\n' | '\r' => { let mut text = String::from(ch); while let Some(&next) = chars.peek() { if next.is_whitespace() { text.push(next); offset += TextSize::of(next); chars.next(); } else { break; } } (SyntaxKind::Whitespace, text) } '/' if chars.peek() == Some(&'/') => { chars.next(); offset += TextSize::of('/'); let mut text = String::from("//"); while let Some(&next) = chars.peek() { if next != '\n' { text.push(next); offset += TextSize::of(next); chars.next(); } else { break; } } (SyntaxKind::Comment, text) } '+' => (SyntaxKind::Plus, String::from("+")), '-' => (SyntaxKind::Minus, String::from("-")), '*' => (SyntaxKind::Star, String::from("*")), '/' => (SyntaxKind::Slash, String::from("/")), '=' if chars.peek() == Some(&'=') => { chars.next(); offset += TextSize::of('='); (SyntaxKind::Eq, String::from("==")) } '=' => (SyntaxKind::Eq, String::from("=")), '!' 
if chars.peek() == Some(&'=') => { chars.next(); offset += TextSize::of('='); (SyntaxKind::Neq, String::from("!=")) } '<' => (SyntaxKind::Lt, String::from("<")), '>' => (SyntaxKind::Gt, String::from(">")), '(' => (SyntaxKind::LParen, String::from("(")), ')' => (SyntaxKind::RParen, String::from(")")), '{' => (SyntaxKind::LBrace, String::from("{")), '}' => (SyntaxKind::RBrace, String::from("}")), ';' => (SyntaxKind::Semicolon, String::from(";")), ',' => (SyntaxKind::Comma, String::from(",")), '"' => { let mut text = String::from("\""); while let Some(&next) = chars.peek() { text.push(next); offset += TextSize::of(next); chars.next(); if next == '"' { break; } } (SyntaxKind::String, text) } c if c.is_ascii_digit() => { let mut text = String::from(c); while let Some(&next) = chars.peek() { if next.is_ascii_digit() { text.push(next); offset += TextSize::of(next); chars.next(); } else { break; } } (SyntaxKind::Number, text) } c if c.is_ascii_alphabetic() || c == '_' => { let mut text = String::from(c); while let Some(&next) = chars.peek() { if next.is_ascii_alphanumeric() || next == '_' { text.push(next); offset += TextSize::of(next); chars.next(); } else { break; } } let kind = match text.as_str() { "let" | "if" | "else" | "while" | "for" | "fn" | "return" | "true" | "false" | "struct" | "enum" | "impl" => SyntaxKind::Keyword, _ => SyntaxKind::Ident, }; (kind, text) } _ => (SyntaxKind::Error, String::from(ch)), }; tokens.push(Token { kind, text, offset: start, }); } tokens } }
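The binary_expression method in the listing implements precedence climbing: parse a left operand, then, while the next operator binds at least as tightly as the current minimum, consume it and recurse with a higher minimum before wrapping the result with start_node_at. The loop can be sketched standalone, without the rowan builder; the token representation and function names below are illustrative, not part of the crate:

```rust
// Standalone sketch of the precedence-climbing loop used by
// `binary_expression`, producing a parenthesized string instead of a
// green tree. Single characters stand in for tokens.
fn precedence(op: char) -> Option<u8> {
    match op {
        '*' | '/' => Some(5),
        '+' | '-' => Some(4),
        _ => None,
    }
}

fn parse_expr(tokens: &[char], pos: &mut usize, min_prec: u8) -> String {
    // A "primary" here is a single-character operand.
    let mut lhs = tokens[*pos].to_string();
    *pos += 1;
    while *pos < tokens.len() {
        let op = tokens[*pos];
        let prec = match precedence(op) {
            // Mirrors `if op_precedence < min_precedence { break; }`.
            Some(p) if p >= min_prec => p,
            _ => break,
        };
        *pos += 1;
        // Left-associativity: recurse with prec + 1, as the listing does.
        let rhs = parse_expr(tokens, pos, prec + 1);
        lhs = format!("({} {} {})", lhs, op, rhs);
    }
    lhs
}

fn main() {
    let tokens: Vec<char> = "a+b*c-d".chars().collect();
    let mut pos = 0;
    // → ((a + (b * c)) - d): * binds tighter than + and -.
    println!("{}", parse_expr(&tokens, &mut pos, 0));
}
```

The rowan version does the same thing, except that instead of building a string it records a checkpoint before the left operand and retroactively wraps it in a BinaryExpr node once an operator is seen.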
Tokenization tracks both the token content and its position in the source text using TextSize. This position information enables accurate error reporting and supports incremental reparsing by identifying which tokens have changed. The tokenizer handles multi-character tokens like comments and strings by consuming characters until reaching a delimiter.
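The offset-tracking idea is independent of rowan. A minimal sketch, using plain usize byte offsets instead of TextSize and a whitespace-only token rule (the Token type here is illustrative, not the one from the listing above):

```rust
// Minimal sketch of byte-offset tracking during tokenization. Each token
// records where it starts, so its text can always be recovered by slicing
// the original source - the property that makes error spans accurate.
#[derive(Debug, PartialEq)]
struct Token {
    text: String,
    offset: usize, // byte offset of the token's first character
}

fn tokenize_words(input: &str) -> Vec<Token> {
    let mut tokens = Vec::new();
    let mut offset = 0;
    for chunk in input.split_inclusive(' ') {
        let word = chunk.trim_end_matches(' ');
        if !word.is_empty() {
            tokens.push(Token { text: word.to_string(), offset });
        }
        // Advance by the full chunk length, delimiter included, so later
        // offsets stay aligned with the source text.
        offset += chunk.len();
    }
    tokens
}

fn main() {
    let source = "let x = 42";
    for t in &tokenize_words(source) {
        // Slicing the source at the recorded offset reproduces the token.
        assert_eq!(&source[t.offset..t.offset + t.text.len()], t.text);
    }
}
```

The listing's tokenizer follows the same discipline with TextSize::of, which also handles multi-byte UTF-8 characters correctly.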
Incremental Reparsing
Rowan supports efficient incremental updates because unchanged green subtrees can be shared between the old and new trees. The example's IncrementalReparser records text edits against the old tree; for simplicity its reparse method re-tokenizes and re-parses the full edited text, whereas a production implementation would reuse the untouched green nodes:
```rust
// Builds on the parser listing above: SyntaxNodeRef, ParseResult,
// Parser, and tokenize are the definitions shown earlier.
#[derive(Debug, Clone)]
pub struct TextEdit {
    pub range: TextRange,
    pub new_text: String,
}

#[derive(Debug)]
pub struct IncrementalReparser {
    _old_tree: SyntaxNodeRef,
    edits: Vec<TextEdit>,
}

impl IncrementalReparser {
    pub fn new(tree: SyntaxNodeRef) -> Self {
        Self {
            _old_tree: tree,
            edits: Vec::new(),
        }
    }

    pub fn add_edit(&mut self, edit: TextEdit) {
        self.edits.push(edit);
    }

    pub fn reparse(&self, new_text: &str) -> ParseResult {
        // Reparses the edited text from scratch; a production
        // implementation would reuse unchanged green subtrees.
        let tokens = tokenize(new_text);
        let parser = Parser::new(tokens);
        parser.parse()
    }
}
```
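Before reparsing, the recorded edits have to be spliced into the old source text. Rowan itself does not provide this step (it operates on trees, not strings), so the apply_edit helper below is a hypothetical sketch using plain byte offsets that mirror TextRange semantics:

```rust
// Hypothetical helper that splices an edit's replacement text into the
// old source. Start/end are byte offsets, like rowan's TextRange.
struct TextEdit {
    start: usize,
    end: usize,
    new_text: String,
}

fn apply_edit(source: &str, edit: &TextEdit) -> String {
    let mut out = String::with_capacity(source.len() + edit.new_text.len());
    out.push_str(&source[..edit.start]); // text before the edit
    out.push_str(&edit.new_text);        // the replacement
    out.push_str(&source[edit.end..]);   // text after the edit
    out
}

fn main() {
    // Mirrors test_incremental_reparse: replace "42" (bytes 8..10)
    // with "100" in "let x = 42;".
    let edited = apply_edit(
        "let x = 42;",
        &TextEdit { start: 8, end: 10, new_text: "100".to_string() },
    );
    assert_eq!(edited, "let x = 100;");
}
```

In an editor integration this splicing is usually done by the text buffer itself; the reparser only needs the resulting string and the edit ranges, which tell it which parts of the old tree are safe to reuse.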
#![allow(unused)] fn main() { use rowan::{GreenNode, GreenNodeBuilder, Language, SyntaxNode, SyntaxToken, TextRange, TextSize}; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(u16)] pub enum SyntaxKind { Whitespace = 0, Comment, Ident, Number, String, Plus, Minus, Star, Slash, Eq, Neq, Lt, Gt, LParen, RParen, LBrace, RBrace, Semicolon, Comma, Keyword, Error, Root, BinaryExpr, UnaryExpr, ParenExpr, Literal, BlockStmt, ExprStmt, LetStmt, IfStmt, WhileStmt, ReturnStmt, FnDef, ParamList, TypeRef, Path, CallExpr, ArgList, } impl From<SyntaxKind> for rowan::SyntaxKind { fn from(kind: SyntaxKind) -> Self { Self(kind as u16) } } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum Lang {} impl Language for Lang { type Kind = SyntaxKind; fn kind_from_raw(raw: rowan::SyntaxKind) -> Self::Kind { assert!(raw.0 <= SyntaxKind::ArgList as u16); unsafe { std::mem::transmute::<u16, SyntaxKind>(raw.0) } } fn kind_to_raw(kind: Self::Kind) -> rowan::SyntaxKind { kind.into() } } pub type SyntaxNodeRef = SyntaxNode<Lang>; pub type SyntaxTokenRef = SyntaxToken<Lang>; #[derive(Debug, Clone)] pub struct ParseResult { pub green_node: GreenNode, pub errors: Vec<ParseError>, } #[derive(Debug, Clone)] pub struct ParseError { pub message: String, pub range: TextRange, } pub struct Parser { builder: GreenNodeBuilder<'static>, errors: Vec<ParseError>, tokens: Vec<Token>, cursor: usize, } #[derive(Debug, Clone)] pub struct Token { pub kind: SyntaxKind, pub text: String, pub offset: TextSize, } impl Parser { pub fn new(tokens: Vec<Token>) -> Self { Self { builder: GreenNodeBuilder::new(), errors: Vec::new(), tokens, cursor: 0, } } pub fn parse(mut self) -> ParseResult { self.builder.start_node(SyntaxKind::Root.into()); while !self.at_end() { if self.at(SyntaxKind::Whitespace) || self.at(SyntaxKind::Comment) { self.trivia(); } else { self.statement(); } } self.builder.finish_node(); ParseResult { green_node: self.builder.finish(), errors: 
self.errors, } } fn statement(&mut self) { match self.current_kind() { Some(SyntaxKind::Keyword) if self.current_text() == Some("let") => { self.let_statement(); } Some(SyntaxKind::Keyword) if self.current_text() == Some("if") => { self.if_statement(); } Some(SyntaxKind::Keyword) if self.current_text() == Some("while") => { self.while_statement(); } Some(SyntaxKind::Keyword) if self.current_text() == Some("return") => { self.return_statement(); } Some(SyntaxKind::Keyword) if self.current_text() == Some("fn") => { self.function_definition(); } _ => { self.expression_statement(); } } } fn let_statement(&mut self) { self.builder.start_node(SyntaxKind::LetStmt.into()); self.consume(SyntaxKind::Keyword); self.skip_trivia(); self.consume(SyntaxKind::Ident); self.skip_trivia(); if self.at(SyntaxKind::Eq) { self.consume(SyntaxKind::Eq); self.skip_trivia(); self.expression(); } self.skip_trivia(); self.consume(SyntaxKind::Semicolon); self.builder.finish_node(); } fn if_statement(&mut self) { self.builder.start_node(SyntaxKind::IfStmt.into()); self.consume(SyntaxKind::Keyword); self.skip_trivia(); self.expression(); self.skip_trivia(); self.block(); if self.at_keyword("else") { self.consume(SyntaxKind::Keyword); self.skip_trivia(); if self.at_keyword("if") { self.if_statement(); } else { self.block(); } } self.builder.finish_node(); } fn while_statement(&mut self) { self.builder.start_node(SyntaxKind::WhileStmt.into()); self.consume(SyntaxKind::Keyword); self.skip_trivia(); self.expression(); self.skip_trivia(); self.block(); self.builder.finish_node(); } fn return_statement(&mut self) { self.builder.start_node(SyntaxKind::ReturnStmt.into()); self.consume(SyntaxKind::Keyword); self.skip_trivia(); if !self.at(SyntaxKind::Semicolon) { self.expression(); } self.skip_trivia(); self.consume(SyntaxKind::Semicolon); self.builder.finish_node(); } fn function_definition(&mut self) { self.builder.start_node(SyntaxKind::FnDef.into()); self.consume(SyntaxKind::Keyword); 
self.skip_trivia(); self.consume(SyntaxKind::Ident); self.skip_trivia(); self.parameter_list(); self.skip_trivia(); self.block(); self.builder.finish_node(); } fn parameter_list(&mut self) { self.builder.start_node(SyntaxKind::ParamList.into()); self.consume(SyntaxKind::LParen); self.skip_trivia(); if !self.at(SyntaxKind::RParen) { loop { self.consume(SyntaxKind::Ident); self.skip_trivia(); if self.at(SyntaxKind::Comma) { self.consume(SyntaxKind::Comma); self.skip_trivia(); } else { break; } } } self.consume(SyntaxKind::RParen); self.builder.finish_node(); } fn block(&mut self) { self.builder.start_node(SyntaxKind::BlockStmt.into()); self.consume(SyntaxKind::LBrace); while !self.at_end() && !self.at(SyntaxKind::RBrace) { if self.at(SyntaxKind::Whitespace) || self.at(SyntaxKind::Comment) { self.trivia(); } else { self.statement(); } } self.consume(SyntaxKind::RBrace); self.builder.finish_node(); } fn expression_statement(&mut self) { self.builder.start_node(SyntaxKind::ExprStmt.into()); self.expression(); self.skip_trivia(); self.consume(SyntaxKind::Semicolon); self.builder.finish_node(); } fn expression(&mut self) { self.binary_expression(0); } fn binary_expression(&mut self, min_precedence: u8) { self.unary_expression(); // Include whitespace in the tree while self.at(SyntaxKind::Whitespace) { self.trivia(); } while let Some(op_precedence) = self.current_binary_op_precedence() { if op_precedence < min_precedence { break; } let checkpoint = self.builder.checkpoint(); if let Some( k @ (SyntaxKind::Plus | SyntaxKind::Minus | SyntaxKind::Star | SyntaxKind::Slash | SyntaxKind::Eq | SyntaxKind::Neq | SyntaxKind::Lt | SyntaxKind::Gt), ) = self.current_kind() { self.consume(k); } // Include whitespace in the tree while self.at(SyntaxKind::Whitespace) { self.trivia(); } self.binary_expression(op_precedence + 1); self.builder .start_node_at(checkpoint, SyntaxKind::BinaryExpr.into()); self.builder.finish_node(); } } fn unary_expression(&mut self) { if 
self.at(SyntaxKind::Minus) || self.at(SyntaxKind::Plus) {
            self.builder.start_node(SyntaxKind::UnaryExpr.into());
            self.consume(self.current_kind().unwrap());
            while self.at(SyntaxKind::Whitespace) {
                self.trivia();
            }
            self.unary_expression();
            self.builder.finish_node();
        } else {
            self.postfix_expression();
        }
    }

    fn postfix_expression(&mut self) {
        // Take the checkpoint before the callee so that the CallExpr node
        // wraps both the callee and its argument list.
        let checkpoint = self.builder.checkpoint();
        self.primary_expression();
        while self.at(SyntaxKind::LParen) {
            self.argument_list();
            self.builder
                .start_node_at(checkpoint, SyntaxKind::CallExpr.into());
            self.builder.finish_node();
        }
    }

    fn argument_list(&mut self) {
        self.builder.start_node(SyntaxKind::ArgList.into());
        self.consume(SyntaxKind::LParen);
        self.skip_trivia();
        if !self.at(SyntaxKind::RParen) {
            loop {
                self.expression();
                self.skip_trivia();
                if self.at(SyntaxKind::Comma) {
                    self.consume(SyntaxKind::Comma);
                    self.skip_trivia();
                } else {
                    break;
                }
            }
        }
        self.consume(SyntaxKind::RParen);
        self.builder.finish_node();
    }

    fn primary_expression(&mut self) {
        match self.current_kind() {
            Some(SyntaxKind::Number) => {
                self.builder.start_node(SyntaxKind::Literal.into());
                self.consume(SyntaxKind::Number);
                self.builder.finish_node();
            }
            Some(SyntaxKind::String) => {
                self.builder.start_node(SyntaxKind::Literal.into());
                self.consume(SyntaxKind::String);
                self.builder.finish_node();
            }
            Some(SyntaxKind::Ident) => {
                self.consume(SyntaxKind::Ident);
            }
            Some(SyntaxKind::LParen) => {
                self.builder.start_node(SyntaxKind::ParenExpr.into());
                self.consume(SyntaxKind::LParen);
                while self.at(SyntaxKind::Whitespace) {
                    self.trivia();
                }
                self.expression();
                while self.at(SyntaxKind::Whitespace) {
                    self.trivia();
                }
                self.consume(SyntaxKind::RParen);
                self.builder.finish_node();
            }
            _ => {
                self.error("Expected expression");
                self.advance();
            }
        }
    }

    fn current_binary_op_precedence(&self) -> Option<u8> {
        match self.current_kind()? {
            SyntaxKind::Star | SyntaxKind::Slash => Some(5),
            SyntaxKind::Plus | SyntaxKind::Minus => Some(4),
            SyntaxKind::Lt | SyntaxKind::Gt => Some(3),
            SyntaxKind::Eq | SyntaxKind::Neq => Some(2),
            _ => None,
        }
    }

    fn trivia(&mut self) {
        while self.at(SyntaxKind::Whitespace) || self.at(SyntaxKind::Comment) {
            let kind = self.current_kind().unwrap();
            self.consume(kind);
        }
    }

    fn skip_trivia(&mut self) {
        while self.at(SyntaxKind::Whitespace) || self.at(SyntaxKind::Comment) {
            self.advance();
        }
    }

    fn consume(&mut self, expected: SyntaxKind) {
        if self.at(expected) {
            let token = &self.tokens[self.cursor];
            self.builder.token(expected.into(), &token.text);
            self.advance();
        } else {
            self.error(&format!("Expected {:?}", expected));
        }
    }

    fn at(&self, kind: SyntaxKind) -> bool {
        self.current_kind() == Some(kind)
    }

    fn at_keyword(&self, keyword: &str) -> bool {
        self.at(SyntaxKind::Keyword) && self.current_text() == Some(keyword)
    }

    fn current_kind(&self) -> Option<SyntaxKind> {
        self.tokens.get(self.cursor).map(|t| t.kind)
    }

    fn current_text(&self) -> Option<&str> {
        self.tokens.get(self.cursor).map(|t| t.text.as_str())
    }

    fn advance(&mut self) {
        if self.cursor < self.tokens.len() {
            self.cursor += 1;
        }
    }

    fn at_end(&self) -> bool {
        self.cursor >= self.tokens.len()
    }

    fn error(&mut self, message: &str) {
        let offset = self
            .tokens
            .get(self.cursor)
            .map(|t| t.offset)
            .unwrap_or_else(|| TextSize::from(0));
        self.errors.push(ParseError {
            message: message.to_string(),
            range: TextRange::empty(offset),
        });
    }
}

pub fn tokenize(input: &str) -> Vec<Token> {
    let mut tokens = Vec::new();
    let mut offset = TextSize::from(0);
    let mut chars = input.chars().peekable();
    while let Some(ch) = chars.next() {
        let start = offset;
        offset += TextSize::of(ch);
        let (kind, text) = match ch {
            ' ' | '\t' | '\n' | '\r' => {
                let mut text = String::from(ch);
                while let Some(&next) = chars.peek() {
                    if next.is_whitespace() {
                        text.push(next);
                        offset += TextSize::of(next);
                        chars.next();
                    } else {
                        break;
                    }
                }
                (SyntaxKind::Whitespace, text)
            }
            '/' if chars.peek() == Some(&'/') => {
                chars.next();
                offset += TextSize::of('/');
                let mut text = String::from("//");
                while let Some(&next) = chars.peek() {
                    if next != '\n' {
                        text.push(next);
                        offset += TextSize::of(next);
                        chars.next();
                    } else {
                        break;
                    }
                }
                (SyntaxKind::Comment, text)
            }
            '+' => (SyntaxKind::Plus, String::from("+")),
            '-' => (SyntaxKind::Minus, String::from("-")),
            '*' => (SyntaxKind::Star, String::from("*")),
            '/' => (SyntaxKind::Slash, String::from("/")),
            '=' if chars.peek() == Some(&'=') => {
                chars.next();
                offset += TextSize::of('=');
                (SyntaxKind::Eq, String::from("=="))
            }
            '=' => (SyntaxKind::Eq, String::from("=")),
            '!' if chars.peek() == Some(&'=') => {
                chars.next();
                offset += TextSize::of('=');
                (SyntaxKind::Neq, String::from("!="))
            }
            '<' => (SyntaxKind::Lt, String::from("<")),
            '>' => (SyntaxKind::Gt, String::from(">")),
            '(' => (SyntaxKind::LParen, String::from("(")),
            ')' => (SyntaxKind::RParen, String::from(")")),
            '{' => (SyntaxKind::LBrace, String::from("{")),
            '}' => (SyntaxKind::RBrace, String::from("}")),
            ';' => (SyntaxKind::Semicolon, String::from(";")),
            ',' => (SyntaxKind::Comma, String::from(",")),
            '"' => {
                let mut text = String::from("\"");
                while let Some(&next) = chars.peek() {
                    text.push(next);
                    offset += TextSize::of(next);
                    chars.next();
                    if next == '"' {
                        break;
                    }
                }
                (SyntaxKind::String, text)
            }
            c if c.is_ascii_digit() => {
                let mut text = String::from(c);
                while let Some(&next) = chars.peek() {
                    if next.is_ascii_digit() {
                        text.push(next);
                        offset += TextSize::of(next);
                        chars.next();
                    } else {
                        break;
                    }
                }
                (SyntaxKind::Number, text)
            }
            c if c.is_ascii_alphabetic() || c == '_' => {
                let mut text = String::from(c);
                while let Some(&next) = chars.peek() {
                    if next.is_ascii_alphanumeric() || next == '_' {
                        text.push(next);
                        offset += TextSize::of(next);
                        chars.next();
                    } else {
                        break;
                    }
                }
                let kind = match text.as_str() {
                    "let" | "if" | "else" | "while" | "for" | "fn" | "return" | "true"
                    | "false" | "struct" | "enum" | "impl" => SyntaxKind::Keyword,
                    _ => SyntaxKind::Ident,
                };
                (kind, text)
            }
            _ => (SyntaxKind::Error, String::from(ch)),
        };
        tokens.push(Token {
            kind,
            text,
            offset: start,
        });
    }
    tokens
}

#[derive(Debug)]
pub struct IncrementalReparser {
    _old_tree: SyntaxNodeRef,
    edits: Vec<TextEdit>,
}

#[derive(Debug, Clone)]
pub struct TextEdit {
    pub range: TextRange,
    pub new_text: String,
}

impl IncrementalReparser {
    pub fn new(tree: SyntaxNodeRef) -> Self {
        Self {
            _old_tree: tree,
            edits: Vec::new(),
        }
    }

    pub fn add_edit(&mut self, edit: TextEdit) {
        self.edits.push(edit);
    }

    pub fn reparse(&self, new_text: &str) -> ParseResult {
        let tokens = tokenize(new_text);
        let parser = Parser::new(tokens);
        parser.parse()
    }
}

pub struct SyntaxTreeBuilder {
    green: GreenNode,
}

impl SyntaxTreeBuilder {
    pub fn new(green: GreenNode) -> Self {
        Self { green }
    }

    pub fn build(self) -> SyntaxNodeRef {
        SyntaxNodeRef::new_root(self.green)
    }
}

pub fn parse_expression(input: &str) -> SyntaxNodeRef {
    let tokens = tokenize(input);
    let mut parser = Parser::new(tokens);
    // Build just an expression tree
    parser.builder.start_node(SyntaxKind::Root.into());
    // Include leading whitespace as trivia
    while parser.at(SyntaxKind::Whitespace) {
        parser.trivia();
    }
    // Parse the expression
    if !parser.at_end() {
        parser.expression();
    }
    // Include trailing whitespace as trivia
    while parser.at(SyntaxKind::Whitespace) {
        parser.trivia();
    }
    parser.builder.finish_node();
    let green_node = parser.builder.finish();
    SyntaxTreeBuilder::new(green_node).build()
}

pub struct AstNode {
    syntax: SyntaxNodeRef,
}

impl AstNode {
    pub fn cast(syntax: SyntaxNodeRef) -> Option<Self> {
        match syntax.kind() {
            SyntaxKind::Root
            | SyntaxKind::BinaryExpr
            | SyntaxKind::UnaryExpr
            | SyntaxKind::ParenExpr
            | SyntaxKind::Literal
            | SyntaxKind::BlockStmt
            | SyntaxKind::ExprStmt
            | SyntaxKind::LetStmt
            | SyntaxKind::IfStmt
            | SyntaxKind::WhileStmt
            | SyntaxKind::ReturnStmt
            | SyntaxKind::FnDef
            | SyntaxKind::CallExpr => Some(Self { syntax }),
            _ => None,
        }
    }

    pub fn syntax(&self) -> &SyntaxNodeRef {
        &self.syntax
    }
}

pub trait AstToken {
    fn cast(syntax: SyntaxTokenRef) -> Option<Self>
    where
        Self: Sized;
    fn syntax(&self) -> &SyntaxTokenRef;
}

pub struct Identifier {
    syntax: SyntaxTokenRef,
}

impl AstToken for Identifier {
    fn cast(syntax: SyntaxTokenRef) -> Option<Self> {
        if syntax.kind() == SyntaxKind::Ident {
            Some(Self { syntax })
        } else {
            None
        }
    }

    fn syntax(&self) -> &SyntaxTokenRef {
        &self.syntax
    }
}

impl Identifier {
    pub fn text(&self) -> &str {
        self.syntax.text()
    }
}

pub fn walk_tree(node: &SyntaxNodeRef, depth: usize) {
    let indent = "  ".repeat(depth);
    println!("{}{:?}", indent, node.kind());
    for child in node.children() {
        walk_tree(&child, depth + 1);
    }
}

pub fn find_node_at_offset(root: &SyntaxNodeRef, offset: TextSize) -> Option<SyntaxNodeRef> {
    if !root.text_range().contains(offset) {
        return None;
    }
    let mut result = root.clone();
    for child in root.children() {
        if child.text_range().contains(offset) {
            if let Some(deeper) = find_node_at_offset(&child, offset) {
                result = deeper;
            }
        }
    }
    Some(result)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenization() {
        let input = "let x = 42 + y;";
        let tokens = tokenize(input);
        assert_eq!(tokens[0].kind, SyntaxKind::Keyword);
        assert_eq!(tokens[0].text, "let");
        assert_eq!(tokens[2].kind, SyntaxKind::Ident);
        assert_eq!(tokens[2].text, "x");
    }

    #[test]
    fn test_parse_expression() {
        let input = "x + 42 * (y - 3)";
        // First check tokenization
        let tokens = tokenize(input);
        // The expression should tokenize properly
        assert!(tokens.len() > 5, "Expected more tokens, got: {:?}", tokens);
        // Check that tokens include the operators
        let has_plus = tokens.iter().any(|t| t.kind == SyntaxKind::Plus);
        let has_star = tokens.iter().any(|t| t.kind == SyntaxKind::Star);
        assert!(has_plus, "Missing + operator in tokens: {:?}", tokens);
        assert!(has_star, "Missing * operator in tokens: {:?}", tokens);
        let tree = parse_expression(input);
        assert_eq!(tree.kind(), SyntaxKind::Root);
        // Check that the tree covers the full expression
        let tree_text = tree.text().to_string();
        assert_eq!(
            tree_text.trim(),
            input,
            "Tree text doesn't match input. Got '{}' expected '{}'",
            tree_text.trim(),
            input
        );
    }

    #[test]
    fn test_incremental_reparse() {
        let input = "let x = 42;";
        let tree = parse_expression(input);
        let mut reparser = IncrementalReparser::new(tree);
        reparser.add_edit(TextEdit {
            range: TextRange::new(TextSize::from(8), TextSize::from(10)),
            new_text: "100".to_string(),
        });
        let new_tree = reparser.reparse("let x = 100;");
        assert_eq!(new_tree.errors.len(), 0);
    }

    #[test]
    fn test_ast_node_cast() {
        let input = "42 + x";
        let tree = parse_expression(input);
        if let Some(first_child) = tree.first_child() {
            let ast_node = AstNode::cast(first_child);
            assert!(ast_node.is_some());
        }
    }
}
}
The incremental reparser tracks edits to the source text. A production implementation would rebuild only the affected portions of the syntax tree, reusing unchanged subtrees from the previous parse; the version shown here records the edits but simply re-tokenizes and reparses the new text for clarity. Incremental reuse is what makes this architecture viable for IDE scenarios, where the source changes on every keystroke and a full reparse of a large file would be prohibitively expensive.
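The reuse decision reduces to range arithmetic: a subtree from the old parse can be kept only if its text range does not intersect any edit. The following is a minimal standalone sketch of that check, using plain `usize` offsets instead of rowan's `TextRange`; the `Edit` type and `can_reuse` helper are illustrative, not part of rowan's API.

```rust
/// Hypothetical edit record: the byte range replaced in the old text.
#[derive(Debug, Clone, Copy)]
struct Edit {
    start: usize,
    end: usize,
}

/// A node spanning `range` in the old text can be reused only if no edit
/// touches it. (A real reparser would also shift reusable ranges by the
/// net length change of the edits that precede them.)
fn can_reuse(range: (usize, usize), edits: &[Edit]) -> bool {
    edits.iter().all(|e| e.end <= range.0 || e.start >= range.1)
}

fn main() {
    // The edit replaces bytes 8..10 ("42" -> "100" in "let x = 42;").
    let edits = [Edit { start: 8, end: 10 }];
    // The identifier `x` at 4..5 is untouched and reusable.
    assert!(can_reuse((4, 5), &edits));
    // The literal at 8..10 overlaps the edit and must be reparsed.
    assert!(!can_reuse((8, 10), &edits));
    println!("ok");
}
```

A real implementation walks the old tree top-down with this predicate, descending into damaged nodes and splicing reusable green subtrees directly into the new parse.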
AST Layer
The AST layer provides a typed interface over the syntax tree:
#![allow(unused)]
fn main() {
// SyntaxKind, SyntaxNodeRef, and SyntaxTokenRef are the definitions from
// the parser listing above.

pub struct AstNode {
    syntax: SyntaxNodeRef,
}

impl AstNode {
    pub fn cast(syntax: SyntaxNodeRef) -> Option<Self> {
        match syntax.kind() {
            SyntaxKind::Root
            | SyntaxKind::BinaryExpr
            | SyntaxKind::UnaryExpr
            | SyntaxKind::ParenExpr
            | SyntaxKind::Literal
            | SyntaxKind::BlockStmt
            | SyntaxKind::ExprStmt
            | SyntaxKind::LetStmt
            | SyntaxKind::IfStmt
            | SyntaxKind::WhileStmt
            | SyntaxKind::ReturnStmt
            | SyntaxKind::FnDef
            | SyntaxKind::CallExpr => Some(Self { syntax }),
            _ => None,
        }
    }

    pub fn syntax(&self) -> &SyntaxNodeRef {
        &self.syntax
    }
}

pub trait AstToken {
    fn cast(syntax: SyntaxTokenRef) -> Option<Self>
    where
        Self: Sized;
    fn syntax(&self) -> &SyntaxTokenRef;
}

pub struct Identifier {
    syntax: SyntaxTokenRef,
}

impl AstToken for Identifier {
    fn cast(syntax: SyntaxTokenRef) -> Option<Self> {
        if syntax.kind() == SyntaxKind::Ident {
            Some(Self { syntax })
        } else {
            None
        }
    }

    fn syntax(&self) -> &SyntaxTokenRef {
        &self.syntax
    }
}

impl Identifier {
    pub fn text(&self) -> &str {
        self.syntax.text()
    }
}

pub fn walk_tree(node: &SyntaxNodeRef, depth: usize) {
    let indent = "  ".repeat(depth);
    println!("{}{:?}", indent, node.kind());
    for child in node.children() {
        walk_tree(&child, depth + 1);
    }
}

pub fn find_node_at_offset(root: &SyntaxNodeRef, offset: TextSize) -> Option<SyntaxNodeRef> {
    if !root.text_range().contains(offset) {
        return None;
    }
    let mut result = root.clone();
    for child in root.children() {
        if child.text_range().contains(offset) {
            if let Some(deeper) = find_node_at_offset(&child, offset) {
                result = deeper;
            }
        }
    }
    Some(result)
}
}
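The `cast` pattern generalizes beyond rowan: a typed layer is just a set of wrapper types whose constructors validate the kind of an untyped node, so downstream code can rely on that invariant. Here is a standalone sketch of the idea with a toy `Node` type standing in for rowan's `SyntaxNode`; all names in it are illustrative.

```rust
// Toy untyped node, standing in for rowan's SyntaxNode.
#[derive(Debug, Clone, PartialEq)]
enum Kind {
    Literal,
    Ident,
}

#[derive(Debug, Clone)]
struct Node {
    kind: Kind,
    text: String,
}

// Typed wrapper: construction succeeds only for the right kind.
struct Literal(Node);

impl Literal {
    fn cast(node: Node) -> Option<Self> {
        if node.kind == Kind::Literal {
            Some(Literal(node))
        } else {
            None
        }
    }

    fn value(&self) -> i64 {
        // Once the cast succeeded we know the node is a literal; a real
        // implementation would return Result for malformed text.
        self.0.text.parse().unwrap_or(0)
    }
}

fn main() {
    let num = Node { kind: Kind::Literal, text: "42".into() };
    let id = Node { kind: Kind::Ident, text: "x".into() };
    // Casting the right kind yields a typed view of the same node.
    assert_eq!(Literal::cast(num).map(|l| l.value()), Some(42));
    // Casting the wrong kind fails instead of producing a bogus wrapper.
    assert!(Literal::cast(id).is_none());
    println!("ok");
}
```

Because the wrapper stores the untyped node rather than copying data out of it, the typed layer stays zero-cost and always in sync with the underlying tree, which is the same design rowan-based tools such as rust-analyzer use.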
} pub struct Identifier { syntax: SyntaxTokenRef, } impl AstToken for Identifier { fn cast(syntax: SyntaxTokenRef) -> Option<Self> { if syntax.kind() == SyntaxKind::Ident { Some(Self { syntax }) } else { None } } fn syntax(&self) -> &SyntaxTokenRef { &self.syntax } } impl Identifier { pub fn text(&self) -> &str { self.syntax.text() } } pub fn walk_tree(node: &SyntaxNodeRef, depth: usize) { let indent = " ".repeat(depth); println!("{}{:?}", indent, node.kind()); for child in node.children() { walk_tree(&child, depth + 1); } } pub fn find_node_at_offset(root: &SyntaxNodeRef, offset: TextSize) -> Option<SyntaxNodeRef> { if !root.text_range().contains(offset) { return None; } let mut result = root.clone(); for child in root.children() { if child.text_range().contains(offset) { if let Some(deeper) = find_node_at_offset(&child, offset) { result = deeper; } } } Some(result) } #[cfg(test)] mod tests { use super::*; #[test] fn test_tokenization() { let input = "let x = 42 + y;"; let tokens = tokenize(input); assert_eq!(tokens[0].kind, SyntaxKind::Keyword); assert_eq!(tokens[0].text, "let"); assert_eq!(tokens[2].kind, SyntaxKind::Ident); assert_eq!(tokens[2].text, "x"); } #[test] fn test_parse_expression() { let input = "x + 42 * (y - 3)"; // First check tokenization let tokens = tokenize(input); // The expression should tokenize properly assert!(tokens.len() > 5, "Expected more tokens, got: {:?}", tokens); // Check if tokens include operators let has_plus = tokens.iter().any(|t| t.kind == SyntaxKind::Plus); let has_star = tokens.iter().any(|t| t.kind == SyntaxKind::Star); assert!(has_plus, "Missing + operator in tokens: {:?}", tokens); assert!(has_star, "Missing * operator in tokens: {:?}", tokens); let tree = parse_expression(input); assert_eq!(tree.kind(), SyntaxKind::Root); // Check the tree contains the full expression let tree_text = tree.text().to_string(); assert_eq!( tree_text.trim(), input, "Tree text doesn't match input. 
Got '{}' expected '{}'", tree_text.trim(), input ); } #[test] fn test_incremental_reparse() { let input = "let x = 42;"; let tree = parse_expression(input); let mut reparser = IncrementalReparser::new(tree); reparser.add_edit(TextEdit { range: TextRange::new(TextSize::from(8), TextSize::from(10)), new_text: "100".to_string(), }); let new_tree = reparser.reparse("let x = 100;"); assert_eq!(new_tree.errors.len(), 0); } #[test] fn test_ast_node_cast() { let input = "42 + x"; let tree = parse_expression(input); if let Some(first_child) = tree.first_child() { let ast_node = AstNode::cast(first_child); assert!(ast_node.is_some()); } } } pub trait AstToken { fn cast(syntax: SyntaxTokenRef) -> Option<Self> where Self: Sized; fn syntax(&self) -> &SyntaxTokenRef; } }
AST nodes wrap syntax nodes with type-safe accessors for their children and properties. The cast method performs runtime type checking to ensure the syntax node has the expected kind. This layer provides the ergonomic API that language servers and other tools use to analyze and transform code.
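The cast pattern can be shown in miniature without rowan. The `Kind` and `Node` types below are hypothetical stand-ins for `SyntaxKind` and the syntax node type, not part of the rowan API; the point is that the typed wrapper is only constructible when the underlying kind matches, so downstream code never has to re-check it:

```rust
// Miniature sketch of the `cast` pattern. `Kind` and `Node` are hypothetical
// stand-ins: the typed wrapper can only be constructed when the node has the
// expected kind, so accessors on the wrapper can rely on that invariant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Kind {
    BinaryExpr,
    Literal,
}

#[derive(Debug)]
pub struct Node {
    pub kind: Kind,
}

pub struct BinaryExpr {
    node: Node,
}

impl BinaryExpr {
    // Runtime-checked "downcast": returns Some only for BinaryExpr nodes.
    pub fn cast(node: Node) -> Option<Self> {
        if node.kind == Kind::BinaryExpr {
            Some(Self { node })
        } else {
            None
        }
    }

    // Escape hatch back to the untyped node, mirroring `syntax()` above.
    pub fn syntax(&self) -> &Node {
        &self.node
    }
}

fn main() {
    assert!(BinaryExpr::cast(Node { kind: Kind::BinaryExpr }).is_some());
    assert!(BinaryExpr::cast(Node { kind: Kind::Literal }).is_none());
}
```

Generating one such wrapper per node kind (by hand or with a macro) is how rust-analyzer-style typed AST layers are usually built on top of an untyped tree.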
Tree Traversal
Rowan provides utilities for navigating and searching the syntax tree. The walk_tree helper prints the node hierarchy with indentation, and find_node_at_offset locates the deepest node whose range covers a given byte offset, which is the core operation behind editor features like hover and go-to-definition:
#![allow(unused)]
fn main() {
// Uses the SyntaxNodeRef alias and TextSize import from the full example above.
pub fn walk_tree(node: &SyntaxNodeRef, depth: usize) {
    let indent = " ".repeat(depth);
    println!("{}{:?}", indent, node.kind());
    for child in node.children() {
        walk_tree(&child, depth + 1);
    }
}

pub fn find_node_at_offset(root: &SyntaxNodeRef, offset: TextSize) -> Option<SyntaxNodeRef> {
    if !root.text_range().contains(offset) {
        return None;
    }
    // Descend into whichever child contains the offset, keeping the deepest match.
    let mut result = root.clone();
    for child in root.children() {
        if child.text_range().contains(offset) {
            if let Some(deeper) = find_node_at_offset(&child, offset) {
                result = deeper;
            }
        }
    }
    Some(result)
}
}
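The descend-to-the-deepest-match strategy behind find_node_at_offset can be illustrated without rowan. The `Node` type and the ranges below are hypothetical stand-ins: each node carries a half-open byte range (like rowan's `TextRange`), and the search recurses into whichever child contains the offset, falling back to the current node when no child does:

```rust
// Toy tree mirroring find_node_at_offset: ranges are half-open [start, end),
// and the search returns the deepest node whose range contains the offset.
// `Node` and the sample layout are hypothetical, for illustration only.
#[derive(Debug)]
pub struct Node {
    pub kind: &'static str,
    pub range: (u32, u32), // half-open [start, end)
    pub children: Vec<Node>,
}

pub fn find_at_offset<'a>(node: &'a Node, offset: u32) -> Option<&'a Node> {
    if offset < node.range.0 || offset >= node.range.1 {
        return None;
    }
    // Prefer a deeper match from a child; otherwise this node is the answer.
    for child in &node.children {
        if let Some(deeper) = find_at_offset(child, offset) {
            return Some(deeper);
        }
    }
    Some(node)
}

// A tree for the source text "1 + 4" (hypothetical layout).
pub fn sample_tree() -> Node {
    Node {
        kind: "BinaryExpr",
        range: (0, 5),
        children: vec![
            Node { kind: "Literal", range: (0, 1), children: vec![] },
            Node { kind: "Literal", range: (4, 5), children: vec![] },
        ],
    }
}

fn main() {
    let tree = sample_tree();
    // Offset 4 lands inside the second literal.
    assert_eq!(find_at_offset(&tree, 4).unwrap().kind, "Literal");
    // Offset 2 (the operator region) only matches the parent expression.
    assert_eq!(find_at_offset(&tree, 2).unwrap().kind, "BinaryExpr");
    // An offset past the end matches nothing.
    assert!(find_at_offset(&tree, 9).is_none());
}
```

The real implementation differs mainly in that rowan nodes are cheap cursors over a shared green tree, so returning owned `SyntaxNodeRef` clones is inexpensive.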
#![allow(unused)] fn main() { use rowan::{GreenNode, GreenNodeBuilder, Language, SyntaxNode, SyntaxToken, TextRange, TextSize}; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(u16)] pub enum SyntaxKind { Whitespace = 0, Comment, Ident, Number, String, Plus, Minus, Star, Slash, Eq, Neq, Lt, Gt, LParen, RParen, LBrace, RBrace, Semicolon, Comma, Keyword, Error, Root, BinaryExpr, UnaryExpr, ParenExpr, Literal, BlockStmt, ExprStmt, LetStmt, IfStmt, WhileStmt, ReturnStmt, FnDef, ParamList, TypeRef, Path, CallExpr, ArgList, } impl From<SyntaxKind> for rowan::SyntaxKind { fn from(kind: SyntaxKind) -> Self { Self(kind as u16) } } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum Lang {} impl Language for Lang { type Kind = SyntaxKind; fn kind_from_raw(raw: rowan::SyntaxKind) -> Self::Kind { assert!(raw.0 <= SyntaxKind::ArgList as u16); unsafe { std::mem::transmute::<u16, SyntaxKind>(raw.0) } } fn kind_to_raw(kind: Self::Kind) -> rowan::SyntaxKind { kind.into() } } pub type SyntaxNodeRef = SyntaxNode<Lang>; pub type SyntaxTokenRef = SyntaxToken<Lang>; #[derive(Debug, Clone)] pub struct ParseResult { pub green_node: GreenNode, pub errors: Vec<ParseError>, } #[derive(Debug, Clone)] pub struct ParseError { pub message: String, pub range: TextRange, } pub struct Parser { builder: GreenNodeBuilder<'static>, errors: Vec<ParseError>, tokens: Vec<Token>, cursor: usize, } #[derive(Debug, Clone)] pub struct Token { pub kind: SyntaxKind, pub text: String, pub offset: TextSize, } impl Parser { pub fn new(tokens: Vec<Token>) -> Self { Self { builder: GreenNodeBuilder::new(), errors: Vec::new(), tokens, cursor: 0, } } pub fn parse(mut self) -> ParseResult { self.builder.start_node(SyntaxKind::Root.into()); while !self.at_end() { if self.at(SyntaxKind::Whitespace) || self.at(SyntaxKind::Comment) { self.trivia(); } else { self.statement(); } } self.builder.finish_node(); ParseResult { green_node: self.builder.finish(), errors: 
self.errors, } } fn statement(&mut self) { match self.current_kind() { Some(SyntaxKind::Keyword) if self.current_text() == Some("let") => { self.let_statement(); } Some(SyntaxKind::Keyword) if self.current_text() == Some("if") => { self.if_statement(); } Some(SyntaxKind::Keyword) if self.current_text() == Some("while") => { self.while_statement(); } Some(SyntaxKind::Keyword) if self.current_text() == Some("return") => { self.return_statement(); } Some(SyntaxKind::Keyword) if self.current_text() == Some("fn") => { self.function_definition(); } _ => { self.expression_statement(); } } } fn let_statement(&mut self) { self.builder.start_node(SyntaxKind::LetStmt.into()); self.consume(SyntaxKind::Keyword); self.skip_trivia(); self.consume(SyntaxKind::Ident); self.skip_trivia(); if self.at(SyntaxKind::Eq) { self.consume(SyntaxKind::Eq); self.skip_trivia(); self.expression(); } self.skip_trivia(); self.consume(SyntaxKind::Semicolon); self.builder.finish_node(); } fn if_statement(&mut self) { self.builder.start_node(SyntaxKind::IfStmt.into()); self.consume(SyntaxKind::Keyword); self.skip_trivia(); self.expression(); self.skip_trivia(); self.block(); if self.at_keyword("else") { self.consume(SyntaxKind::Keyword); self.skip_trivia(); if self.at_keyword("if") { self.if_statement(); } else { self.block(); } } self.builder.finish_node(); } fn while_statement(&mut self) { self.builder.start_node(SyntaxKind::WhileStmt.into()); self.consume(SyntaxKind::Keyword); self.skip_trivia(); self.expression(); self.skip_trivia(); self.block(); self.builder.finish_node(); } fn return_statement(&mut self) { self.builder.start_node(SyntaxKind::ReturnStmt.into()); self.consume(SyntaxKind::Keyword); self.skip_trivia(); if !self.at(SyntaxKind::Semicolon) { self.expression(); } self.skip_trivia(); self.consume(SyntaxKind::Semicolon); self.builder.finish_node(); } fn function_definition(&mut self) { self.builder.start_node(SyntaxKind::FnDef.into()); self.consume(SyntaxKind::Keyword); 
self.skip_trivia(); self.consume(SyntaxKind::Ident); self.skip_trivia(); self.parameter_list(); self.skip_trivia(); self.block(); self.builder.finish_node(); } fn parameter_list(&mut self) { self.builder.start_node(SyntaxKind::ParamList.into()); self.consume(SyntaxKind::LParen); self.skip_trivia(); if !self.at(SyntaxKind::RParen) { loop { self.consume(SyntaxKind::Ident); self.skip_trivia(); if self.at(SyntaxKind::Comma) { self.consume(SyntaxKind::Comma); self.skip_trivia(); } else { break; } } } self.consume(SyntaxKind::RParen); self.builder.finish_node(); } fn block(&mut self) { self.builder.start_node(SyntaxKind::BlockStmt.into()); self.consume(SyntaxKind::LBrace); while !self.at_end() && !self.at(SyntaxKind::RBrace) { if self.at(SyntaxKind::Whitespace) || self.at(SyntaxKind::Comment) { self.trivia(); } else { self.statement(); } } self.consume(SyntaxKind::RBrace); self.builder.finish_node(); } fn expression_statement(&mut self) { self.builder.start_node(SyntaxKind::ExprStmt.into()); self.expression(); self.skip_trivia(); self.consume(SyntaxKind::Semicolon); self.builder.finish_node(); } fn expression(&mut self) { self.binary_expression(0); } fn binary_expression(&mut self, min_precedence: u8) { self.unary_expression(); // Include whitespace in the tree while self.at(SyntaxKind::Whitespace) { self.trivia(); } while let Some(op_precedence) = self.current_binary_op_precedence() { if op_precedence < min_precedence { break; } let checkpoint = self.builder.checkpoint(); if let Some( k @ (SyntaxKind::Plus | SyntaxKind::Minus | SyntaxKind::Star | SyntaxKind::Slash | SyntaxKind::Eq | SyntaxKind::Neq | SyntaxKind::Lt | SyntaxKind::Gt), ) = self.current_kind() { self.consume(k); } // Include whitespace in the tree while self.at(SyntaxKind::Whitespace) { self.trivia(); } self.binary_expression(op_precedence + 1); self.builder .start_node_at(checkpoint, SyntaxKind::BinaryExpr.into()); self.builder.finish_node(); } } fn unary_expression(&mut self) { if 
self.at(SyntaxKind::Minus) || self.at(SyntaxKind::Plus) { self.builder.start_node(SyntaxKind::UnaryExpr.into()); self.consume(self.current_kind().unwrap()); while self.at(SyntaxKind::Whitespace) { self.trivia(); } self.unary_expression(); self.builder.finish_node(); } else { self.postfix_expression(); } } fn postfix_expression(&mut self) { self.primary_expression(); while self.at(SyntaxKind::LParen) { self.builder.start_node(SyntaxKind::CallExpr.into()); let checkpoint = self.builder.checkpoint(); self.argument_list(); self.builder .start_node_at(checkpoint, SyntaxKind::CallExpr.into()); self.builder.finish_node(); } } fn argument_list(&mut self) { self.builder.start_node(SyntaxKind::ArgList.into()); self.consume(SyntaxKind::LParen); self.skip_trivia(); if !self.at(SyntaxKind::RParen) { loop { self.expression(); self.skip_trivia(); if self.at(SyntaxKind::Comma) { self.consume(SyntaxKind::Comma); self.skip_trivia(); } else { break; } } } self.consume(SyntaxKind::RParen); self.builder.finish_node(); } fn primary_expression(&mut self) { match self.current_kind() { Some(SyntaxKind::Number) => { self.builder.start_node(SyntaxKind::Literal.into()); self.consume(SyntaxKind::Number); self.builder.finish_node(); } Some(SyntaxKind::String) => { self.builder.start_node(SyntaxKind::Literal.into()); self.consume(SyntaxKind::String); self.builder.finish_node(); } Some(SyntaxKind::Ident) => { self.consume(SyntaxKind::Ident); } Some(SyntaxKind::LParen) => { self.builder.start_node(SyntaxKind::ParenExpr.into()); self.consume(SyntaxKind::LParen); while self.at(SyntaxKind::Whitespace) { self.trivia(); } self.expression(); while self.at(SyntaxKind::Whitespace) { self.trivia(); } self.consume(SyntaxKind::RParen); self.builder.finish_node(); } _ => { self.error("Expected expression"); self.advance(); } } } fn current_binary_op_precedence(&self) -> Option<u8> { match self.current_kind()? 
{ SyntaxKind::Star | SyntaxKind::Slash => Some(5), SyntaxKind::Plus | SyntaxKind::Minus => Some(4), SyntaxKind::Lt | SyntaxKind::Gt => Some(3), SyntaxKind::Eq | SyntaxKind::Neq => Some(2), _ => None, } } fn trivia(&mut self) { while self.at(SyntaxKind::Whitespace) || self.at(SyntaxKind::Comment) { let kind = self.current_kind().unwrap(); self.consume(kind); } } fn skip_trivia(&mut self) { while self.at(SyntaxKind::Whitespace) || self.at(SyntaxKind::Comment) { self.advance(); } } fn consume(&mut self, expected: SyntaxKind) { if self.at(expected) { let token = &self.tokens[self.cursor]; self.builder.token(expected.into(), &token.text); self.advance(); } else { self.error(&format!("Expected {:?}", expected)); } } fn at(&self, kind: SyntaxKind) -> bool { self.current_kind() == Some(kind) } fn at_keyword(&self, keyword: &str) -> bool { self.at(SyntaxKind::Keyword) && self.current_text() == Some(keyword) } fn current_kind(&self) -> Option<SyntaxKind> { self.tokens.get(self.cursor).map(|t| t.kind) } fn current_text(&self) -> Option<&str> { self.tokens.get(self.cursor).map(|t| t.text.as_str()) } fn advance(&mut self) { if self.cursor < self.tokens.len() { self.cursor += 1; } } fn at_end(&self) -> bool { self.cursor >= self.tokens.len() } fn error(&mut self, message: &str) { let offset = self .tokens .get(self.cursor) .map(|t| t.offset) .unwrap_or_else(|| TextSize::from(0)); self.errors.push(ParseError { message: message.to_string(), range: TextRange::empty(offset), }); } } pub fn tokenize(input: &str) -> Vec<Token> { let mut tokens = Vec::new(); let mut offset = TextSize::from(0); let mut chars = input.chars().peekable(); while let Some(ch) = chars.next() { let start = offset; offset += TextSize::of(ch); let (kind, text) = match ch { ' ' | '\t' | '\n' | '\r' => { let mut text = String::from(ch); while let Some(&next) = chars.peek() { if next.is_whitespace() { text.push(next); offset += TextSize::of(next); chars.next(); } else { break; } } (SyntaxKind::Whitespace, text) } 
'/' if chars.peek() == Some(&'/') => { chars.next(); offset += TextSize::of('/'); let mut text = String::from("//"); while let Some(&next) = chars.peek() { if next != '\n' { text.push(next); offset += TextSize::of(next); chars.next(); } else { break; } } (SyntaxKind::Comment, text) } '+' => (SyntaxKind::Plus, String::from("+")), '-' => (SyntaxKind::Minus, String::from("-")), '*' => (SyntaxKind::Star, String::from("*")), '/' => (SyntaxKind::Slash, String::from("/")), '=' if chars.peek() == Some(&'=') => { chars.next(); offset += TextSize::of('='); (SyntaxKind::Eq, String::from("==")) } '=' => (SyntaxKind::Eq, String::from("=")), '!' if chars.peek() == Some(&'=') => { chars.next(); offset += TextSize::of('='); (SyntaxKind::Neq, String::from("!=")) } '<' => (SyntaxKind::Lt, String::from("<")), '>' => (SyntaxKind::Gt, String::from(">")), '(' => (SyntaxKind::LParen, String::from("(")), ')' => (SyntaxKind::RParen, String::from(")")), '{' => (SyntaxKind::LBrace, String::from("{")), '}' => (SyntaxKind::RBrace, String::from("}")), ';' => (SyntaxKind::Semicolon, String::from(";")), ',' => (SyntaxKind::Comma, String::from(",")), '"' => { let mut text = String::from("\""); while let Some(&next) = chars.peek() { text.push(next); offset += TextSize::of(next); chars.next(); if next == '"' { break; } } (SyntaxKind::String, text) } c if c.is_ascii_digit() => { let mut text = String::from(c); while let Some(&next) = chars.peek() { if next.is_ascii_digit() { text.push(next); offset += TextSize::of(next); chars.next(); } else { break; } } (SyntaxKind::Number, text) } c if c.is_ascii_alphabetic() || c == '_' => { let mut text = String::from(c); while let Some(&next) = chars.peek() { if next.is_ascii_alphanumeric() || next == '_' { text.push(next); offset += TextSize::of(next); chars.next(); } else { break; } } let kind = match text.as_str() { "let" | "if" | "else" | "while" | "for" | "fn" | "return" | "true" | "false" | "struct" | "enum" | "impl" => SyntaxKind::Keyword, _ => 
SyntaxKind::Ident, }; (kind, text) } _ => (SyntaxKind::Error, String::from(ch)), }; tokens.push(Token { kind, text, offset: start, }); } tokens } #[derive(Debug)] pub struct IncrementalReparser { _old_tree: SyntaxNodeRef, edits: Vec<TextEdit>, } #[derive(Debug, Clone)] pub struct TextEdit { pub range: TextRange, pub new_text: String, } impl IncrementalReparser { pub fn new(tree: SyntaxNodeRef) -> Self { Self { _old_tree: tree, edits: Vec::new(), } } pub fn add_edit(&mut self, edit: TextEdit) { self.edits.push(edit); } pub fn reparse(&self, new_text: &str) -> ParseResult { let tokens = tokenize(new_text); let parser = Parser::new(tokens); parser.parse() } } pub struct SyntaxTreeBuilder { green: GreenNode, } impl SyntaxTreeBuilder { pub fn new(green: GreenNode) -> Self { Self { green } } pub fn build(self) -> SyntaxNodeRef { SyntaxNodeRef::new_root(self.green) } } pub fn parse_expression(input: &str) -> SyntaxNodeRef { let tokens = tokenize(input); let mut parser = Parser::new(tokens); // Build just an expression tree parser.builder.start_node(SyntaxKind::Root.into()); // Include leading whitespace as trivia while parser.at(SyntaxKind::Whitespace) { parser.trivia(); } // Parse the expression if !parser.at_end() { parser.expression(); } // Include trailing whitespace as trivia while parser.at(SyntaxKind::Whitespace) { parser.trivia(); } parser.builder.finish_node(); let green_node = parser.builder.finish(); SyntaxTreeBuilder::new(green_node).build() } pub struct AstNode { syntax: SyntaxNodeRef, } impl AstNode { pub fn cast(syntax: SyntaxNodeRef) -> Option<Self> { match syntax.kind() { SyntaxKind::Root | SyntaxKind::BinaryExpr | SyntaxKind::UnaryExpr | SyntaxKind::ParenExpr | SyntaxKind::Literal | SyntaxKind::BlockStmt | SyntaxKind::ExprStmt | SyntaxKind::LetStmt | SyntaxKind::IfStmt | SyntaxKind::WhileStmt | SyntaxKind::ReturnStmt | SyntaxKind::FnDef | SyntaxKind::CallExpr => Some(Self { syntax }), _ => None, } } pub fn syntax(&self) -> &SyntaxNodeRef { &self.syntax } 
} pub trait AstToken { fn cast(syntax: SyntaxTokenRef) -> Option<Self> where Self: Sized; fn syntax(&self) -> &SyntaxTokenRef; } pub struct Identifier { syntax: SyntaxTokenRef, } impl AstToken for Identifier { fn cast(syntax: SyntaxTokenRef) -> Option<Self> { if syntax.kind() == SyntaxKind::Ident { Some(Self { syntax }) } else { None } } fn syntax(&self) -> &SyntaxTokenRef { &self.syntax } } impl Identifier { pub fn text(&self) -> &str { self.syntax.text() } } pub fn walk_tree(node: &SyntaxNodeRef, depth: usize) { let indent = " ".repeat(depth); println!("{}{:?}", indent, node.kind()); for child in node.children() { walk_tree(&child, depth + 1); } } #[cfg(test)] mod tests { use super::*; #[test] fn test_tokenization() { let input = "let x = 42 + y;"; let tokens = tokenize(input); assert_eq!(tokens[0].kind, SyntaxKind::Keyword); assert_eq!(tokens[0].text, "let"); assert_eq!(tokens[2].kind, SyntaxKind::Ident); assert_eq!(tokens[2].text, "x"); } #[test] fn test_parse_expression() { let input = "x + 42 * (y - 3)"; // First check tokenization let tokens = tokenize(input); // The expression should tokenize properly assert!(tokens.len() > 5, "Expected more tokens, got: {:?}", tokens); // Check if tokens include operators let has_plus = tokens.iter().any(|t| t.kind == SyntaxKind::Plus); let has_star = tokens.iter().any(|t| t.kind == SyntaxKind::Star); assert!(has_plus, "Missing + operator in tokens: {:?}", tokens); assert!(has_star, "Missing * operator in tokens: {:?}", tokens); let tree = parse_expression(input); assert_eq!(tree.kind(), SyntaxKind::Root); // Check the tree contains the full expression let tree_text = tree.text().to_string(); assert_eq!( tree_text.trim(), input, "Tree text doesn't match input. 
Got '{}' expected '{}'", tree_text.trim(), input ); } #[test] fn test_incremental_reparse() { let input = "let x = 42;"; let tree = parse_expression(input); let mut reparser = IncrementalReparser::new(tree); reparser.add_edit(TextEdit { range: TextRange::new(TextSize::from(8), TextSize::from(10)), new_text: "100".to_string(), }); let new_tree = reparser.reparse("let x = 100;"); assert_eq!(new_tree.errors.len(), 0); } #[test] fn test_ast_node_cast() { let input = "42 + x"; let tree = parse_expression(input); if let Some(first_child) = tree.first_child() { let ast_node = AstNode::cast(first_child); assert!(ast_node.is_some()); } } } pub fn find_node_at_offset(root: &SyntaxNodeRef, offset: TextSize) -> Option<SyntaxNodeRef> { if !root.text_range().contains(offset) { return None; } let mut result = root.clone(); for child in root.children() { if child.text_range().contains(offset) { if let Some(deeper) = find_node_at_offset(&child, offset) { result = deeper; } } } Some(result) } }
Tree traversal functions enable common IDE operations like finding the syntax node at a cursor position or collecting all nodes of a specific type. The find_node_at_offset function is particularly useful for implementing hover information and go-to-definition features in language servers.
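The descent pattern behind this lookup can be sketched without rowan at all. The following is an illustrative toy (the `Node` type and names are not rowan's API): each node stores a byte range, and the search recurses into whichever child contains the offset, returning the innermost match.

```rust
use std::ops::Range;

// Toy offset lookup: each node carries a byte range; the search descends
// into whichever child contains the offset, yielding the innermost node.
struct Node {
    name: &'static str,
    range: Range<usize>,
    children: Vec<Node>,
}

fn node_at_offset<'a>(node: &'a Node, offset: usize) -> Option<&'a Node> {
    if !node.range.contains(&offset) {
        return None;
    }
    node.children
        .iter()
        .find_map(|child| node_at_offset(child, offset))
        .or(Some(node))
}

// A tree for `x + y` spanning bytes 0..5: identifiers at 0..1 and 4..5.
fn sample_tree() -> Node {
    Node {
        name: "binary_expr",
        range: 0..5,
        children: vec![
            Node { name: "lhs", range: 0..1, children: vec![] },
            Node { name: "rhs", range: 4..5, children: vec![] },
        ],
    }
}

fn main() {
    let tree = sample_tree();
    // Hovering over byte 0 finds the identifier; byte 2 (inside the
    // operator's surrounding whitespace) falls through to the expression.
    assert_eq!(node_at_offset(&tree, 0).unwrap().name, "lhs");
    assert_eq!(node_at_offset(&tree, 2).unwrap().name, "binary_expr");
    assert!(node_at_offset(&tree, 9).is_none());
    println!("offset lookups ok");
}
```

The `or(Some(node))` fallback is what makes the result "innermost node or self": a parent only answers for offsets none of its children claim.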
Best Practices
Design your syntax kinds hierarchy to balance granularity with usability. Too few kinds make the tree difficult to analyze, while too many create unnecessary complexity. Group related tokens into categories like operators or keywords when they behave similarly in the grammar.
Implement error recovery in the parser to produce valid trees even for incorrect input. Skip unexpected tokens rather than failing completely, and use error nodes to mark problematic regions. This approach enables IDE features to work on incomplete or incorrect code.
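The skip-to-synchronization idea can be shown in a self-contained sketch (the token and grammar names below are illustrative, not rowan's API): on a malformed statement, the parser records an error and advances to the next `;`, so the statements after it still parse.

```rust
// Panic-mode recovery sketch: a bad statement produces an error and a skip
// to the next `;` synchronization point instead of aborting the parse.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Tok { Let, Ident, Eq, Num, Semi, Unknown }

struct Parser<'a> {
    toks: &'a [Tok],
    pos: usize,
    errors: Vec<String>,
}

impl<'a> Parser<'a> {
    fn new(toks: &'a [Tok]) -> Self {
        Self { toks, pos: 0, errors: Vec::new() }
    }

    fn at(&self, t: Tok) -> bool {
        self.toks.get(self.pos) == Some(&t)
    }

    fn eat(&mut self, t: Tok) -> bool {
        if self.at(t) { self.pos += 1; true } else { false }
    }

    /// Parses one `let ident = num ;` statement; on failure, skips to `;`.
    fn stmt(&mut self) -> bool {
        let ok = self.eat(Tok::Let) && self.eat(Tok::Ident)
            && self.eat(Tok::Eq) && self.eat(Tok::Num);
        if !ok {
            self.errors.push(format!("malformed statement at token {}", self.pos));
            while self.pos < self.toks.len() && !self.at(Tok::Semi) {
                self.pos += 1;
            }
        }
        self.eat(Tok::Semi);
        ok
    }

    /// Returns how many statements parsed cleanly; errors accumulate rather
    /// than aborting, so later code remains visible to IDE features.
    fn file(&mut self) -> usize {
        let mut complete = 0;
        while self.pos < self.toks.len() {
            if self.stmt() {
                complete += 1;
            }
        }
        complete
    }
}

fn main() {
    use Tok::*;
    // The middle statement is broken; the first and third are fine.
    let toks = [Let, Ident, Eq, Num, Semi,
                Let, Ident, Unknown, Semi,
                Let, Ident, Eq, Num, Semi];
    let mut p = Parser::new(&toks);
    let complete = p.file();
    assert_eq!(complete, 2);
    assert_eq!(p.errors.len(), 1);
    println!("recovered: {} good statements, {} errors", complete, p.errors.len());
}
```

In a real rowan parser the skipped region would be wrapped in an error node rather than discarded, preserving the lossless property.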
Use checkpoints and node wrapping to handle operator precedence and associativity. The checkpoint mechanism allows the parser to defer node creation until it has enough context to build the correct structure.
Preserve all source text including whitespace and comments. This lossless approach enables accurate source reconstruction and formatting preservation. Treat whitespace and comments as trivia tokens that the parser can skip when building the logical structure.
Cache tokenization results when implementing incremental parsing. Most edits affect only a small portion of the token stream, so reusing unchanged tokens significantly improves performance.
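A minimal sketch of that reuse, with a toy whitespace lexer standing in for the real one: every token that ends strictly before the edited range cannot have changed, so only the remainder of the new text is re-lexed. (A production incremental lexer would also reuse the unchanged suffix; this prefix-only version keeps the idea visible in a few lines.)

```rust
use std::ops::Range;

// Token reuse across an edit: keep the prefix of tokens that end before the
// edited range, then re-lex from the first invalidated token onward.
#[derive(Debug, Clone, PartialEq)]
struct Tok {
    text: String,
    start: usize,
}

/// Toy lexer: whitespace-separated words with byte offsets, starting at `from`.
fn lex_from(text: &str, from: usize) -> Vec<Tok> {
    let bytes = text.as_bytes();
    let mut toks = Vec::new();
    let mut i = from;
    while i < bytes.len() {
        if bytes[i].is_ascii_whitespace() {
            i += 1;
            continue;
        }
        let start = i;
        while i < bytes.len() && !bytes[i].is_ascii_whitespace() {
            i += 1;
        }
        toks.push(Tok { text: text[start..i].to_string(), start });
    }
    toks
}

/// Applies `insert` over `old[range]`, reusing every token that ends strictly
/// before the edit. Returns the new text, new tokens, and the reuse count.
fn relex(old_toks: &[Tok], old: &str, range: Range<usize>, insert: &str) -> (String, Vec<Tok>, usize) {
    let new_text = format!("{}{}{}", &old[..range.start], insert, &old[range.end..]);
    let keep = old_toks
        .iter()
        .take_while(|t| t.start + t.text.len() < range.start)
        .count();
    let resume = old_toks.get(keep).map_or(range.start, |t| t.start);
    let mut toks = old_toks[..keep].to_vec();
    toks.extend(lex_from(&new_text, resume));
    (new_text, toks, keep)
}

fn main() {
    let old = "let x = 42;";
    let old_toks = lex_from(old, 0);
    // Replace `42` (bytes 8..10) with `100`.
    let (new_text, new_toks, reused) = relex(&old_toks, old, 8..10, "100");
    assert_eq!(new_text, "let x = 100;");
    assert_eq!(reused, 3); // `let`, `x`, `=` survive the edit untouched
    assert_eq!(new_toks, lex_from(&new_text, 0)); // agrees with a full re-lex
    println!("reused {} of {} tokens", reused, new_toks.len());
}
```

The strict `<` comparison is deliberately conservative: a token that abuts the edit point could merge with inserted text, so it is re-lexed rather than reused.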
Build a typed AST layer over the raw syntax tree for ergonomic access. While the syntax tree provides complete information, the AST layer offers a more natural API for analysis and transformation tools.
The rowan architecture has proven highly successful in rust-analyzer, demonstrating that lossless syntax trees can provide both the precision needed for IDE features and the performance required for interactive use. The separation of green and red trees, combined with incremental reparsing, creates a solid foundation for modern language tooling.
rust_sitter
rust_sitter provides a declarative approach to generating Tree-sitter parsers directly from Rust code. Unlike traditional parser generators that require separate grammar files, rust_sitter uses procedural macros to transform annotated Rust enums and structs into fully functional Tree-sitter parsers. This approach ensures type safety, enables IDE support, and keeps the grammar definition close to the AST types that consume it.
The library excels at creating incremental parsers suitable for editor integration, where parsing must handle incomplete or invalid input gracefully. The generated parsers support error recovery, incremental reparsing, and syntax highlighting through Tree-sitter’s proven infrastructure. By defining grammars as Rust types, rust_sitter eliminates the impedance mismatch between grammar specifications and the data structures that represent parse trees.
Basic Arithmetic Grammar
```rust
#[rust_sitter::grammar("arithmetic")]
pub mod arithmetic {
    #[rust_sitter::language]
    #[derive(Debug, Clone, PartialEq)]
    pub enum Expr {
        Number(
            #[rust_sitter::leaf(pattern = r"\d+(\.\d+)?", transform = |v| v.parse().unwrap())]
            f64,
        ),
        #[rust_sitter::prec_left(1)]
        Add(Box<Expr>, #[rust_sitter::leaf(text = "+")] (), Box<Expr>),
        #[rust_sitter::prec_left(1)]
        Sub(Box<Expr>, #[rust_sitter::leaf(text = "-")] (), Box<Expr>),
        #[rust_sitter::prec_left(2)]
        Mul(Box<Expr>, #[rust_sitter::leaf(text = "*")] (), Box<Expr>),
        #[rust_sitter::prec_left(2)]
        Div(Box<Expr>, #[rust_sitter::leaf(text = "/")] (), Box<Expr>),
        #[rust_sitter::prec_right(3)]
        Pow(Box<Expr>, #[rust_sitter::leaf(text = "^")] (), Box<Expr>),
        #[rust_sitter::prec(4)]
        Neg(#[rust_sitter::leaf(text = "-")] (), Box<Expr>),
        #[rust_sitter::prec(4)]
        Paren(
            #[rust_sitter::leaf(text = "(")] (),
            Box<Expr>,
            #[rust_sitter::leaf(text = ")")] (),
        ),
    }

    #[rust_sitter::extra]
    struct Whitespace {
        #[rust_sitter::leaf(pattern = r"\s")]
        _whitespace: (),
    }
}
```
The arithmetic module demonstrates fundamental grammar construction with operator precedence. The Expr enum represents different expression types, with each variant annotated to describe its parsing behavior. The leaf attribute defines terminal symbols, either matching exact text or patterns with optional transformations. Precedence annotations control parsing ambiguities, with higher numbers binding more tightly. Left-associative operators like addition and subtraction share the same precedence level, while right-associative exponentiation uses prec_right.
The Whitespace struct marked with the extra attribute defines tokens that can appear anywhere in the input without being part of the AST. This separation of structural and formatting concerns simplifies grammar definitions while maintaining flexibility in handling different coding styles.
Expression Evaluation
```rust
impl arithmetic::Expr {
    pub fn eval(&self) -> f64 {
        match self {
            arithmetic::Expr::Number(n) => *n,
            arithmetic::Expr::Add(l, _, r) => l.eval() + r.eval(),
            arithmetic::Expr::Sub(l, _, r) => l.eval() - r.eval(),
            arithmetic::Expr::Mul(l, _, r) => l.eval() * r.eval(),
            arithmetic::Expr::Div(l, _, r) => l.eval() / r.eval(),
            arithmetic::Expr::Pow(l, _, r) => l.eval().powf(r.eval()),
            arithmetic::Expr::Paren(_, e, _) => e.eval(),
            arithmetic::Expr::Neg(_, e) => -e.eval(),
        }
    }
}
```
The eval method demonstrates how parsed ASTs integrate with Rust code. Since the grammar produces strongly-typed Rust values, implementing interpreters or compilers becomes straightforward pattern matching. The recursive structure naturally maps to recursive evaluation, with each expression type defining its semantic behavior.
S-Expression Grammar
```rust
#[rust_sitter::grammar("s_expression")]
pub mod s_expression {
    #[rust_sitter::language]
    #[derive(Debug, Clone, PartialEq)]
    pub enum SExpr {
        Symbol(
            #[rust_sitter::leaf(pattern = r"[a-zA-Z_][a-zA-Z0-9_\-]*", transform = |s| s.to_string())]
            String,
        ),
        Number(
            #[rust_sitter::leaf(pattern = r"-?\d+", transform = |s| s.parse().unwrap())]
            i64,
        ),
        String(StringLiteral),
        List(
            #[rust_sitter::leaf(text = "(")] (),
            #[rust_sitter::repeat(non_empty = false)]
            Vec<SExpr>,
            #[rust_sitter::leaf(text = ")")] (),
        ),
    }

    #[derive(Debug, Clone, PartialEq)]
    pub struct StringLiteral {
        #[rust_sitter::leaf(text = "\"")]
        _open: (),
        #[rust_sitter::leaf(pattern = r#"([^"\\]|\\.)*"#, transform = |s| s.to_string())]
        pub value: String,
        #[rust_sitter::leaf(text = "\"")]
        _close: (),
    }

    #[rust_sitter::extra]
    struct Whitespace {
        #[rust_sitter::leaf(pattern = r"[ \t\n\r]+")]
        _whitespace: (),
    }

    #[rust_sitter::extra]
    struct Comment {
        #[rust_sitter::leaf(pattern = r";[^\n]*")]
        _comment: (),
    }
}
```
The S-expression grammar showcases parsing of LISP-style symbolic expressions with nested lists and multiple data types. The SExpr enum includes symbols, numbers, strings, and lists. The repeat attribute with non_empty flag controls whether empty lists are allowed.
String parsing demonstrates pattern-based lexing with transformations. The pattern matches any sequence of non-quote characters or escape sequences, while the transform function converts the matched text into a Rust String. This separation of lexical and semantic concerns keeps grammars readable while supporting complex tokenization rules.
Configuration Language Grammar
```rust
#[rust_sitter::grammar("config")]
pub mod config {
    use rust_sitter::Spanned;

    #[rust_sitter::language]
    #[derive(Debug, Clone)]
    pub struct Config {
        #[rust_sitter::repeat(non_empty = false)]
        pub entries: Vec<Entry>,
    }

    #[derive(Debug, Clone)]
    pub struct Entry {
        pub key: Key,
        #[rust_sitter::leaf(text = "=")]
        _eq: (),
        pub value: Spanned<Value>,
        #[rust_sitter::leaf(text = "\n")]
        _newline: (),
    }

    #[derive(Debug, Clone)]
    pub struct Key {
        #[rust_sitter::leaf(pattern = r"[a-zA-Z][a-zA-Z0-9_\.]*", transform = |s| s.to_string())]
        pub name: String,
    }

    #[derive(Debug, Clone)]
    pub enum Value {
        String(StringValue),
        Number(
            #[rust_sitter::leaf(pattern = r"-?\d+(\.\d+)?", transform = |s| s.parse().unwrap())]
            f64,
        ),
        Bool(
            #[rust_sitter::leaf(pattern = r"true|false", transform = |s| s == "true")]
            bool,
        ),
        List(ListValue),
    }

    // Assumed definition: mirrors the StringLiteral rule from the
    // s-expression grammar above.
    #[derive(Debug, Clone)]
    pub struct StringValue {
        #[rust_sitter::leaf(text = "\"")]
        _open: (),
        #[rust_sitter::leaf(pattern = r#"([^"\\]|\\.)*"#, transform = |s| s.to_string())]
        pub value: String,
        #[rust_sitter::leaf(text = "\"")]
        _close: (),
    }

    #[derive(Debug, Clone)]
    pub struct ListValue {
        #[rust_sitter::leaf(text = "[")]
        _open: (),
        #[rust_sitter::repeat(non_empty = false)]
        #[rust_sitter::delimited(
            #[rust_sitter::leaf(text = ",")]
            ()
        )]
        pub items: Vec<Value>,
        #[rust_sitter::leaf(text = "]")]
        _close: (),
    }

    // Assumed extras: the whitespace rule excludes newlines, which terminate
    // entries; comment syntax here is illustrative.
    #[rust_sitter::extra]
    struct Whitespace {
        #[rust_sitter::leaf(pattern = r"[ \t]+")]
        _whitespace: (),
    }

    #[rust_sitter::extra]
    struct Comment {
        #[rust_sitter::leaf(pattern = r"#[^\n]*")]
        _comment: (),
    }
}
```
The config module implements a simple configuration file grammar with key-value pairs. The Config struct serves as the root node, containing a vector of entries. Each entry has a key, equals sign, value, and newline terminator. Values can be strings, numbers, booleans, or lists.
The Spanned type wraps values with source location information, useful for error reporting. Lists use the delimited attribute to handle comma-separated items. The extra whitespace and comment rules allow these tokens between any grammar elements.
This grammar demonstrates a practical use case for configuration files, showing how rust_sitter handles line-oriented formats with mixed value types.
Grammar Annotations
rust_sitter provides several key annotations for controlling parser generation. The grammar attribute on a module specifies the parser name and generates the parse function. The language attribute marks the root type that parsing produces. Within types, leaf attributes define terminal symbols with optional patterns and transformations.
Precedence control uses prec, prec_left, and prec_right attributes with numeric levels. Higher numbers bind more tightly, resolving ambiguities in expression parsing. Associativity attributes determine how operators of the same precedence combine, critical for arithmetic and logical operations.
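The effect of those numeric levels can be pictured with a few lines of ordinary precedence climbing. This is an illustration of the scheme, not rust_sitter's internals: an operator at level `p` parses its right operand with minimum level `p + 1`, so higher numbers bind more tightly and same-level operators group to the left, exactly what `prec_left`'s numbers encode.

```rust
// Precedence climbing over single-digit operands, mirroring prec_left levels.
fn prec(op: char) -> Option<u8> {
    match op {
        '+' | '-' => Some(1), // loosest
        '*' | '/' => Some(2), // binds more tightly
        _ => None,
    }
}

fn eval(tokens: &[char], pos: &mut usize, min_prec: u8) -> f64 {
    let mut lhs = f64::from(tokens[*pos] as u8 - b'0');
    *pos += 1;
    while *pos < tokens.len() {
        let op = tokens[*pos];
        let p = match prec(op) {
            Some(p) if p >= min_prec => p,
            _ => break,
        };
        *pos += 1;
        // Requiring `p + 1` on the right makes the operator left-associative.
        let rhs = eval(tokens, pos, p + 1);
        lhs = match op {
            '+' => lhs + rhs,
            '-' => lhs - rhs,
            '*' => lhs * rhs,
            _ => lhs / rhs,
        };
    }
    lhs
}

fn run(src: &str) -> f64 {
    let tokens: Vec<char> = src.chars().filter(|c| !c.is_whitespace()).collect();
    eval(&tokens, &mut 0, 0)
}

fn main() {
    assert_eq!(run("1 + 2 * 3"), 7.0); // * outranks +
    assert_eq!(run("8 - 3 - 2"), 3.0); // left associative, not 8 - (3 - 2)
    assert_eq!(run("2 * 3 + 4 * 5"), 26.0);
    println!("precedence ok");
}
```

A right-associative operator (like `prec_right` exponentiation) would pass `p` rather than `p + 1` for its right operand.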
The repeat attribute generates zero-or-more or one-or-more patterns, with non_empty controlling minimums. Combined with delimited, it handles separated lists common in programming languages. The extra attribute marks ignorable tokens like whitespace and comments that can appear between any symbols.
Error Recovery
Tree-sitter parsers excel at error recovery, continuing to parse even when encountering invalid syntax. This robustness makes them ideal for editor integration where code is frequently incomplete or temporarily malformed. The generated parser produces partial ASTs with error nodes marking problem areas, enabling features like syntax highlighting and code folding even in broken code.
Error nodes preserve the input text while marking parse failures, allowing tools to provide meaningful error messages. The incremental parsing capability means only changed regions require reparsing, maintaining responsiveness even in large files.
Testing Patterns
```rust
#[test]
fn test_arithmetic_parsing() {
    // Parse simple number
    let expr = arithmetic::parse("42").unwrap();
    assert_eq!(expr, arithmetic::Expr::Number(42.0));

    // Parse addition
    let expr = arithmetic::parse("1 + 2").unwrap();
    match expr {
        arithmetic::Expr::Add(l, _, r) => {
            assert_eq!(*l, arithmetic::Expr::Number(1.0));
            assert_eq!(*r, arithmetic::Expr::Number(2.0));
        }
        _ => panic!("Expected Add"),
    }

    // Parse with precedence
    let expr = arithmetic::parse("1 + 2 * 3").unwrap();
    assert_eq!(expr.eval(), 7.0); // 1 + (2 * 3)
}
```
Parser testing combines unit tests for specific constructs with integration tests for complete programs. The parse function returns a Result, enabling standard Rust error handling. Comparing parsed ASTs with expected structures verifies parsing behavior, while evaluation tests confirm semantic correctness.
```rust
#[test]
fn test_arithmetic_evaluation() {
    let cases = vec![
        ("10.5 + 20.3", 30.8),
        ("100 - 50", 50.0),
        ("6 * 7", 42.0),
        ("20 / 4", 5.0),
        ("2 ^ 8", 256.0),
        ("(2 + 3) * (4 + 5)", 45.0),
        ("2 * 3 + 4 * 5", 26.0),
    ];

    for (input, expected) in cases {
        let expr = arithmetic::parse(input).unwrap();
        let result = expr.eval();
        assert!(
            (result - expected).abs() < 0.001,
            "Failed for '{}': got {}, expected {}",
            input, result, expected
        );
    }
}
```
Property-based testing works well with grammar-based parsers. Generate random valid inputs according to the grammar, parse them, and verify properties like roundtrip printing or evaluation consistency. This approach finds edge cases that hand-written tests might miss.
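The roundtrip idea can be sketched with no external crates at all. In a real project, proptest or quickcheck would replace the hand-rolled seeded generator below; everything here is a std-only stand-in used to make the property concrete: generate random expression trees, print them, re-parse the printed form, and check that evaluation survives the roundtrip.

```rust
// Property sketch: printing then re-parsing a random tree preserves its value.
enum Expr {
    Num(i64),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
}

impl Expr {
    fn eval(&self) -> i64 {
        match self {
            Expr::Num(n) => *n,
            Expr::Add(l, r) => l.eval() + r.eval(),
            Expr::Mul(l, r) => l.eval() * r.eval(),
        }
    }
    fn print(&self) -> String {
        match self {
            Expr::Num(n) => n.to_string(),
            Expr::Add(l, r) => format!("({}+{})", l.print(), r.print()),
            Expr::Mul(l, r) => format!("({}*{})", l.print(), r.print()),
        }
    }
}

/// Deterministic generator: a linear congruential RNG drives the tree shape.
fn gen(seed: &mut u64, depth: u32) -> Expr {
    *seed = seed.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
    let r = (*seed >> 33) % 3;
    if depth == 0 || r == 0 {
        Expr::Num(((*seed >> 16) % 10) as i64)
    } else if r == 1 {
        Expr::Add(Box::new(gen(seed, depth - 1)), Box::new(gen(seed, depth - 1)))
    } else {
        Expr::Mul(Box::new(gen(seed, depth - 1)), Box::new(gen(seed, depth - 1)))
    }
}

/// Parses exactly the fully parenthesized form produced by `print`.
fn parse(s: &[u8], i: &mut usize) -> Expr {
    if s[*i] == b'(' {
        *i += 1; // '('
        let l = parse(s, i);
        let op = s[*i];
        *i += 1;
        let r = parse(s, i);
        *i += 1; // ')'
        if op == b'+' {
            Expr::Add(Box::new(l), Box::new(r))
        } else {
            Expr::Mul(Box::new(l), Box::new(r))
        }
    } else {
        let start = *i;
        while *i < s.len() && s[*i].is_ascii_digit() {
            *i += 1;
        }
        Expr::Num(std::str::from_utf8(&s[start..*i]).unwrap().parse().unwrap())
    }
}

fn main() {
    let mut seed = 42u64;
    for _ in 0..200 {
        let e = gen(&mut seed, 4);
        let printed = e.print();
        let reparsed = parse(printed.as_bytes(), &mut 0);
        assert_eq!(reparsed.eval(), e.eval(), "roundtrip failed for {}", printed);
    }
    println!("200 roundtrip cases passed");
}
```

With rust_sitter, the same property would compare `arithmetic::parse(&printed)` against the generated tree, exercising precedence and associativity across many more shapes than hand-written cases cover.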
Build Configuration
The build.rs script integrates parser generation into the Cargo build process. The rust_sitter_tool::build_parsers function processes annotated modules, generating C code for Tree-sitter and Rust bindings. This generation happens at build time, ensuring parsers stay synchronized with their grammar definitions.
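A minimal build script following the pattern in rust_sitter's documentation looks like this; the path is an assumption and must point at whichever file holds the annotated grammar module:

```rust
// build.rs — runs at build time to generate the Tree-sitter parser.
use std::path::PathBuf;

fn main() {
    rust_sitter_tool::build_parsers(&PathBuf::from("src/lib.rs"));
}
```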
The generated code includes both the Tree-sitter parser tables and Rust wrapper functions. The parser tables use Tree-sitter’s compact representation optimized for incremental parsing, while the wrapper provides safe Rust APIs matching the original type definitions.
Integration with Tree-sitter
rust_sitter generates standard Tree-sitter parsers compatible with the entire Tree-sitter ecosystem. The parsers work with Tree-sitter’s highlighting queries, language servers, and editor plugins. This compatibility means rust_sitter grammars can power syntax highlighting in editors like Neovim, VSCode, and Emacs without additional work.
The generated parsers support Tree-sitter’s query language for pattern matching over syntax trees. Queries can extract specific patterns, power syntax highlighting, or identify code patterns for refactoring tools. This declarative approach to tree traversal complements the type-safe AST access from Rust code.
Performance Characteristics
Tree-sitter parsers use table-driven parsing with excellent performance characteristics. The generated parsers handle gigabyte-scale files with sub-second parse times, while incremental reparsing typically completes in microseconds. Memory usage remains bounded even for large files through Tree-sitter’s compressed tree representation.
The parsing algorithm is a generalized LR (GLR) variant optimized for error recovery and incremental updates. Unlike traditional LR parsers, which halt at the first error, Tree-sitter continues parsing to produce useful partial results. This robustness comes with minimal performance overhead compared to strict parsers.
Best Practices
Structure grammars to mirror the intended AST closely, using Rust’s type system to enforce invariants. Separate lexical concerns using leaf patterns from structural concerns in the type hierarchy. Use precedence annotations consistently, documenting the intended associativity and binding strength.
Keep terminal patterns simple and unambiguous. Complex lexical rules should use separate leaf types rather than elaborate patterns. This separation improves error messages and makes grammars easier to understand and modify.
Design ASTs for consumption, not just parsing. Include semantic information in the types where it simplifies later processing. The transform functions on leaf nodes can perform initial interpretation, converting strings to numbers or normalizing identifiers.
Test grammars extensively with both valid and invalid inputs. Verify that error recovery produces useful partial results. Check that precedence and associativity match language specifications. Include tests for edge cases like empty inputs, deeply nested structures, and maximum-length identifiers.
rust_sitter bridges the gap between Tree-sitter’s powerful parsing infrastructure and Rust’s type system. By generating parsers from type definitions, it ensures grammar and AST remain synchronized while providing excellent IDE support and type safety. The combination of declarative grammar specification, automatic parser generation, and Tree-sitter’s robust parsing algorithm makes rust_sitter an excellent choice for building development tools and language implementations.
rustc_lexer
The rustc_lexer crate is the actual lexer used by the Rust compiler, extracted as a standalone library. Unlike traditional lexer generators, it provides a hand-written, highly optimized tokenizer specifically designed for the Rust language. This makes it invaluable for building Rust tooling, language servers, and compilers for Rust-like languages.
This lexer operates at the lowest level, producing raw tokens without any semantic understanding. It handles all of Rust’s complex lexical features including raw strings, byte strings, numeric literals with various bases, and proper Unicode support. The lexer is designed for maximum performance and minimal allocation, making it suitable for incremental parsing scenarios.
Basic Usage
The lexer provides a simple cursor-based API that produces one token at a time. Each token includes its kind and byte length in the source.
use std::ops::Range; use rustc_lexer::{self, Base, LiteralKind, TokenKind}; #[derive(Debug, Clone, PartialEq)] pub struct Token { pub kind: TokenKind, pub text: String, pub span: Range<usize>, } pub struct Lexer<'input> { input: &'input str, position: usize, } pub fn strip_shebang(input: &str) -> &str { rustc_lexer::strip_shebang(input) .map(|shebang_len| &input[shebang_len..]) .unwrap_or(input) } pub fn cook_lexer_literal( kind: LiteralKind, text: &str, _start: usize, ) -> Result<ParsedLiteral, LiteralError> { match kind { LiteralKind::Int { base, empty_int } => { if empty_int { return Err(LiteralError::EmptyInt); } let text = text.replace('_', ""); let value = match base { Base::Binary => u128::from_str_radix(&text[2..], 2), Base::Octal => u128::from_str_radix(&text[2..], 8), Base::Decimal => text.parse(), Base::Hexadecimal => u128::from_str_radix(&text[2..], 16), }; match value { Ok(n) => Ok(ParsedLiteral::Int(n)), Err(_) => Err(LiteralError::IntegerOverflow), } } LiteralKind::Float { base, empty_exponent, } => { if empty_exponent { return Err(LiteralError::EmptyExponent); } if base != Base::Decimal { return Err(LiteralError::NonDecimalFloat); } let text = text.replace('_', ""); match text.parse() { Ok(f) => Ok(ParsedLiteral::Float(f)), Err(_) => Err(LiteralError::InvalidFloat), } } LiteralKind::Char { terminated } => { if !terminated { return Err(LiteralError::UnterminatedChar); } let content = &text[1..text.len() - 1]; let unescaped = unescape_char(content)?; Ok(ParsedLiteral::Char(unescaped)) } LiteralKind::Byte { terminated } => { if !terminated { return Err(LiteralError::UnterminatedByte); } let content = &text[2..text.len() - 1]; let unescaped = unescape_byte(content)?; Ok(ParsedLiteral::Byte(unescaped)) } LiteralKind::Str { terminated } => { if !terminated { return Err(LiteralError::UnterminatedString); } let content = &text[1..text.len() - 1]; let unescaped = unescape_string(content)?; Ok(ParsedLiteral::Str(unescaped)) } LiteralKind::ByteStr { terminated 
} => { if !terminated { return Err(LiteralError::UnterminatedByteString); } let content = &text[2..text.len() - 1]; let unescaped = unescape_byte_string(content)?; Ok(ParsedLiteral::ByteStr(unescaped)) } LiteralKind::RawStr { n_hashes, started, terminated, } => { if !started || !terminated { return Err(LiteralError::UnterminatedRawString); } let _hashes = "#".repeat(n_hashes); let start = 2 + n_hashes; let end = text.len() - n_hashes - 1; let content = text[start..end].to_string(); Ok(ParsedLiteral::RawStr(content)) } LiteralKind::RawByteStr { n_hashes, started, terminated, } => { if !started || !terminated { return Err(LiteralError::UnterminatedRawByteString); } let _hashes = "#".repeat(n_hashes); let start = 3 + n_hashes; let end = text.len() - n_hashes - 1; let content = text.as_bytes()[start..end].to_vec(); Ok(ParsedLiteral::RawByteStr(content)) } } } #[derive(Debug, Clone, PartialEq)] pub enum ParsedLiteral { Int(u128), Float(f64), Char(char), Byte(u8), Str(String), ByteStr(Vec<u8>), RawStr(String), RawByteStr(Vec<u8>), } #[derive(Debug, Clone, PartialEq)] pub enum LiteralError { EmptyInt, IntegerOverflow, EmptyExponent, NonDecimalFloat, InvalidFloat, UnterminatedChar, UnterminatedByte, UnterminatedString, UnterminatedByteString, UnterminatedRawString, UnterminatedRawByteString, InvalidEscape(String), } fn unescape_char(s: &str) -> Result<char, LiteralError> { if let Some(stripped) = s.strip_prefix('\\') { match stripped { "n" => Ok('\n'), "r" => Ok('\r'), "t" => Ok('\t'), "\\" => Ok('\\'), "'" => Ok('\''), "\"" => Ok('"'), "0" => Ok('\0'), _ => Err(LiteralError::InvalidEscape(s.to_string())), } } else if s.len() == 1 { Ok(s.chars().next().unwrap()) } else { Err(LiteralError::InvalidEscape(s.to_string())) } } fn unescape_byte(s: &str) -> Result<u8, LiteralError> { unescape_char(s).and_then(|c| { if c as u32 <= 255 { Ok(c as u8) } else { Err(LiteralError::InvalidEscape(s.to_string())) } }) } fn unescape_string(s: &str) -> Result<String, LiteralError> { let mut 
result = String::new(); let mut chars = s.chars(); while let Some(ch) = chars.next() { if ch == '\\' { if let Some(next) = chars.next() { match next { 'n' => result.push('\n'), 'r' => result.push('\r'), 't' => result.push('\t'), '\\' => result.push('\\'), '\'' => result.push('\''), '"' => result.push('"'), '0' => result.push('\0'), _ => return Err(LiteralError::InvalidEscape(format!("\\{}", next))), } } } else { result.push(ch); } } Ok(result) } fn unescape_byte_string(s: &str) -> Result<Vec<u8>, LiteralError> { unescape_string(s).map(|s| s.into_bytes()) } pub fn tokenize_and_validate(input: &str) -> Result<Vec<Token>, Vec<ValidationError>> { let mut lexer = Lexer::new(input); let mut errors = Vec::new(); let tokens = lexer.tokenize_with_trivia(); for (i, token) in tokens.iter().enumerate() { match &token.kind { TokenKind::Unknown => { errors.push(ValidationError { token_index: i, kind: ValidationErrorKind::UnknownToken, span: token.span.clone(), }); } TokenKind::Literal { kind, .. } => { if let Err(e) = cook_lexer_literal(*kind, &token.text, token.span.start) { errors.push(ValidationError { token_index: i, kind: ValidationErrorKind::InvalidLiteral(e), span: token.span.clone(), }); } } _ => {} } } if errors.is_empty() { Ok(tokens) } else { Err(errors) } } #[derive(Debug, Clone)] pub struct ValidationError { pub token_index: usize, pub kind: ValidationErrorKind, pub span: Range<usize>, } #[derive(Debug, Clone)] pub enum ValidationErrorKind { UnknownToken, InvalidLiteral(LiteralError), } pub fn is_whitespace(kind: TokenKind) -> bool { matches!(kind, TokenKind::Whitespace) } pub fn is_comment(kind: TokenKind) -> bool { matches!( kind, TokenKind::LineComment | TokenKind::BlockComment { .. } ) } pub fn is_literal(kind: TokenKind) -> bool { matches!(kind, TokenKind::Literal { .. }) } pub fn describe_token(kind: TokenKind) -> &'static str { match kind { TokenKind::Ident => "identifier", TokenKind::RawIdent => "raw identifier", TokenKind::Literal { kind, .. 
} => match kind { LiteralKind::Int { .. } => "integer literal", LiteralKind::Float { .. } => "float literal", LiteralKind::Char { .. } => "character literal", LiteralKind::Byte { .. } => "byte literal", LiteralKind::Str { .. } => "string literal", LiteralKind::ByteStr { .. } => "byte string literal", LiteralKind::RawStr { .. } => "raw string literal", LiteralKind::RawByteStr { .. } => "raw byte string literal", }, TokenKind::Lifetime { .. } => "lifetime", TokenKind::Semi => "semicolon", TokenKind::Comma => "comma", TokenKind::Dot => "dot", TokenKind::OpenParen => "open parenthesis", TokenKind::CloseParen => "close parenthesis", TokenKind::OpenBrace => "open brace", TokenKind::CloseBrace => "close brace", TokenKind::OpenBracket => "open bracket", TokenKind::CloseBracket => "close bracket", TokenKind::At => "at sign", TokenKind::Pound => "pound sign", TokenKind::Tilde => "tilde", TokenKind::Question => "question mark", TokenKind::Colon => "colon", TokenKind::Dollar => "dollar sign", TokenKind::Eq => "equals", TokenKind::Lt => "less than", TokenKind::Gt => "greater than", TokenKind::Minus => "minus", TokenKind::And => "ampersand", TokenKind::Or => "pipe", TokenKind::Plus => "plus", TokenKind::Star => "star", TokenKind::Slash => "slash", TokenKind::Caret => "caret", TokenKind::Percent => "percent", TokenKind::Unknown => "unknown token", TokenKind::Not => "exclamation mark", TokenKind::Whitespace => "whitespace", TokenKind::LineComment => "line comment", TokenKind::BlockComment { .. 
} => "block comment", } } #[cfg(test)] mod tests { use super::*; #[test] fn test_basic_tokenization() { let input = "fn main() { let x = 42; }"; let mut lexer = Lexer::new(input); let tokens = lexer.tokenize(); assert_eq!(tokens[0].kind, TokenKind::Ident); assert_eq!(tokens[0].text, "fn"); assert_eq!(tokens[1].kind, TokenKind::Ident); assert_eq!(tokens[1].text, "main"); assert_eq!(tokens[2].kind, TokenKind::OpenParen); assert_eq!(tokens[3].kind, TokenKind::CloseParen); assert_eq!(tokens[4].kind, TokenKind::OpenBrace); } #[test] fn test_literals() { let input = r##"42 3.14 'a' b'x' "hello" b"bytes" r#"raw"#"##; let mut lexer = Lexer::new(input); let tokens = lexer.tokenize(); // Check that all are literals for token in &tokens { assert!(is_literal(token.kind)); } } #[test] fn test_trivia_handling() { let input = "// comment\nfn /* block */ main()"; let mut lexer = Lexer::new(input); // Without trivia let tokens = lexer.tokenize(); assert_eq!(tokens.len(), 4); // fn main ( ) // With trivia let mut lexer = Lexer::new(input); let tokens = lexer.tokenize_with_trivia(); assert!(tokens.len() > 4); // includes comments and whitespace } #[test] fn test_shebang() { let input = "#!/usr/bin/env rust\nfn main() {}"; let stripped = strip_shebang(input); // The newline is included after stripping the shebang assert!(stripped.starts_with("\nfn main()")); } #[test] fn test_literal_parsing() { let cases = vec![ ( LiteralKind::Int { base: Base::Decimal, empty_int: false, }, "42", ParsedLiteral::Int(42), ), ( LiteralKind::Int { base: Base::Hexadecimal, empty_int: false, }, "0xFF", ParsedLiteral::Int(255), ), ( LiteralKind::Float { base: Base::Decimal, empty_exponent: false, }, "3.14", ParsedLiteral::Float(3.14), ), ( LiteralKind::Char { terminated: true }, "'a'", ParsedLiteral::Char('a'), ), ]; for (kind, text, expected) in cases { let result = cook_lexer_literal(kind, text, 0).unwrap(); assert_eq!(result, expected); } } } impl<'input> Lexer<'input> { pub fn new(input: &'input str) -> 
Self { Self { input, position: 0 } } pub fn tokenize(&mut self) -> Vec<Token> { let mut tokens = Vec::new(); while self.position < self.input.len() { let remaining = &self.input[self.position..]; let token = rustc_lexer::first_token(remaining); let start = self.position; let end = self.position + token.len as usize; let text = self.input[start..end].to_string(); // Skip whitespace and comments unless we're preserving them match token.kind { TokenKind::Whitespace | TokenKind::LineComment | TokenKind::BlockComment { .. } => { self.position = end; continue; } _ => {} } tokens.push(Token { kind: token.kind, text, span: start..end, }); self.position = end; } tokens } pub fn tokenize_with_trivia(&mut self) -> Vec<Token> { let mut tokens = Vec::new(); while self.position < self.input.len() { let remaining = &self.input[self.position..]; let token = rustc_lexer::first_token(remaining); let start = self.position; let end = self.position + token.len as usize; let text = self.input[start..end].to_string(); tokens.push(Token { kind: token.kind, text, span: start..end, }); self.position = end; } tokens } }
This wrapper accumulates tokens into a vector for convenience. The lexer skips whitespace and comments by default, focusing on syntactically significant tokens.
Token Kinds
The TokenKind enum covers all possible Rust tokens, from simple punctuation to complex literal forms. The lexer distinguishes between many subtle cases that are important for proper Rust parsing.
```rust
use std::ops::Range;
use rustc_lexer::{self, Base, LiteralKind, TokenKind};

#[derive(Debug, Clone, PartialEq)]
pub struct Token {
    pub kind: TokenKind,
    pub text: String,
    pub span: Range<usize>,
}

pub fn strip_shebang(input: &str) -> &str {
    rustc_lexer::strip_shebang(input)
        .map(|shebang_len| &input[shebang_len..])
        .unwrap_or(input)
}

pub fn is_whitespace(kind: TokenKind) -> bool {
    matches!(kind, TokenKind::Whitespace)
}

pub fn is_comment(kind: TokenKind) -> bool {
    matches!(
        kind,
        TokenKind::LineComment | TokenKind::BlockComment { .. }
    )
}

pub fn is_literal(kind: TokenKind) -> bool {
    matches!(kind, TokenKind::Literal { .. })
}

pub fn describe_token(kind: TokenKind) -> &'static str {
    match kind {
        TokenKind::Ident => "identifier",
        TokenKind::RawIdent => "raw identifier",
        TokenKind::Literal { kind, .. } => match kind {
            LiteralKind::Int { .. } => "integer literal",
            LiteralKind::Float { .. } => "float literal",
            LiteralKind::Char { .. } => "character literal",
            LiteralKind::Byte { .. } => "byte literal",
            LiteralKind::Str { .. } => "string literal",
            LiteralKind::ByteStr { .. } => "byte string literal",
            LiteralKind::RawStr { .. } => "raw string literal",
            LiteralKind::RawByteStr { .. } => "raw byte string literal",
        },
        TokenKind::Lifetime { .. } => "lifetime",
        TokenKind::Semi => "semicolon",
        TokenKind::Comma => "comma",
        TokenKind::Dot => "dot",
        TokenKind::OpenParen => "open parenthesis",
        TokenKind::CloseParen => "close parenthesis",
        TokenKind::OpenBrace => "open brace",
        TokenKind::CloseBrace => "close brace",
        TokenKind::OpenBracket => "open bracket",
        TokenKind::CloseBracket => "close bracket",
        TokenKind::At => "at sign",
        TokenKind::Pound => "pound sign",
        TokenKind::Tilde => "tilde",
        TokenKind::Question => "question mark",
        TokenKind::Colon => "colon",
        TokenKind::Dollar => "dollar sign",
        TokenKind::Eq => "equals",
        TokenKind::Lt => "less than",
        TokenKind::Gt => "greater than",
        TokenKind::Minus => "minus",
        TokenKind::And => "ampersand",
        TokenKind::Or => "pipe",
        TokenKind::Plus => "plus",
        TokenKind::Star => "star",
        TokenKind::Slash => "slash",
        TokenKind::Caret => "caret",
        TokenKind::Percent => "percent",
        TokenKind::Unknown => "unknown token",
        TokenKind::Not => "exclamation mark",
        TokenKind::Whitespace => "whitespace",
        TokenKind::LineComment => "line comment",
        TokenKind::BlockComment { .. } => "block comment",
    }
}
```
This function provides human-readable descriptions for each token kind, useful for error messages and debugging.
Literal Processing
Raw tokens need to be “cooked” to extract their actual values. The lexer identifies literal kinds but doesn’t parse their contents, leaving that to a separate validation step.
```rust
use rustc_lexer::{Base, LiteralKind};

#[derive(Debug, Clone, PartialEq)]
pub enum ParsedLiteral {
    Int(u128),
    Float(f64),
    Char(char),
    Byte(u8),
    Str(String),
    ByteStr(Vec<u8>),
    RawStr(String),
    RawByteStr(Vec<u8>),
}

#[derive(Debug, Clone, PartialEq)]
pub enum LiteralError {
    EmptyInt,
    IntegerOverflow,
    EmptyExponent,
    NonDecimalFloat,
    InvalidFloat,
    UnterminatedChar,
    UnterminatedByte,
    UnterminatedString,
    UnterminatedByteString,
    UnterminatedRawString,
    UnterminatedRawByteString,
    InvalidEscape(String),
}

fn unescape_char(s: &str) -> Result<char, LiteralError> {
    if let Some(stripped) = s.strip_prefix('\\') {
        match stripped {
            "n" => Ok('\n'),
            "r" => Ok('\r'),
            "t" => Ok('\t'),
            "\\" => Ok('\\'),
            "'" => Ok('\''),
            "\"" => Ok('"'),
            "0" => Ok('\0'),
            _ => Err(LiteralError::InvalidEscape(s.to_string())),
        }
    } else if s.len() == 1 {
        Ok(s.chars().next().unwrap())
    } else {
        Err(LiteralError::InvalidEscape(s.to_string()))
    }
}

fn unescape_byte(s: &str) -> Result<u8, LiteralError> {
    unescape_char(s).and_then(|c| {
        if c as u32 <= 255 {
            Ok(c as u8)
        } else {
            Err(LiteralError::InvalidEscape(s.to_string()))
        }
    })
}

fn unescape_string(s: &str) -> Result<String, LiteralError> {
    let mut result = String::new();
    let mut chars = s.chars();
    while let Some(ch) = chars.next() {
        if ch == '\\' {
            if let Some(next) = chars.next() {
                match next {
                    'n' => result.push('\n'),
                    'r' => result.push('\r'),
                    't' => result.push('\t'),
                    '\\' => result.push('\\'),
                    '\'' => result.push('\''),
                    '"' => result.push('"'),
                    '0' => result.push('\0'),
                    _ => return Err(LiteralError::InvalidEscape(format!("\\{}", next))),
                }
            }
        } else {
            result.push(ch);
        }
    }
    Ok(result)
}

fn unescape_byte_string(s: &str) -> Result<Vec<u8>, LiteralError> {
    unescape_string(s).map(|s| s.into_bytes())
}

pub fn cook_lexer_literal(
    kind: LiteralKind,
    text: &str,
    _start: usize,
) -> Result<ParsedLiteral, LiteralError> {
    match kind {
        LiteralKind::Int { base, empty_int } => {
            if empty_int {
                return Err(LiteralError::EmptyInt);
            }
            let text = text.replace('_', "");
            let value = match base {
                Base::Binary => u128::from_str_radix(&text[2..], 2),
                Base::Octal => u128::from_str_radix(&text[2..], 8),
                Base::Decimal => text.parse(),
                Base::Hexadecimal => u128::from_str_radix(&text[2..], 16),
            };
            match value {
                Ok(n) => Ok(ParsedLiteral::Int(n)),
                Err(_) => Err(LiteralError::IntegerOverflow),
            }
        }
        LiteralKind::Float { base, empty_exponent } => {
            if empty_exponent {
                return Err(LiteralError::EmptyExponent);
            }
            if base != Base::Decimal {
                return Err(LiteralError::NonDecimalFloat);
            }
            let text = text.replace('_', "");
            match text.parse() {
                Ok(f) => Ok(ParsedLiteral::Float(f)),
                Err(_) => Err(LiteralError::InvalidFloat),
            }
        }
        LiteralKind::Char { terminated } => {
            if !terminated {
                return Err(LiteralError::UnterminatedChar);
            }
            let content = &text[1..text.len() - 1];
            let unescaped = unescape_char(content)?;
            Ok(ParsedLiteral::Char(unescaped))
        }
        LiteralKind::Byte { terminated } => {
            if !terminated {
                return Err(LiteralError::UnterminatedByte);
            }
            let content = &text[2..text.len() - 1];
            let unescaped = unescape_byte(content)?;
            Ok(ParsedLiteral::Byte(unescaped))
        }
        LiteralKind::Str { terminated } => {
            if !terminated {
                return Err(LiteralError::UnterminatedString);
            }
            let content = &text[1..text.len() - 1];
            let unescaped = unescape_string(content)?;
            Ok(ParsedLiteral::Str(unescaped))
        }
        LiteralKind::ByteStr { terminated } => {
            if !terminated {
                return Err(LiteralError::UnterminatedByteString);
            }
            let content = &text[2..text.len() - 1];
            let unescaped = unescape_byte_string(content)?;
            Ok(ParsedLiteral::ByteStr(unescaped))
        }
        LiteralKind::RawStr { n_hashes, started, terminated } => {
            if !started || !terminated {
                return Err(LiteralError::UnterminatedRawString);
            }
            let start = 2 + n_hashes; // skip r, hashes, opening quote
            let end = text.len() - n_hashes - 1;
            let content = text[start..end].to_string();
            Ok(ParsedLiteral::RawStr(content))
        }
        LiteralKind::RawByteStr { n_hashes, started, terminated } => {
            if !started || !terminated {
                return Err(LiteralError::UnterminatedRawByteString);
            }
            let start = 3 + n_hashes; // skip br, hashes, opening quote
            let end = text.len() - n_hashes - 1;
            let content = text.as_bytes()[start..end].to_vec();
            Ok(ParsedLiteral::RawByteStr(content))
        }
    }
}
```
This function handles all of Rust’s literal forms, including integer literals with different bases, floating-point numbers with scientific notation, character escapes, and various string literal types.
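The integer path is easy to check in isolation. The following standalone sketch distills the underscore-stripping and base-prefix handling from the Int arm above; `cook_int` is an illustrative helper invented for this example, not a rustc_lexer API.

```rust
// Distillation of the integer-cooking step: drop the `_` digit
// separators, then parse according to the base prefix.
// `cook_int` is a hypothetical helper, not part of rustc_lexer.
fn cook_int(text: &str) -> Option<u128> {
    let clean: String = text.chars().filter(|&c| c != '_').collect();
    if let Some(rest) = clean.strip_prefix("0x") {
        u128::from_str_radix(rest, 16).ok()
    } else if let Some(rest) = clean.strip_prefix("0o") {
        u128::from_str_radix(rest, 8).ok()
    } else if let Some(rest) = clean.strip_prefix("0b") {
        u128::from_str_radix(rest, 2).ok()
    } else {
        clean.parse().ok()
    }
}

fn main() {
    assert_eq!(cook_int("1_000_000"), Some(1_000_000));
    assert_eq!(cook_int("0xFF"), Some(255));
    assert_eq!(cook_int("0b1010"), Some(10));
}
```

Overflow surfaces naturally here: any value that does not fit in `u128` makes the radix parse fail, which is what `cook_lexer_literal` maps to `LiteralError::IntegerOverflow`.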
Trivia Handling
Comments and whitespace (collectively called “trivia”) can be preserved or discarded depending on the use case. Language servers need trivia for formatting, while parsers typically skip it.
The tokenize_with_trivia method on the Lexer struct returns every token without filtering. This variant preserves whitespace and comments alongside the syntactically significant tokens, which is essential for tools such as formatters and language servers that need to maintain source fidelity.
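The trivia/significant split can be sketched without the crate at all. Here `Kind` is a simplified stand-in for rustc_lexer's TokenKind, defined locally so the example is self-contained:

```rust
// Simplified stand-in for rustc_lexer::TokenKind, just enough to show
// the trivia/significant split.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Kind {
    Ident,
    Whitespace,
    LineComment,
}

// Whitespace and comments are "trivia": present in the source but
// irrelevant to the grammar.
fn is_trivia(kind: Kind) -> bool {
    matches!(kind, Kind::Whitespace | Kind::LineComment)
}

// A parser-facing stream drops trivia; a formatter keeps everything.
fn significant(tokens: &[Kind]) -> Vec<Kind> {
    tokens.iter().copied().filter(|&k| !is_trivia(k)).collect()
}

fn main() {
    let with_trivia = [Kind::LineComment, Kind::Whitespace, Kind::Ident];
    assert_eq!(significant(&with_trivia), vec![Kind::Ident]);
}
```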
Error Recovery
The lexer is designed for robust error recovery and continues to tokenize even when it encounters invalid input. Unknown characters produce Unknown tokens rather than aborting the entire scan.
```rust
use std::ops::Range;

#[derive(Debug, Clone)]
pub struct ValidationError {
    pub token_index: usize,
    pub kind: ValidationErrorKind,
    pub span: Range<usize>,
}

#[derive(Debug, Clone)]
pub enum ValidationErrorKind {
    UnknownToken,
    InvalidLiteral(LiteralError),
}

pub fn tokenize_and_validate(input: &str) -> Result<Vec<Token>, Vec<ValidationError>> {
    let mut lexer = Lexer::new(input);
    let mut errors = Vec::new();
    let tokens = lexer.tokenize_with_trivia();
    for (i, token) in tokens.iter().enumerate() {
        match &token.kind {
            TokenKind::Unknown => {
                errors.push(ValidationError {
                    token_index: i,
                    kind: ValidationErrorKind::UnknownToken,
                    span: token.span.clone(),
                });
            }
            TokenKind::Literal { kind, .. } => {
                if let Err(e) = cook_lexer_literal(*kind, &token.text, token.span.start) {
                    errors.push(ValidationError {
                        token_index: i,
                        kind: ValidationErrorKind::InvalidLiteral(e),
                        span: token.span.clone(),
                    });
                }
            }
            _ => {}
        }
    }
    if errors.is_empty() {
        Ok(tokens)
    } else {
        Err(errors)
    }
}
```
This function combines tokenization with validation, collecting all errors while still producing a complete token stream. This approach enables IDEs to provide multiple error markers simultaneously.
Raw Strings
Rust’s raw string literals require special handling due to their configurable delimiters. The lexer tracks the number of pound signs and validates proper termination.
The lexer correctly handles raw strings with any number of pound signs, matching the count of opening and closing delimiters, which makes it possible to include arbitrary content without escaping. This is particularly useful for embedding other languages or test data in Rust code.
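The delimiter-matching logic can be sketched as a standalone function. This is an illustrative reimplementation of the counting approach, not rustc_lexer's actual API; `validate_raw_str` is a hypothetical helper.

```rust
/// Validate a raw string literal of the form r#"..."# with any number of
/// pound signs. Returns the hash count on success, or None when the literal
/// is malformed or unterminated. Illustrative sketch only; rustc_lexer's
/// internal implementation differs in structure and error reporting.
fn validate_raw_str(literal: &str) -> Option<u32> {
    let rest = literal.strip_prefix('r')?;
    // Count the opening pound signs, then require a quote.
    let hashes = rest.bytes().take_while(|&b| b == b'#').count();
    let body = rest[hashes..].strip_prefix('"')?;
    // The closing delimiter is a quote followed by the same number of pounds.
    let closer = format!("\"{}", "#".repeat(hashes));
    body.find(&closer).map(|_| hashes as u32)
}

fn main() {
    assert_eq!(validate_raw_str("r#\"raw\"#"), Some(1));
    assert_eq!(validate_raw_str("r\"plain\""), Some(0));
    assert_eq!(validate_raw_str("r##\"unterminated\"#"), None);
}
```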
Performance Characteristics
The rustc_lexer is highly optimized for the common case of valid Rust code. It uses table lookups for character classification and minimizes branching in hot paths. The lexer operates in linear time with respect to input size and performs no allocations during tokenization itself.
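The table-lookup technique works by precomputing a character class for every byte so that the hot loop does a single array index instead of a chain of comparisons. The sketch below illustrates the pattern with our own class names; it is not rustc_lexer's actual table.

```rust
/// Character classes for ASCII bytes, precomputed once at compile time.
/// Illustrative sketch of the lookup-table technique, not rustc_lexer's
/// actual implementation.
#[derive(Clone, Copy, PartialEq, Debug)]
enum CharClass {
    IdentStart,
    Digit,
    Whitespace,
    Other,
}

const fn build_table() -> [CharClass; 256] {
    let mut table = [CharClass::Other; 256];
    let mut i = 0;
    while i < 256 {
        table[i] = match i as u8 {
            b'a'..=b'z' | b'A'..=b'Z' | b'_' => CharClass::IdentStart,
            b'0'..=b'9' => CharClass::Digit,
            b' ' | b'\t' | b'\n' | b'\r' => CharClass::Whitespace,
            _ => CharClass::Other,
        };
        i += 1;
    }
    table
}

static CLASS_TABLE: [CharClass; 256] = build_table();

/// Classify a byte with a single table lookup, no branching.
fn classify(b: u8) -> CharClass {
    CLASS_TABLE[b as usize]
}

fn main() {
    assert_eq!(classify(b'x'), CharClass::IdentStart);
    assert_eq!(classify(b'7'), CharClass::Digit);
    assert_eq!(classify(b' '), CharClass::Whitespace);
}
```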
The cursor-based API allows for incremental lexing, where you can tokenize just a portion of the input or stop early based on some condition. This is crucial for responsive IDE experiences where files may be partially invalid during editing.
Integration Patterns
For building a parser, wrap the lexer in a token stream that provides lookahead:
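A minimal shape for such a wrapper is sketched below. `Token` and `TokenKind` here are stand-ins for the chapter's types, reduced to the fields the stream itself needs.

```rust
/// Simplified stand-ins for the chapter's token types.
#[derive(Debug, Clone, Copy, PartialEq)]
enum TokenKind {
    Ident,
    Number,
    Plus,
    Eof,
}

#[derive(Debug, Clone)]
struct Token {
    kind: TokenKind,
}

/// Wraps a token buffer with single-token lookahead, the shape a
/// recursive-descent parser typically consumes.
struct TokenStream {
    tokens: Vec<Token>,
    pos: usize,
}

impl TokenStream {
    fn new(tokens: Vec<Token>) -> Self {
        Self { tokens, pos: 0 }
    }

    /// Look at the current token without consuming it.
    fn peek(&self) -> TokenKind {
        self.tokens.get(self.pos).map_or(TokenKind::Eof, |t| t.kind)
    }

    /// Consume and return the current token.
    fn next(&mut self) -> TokenKind {
        let kind = self.peek();
        self.pos += 1;
        kind
    }

    /// Consume the current token only if it matches `expected`.
    fn eat(&mut self, expected: TokenKind) -> bool {
        if self.peek() == expected {
            self.pos += 1;
            true
        } else {
            false
        }
    }
}

fn main() {
    let tokens = vec![
        Token { kind: TokenKind::Ident },
        Token { kind: TokenKind::Plus },
        Token { kind: TokenKind::Number },
    ];
    let mut stream = TokenStream::new(tokens);
    assert_eq!(stream.peek(), TokenKind::Ident);
    assert_eq!(stream.next(), TokenKind::Ident);
    assert!(stream.eat(TokenKind::Plus));
    assert!(!stream.eat(TokenKind::Plus));
    assert_eq!(stream.next(), TokenKind::Number);
    assert_eq!(stream.peek(), TokenKind::Eof);
}
```

Returning a sentinel `Eof` kind from `peek` rather than an `Option` keeps parser code free of unwrapping at every lookahead site.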
The lexer integrates naturally with parser combinators or hand-written recursive descent parsers. Its error recovery properties ensure the parser always has tokens to work with, even for invalid input.
For syntax highlighting, process tokens with trivia and map token kinds to color categories. The lexer’s precise token classification enables accurate highlighting that matches rustc’s interpretation.
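A simplified mapping from token kinds to highlight categories might look like the following; the reduced `TokenKind` and the category names are illustrative, not part of any library API.

```rust
/// Highlight categories for an editor theme. The names are our own; a real
/// highlighter would map the chapter's full TokenKind enum.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Highlight {
    Identifier,
    Literal,
    Comment,
    Punctuation,
    None,
}

/// Simplified stand-in for the chapter's TokenKind.
#[derive(Debug, Clone, Copy, PartialEq)]
enum TokenKind {
    Ident,
    Literal,
    LineComment,
    BlockComment,
    OpenBrace,
    Whitespace,
}

fn highlight_class(kind: TokenKind) -> Highlight {
    match kind {
        TokenKind::Ident => Highlight::Identifier,
        TokenKind::Literal => Highlight::Literal,
        // Trivia must be present in the stream for comments to be colored,
        // which is why highlighting consumes the with-trivia token stream.
        TokenKind::LineComment | TokenKind::BlockComment => Highlight::Comment,
        TokenKind::OpenBrace => Highlight::Punctuation,
        TokenKind::Whitespace => Highlight::None,
    }
}

fn main() {
    assert_eq!(highlight_class(TokenKind::LineComment), Highlight::Comment);
    assert_eq!(highlight_class(TokenKind::Ident), Highlight::Identifier);
}
```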
Best Practices
Cache the token stream when possible rather than re-lexing. While the lexer is fast, avoiding redundant work improves overall performance. For incremental scenarios, track which portions of the input have changed and re-lex only affected regions.
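The bookkeeping for "re-lex only affected regions" starts with locating the edited span. The sketch below computes the changed byte range between two revisions by trimming the common prefix and suffix; a real incremental lexer would then widen this range to token boundaries before re-lexing. `dirty_range` is a hypothetical helper, not a library function.

```rust
use std::ops::Range;

/// Find the byte range of `new` that changed relative to `old`, so only that
/// region (extended to token boundaries in practice) needs re-lexing.
/// A sketch of the bookkeeping, not a full incremental lexer.
fn dirty_range(old: &str, new: &str) -> Range<usize> {
    // Length of the shared prefix.
    let prefix = old
        .bytes()
        .zip(new.bytes())
        .take_while(|(a, b)| a == b)
        .count();
    // Length of the shared suffix, careful not to overlap the prefix.
    let max_suffix = old.len().min(new.len()) - prefix;
    let suffix = old
        .bytes()
        .rev()
        .zip(new.bytes().rev())
        .take(max_suffix)
        .take_while(|(a, b)| a == b)
        .count();
    prefix..new.len() - suffix
}

fn main() {
    // Only the changed identifier needs re-lexing.
    assert_eq!(dirty_range("let x = 1;", "let y = 1;"), 4..5);
    // Pure insertion at the end yields a range covering the new text.
    assert_eq!(dirty_range("abc", "abcd"), 3..4);
    // Pure deletion yields an empty range at the deletion point.
    assert_eq!(dirty_range("abcd", "abc"), 3..3);
}
```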
Validate literals in a separate pass rather than during lexing. This separation of concerns keeps the lexer simple and fast while allowing for better error messages during validation.
Handle both terminated and unterminated comments gracefully. IDEs need to provide reasonable behavior even when comments are unclosed, and the lexer’s design supports this requirement.
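One way to make unterminated comments well-behaved is to always produce a token and report termination separately, as in this sketch of a nesting-aware block comment scanner (our own helper, assuming Rust-style nested `/* */` comments):

```rust
/// Scan a block comment starting at `/*`, supporting nesting as Rust does.
/// Returns (bytes consumed, terminated?). An unterminated comment still
/// yields a token spanning the rest of the input, so downstream consumers
/// always see a complete token stream. Illustrative sketch only.
fn scan_block_comment(input: &str) -> (usize, bool) {
    let bytes = input.as_bytes();
    let mut depth = 1usize; // already inside the opening `/*`
    let mut i = 2;
    while i + 1 < bytes.len() {
        match (bytes[i], bytes[i + 1]) {
            (b'/', b'*') => {
                depth += 1;
                i += 2;
            }
            (b'*', b'/') => {
                depth -= 1;
                i += 2;
                if depth == 0 {
                    return (i, true);
                }
            }
            _ => i += 1,
        }
    }
    // Unterminated: consume to end of input and flag it for diagnostics.
    (input.len(), false)
}

fn main() {
    assert_eq!(scan_block_comment("/* a /* b */ c */x"), (17, true));
    assert_eq!(scan_block_comment("/* oops"), (7, false));
}
```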
The rustc_lexer provides a solid foundation for Rust language tooling. Its battle-tested implementation handles all the edge cases that make Rust lexing challenging, from raw identifiers to complex numeric literals. By using the same lexer as rustc, tools can ensure compatibility with the official Rust implementation.
winnow
winnow is a parser combinator library that emphasizes performance, ergonomics, and flexibility. Building on lessons learned from nom and other parser libraries, winnow provides a streamlined API for constructing parsers from simple, composable pieces. The library uses a mutable reference approach that enables better error messages and more intuitive parser composition while maintaining excellent performance characteristics.
The library’s design philosophy centers on making the common case easy while keeping the complex possible. winnow parsers operate on mutable string slices, automatically advancing the input position as parsing proceeds. This approach eliminates the manual input management required by other libraries while providing clear semantics for parser composition and error handling.
Arithmetic Expression Parser
#![allow(unused)] fn main() { use winnow::ascii::{digit1, space0}; use winnow::combinator::{alt, delimited, preceded, repeat}; use winnow::{PResult, Parser}; #[derive(Debug, Clone, PartialEq)] pub enum Expr { Number(f64), Add(Box<Expr>, Box<Expr>), Sub(Box<Expr>, Box<Expr>), Mul(Box<Expr>, Box<Expr>), Div(Box<Expr>, Box<Expr>), Paren(Box<Expr>), } pub fn parse_expression(input: &str) -> Result<Expr, String> { expr.parse(input).map_err(|e| e.to_string()) } fn expr(input: &mut &str) -> PResult<Expr> { add_sub(input) } fn add_sub(input: &mut &str) -> PResult<Expr> { let init = mul_div(input)?; repeat(0.., ( delimited(space0, alt(('+', '-')), space0), mul_div )) .fold(move || init.clone(), |acc, (op, val)| { match op { '+' => Expr::Add(Box::new(acc), Box::new(val)), '-' => Expr::Sub(Box::new(acc), Box::new(val)), _ => unreachable!(), } }) .parse_next(input) } }
The expression parser demonstrates winnow’s approach to building precedence-aware parsers. The mutable reference parameter automatically tracks position in the input, eliminating the need to explicitly thread state through the parser. The fold combinator builds left-associative operations by accumulating results, transforming a sequence of operations into a properly structured AST.
The PResult type alias simplifies error handling while maintaining full error information. The parse_next method advances the input reference, consuming matched characters. This mutation-based approach provides cleaner composition semantics than returning remaining input, as each parser clearly consumes its portion of the input.
Factor and Number Parsing
#![allow(unused)] fn main() { use winnow::token::take_while; fn mul_div(input: &mut &str) -> PResult<Expr> { let init = factor(input)?; repeat(0.., ( delimited(space0, alt(('*', '/')), space0), factor )) .fold(move || init.clone(), |acc, (op, val)| { match op { '*' => Expr::Mul(Box::new(acc), Box::new(val)), '/' => Expr::Div(Box::new(acc), Box::new(val)), _ => unreachable!(), } }) .parse_next(input) } fn factor(input: &mut &str) -> PResult<Expr> { alt(( number.map(Expr::Number), delimited('(', preceded(space0, expr), preceded(space0, ')')) .map(|e| Expr::Paren(Box::new(e))), )) .parse_next(input) } fn number(input: &mut &str) -> PResult<f64> { take_while(1.., |c: char| c.is_ascii_digit() || c == '.') .try_map(|s: &str| s.parse::<f64>()) .parse_next(input) } }
The factor parser handles both atomic numbers and parenthesized expressions, demonstrating recursive parser composition. The alt combinator tries alternatives in order, selecting the first successful match. The delimited combinator parses bracketed content while the preceded combinator handles leading whitespace, showing how winnow’s combinators compose naturally.
Number parsing uses take_while to consume digit characters and decimal points, then try_map to parse the string into a floating-point value. The try_map combinator propagates parsing errors properly, converting string parse failures into parser errors with appropriate context. This error handling ensures that invalid numbers produce meaningful error messages rather than panics.
JSON Parser
#![allow(unused)] fn main() { use winnow::token::take_till; #[derive(Debug, Clone, PartialEq)] pub enum Json { Null, Bool(bool), Number(f64), String(String), Array(Vec<Json>), Object(Vec<(String, Json)>), } fn json_value(input: &mut &str) -> PResult<Json> { delimited( space0, alt(( "null".value(Json::Null), "true".value(Json::Bool(true)), "false".value(Json::Bool(false)), json_number, json_string.map(Json::String), json_array, json_object, )), space0, ) .parse_next(input) } fn json_string(input: &mut &str) -> PResult<String> { delimited( '"', take_till(0.., '"').map(|s: &str| s.to_string()), '"', ) .parse_next(input) } }
The JSON parser showcases winnow’s elegant handling of heterogeneous data structures. The value method on string literals creates parsers that return constant values upon matching, eliminating boilerplate mapping functions. This approach makes literal parsing concise while maintaining type safety.
String parsing uses take_till to consume characters until encountering a delimiter. This combinator efficiently handles variable-length content without backtracking. The simplified string parser shown here would need escape sequence handling for production use, but demonstrates the core parsing approach.
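The escape-sequence decoding that a production string parser would need can be separated from the combinator layer entirely. The sketch below shows the decoding logic that would sit behind a `try_map` on the raw string contents; `unescape_json` is our own helper, and it deliberately omits surrogate-pair recombination for brevity.

```rust
/// Decode JSON escape sequences: \" \\ \/ \b \f \n \r \t and \uXXXX.
/// A sketch of the logic a production json_string parser would apply via
/// try_map; not a fully conformant decoder (surrogate pairs are not
/// recombined here).
fn unescape_json(s: &str) -> Result<String, String> {
    let mut out = String::with_capacity(s.len());
    let mut chars = s.chars();
    while let Some(c) = chars.next() {
        if c != '\\' {
            out.push(c);
            continue;
        }
        match chars.next() {
            Some('"') => out.push('"'),
            Some('\\') => out.push('\\'),
            Some('/') => out.push('/'),
            Some('b') => out.push('\u{0008}'),
            Some('f') => out.push('\u{000C}'),
            Some('n') => out.push('\n'),
            Some('r') => out.push('\r'),
            Some('t') => out.push('\t'),
            Some('u') => {
                // Exactly four hex digits follow \u in JSON.
                let hex: String = chars.by_ref().take(4).collect();
                let code = u32::from_str_radix(&hex, 16)
                    .map_err(|_| format!("bad \\u escape: {hex}"))?;
                out.push(char::from_u32(code).ok_or("invalid code point")?);
            }
            other => return Err(format!("bad escape: {other:?}")),
        }
    }
    Ok(out)
}

fn main() {
    assert_eq!(unescape_json("hello\\nworld").unwrap(), "hello\nworld");
    assert_eq!(unescape_json("\\u0041").unwrap(), "A");
    assert!(unescape_json("\\q").is_err());
}
```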
Array and Object Parsing
#![allow(unused)] fn main() { use winnow::combinator::{separated, terminated}; fn json_array(input: &mut &str) -> PResult<Json> { delimited( '[', delimited( space0, separated(0.., json_value, delimited(space0, ',', space0)), space0, ), ']', ) .map(Json::Array) .parse_next(input) } fn json_object(input: &mut &str) -> PResult<Json> { delimited( '{', delimited( space0, separated(0.., json_member, delimited(space0, ',', space0)), space0, ), '}', ) .map(Json::Object) .parse_next(input) } fn json_member(input: &mut &str) -> PResult<(String, Json)> { ( terminated(json_string, delimited(space0, ':', space0)), json_value, ) .parse_next(input) } }
Array parsing demonstrates winnow’s separated combinator, which handles delimiter-separated lists with proper edge case handling. The nested delimited calls manage both the array brackets and internal whitespace, showing how combinators compose to handle complex formatting requirements. The 0.. range allows empty arrays while separated automatically handles trailing delimiter issues.
Object parsing combines multiple concepts including tuple parsing for key-value pairs. The terminated combinator discards the colon separator after parsing the key, while maintaining the key value. This design keeps parsing logic clear while building the exact data structure needed. The member parser returns tuples that map directly to the Object variant’s expected type.
S-Expression Parser
#![allow(unused)] fn main() { #[derive(Debug, Clone, PartialEq)] pub enum SExpr { Symbol(String), Number(i64), String(String), List(Vec<SExpr>), } fn sexpr_value(input: &mut &str) -> PResult<SExpr> { delimited( sexpr_ws, alt(( sexpr_number, sexpr_string, sexpr_symbol, sexpr_list, )), sexpr_ws, ) .parse_next(input) } fn sexpr_symbol(input: &mut &str) -> PResult<SExpr> { take_while(1.., |c: char| { c.is_ascii_alphanumeric() || c == '_' || c == '-' || c == '+' || c == '*' || c == '/' || c == '?' }) .map(|s: &str| SExpr::Symbol(s.to_string())) .parse_next(input) } fn sexpr_list(input: &mut &str) -> PResult<SExpr> { delimited( '(', repeat(0.., sexpr_value), ')', ) .map(SExpr::List) .parse_next(input) } }
The S-expression parser illustrates recursive data structure parsing with minimal complexity. Symbol parsing accepts standard LISP identifier characters, including operators that are treated as regular symbols in LISP syntax. The take_while combinator with a minimum count ensures symbols contain at least one character while accepting any valid symbol character.
List parsing recursively calls sexpr_value for each element, enabling arbitrary nesting depth. The repeat combinator with 0.. accepts empty lists, matching LISP’s treatment of () as a valid empty list. The clean separation between value parsing and whitespace handling keeps the grammar clear and maintainable.
Configuration Parser
#![allow(unused)] fn main() { use winnow::ascii::{alpha1, alphanumeric1}; #[derive(Debug, Clone, PartialEq)] pub struct Config { pub entries: Vec<ConfigEntry>, } #[derive(Debug, Clone, PartialEq)] pub struct ConfigEntry { pub key: String, pub value: ConfigValue, } fn config_entry(input: &mut &str) -> PResult<ConfigEntry> { let _ = config_ws(input)?; let key = config_key(input)?; let _ = config_ws(input)?; let _ = '='.parse_next(input)?; let _ = config_ws(input)?; let value = config_value(input)?; let _ = alt(('\n', '\r')).parse_next(input).ok(); Ok(ConfigEntry { key, value }) } fn config_key(input: &mut &str) -> PResult<String> { ( alpha1, take_while(0.., |c: char| c.is_ascii_alphanumeric() || c == '_' || c == '.' ), ) .recognize() .map(|s: &str| s.to_string()) .parse_next(input) } }
Configuration parsing demonstrates line-oriented parsing with winnow’s explicit control flow. The config_entry function manually sequences parsing steps, providing clear control over whitespace handling and error recovery. This explicit approach makes the parser’s behavior transparent, especially useful for formats where line boundaries matter.
The key parser uses recognize to capture the entire matched region as a string, avoiding the need to reconstruct the key from components. The tuple parser ensures keys start with a letter while allowing numbers and underscores in subsequent positions, enforcing common configuration file conventions.
URL Parser
#![allow(unused)] fn main() { #[derive(Debug, Clone, PartialEq)] pub struct Url { pub scheme: String, pub host: String, pub port: Option<u16>, pub path: String, pub query: Option<String>, pub fragment: Option<String>, } fn url(input: &mut &str) -> PResult<Url> { let scheme = terminated(alpha1, "://") .map(|s: &str| s.to_string()) .parse_next(input)?; let host = take_while(1.., |c: char| c.is_ascii_alphanumeric() || c == '.' || c == '-' ) .map(|s: &str| s.to_string()) .parse_next(input)?; let port = winnow::combinator::opt( preceded(':', digit1.try_map(|s: &str| s.parse::<u16>())) ) .parse_next(input)?; let path = alt(( take_while(1.., |c: char| c != '?' && c != '#') .map(|s: &str| s.to_string()), winnow::combinator::empty.value(String::from("/")), )) .parse_next(input)?; let query = winnow::combinator::opt( preceded('?', take_while(1.., |c: char| c != '#') .map(|s: &str| s.to_string())) ) .parse_next(input)?; let fragment = winnow::combinator::opt( preceded('#', winnow::combinator::rest .map(|s: &str| s.to_string())) ) .parse_next(input)?; Ok(Url { scheme, host, port, path, query, fragment, }) } }
URL parsing showcases sequential parsing with optional components. The scheme parser uses terminated to consume the :// separator while keeping only the scheme name. The opt combinator handles optional components like ports and query strings, returning None when the component is absent rather than failing the parse.
The path parser demonstrates fallback behavior using alt with empty, providing a default value when no path is specified. The rest combinator consumes all remaining input for the fragment, appropriate since fragments appear last in URLs. This structured approach handles the various optional components of URLs while maintaining parse correctness.
Error Context
#![allow(unused)] fn main() { use winnow::error::ContextError; fn number_with_context(input: &mut &str) -> PResult<f64> { take_while(1.., |c: char| c.is_ascii_digit() || c == '.') .try_map(|s: &str| s.parse::<f64>()) .context("number") .parse_next(input) } fn json_value_with_context(input: &mut &str) -> PResult<Json> { delimited( space0, alt(( "null".value(Json::Null).context("null"), "true".value(Json::Bool(true)).context("boolean"), "false".value(Json::Bool(false)).context("boolean"), number_with_context.map(Json::Number), json_string.map(Json::String).context("string"), json_array.context("array"), json_object.context("object"), )), space0, ) .parse_next(input) } }
winnow’s context system provides meaningful error messages by labeling parser components. The context method attaches a description to a parser, which appears in error messages when parsing fails. This labeling helps users understand what the parser expected at each position, crucial for debugging complex grammars.
Context annotations work throughout the parser hierarchy, with inner contexts providing specific details while outer contexts give broader structure. This layered approach produces error messages that guide users from high-level structure down to specific token requirements.
Stream Positioning
#![allow(unused)] fn main() { use winnow::stream::Checkpoint; fn parse_with_recovery(input: &mut &str) -> Vec<Json> { let mut results = Vec::new(); while !input.is_empty() { let checkpoint = input.checkpoint(); match json_value(input) { Ok(value) => results.push(value), Err(_) => { input.reset(&checkpoint); // Skip one character and try again if !input.is_empty() { *input = &input[1..]; } } } } results } }
winnow’s checkpoint system enables error recovery and backtracking when needed. Creating a checkpoint saves the current input position, allowing the parser to reset on failure. This mechanism supports fault-tolerant parsing where partial results remain valuable, such as in IDE syntax highlighting or data extraction from corrupted inputs.
The checkpoint approach provides explicit control over backtracking, making performance implications clear. Unlike implicit backtracking, checkpoints document where recovery might occur, helping maintain predictable parser performance.
Custom Input Types
#![allow(unused)] fn main() { use winnow::stream::Stream; use winnow::token::any; fn parse_tokens<'a>(tokens: &mut &'a [Token]) -> PResult<Ast> { let token = any.verify(|t: &Token| t.kind == TokenKind::Identifier).parse_next(tokens)?; // Continue building the parser with token-based input todo!() } }
winnow supports custom input types through the Stream trait, enabling parsing of pre-tokenized input or binary formats. Token-based parsing separates lexical analysis from syntax analysis, often improving performance and error messages. The any combinator consumes single elements from any stream type, while verify adds conditions without consuming extra input.
Custom streams enable specialized parsing scenarios like network protocols with length-prefixed fields or AST transformations where the input is already structured. The consistent combinator interface works across all stream types, allowing parser logic to remain unchanged when switching input representations.
Performance Considerations
#![allow(unused)] fn main() { use winnow::combinator::cut_err; fn json_array_cut(input: &mut &str) -> PResult<Json> { delimited( '[', cut_err(delimited( space0, separated(0.., json_value, delimited(space0, ',', space0)), space0, )), ']', ) .map(Json::Array) .parse_next(input) } }
The cut_err combinator commits the parser once a distinguishing prefix has matched. After the opening bracket, cut_err prevents backtracking out of the array branch, so a failure inside the array is reported directly rather than retried against the remaining alternatives. This avoids wasted backtracking on deeply nested structures and produces clearer error messages.
winnow’s mutable reference approach eliminates allocation overhead associated with returning remaining input. Parsers operate directly on string slices without copying, maintaining zero-copy parsing throughout. The fold combinator builds results incrementally without intermediate collections, reducing memory allocation in repetitive parsers.
Testing Strategies
#![allow(unused)] fn main() { #[cfg(test)] mod tests { use super::*; #[test] fn test_expression_evaluation() { let cases = vec![ ("42", 42.0), ("1 + 2", 3.0), ("1 + 2 * 3", 7.0), ("(1 + 2) * 3", 9.0), ]; for (input, expected) in cases { let expr = parse_expression(input).unwrap(); assert_eq!(expr.eval(), expected); } } #[test] fn test_partial_parse() { let mut input = "123 extra"; let num = number(&mut input).unwrap(); assert_eq!(num, 123.0); assert_eq!(input, " extra"); } #[test] fn test_error_recovery() { let mut input = "[1, invalid, 3]"; let result = parse_with_recovery(&mut input); assert!(!result.is_empty()); } } }
Testing winnow parsers involves validating both complete and partial parsing scenarios. The mutable reference approach makes it easy to test partial parsing, verifying that parsers consume exactly the expected input. Tests can examine the remaining input after parsing, ensuring parsers stop at appropriate boundaries.
Error recovery testing validates that parsers handle malformed input gracefully. Recovery tests ensure parsers extract valid portions from partially correct input, important for tooling that must handle incomplete or incorrect code. The checkpoint system makes it straightforward to test recovery strategies.
Best Practices
Design parsers with clear separation between lexical and syntactic concerns. Use whitespace handling combinators consistently throughout the grammar rather than embedding whitespace in every parser. This separation simplifies both the grammar and error messages while making the parser’s intent clearer.
Leverage winnow’s mutable reference model for cleaner parser composition. The automatic position tracking eliminates manual state threading while making parser behavior more predictable. Use the checkpoint system sparingly, only where error recovery provides clear value.
Apply context annotations throughout the parser to improve error messages. Good context labels describe what the parser expects in user terms rather than implementation details. Layer contexts from general to specific, providing users with navigable error information.
Choose appropriate combinators for each parsing task. Use cut_err to commit to parse paths once initial markers are recognized. Apply fold for building accumulative results without intermediate collections. Select separated for delimiter-separated lists rather than manually handling separators.
winnow provides an elegant and performant approach to parsing that balances simplicity with power. Its mutable reference model and thoughtful combinator design create parsers that are both efficient and maintainable. The library’s focus on ergonomics and error reporting makes it an excellent choice for building robust parsers for compilers, data formats, and domain-specific languages.
nom_locate
The nom_locate crate extends nom’s parser combinators with precise source location tracking. While nom excels at building fast, composable parsers, it doesn’t inherently track where in the source text each parse occurred. For compiler construction, this location information is crucial: every error message, warning, and diagnostic needs to point users to the exact position in their source code. nom_locate solves this by wrapping input slices with location metadata that flows through the parsing process.
Location tracking enables compilers to provide helpful error messages that show not just what went wrong, but exactly where. It also supports advanced IDE features like go-to-definition, hover information, and refactoring tools that need to map between source text and AST nodes. The crate integrates seamlessly with nom’s existing combinators while adding minimal overhead to parsing performance.
Core Types
The foundation of nom_locate is the LocatedSpan type, typically aliased for convenience:
#![allow(unused)] fn main() { type Span<'a> = LocatedSpan<&'a str>; }
#![allow(unused)] fn main() { use std::ops::Range; use nom::branch::alt; use nom::bytes::complete::{tag, take_while, take_while1}; use nom::character::complete::{char, digit1, multispace0, multispace1}; use nom::combinator::{map, recognize}; use nom::error::{Error, ErrorKind}; use nom::multi::separated_list0; use nom::sequence::{delimited, pair, preceded}; use nom::{Err, IResult, Parser as NomParser}; use nom_locate::{position, LocatedSpan}; /// A span type that tracks position information pub type Span<'a> = LocatedSpan<&'a str>; /// A source location in line, column, and byte-offset form #[derive(Debug, Clone, Copy, PartialEq)] pub struct Location { pub line: u32, pub column: usize, pub offset: usize, } impl Location { pub fn from_span(span: Span<'_>) -> Self { Self { line: span.location_line(), column: span.get_utf8_column(), offset: span.location_offset(), } } } /// A range of source locations #[derive(Debug, Clone, PartialEq)] pub struct SourceRange { pub start: Location, pub end: Location, } impl SourceRange { pub fn from_spans(start: Span<'_>, end: Span<'_>) -> Self { Self { start: Location::from_span(start), end: Location::from_span(end), } } pub fn to_range(&self) -> Range<usize> { self.start.offset..self.end.offset } } /// AST node with location information #[derive(Debug, Clone, PartialEq)] pub struct Spanned<T> { pub node: T, pub span: SourceRange, } impl<T> Spanned<T> { pub fn new(node: T, start: Span<'_>, end: Span<'_>) -> Self { Self { node, span: SourceRange::from_spans(start, end), } } } /// Expression AST for a simple language #[derive(Debug, Clone, PartialEq)] pub enum Expr { Number(i64), Identifier(String), Binary { left: Box<Spanned<Expr>>, op: BinaryOp, right: Box<Spanned<Expr>>, }, Call { func: Box<Spanned<Expr>>, args: Vec<Spanned<Expr>>, }, Let { name: String, value: Box<Spanned<Expr>>, body: Box<Spanned<Expr>>, }, } #[derive(Debug, Clone, PartialEq)] pub enum BinaryOp { Add, Sub, Mul, Div, Eq, Lt, Gt, } /// Parser error with precise location information #[derive(Debug, Clone)] pub struct ParseError { pub message: String, pub location: Location, pub expected: Vec<String>, } impl ParseError { pub fn
from_nom_error(_input: Span<'_>, error: Error<Span<'_>>) -> Self { let location = Location::from_span(error.input); let message = match error.code { ErrorKind::Tag => "unexpected token".to_string(), ErrorKind::Digit => "expected number".to_string(), ErrorKind::Alpha => "expected identifier".to_string(), ErrorKind::Char => "expected character".to_string(), ErrorKind::Many0 => "expected list".to_string(), ErrorKind::SeparatedList => "expected comma-separated list".to_string(), _ => format!("parse error: {:?}", error.code), }; Self { message, location, expected: vec![], } } } /// Main parser implementation pub struct Parser; impl Parser { /// Parse a complete expression from input pub fn parse_expression(input: &str) -> Result<Spanned<Expr>, ParseError> { let span = Span::new(input); match Self::expression(span) { Ok((_, expr)) => Ok(expr), Err(Err::Error(e)) | Err(Err::Failure(e)) => Err(ParseError::from_nom_error(span, e)), Err(Err::Incomplete(_)) => Err(ParseError { message: "incomplete input".to_string(), location: Location::from_span(span), expected: vec!["more input".to_string()], }), } } /// Parse an expression with precedence fn expression(input: Span<'_>) -> IResult<Span<'_>, Spanned<Expr>> { Self::binary_expr(input, 0) } /// Parse binary expressions with operator precedence fn binary_expr(input: Span<'_>, min_prec: u8) -> IResult<Span<'_>, Spanned<Expr>> { let start_pos = position(input)?; let (input, mut left) = Self::primary_expr(input)?; let mut current_input = input; loop { let (input, _) = multispace0(current_input)?; // Try to parse an operator let op_result: IResult<Span<'_>, (BinaryOp, u8)> = alt(( map(char('+'), |_| (BinaryOp::Add, 1)), map(char('-'), |_| (BinaryOp::Sub, 1)), map(char('*'), |_| (BinaryOp::Mul, 2)), map(char('/'), |_| (BinaryOp::Div, 2)), map(tag("=="), |_| (BinaryOp::Eq, 0)), map(char('<'), |_| (BinaryOp::Lt, 0)), map(char('>'), |_| (BinaryOp::Gt, 0)), )) .parse(input); match op_result { Ok((input, (op, prec))) if prec >= min_prec 
=> { let (input, _) = multispace0(input)?; let (input, right) = Self::binary_expr(input, prec + 1)?; let end_span = position(input)?; left = Spanned::new( Expr::Binary { left: Box::new(left), op, right: Box::new(right), }, start_pos.0, end_span.0, ); current_input = input; } _ => break, } } Ok((current_input, left)) } /// Parse primary expressions (atoms and parenthesized expressions) fn primary_expr(input: Span<'_>) -> IResult<Span<'_>, Spanned<Expr>> { let (input, _) = multispace0(input)?; alt(( Self::parenthesized_expr, Self::function_call, Self::let_expr, Self::number, Self::identifier, )) .parse(input) } /// Parse parenthesized expressions fn parenthesized_expr(input: Span<'_>) -> IResult<Span<'_>, Spanned<Expr>> { let start_pos = position(input)?; let (input, expr) = delimited( char('('), preceded(multispace0, Self::expression), preceded(multispace0, char(')')), ) .parse(input)?; let end_pos = position(input)?; Ok((input, Spanned::new(expr.node, start_pos.0, end_pos.0))) } /// Parse function calls fn function_call(input: Span<'_>) -> IResult<Span<'_>, Spanned<Expr>> { let start_pos = position(input)?; let (input, func) = Self::identifier(input)?; let (input, _) = multispace0(input)?; let (input, _) = char('(')(input)?; let (input, _) = multispace0(input)?; let (input, args) = separated_list0( delimited(multispace0, char(','), multispace0), Self::expression, ) .parse(input)?; let (input, _) = multispace0(input)?; let (input, _) = char(')')(input)?; let end_pos = position(input)?; Ok(( input, Spanned::new( Expr::Call { func: Box::new(func), args, }, start_pos.0, end_pos.0, ), )) } /// Parse let expressions fn let_expr(input: Span<'_>) -> IResult<Span<'_>, Spanned<Expr>> { let start_pos = position(input)?; let (input, _) = tag("let")(input)?; let (input, _) = multispace1(input)?; let (input, name) = Self::identifier_string(input)?; let (input, _) = multispace0(input)?; let (input, _) = char('=')(input)?; let (input, _) = multispace0(input)?; let (input, value) = 
Self::expression(input)?;
        let (input, _) = multispace0(input)?;
        let (input, _) = tag("in")(input)?;
        let (input, _) = multispace0(input)?;
        let (input, body) = Self::expression(input)?;
        let end_pos = position(input)?;
        Ok((
            input,
            Spanned::new(
                Expr::Let {
                    name,
                    value: Box::new(value),
                    body: Box::new(body),
                },
                start_pos.0,
                end_pos.0,
            ),
        ))
    }

    /// Parse numbers
    fn number(input: Span<'_>) -> IResult<Span<'_>, Spanned<Expr>> {
        let start_pos = position(input)?;
        let (input, digits) = digit1(input)?;
        let end_pos = position(input)?;
        let num = digits
            .fragment()
            .parse()
            .map_err(|_| Err::Error(Error::new(input, ErrorKind::Digit)))?;
        Ok((
            input,
            Spanned::new(Expr::Number(num), start_pos.0, end_pos.0),
        ))
    }

    /// Parse identifiers
    fn identifier(input: Span<'_>) -> IResult<Span<'_>, Spanned<Expr>> {
        let start_pos = position(input)?;
        let (input, ident) = Self::identifier_string(input)?;
        let end_pos = position(input)?;
        Ok((
            input,
            Spanned::new(Expr::Identifier(ident), start_pos.0, end_pos.0),
        ))
    }

    /// Parse identifier strings
    fn identifier_string(input: Span<'_>) -> IResult<Span<'_>, String> {
        let (input, ident) = recognize(pair(
            alt((tag("_"), take_while1(|c: char| c.is_ascii_alphabetic()))),
            take_while(|c: char| c.is_ascii_alphanumeric() || c == '_'),
        ))
        .parse(input)?;
        Ok((input, ident.fragment().to_string()))
    }

    /// Get position information for error reporting
    pub fn get_position_info(input: &str, offset: usize) -> Option<(u32, usize)> {
        if offset <= input.len() {
            // Count lines and columns up to the offset
            let mut line = 1;
            let mut col = 1;
            for (i, ch) in input.char_indices() {
                if i >= offset {
                    break;
                }
                if ch == '\n' {
                    line += 1;
                    col = 1;
                } else {
                    col += 1;
                }
            }
            Some((line, col))
        } else {
            None
        }
    }

    /// Extract line content for error reporting
    pub fn get_line_content(input: &str, line_number: u32) -> Option<&str> {
        input.lines().nth(line_number.saturating_sub(1) as usize)
    }
}

/// A lexer that preserves location information for each token
pub struct LocatedLexer<'a> {
    input: Span<'a>,
}

#[derive(Debug, Clone, PartialEq)]
pub struct LocatedToken {
    pub kind: TokenKind,
    pub location: Location,
    pub text: String,
}

#[derive(Debug, Clone, PartialEq)]
pub enum TokenKind {
    Number,
    Identifier,
    Keyword(String),
    Operator(String),
    LeftParen,
    RightParen,
    Comma,
    Equals,
    Whitespace,
    Comment,
    Eof,
}

impl<'a> LocatedLexer<'a> {
    pub fn new(input: &'a str) -> Self {
        Self {
            input: Span::new(input),
        }
    }

    /// Tokenize input while preserving location information
    pub fn tokenize(&mut self) -> Result<Vec<LocatedToken>, ParseError> {
        let mut tokens = Vec::new();
        let mut current = self.input;
        while !current.fragment().is_empty() {
            let _start_pos = position(current).map_err(|e| {
                ParseError::from_nom_error(
                    current,
                    match e {
                        Err::Error(err) | Err::Failure(err) => err,
                        Err::Incomplete(_) => Error::new(current, ErrorKind::Complete),
                    },
                )
            })?;
            let (remaining, token) = self.next_token(current).map_err(|e| match e {
                Err::Error(err) | Err::Failure(err) => ParseError::from_nom_error(current, err),
                Err::Incomplete(_) => ParseError {
                    message: "incomplete token".to_string(),
                    location: Location::from_span(current),
                    expected: vec!["complete token".to_string()],
                },
            })?;
            if let Some(token) = token {
                tokens.push(token);
            }
            current = remaining;
        }
        tokens.push(LocatedToken {
            kind: TokenKind::Eof,
            location: Location::from_span(current),
            text: String::new(),
        });
        Ok(tokens)
    }

    fn next_token(&self, input: Span<'a>) -> IResult<Span<'a>, Option<LocatedToken>> {
        alt((
            map(|i| self.whitespace_or_comment(i), |_| None),
            map(|i| self.keyword_or_identifier(i), Some),
            map(|i| self.number_token(i), Some),
            map(|i| self.operator_token(i), Some),
            map(|i| self.punctuation_token(i), Some),
        ))
        .parse(input)
    }

    fn whitespace_or_comment(&self, input: Span<'a>) -> IResult<Span<'a>, ()> {
        let (input, _) = alt((
            multispace1,
            recognize((tag("//"), take_while(|c| c != '\n'))),
        ))
        .parse(input)?;
        Ok((input, ()))
    }

    fn keyword_or_identifier(&self, input: Span<'a>) -> IResult<Span<'a>, LocatedToken> {
        let start_pos = position(input)?;
        let (input, ident) = recognize(pair(
            alt((tag("_"), take_while1(|c: char| c.is_ascii_alphabetic()))),
            take_while(|c: char| c.is_ascii_alphanumeric() || c == '_'),
        ))
        .parse(input)?;
        let text = ident.fragment().to_string();
        let kind = match text.as_str() {
            "let" | "in" | "if" | "then" | "else" | "fn" => TokenKind::Keyword(text.clone()),
            _ => TokenKind::Identifier,
        };
        Ok((
            input,
            LocatedToken {
                kind,
                location: Location::from_span(start_pos.0),
                text,
            },
        ))
    }

    fn number_token(&self, input: Span<'a>) -> IResult<Span<'a>, LocatedToken> {
        let start_pos = position(input)?;
        let (input, number) = digit1(input)?;
        Ok((
            input,
            LocatedToken {
                kind: TokenKind::Number,
                location: Location::from_span(start_pos.0),
                text: number.fragment().to_string(),
            },
        ))
    }

    fn operator_token(&self, input: Span<'a>) -> IResult<Span<'a>, LocatedToken> {
        let start_pos = position(input)?;
        let (input, op) = alt((
            tag("=="),
            tag("<="),
            tag(">="),
            tag("!="),
            tag("+"),
            tag("-"),
            tag("*"),
            tag("/"),
            tag("<"),
            tag(">"),
        ))
        .parse(input)?;
        Ok((
            input,
            LocatedToken {
                kind: TokenKind::Operator(op.fragment().to_string()),
                location: Location::from_span(start_pos.0),
                text: op.fragment().to_string(),
            },
        ))
    }

    fn punctuation_token(&self, input: Span<'a>) -> IResult<Span<'a>, LocatedToken> {
        let start_pos = position(input)?;
        let (input, punct) = alt((
            map(char('('), |_| TokenKind::LeftParen),
            map(char(')'), |_| TokenKind::RightParen),
            map(char(','), |_| TokenKind::Comma),
            map(char('='), |_| TokenKind::Equals),
        ))
        .parse(input)?;
        let text = match punct {
            TokenKind::LeftParen => "(".to_string(),
            TokenKind::RightParen => ")".to_string(),
            TokenKind::Comma => ",".to_string(),
            TokenKind::Equals => "=".to_string(),
            _ => String::new(),
        };
        Ok((
            input,
            LocatedToken {
                kind: punct,
                location: Location::from_span(start_pos.0),
                text,
            },
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_location_tracking() {
        let input = "let x = 42\nin y + 1";
        let span = Span::new(input);
        // Test line and column calculation
        assert_eq!(span.location_line(), 1);
        assert_eq!(span.get_utf8_column(), 1);
        // Test position after advancing
        let (remaining, _): (Span, Span) = tag::<&str, Span, Error<Span>>("let")(span).unwrap();
        assert_eq!(remaining.location_offset(), 3);
    }

    #[test]
    fn test_number_parsing() {
        let result = Parser::parse_expression("42").unwrap();
        assert_eq!(result.node, Expr::Number(42));
        assert_eq!(result.span.start.line, 1);
        assert_eq!(result.span.start.column, 1);
    }

    #[test]
    fn test_binary_expression() {
        let result = Parser::parse_expression("2 + 3 * 4").unwrap();
        if let Expr::Binary { left, op, right } = result.node {
            assert_eq!(op, BinaryOp::Add);
            assert_eq!(left.node, Expr::Number(2));
            if let Expr::Binary { left, op, right } = right.node {
                assert_eq!(op, BinaryOp::Mul);
                assert_eq!(left.node, Expr::Number(3));
                assert_eq!(right.node, Expr::Number(4));
            } else {
                panic!("Expected binary expression for multiplication");
            }
        } else {
            panic!("Expected binary expression");
        }
    }

    #[test]
    fn test_function_call_parsing() {
        let result = Parser::parse_expression("add(1, 2)").unwrap();
        if let Expr::Call { func, args } = result.node {
            assert_eq!(func.node, Expr::Identifier("add".to_string()));
            assert_eq!(args.len(), 2);
            assert_eq!(args[0].node, Expr::Number(1));
            assert_eq!(args[1].node, Expr::Number(2));
        } else {
            panic!("Expected function call");
        }
    }

    #[test]
    fn test_let_expression() {
        let result = Parser::parse_expression("let x = 5 in x + 1").unwrap();
        if let Expr::Let { name, value, body } = result.node {
            assert_eq!(name, "x");
            assert_eq!(value.node, Expr::Number(5));
            if let Expr::Binary { left, op, right } = body.node {
                assert_eq!(op, BinaryOp::Add);
                assert_eq!(left.node, Expr::Identifier("x".to_string()));
                assert_eq!(right.node, Expr::Number(1));
            } else {
                panic!("Expected binary expression in let body");
            }
        } else {
            panic!("Expected let expression");
        }
    }

    #[test]
    fn test_error_location() {
        let result = Parser::parse_expression("2 + ");
        assert!(result.is_err());
        let error = result.unwrap_err();
        assert_eq!(error.location.line, 1);
        assert_eq!(error.location.column, 5);
    }

    #[test]
    fn test_lexer_with_locations() {
        let mut lexer = LocatedLexer::new("let x = 42");
        let tokens = lexer.tokenize().unwrap();
        assert_eq!(tokens.len(), 5); // let, x, =, 42, EOF
        assert_eq!(tokens[0].kind, TokenKind::Keyword("let".to_string()));
        assert_eq!(tokens[0].location.column, 1);
        assert_eq!(tokens[1].kind, TokenKind::Identifier);
        assert_eq!(tokens[1].text, "x");
        assert_eq!(tokens[1].location.column, 5);
        assert_eq!(tokens[2].kind, TokenKind::Equals);
        assert_eq!(tokens[2].location.column, 7);
    }

    #[test]
    fn test_multiline_locations() {
        let input = "let x = 1\nlet y = 2";
        let mut lexer = LocatedLexer::new(input);
        let tokens = lexer.tokenize().unwrap();
        // Find the second 'let' token
        let second_let = tokens
            .iter()
            .find(|t| {
                matches!(t.kind, TokenKind::Keyword(ref k) if k == "let") && t.location.line == 2
            })
            .expect("Should find second let token");
        assert_eq!(second_let.location.line, 2);
        assert_eq!(second_let.location.column, 1);
    }

    #[test]
    fn test_position_helpers() {
        let input = "line1\nline2\nline3";
        let pos_info = Parser::get_position_info(input, 7); // offset 7 is the second char of "line2"
        assert_eq!(pos_info, Some((2, 2)));
        let line_content = Parser::get_line_content(input, 2);
        assert_eq!(line_content, Some("line2"));
    }

    #[test]
    fn test_source_range() {
        let result = Parser::parse_expression("42 + 3").unwrap();
        let range = result.span.to_range();
        assert_eq!(range, 0..6);
    }
}

/// Location information extracted from a span
#[derive(Debug, Clone, PartialEq)]
pub struct Location {
    pub line: u32,
    pub column: usize,
    pub offset: usize,
}
}
The location-tracking layer is built from a few small types that wrap nom_locate's `LocatedSpan`:

```rust
use std::ops::Range;

use nom_locate::LocatedSpan;

/// A span type that tracks position information
pub type Span<'a> = LocatedSpan<&'a str>;

/// Location information extracted from a span
#[derive(Debug, Clone, PartialEq)]
pub struct Location {
    pub line: u32,
    pub column: usize,
    pub offset: usize,
}

impl Location {
    pub fn from_span(span: Span<'_>) -> Self {
        Self {
            line: span.location_line(),
            column: span.get_utf8_column(),
            offset: span.location_offset(),
        }
    }
}

/// A range of source locations
#[derive(Debug, Clone, PartialEq)]
pub struct SourceRange {
    pub start: Location,
    pub end: Location,
}

impl SourceRange {
    pub fn from_spans(start: Span<'_>, end: Span<'_>) -> Self {
        Self {
            start: Location::from_span(start),
            end: Location::from_span(end),
        }
    }

    pub fn to_range(&self) -> Range<usize> {
        self.start.offset..self.end.offset
    }
}

/// AST node with location information
#[derive(Debug, Clone, PartialEq)]
pub struct Spanned<T> {
    pub node: T,
    pub span: SourceRange,
}

impl<T> Spanned<T> {
    pub fn new(node: T, start: Span<'_>, end: Span<'_>) -> Self {
        Self {
            node,
            span: SourceRange::from_spans(start, end),
        }
    }
}
```
These types provide line, column, and byte offset information for any parsed element.
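To make the offset-to-line/column mapping concrete, here is a small self-contained sketch, independent of nom_locate, that performs the same computation as `Parser::get_position_info` above. The function name `offset_to_line_col` is ours, not part of any crate API:

```rust
/// Map a byte offset into `input` to a 1-based (line, column) pair,
/// analogous to what nom_locate reports via `location_line` and
/// `get_utf8_column`. Returns `None` if the offset lies past the end
/// of the input. Columns count characters, not bytes.
pub fn offset_to_line_col(input: &str, offset: usize) -> Option<(u32, usize)> {
    if offset > input.len() {
        return None;
    }
    let mut line = 1;
    let mut col = 1;
    for (i, ch) in input.char_indices() {
        if i >= offset {
            break;
        }
        if ch == '\n' {
            line += 1; // a newline starts a new line...
            col = 1; // ...and resets the column
        } else {
            col += 1;
        }
    }
    Some((line, col))
}
```

The linear scan is fine for one-off diagnostics; a production compiler would typically precompute a table of line-start offsets and binary-search it instead.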
Creating a Located Parser
Transform a nom parser to track locations by wrapping the input:
```rust
/// Main parser implementation
pub struct Parser;

impl Parser {
    /// Parse a complete expression from input
    pub fn parse_expression(input: &str) -> Result<Spanned<Expr>, ParseError> {
        let span = Span::new(input);
        match Self::expression(span) {
            Ok((_, expr)) => Ok(expr),
            Err(Err::Error(e)) | Err(Err::Failure(e)) => Err(ParseError::from_nom_error(span, e)),
            Err(Err::Incomplete(_)) => Err(ParseError {
                message: "incomplete input".to_string(),
                location: Location::from_span(span),
                expected: vec!["more input".to_string()],
            }),
        }
    }

    /// Parse numbers, capturing the position before and after the token
    fn number(input: Span<'_>) -> IResult<Span<'_>, Spanned<Expr>> {
        let start_pos = position(input)?;
        let (input, digits) = digit1(input)?;
        let end_pos = position(input)?;
        let num = digits
            .fragment()
            .parse()
            .map_err(|_| Err::Error(Error::new(input, ErrorKind::Digit)))?;
        Ok((
            input,
            Spanned::new(Expr::Number(num), start_pos.0, end_pos.0),
        ))
    }
}
```

Every sub-parser follows the same shape: call `position` before and after consuming input, then wrap the node in `Spanned`. The remaining parsers for binary expressions, calls, identifiers, and `let` bindings apply the identical pattern and appear in the full listing above.
The parser methods work with Span instead of &str, automatically tracking positions as parsing proceeds.
Expression Parsing with Locations
Building an expression parser that preserves source locations:
#![allow(unused)] fn main() { use std::ops::Range; use nom::branch::alt; use nom::bytes::complete::{tag, take_while, take_while1}; use nom::character::complete::{char, digit1, multispace0, multispace1}; use nom::combinator::{map, recognize}; use nom::error::{Error, ErrorKind}; use nom::multi::separated_list0; use nom::sequence::{delimited, pair, preceded}; use nom::{Err, IResult, Parser as NomParser}; use nom_locate::{position, LocatedSpan}; /// A span type that tracks position information pub type Span<'a> = LocatedSpan<&'a str>; /// Location information extracted from a span #[derive(Debug, Clone, PartialEq)] pub struct Location { pub line: u32, pub column: usize, pub offset: usize, } impl Location { pub fn from_span(span: Span<'_>) -> Self { Self { line: span.location_line(), column: span.get_utf8_column(), offset: span.location_offset(), } } } /// A range of source locations #[derive(Debug, Clone, PartialEq)] pub struct SourceRange { pub start: Location, pub end: Location, } impl SourceRange { pub fn from_spans(start: Span<'_>, end: Span<'_>) -> Self { Self { start: Location::from_span(start), end: Location::from_span(end), } } pub fn to_range(&self) -> Range<usize> { self.start.offset..self.end.offset } } /// AST node with location information #[derive(Debug, Clone, PartialEq)] pub struct Spanned<T> { pub node: T, pub span: SourceRange, } impl<T> Spanned<T> { pub fn new(node: T, start: Span<'_>, end: Span<'_>) -> Self { Self { node, span: SourceRange::from_spans(start, end), } } } #[derive(Debug, Clone, PartialEq)] pub enum BinaryOp { Add, Sub, Mul, Div, Eq, Lt, Gt, } /// Parser error with precise location information #[derive(Debug, Clone)] pub struct ParseError { pub message: String, pub location: Location, pub expected: Vec<String>, } impl ParseError { pub fn from_nom_error(_input: Span<'_>, error: Error<Span<'_>>) -> Self { let location = Location::from_span(error.input); let message = match error.code { ErrorKind::Tag => "unexpected token".to_string(), 
ErrorKind::Digit => "expected number".to_string(), ErrorKind::Alpha => "expected identifier".to_string(), ErrorKind::Char => "expected character".to_string(), ErrorKind::Many0 => "expected list".to_string(), ErrorKind::SeparatedList => "expected comma-separated list".to_string(), _ => format!("parse error: {:?}", error.code), }; Self { message, location, expected: vec![], } } } /// Main parser implementation pub struct Parser; impl Parser { /// Parse a complete expression from input pub fn parse_expression(input: &str) -> Result<Spanned<Expr>, ParseError> { let span = Span::new(input); match Self::expression(span) { Ok((_, expr)) => Ok(expr), Err(Err::Error(e)) | Err(Err::Failure(e)) => Err(ParseError::from_nom_error(span, e)), Err(Err::Incomplete(_)) => Err(ParseError { message: "incomplete input".to_string(), location: Location::from_span(span), expected: vec!["more input".to_string()], }), } } /// Parse an expression with precedence fn expression(input: Span<'_>) -> IResult<Span<'_>, Spanned<Expr>> { Self::binary_expr(input, 0) } /// Parse binary expressions with operator precedence fn binary_expr(input: Span<'_>, min_prec: u8) -> IResult<Span<'_>, Spanned<Expr>> { let start_pos = position(input)?; let (input, mut left) = Self::primary_expr(input)?; let mut current_input = input; loop { let (input, _) = multispace0(current_input)?; // Try to parse an operator let op_result: IResult<Span<'_>, (BinaryOp, u8)> = alt(( map(char('+'), |_| (BinaryOp::Add, 1)), map(char('-'), |_| (BinaryOp::Sub, 1)), map(char('*'), |_| (BinaryOp::Mul, 2)), map(char('/'), |_| (BinaryOp::Div, 2)), map(tag("=="), |_| (BinaryOp::Eq, 0)), map(char('<'), |_| (BinaryOp::Lt, 0)), map(char('>'), |_| (BinaryOp::Gt, 0)), )) .parse(input); match op_result { Ok((input, (op, prec))) if prec >= min_prec => { let (input, _) = multispace0(input)?; let (input, right) = Self::binary_expr(input, prec + 1)?; let end_span = position(input)?; left = Spanned::new( Expr::Binary { left: Box::new(left), op, 
right: Box::new(right), }, start_pos.0, end_span.0, ); current_input = input; } _ => break, } } Ok((current_input, left)) } /// Parse primary expressions (atoms and parenthesized expressions) fn primary_expr(input: Span<'_>) -> IResult<Span<'_>, Spanned<Expr>> { let (input, _) = multispace0(input)?; alt(( Self::parenthesized_expr, Self::function_call, Self::let_expr, Self::number, Self::identifier, )) .parse(input) } /// Parse parenthesized expressions fn parenthesized_expr(input: Span<'_>) -> IResult<Span<'_>, Spanned<Expr>> { let start_pos = position(input)?; let (input, expr) = delimited( char('('), preceded(multispace0, Self::expression), preceded(multispace0, char(')')), ) .parse(input)?; let end_pos = position(input)?; Ok((input, Spanned::new(expr.node, start_pos.0, end_pos.0))) } /// Parse function calls fn function_call(input: Span<'_>) -> IResult<Span<'_>, Spanned<Expr>> { let start_pos = position(input)?; let (input, func) = Self::identifier(input)?; let (input, _) = multispace0(input)?; let (input, _) = char('(')(input)?; let (input, _) = multispace0(input)?; let (input, args) = separated_list0( delimited(multispace0, char(','), multispace0), Self::expression, ) .parse(input)?; let (input, _) = multispace0(input)?; let (input, _) = char(')')(input)?; let end_pos = position(input)?; Ok(( input, Spanned::new( Expr::Call { func: Box::new(func), args, }, start_pos.0, end_pos.0, ), )) } /// Parse let expressions fn let_expr(input: Span<'_>) -> IResult<Span<'_>, Spanned<Expr>> { let start_pos = position(input)?; let (input, _) = tag("let")(input)?; let (input, _) = multispace1(input)?; let (input, name) = Self::identifier_string(input)?; let (input, _) = multispace0(input)?; let (input, _) = char('=')(input)?; let (input, _) = multispace0(input)?; let (input, value) = Self::expression(input)?; let (input, _) = multispace0(input)?; let (input, _) = tag("in")(input)?; let (input, _) = multispace0(input)?; let (input, body) = Self::expression(input)?; let end_pos 
= position(input)?; Ok(( input, Spanned::new( Expr::Let { name, value: Box::new(value), body: Box::new(body), }, start_pos.0, end_pos.0, ), )) } /// Parse numbers fn number(input: Span<'_>) -> IResult<Span<'_>, Spanned<Expr>> { let start_pos = position(input)?; let (input, digits) = digit1(input)?; let end_pos = position(input)?; let num = digits .fragment() .parse() .map_err(|_| Err::Error(Error::new(input, ErrorKind::Digit)))?; Ok(( input, Spanned::new(Expr::Number(num), start_pos.0, end_pos.0), )) } /// Parse identifiers fn identifier(input: Span<'_>) -> IResult<Span<'_>, Spanned<Expr>> { let start_pos = position(input)?; let (input, ident) = Self::identifier_string(input)?; let end_pos = position(input)?; Ok(( input, Spanned::new(Expr::Identifier(ident), start_pos.0, end_pos.0), )) } /// Parse identifier strings fn identifier_string(input: Span<'_>) -> IResult<Span<'_>, String> { let (input, ident) = recognize(pair( alt((tag("_"), take_while1(|c: char| c.is_ascii_alphabetic()))), take_while(|c: char| c.is_ascii_alphanumeric() || c == '_'), )) .parse(input)?; Ok((input, ident.fragment().to_string())) } /// Get position information for error reporting pub fn get_position_info(input: &str, offset: usize) -> Option<(u32, usize)> { if offset <= input.len() { // Count lines and column up to the offset let mut line = 1; let mut col = 1; for (i, ch) in input.char_indices() { if i >= offset { break; } if ch == '\n' { line += 1; col = 1; } else { col += 1; } } Some((line, col)) } else { None } } /// Extract line content for error reporting pub fn get_line_content(input: &str, line_number: u32) -> Option<&str> { input.lines().nth(line_number.saturating_sub(1) as usize) } } /// A lexer that preserves location information for each token pub struct LocatedLexer<'a> { input: Span<'a>, } #[derive(Debug, Clone, PartialEq)] pub struct LocatedToken { pub kind: TokenKind, pub location: Location, pub text: String, } #[derive(Debug, Clone, PartialEq)] pub enum TokenKind { Number, 
Identifier, Keyword(String), Operator(String), LeftParen, RightParen, Comma, Equals, Whitespace, Comment, Eof, } impl<'a> LocatedLexer<'a> { pub fn new(input: &'a str) -> Self { Self { input: Span::new(input), } } /// Tokenize input while preserving location information pub fn tokenize(&mut self) -> Result<Vec<LocatedToken>, ParseError> { let mut tokens = Vec::new(); let mut current = self.input; while !current.fragment().is_empty() { let _start_pos = position(current).map_err(|e| { ParseError::from_nom_error( current, match e { Err::Error(err) | Err::Failure(err) => err, Err::Incomplete(_) => Error::new(current, ErrorKind::Complete), }, ) })?; let (remaining, token) = self.next_token(current).map_err(|e| match e { Err::Error(err) | Err::Failure(err) => ParseError::from_nom_error(current, err), Err::Incomplete(_) => ParseError { message: "incomplete token".to_string(), location: Location::from_span(current), expected: vec!["complete token".to_string()], }, })?; if let Some(token) = token { tokens.push(token); } current = remaining; } tokens.push(LocatedToken { kind: TokenKind::Eof, location: Location::from_span(current), text: String::new(), }); Ok(tokens) } fn next_token(&self, input: Span<'a>) -> IResult<Span<'a>, Option<LocatedToken>> { alt(( map(|i| self.whitespace_or_comment(i), |_| None), map(|i| self.keyword_or_identifier(i), Some), map(|i| self.number_token(i), Some), map(|i| self.operator_token(i), Some), map(|i| self.punctuation_token(i), Some), )) .parse(input) } fn whitespace_or_comment(&self, input: Span<'a>) -> IResult<Span<'a>, ()> { let (input, _) = alt(( multispace1, recognize((tag("//"), take_while(|c| c != '\n'))), )) .parse(input)?; Ok((input, ())) } fn keyword_or_identifier(&self, input: Span<'a>) -> IResult<Span<'a>, LocatedToken> { let start_pos = position(input)?; let (input, ident) = recognize(pair( alt((tag("_"), take_while1(|c: char| c.is_ascii_alphabetic()))), take_while(|c: char| c.is_ascii_alphanumeric() || c == '_'), )) 
.parse(input)?; let text = ident.fragment().to_string(); let kind = match text.as_str() { "let" | "in" | "if" | "then" | "else" | "fn" => TokenKind::Keyword(text.clone()), _ => TokenKind::Identifier, }; Ok(( input, LocatedToken { kind, location: Location::from_span(start_pos.0), text, }, )) } fn number_token(&self, input: Span<'a>) -> IResult<Span<'a>, LocatedToken> { let start_pos = position(input)?; let (input, number) = digit1(input)?; Ok(( input, LocatedToken { kind: TokenKind::Number, location: Location::from_span(start_pos.0), text: number.fragment().to_string(), }, )) } fn operator_token(&self, input: Span<'a>) -> IResult<Span<'a>, LocatedToken> { let start_pos = position(input)?; let (input, op) = alt(( tag("=="), tag("<="), tag(">="), tag("!="), tag("+"), tag("-"), tag("*"), tag("/"), tag("<"), tag(">"), )) .parse(input)?; Ok(( input, LocatedToken { kind: TokenKind::Operator(op.fragment().to_string()), location: Location::from_span(start_pos.0), text: op.fragment().to_string(), }, )) } fn punctuation_token(&self, input: Span<'a>) -> IResult<Span<'a>, LocatedToken> { let start_pos = position(input)?; let (input, punct) = alt(( map(char('('), |_| TokenKind::LeftParen), map(char(')'), |_| TokenKind::RightParen), map(char(','), |_| TokenKind::Comma), map(char('='), |_| TokenKind::Equals), )) .parse(input)?; let text = match punct { TokenKind::LeftParen => "(".to_string(), TokenKind::RightParen => ")".to_string(), TokenKind::Comma => ",".to_string(), TokenKind::Equals => "=".to_string(), _ => String::new(), }; Ok(( input, LocatedToken { kind: punct, location: Location::from_span(start_pos.0), text, }, )) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_location_tracking() { let input = "let x = 42\nin y + 1"; let span = Span::new(input); // Test line and column calculation assert_eq!(span.location_line(), 1); assert_eq!(span.get_utf8_column(), 1); // Test position after advancing let (remaining, _): (Span, Span) = tag::<&str, Span, 
Error<Span>>("let")(span).unwrap(); assert_eq!(remaining.location_offset(), 3); } #[test] fn test_number_parsing() { let result = Parser::parse_expression("42").unwrap(); assert_eq!(result.node, Expr::Number(42)); assert_eq!(result.span.start.line, 1); assert_eq!(result.span.start.column, 1); } #[test] fn test_binary_expression() { let result = Parser::parse_expression("2 + 3 * 4").unwrap(); if let Expr::Binary { left, op, right } = result.node { assert_eq!(op, BinaryOp::Add); assert_eq!(left.node, Expr::Number(2)); if let Expr::Binary { left, op, right } = right.node { assert_eq!(op, BinaryOp::Mul); assert_eq!(left.node, Expr::Number(3)); assert_eq!(right.node, Expr::Number(4)); } else { panic!("Expected binary expression for multiplication"); } } else { panic!("Expected binary expression"); } } #[test] fn test_function_call_parsing() { let result = Parser::parse_expression("add(1, 2)").unwrap(); if let Expr::Call { func, args } = result.node { assert_eq!(func.node, Expr::Identifier("add".to_string())); assert_eq!(args.len(), 2); assert_eq!(args[0].node, Expr::Number(1)); assert_eq!(args[1].node, Expr::Number(2)); } else { panic!("Expected function call"); } } #[test] fn test_let_expression() { let result = Parser::parse_expression("let x = 5 in x + 1").unwrap(); if let Expr::Let { name, value, body } = result.node { assert_eq!(name, "x"); assert_eq!(value.node, Expr::Number(5)); if let Expr::Binary { left, op, right } = body.node { assert_eq!(op, BinaryOp::Add); assert_eq!(left.node, Expr::Identifier("x".to_string())); assert_eq!(right.node, Expr::Number(1)); } else { panic!("Expected binary expression in let body"); } } else { panic!("Expected let expression"); } } #[test] fn test_error_location() { let result = Parser::parse_expression("2 + "); assert!(result.is_err()); let error = result.unwrap_err(); assert_eq!(error.location.line, 1); assert_eq!(error.location.column, 5); } #[test] fn test_lexer_with_locations() { let mut lexer = LocatedLexer::new("let x = 
42"); let tokens = lexer.tokenize().unwrap(); assert_eq!(tokens.len(), 5); // let, x, =, 42, EOF assert_eq!(tokens[0].kind, TokenKind::Keyword("let".to_string())); assert_eq!(tokens[0].location.column, 1); assert_eq!(tokens[1].kind, TokenKind::Identifier); assert_eq!(tokens[1].text, "x"); assert_eq!(tokens[1].location.column, 5); assert_eq!(tokens[2].kind, TokenKind::Equals); assert_eq!(tokens[2].location.column, 7); } #[test] fn test_multiline_locations() { let input = "let x = 1\nlet y = 2"; let mut lexer = LocatedLexer::new(input); let tokens = lexer.tokenize().unwrap(); // Find the second 'let' token let second_let = tokens .iter() .find(|t| { matches!(t.kind, TokenKind::Keyword(ref k) if k == "let") && t.location.line == 2 }) .expect("Should find second let token"); assert_eq!(second_let.location.line, 2); assert_eq!(second_let.location.column, 1); } #[test] fn test_position_helpers() { let input = "line1\nline2\nline3"; let pos_info = Parser::get_position_info(input, 7); // Position of 'l' in "line2" assert_eq!(pos_info, Some((2, 2))); let line_content = Parser::get_line_content(input, 2); assert_eq!(line_content, Some("line2")); } #[test] fn test_source_range() { let result = Parser::parse_expression("42 + 3").unwrap(); let range = result.span.to_range(); assert_eq!(range, 0..6); } } /// Expression AST for a simple language #[derive(Debug, Clone, PartialEq)] pub enum Expr { Number(i64), Identifier(String), Binary { left: Box<Spanned<Expr>>, op: BinaryOp, right: Box<Spanned<Expr>>, }, Call { func: Box<Spanned<Expr>>, args: Vec<Spanned<Expr>>, }, Let { name: String, value: Box<Spanned<Expr>>, body: Box<Spanned<Expr>>, }, } }
Every AST node is wrapped with `Spanned<T>` to preserve its source location alongside the parsed data.
Extracting Position Information
Convert a byte offset (such as one obtained from nom_locate's `location_offset`) into user-friendly line and column numbers:
```rust
/// Get position information for error reporting
pub fn get_position_info(input: &str, offset: usize) -> Option<(u32, usize)> {
    if offset <= input.len() {
        // Count lines and columns up to the offset
        let mut line = 1;
        let mut col = 1;
        for (i, ch) in input.char_indices() {
            if i >= offset {
                break;
            }
            if ch == '\n' {
                line += 1;
                col = 1;
            } else {
                col += 1;
            }
        }
        Some((line, col))
    } else {
        None
    }
}
```
This provides the exact line and column for error reporting and IDE features.
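As a concrete check of the arithmetic, here is the same conversion as a standalone function with a worked example (the `position_of` name is illustrative; the body restates the logic above):

```rust
/// Dependency-free restatement of the offset-to-position conversion.
fn position_of(input: &str, offset: usize) -> Option<(u32, usize)> {
    if offset > input.len() {
        return None;
    }
    let (mut line, mut col) = (1u32, 1usize);
    for (i, ch) in input.char_indices() {
        if i >= offset {
            break;
        }
        if ch == '\n' {
            line += 1;
            col = 1;
        } else {
            col += 1;
        }
    }
    Some((line, col))
}

fn main() {
    let src = "let x = 1\nlet y = 2";
    // Byte offset 14 lands on the `y` of the second line:
    // "let x = 1" plus the newline occupies bytes 0..=9.
    assert_eq!(position_of(src, 14), Some((2, 5)));
    // Offsets past the end of the input are rejected.
    assert_eq!(position_of(src, 99), None);
    println!("ok");
}
```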
Error Reporting with Context
Generate helpful error messages with source context:
```rust
/// Extract line content for error reporting
pub fn get_line_content(input: &str, line_number: u32) -> Option<&str> {
    input.lines().nth(line_number.saturating_sub(1) as usize)
}
```
Combined with the position helper above, this supports error displays that show the offending line with a caret pointing at the exact error location.
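One possible rendering, sketched as a small self-contained helper (the `render_error` name and output layout are illustrative, not from any crate):

```rust
/// Render a one-line diagnostic with a caret under the error column.
/// Hypothetical helper combining the line-extraction and position logic
/// shown above.
fn render_error(input: &str, line: u32, column: usize, message: &str) -> String {
    // Equivalent to get_line_content: lines are 1-based, `nth` is 0-based.
    let content = input.lines().nth(line as usize - 1).unwrap_or("");
    format!(
        "error: {message}\n --> line {line}, column {column}\n  | {content}\n  | {caret:>width$}",
        caret = "^",
        width = column,
    )
}

fn main() {
    // The parser stopped at line 2, column 9, expecting an expression.
    let src = "let x = 1\nlet y = ";
    let out = render_error(src, 2, 9, "expected expression");
    println!("{out}");
    assert!(out.contains("line 2, column 9"));
    assert!(out.contains("let y = "));
}
```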
Tokenizer with Location Tracking
Building a located lexer that preserves position information for each token:
```rust
use nom::branch::alt;
use nom::bytes::complete::{tag, take_while, take_while1};
use nom::character::complete::{char, digit1, multispace1};
use nom::combinator::{map, recognize};
use nom::sequence::pair;
use nom::{Err, IResult, Parser as NomParser};
use nom_locate::position;

// `Span`, `Location`, and `ParseError` are the types defined in the
// sections above.

/// A token paired with its source location.
#[derive(Debug, Clone, PartialEq)]
pub struct LocatedToken {
    pub kind: TokenKind,
    pub location: Location,
    pub text: String,
}

#[derive(Debug, Clone, PartialEq)]
pub enum TokenKind {
    Number,
    Identifier,
    Keyword(String),
    Operator(String),
    LeftParen,
    RightParen,
    Comma,
    Equals,
    Whitespace,
    Comment,
    Eof,
}

/// A lexer that preserves location information for each token
pub struct LocatedLexer<'a> {
    input: Span<'a>,
}

impl<'a> LocatedLexer<'a> {
    pub fn new(input: &'a str) -> Self {
        Self {
            input: Span::new(input),
        }
    }

    /// Tokenize input while preserving location information
    pub fn tokenize(&mut self) -> Result<Vec<LocatedToken>, ParseError> {
        let mut tokens = Vec::new();
        let mut current = self.input;

        while !current.fragment().is_empty() {
            let (remaining, token) = self.next_token(current).map_err(|e| match e {
                Err::Error(err) | Err::Failure(err) => ParseError::from_nom_error(current, err),
                Err::Incomplete(_) => ParseError {
                    message: "incomplete token".to_string(),
                    location: Location::from_span(current),
                    expected: vec!["complete token".to_string()],
                },
            })?;

            // Whitespace and comments yield `None` and are skipped.
            if let Some(token) = token {
                tokens.push(token);
            }
            current = remaining;
        }

        tokens.push(LocatedToken {
            kind: TokenKind::Eof,
            location: Location::from_span(current),
            text: String::new(),
        });
        Ok(tokens)
    }

    fn next_token(&self, input: Span<'a>) -> IResult<Span<'a>, Option<LocatedToken>> {
        alt((
            map(|i| self.whitespace_or_comment(i), |_| None),
            map(|i| self.keyword_or_identifier(i), Some),
            map(|i| self.number_token(i), Some),
            map(|i| self.operator_token(i), Some),
            map(|i| self.punctuation_token(i), Some),
        ))
        .parse(input)
    }

    fn whitespace_or_comment(&self, input: Span<'a>) -> IResult<Span<'a>, ()> {
        let (input, _) = alt((
            multispace1,
            recognize((tag("//"), take_while(|c| c != '\n'))),
        ))
        .parse(input)?;
        Ok((input, ()))
    }

    fn keyword_or_identifier(&self, input: Span<'a>) -> IResult<Span<'a>, LocatedToken> {
        let start_pos = position(input)?;
        let (input, ident) = recognize(pair(
            alt((tag("_"), take_while1(|c: char| c.is_ascii_alphabetic()))),
            take_while(|c: char| c.is_ascii_alphanumeric() || c == '_'),
        ))
        .parse(input)?;
        let text = ident.fragment().to_string();
        let kind = match text.as_str() {
            "let" | "in" | "if" | "then" | "else" | "fn" => TokenKind::Keyword(text.clone()),
            _ => TokenKind::Identifier,
        };
        Ok((
            input,
            LocatedToken {
                kind,
                location: Location::from_span(start_pos.0),
                text,
            },
        ))
    }

    fn number_token(&self, input: Span<'a>) -> IResult<Span<'a>, LocatedToken> {
        let start_pos = position(input)?;
        let (input, number) = digit1(input)?;
        Ok((
            input,
            LocatedToken {
                kind: TokenKind::Number,
                location: Location::from_span(start_pos.0),
                text: number.fragment().to_string(),
            },
        ))
    }

    fn operator_token(&self, input: Span<'a>) -> IResult<Span<'a>, LocatedToken> {
        let start_pos = position(input)?;
        // Multi-character operators must be tried before their prefixes.
        let (input, op) = alt((
            tag("=="), tag("<="), tag(">="), tag("!="),
            tag("+"), tag("-"), tag("*"), tag("/"),
            tag("<"), tag(">"),
        ))
        .parse(input)?;
        Ok((
            input,
            LocatedToken {
                kind: TokenKind::Operator(op.fragment().to_string()),
                location: Location::from_span(start_pos.0),
                text: op.fragment().to_string(),
            },
        ))
    }

    fn punctuation_token(&self, input: Span<'a>) -> IResult<Span<'a>, LocatedToken> {
        let start_pos = position(input)?;
        let (input, punct) = alt((
            map(char('('), |_| TokenKind::LeftParen),
            map(char(')'), |_| TokenKind::RightParen),
            map(char(','), |_| TokenKind::Comma),
            map(char('='), |_| TokenKind::Equals),
        ))
        .parse(input)?;
        let text = match punct {
            TokenKind::LeftParen => "(".to_string(),
            TokenKind::RightParen => ")".to_string(),
            TokenKind::Comma => ",".to_string(),
            TokenKind::Equals => "=".to_string(),
            _ => String::new(),
        };
        Ok((
            input,
            LocatedToken {
                kind: punct,
                location: Location::from_span(start_pos.0),
                text,
            },
        ))
    }
}
```
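The bookkeeping that nom_locate automates here, tracking the line and column at which each token begins, can be illustrated without any parser library. A dependency-free sketch (the `locate_words` helper and `Located` struct are illustrative, not from the crate) that records where each space-separated word starts:

```rust
#[derive(Debug, PartialEq)]
struct Located<'a> {
    text: &'a str,
    line: u32,
    column: usize,
}

/// Split on single spaces, recording the 1-based line and column of each
/// word. A hand-rolled sketch of what nom_locate's LocatedSpan tracks for
/// us; a real lexer would scan characters and handle tabs and comments.
fn locate_words(input: &str) -> Vec<Located<'_>> {
    let mut out = Vec::new();
    for (lineno, line) in input.lines().enumerate() {
        let mut col = 1;
        for word in line.split(' ') {
            if !word.is_empty() {
                out.push(Located {
                    text: word,
                    line: lineno as u32 + 1,
                    column: col,
                });
            }
            col += word.len() + 1; // +1 for the separating space
        }
    }
    out
}

fn main() {
    let toks = locate_words("let x = 1\nlet y = 2");
    // First token starts at line 1, column 1.
    assert_eq!(toks[0], Located { text: "let", line: 1, column: 1 });
    // `y` sits on the second line, fifth column.
    assert_eq!(toks[5], Located { text: "y", line: 2, column: 5 });
    println!("{toks:?}");
}
```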
#![allow(unused)]
fn main() {
use nom::branch::alt;
use nom::bytes::complete::{tag, take_while, take_while1};
use nom::character::complete::multispace1;
use nom::combinator::{map, recognize};
use nom::sequence::pair;
use nom::{Err, IResult, Parser as NomParser};
use nom_locate::{position, LocatedSpan};

/// A span type that tracks position information
pub type Span<'a> = LocatedSpan<&'a str>;

/// Location information extracted from a span
#[derive(Debug, Clone, PartialEq)]
pub struct Location {
    pub line: u32,
    pub column: usize,
    pub offset: usize,
}

impl Location {
    pub fn from_span(span: Span<'_>) -> Self {
        Self {
            line: span.location_line(),
            column: span.get_utf8_column(),
            offset: span.location_offset(),
        }
    }
}

#[derive(Debug, Clone, PartialEq)]
pub struct LocatedToken {
    pub kind: TokenKind,
    pub location: Location,
    pub text: String,
}

#[derive(Debug, Clone, PartialEq)]
pub enum TokenKind {
    Number,
    Identifier,
    Keyword(String),
    Operator(String),
    LeftParen,
    RightParen,
    Comma,
    Equals,
    Whitespace,
    Comment,
    Eof,
}

/// A lexer that preserves location information for each token
pub struct LocatedLexer<'a> {
    input: Span<'a>,
}

impl<'a> LocatedLexer<'a> {
    pub fn new(input: &'a str) -> Self {
        Self {
            input: Span::new(input),
        }
    }

    /// Tokenize input while preserving location information.
    /// `ParseError` is the located error type defined earlier in this chapter.
    pub fn tokenize(&mut self) -> Result<Vec<LocatedToken>, ParseError> {
        let mut tokens = Vec::new();
        let mut current = self.input;
        while !current.fragment().is_empty() {
            let (remaining, token) = self.next_token(current).map_err(|e| match e {
                Err::Error(err) | Err::Failure(err) => ParseError::from_nom_error(current, err),
                Err::Incomplete(_) => ParseError {
                    message: "incomplete token".to_string(),
                    location: Location::from_span(current),
                    expected: vec!["complete token".to_string()],
                },
            })?;
            if let Some(token) = token {
                tokens.push(token);
            }
            current = remaining;
        }
        tokens.push(LocatedToken {
            kind: TokenKind::Eof,
            location: Location::from_span(current),
            text: String::new(),
        });
        Ok(tokens)
    }

    fn next_token(&self, input: Span<'a>) -> IResult<Span<'a>, Option<LocatedToken>> {
        alt((
            map(|i| self.whitespace_or_comment(i), |_| None),
            map(|i| self.keyword_or_identifier(i), Some),
            map(|i| self.number_token(i), Some),
            map(|i| self.operator_token(i), Some),
            map(|i| self.punctuation_token(i), Some),
        ))
        .parse(input)
    }

    fn whitespace_or_comment(&self, input: Span<'a>) -> IResult<Span<'a>, ()> {
        let (input, _) = alt((
            multispace1,
            recognize((tag("//"), take_while(|c| c != '\n'))),
        ))
        .parse(input)?;
        Ok((input, ()))
    }

    fn keyword_or_identifier(&self, input: Span<'a>) -> IResult<Span<'a>, LocatedToken> {
        let start_pos = position(input)?;
        let (input, ident) = recognize(pair(
            alt((tag("_"), take_while1(|c: char| c.is_ascii_alphabetic()))),
            take_while(|c: char| c.is_ascii_alphanumeric() || c == '_'),
        ))
        .parse(input)?;
        let text = ident.fragment().to_string();
        let kind = match text.as_str() {
            "let" | "in" | "if" | "then" | "else" | "fn" => TokenKind::Keyword(text.clone()),
            _ => TokenKind::Identifier,
        };
        Ok((
            input,
            LocatedToken {
                kind,
                location: Location::from_span(start_pos.0),
                text,
            },
        ))
    }

    // number_token, operator_token, and punctuation_token (shown in the full
    // listing above) follow the same pattern: capture the position, parse the
    // token text, and attach a Location built from the starting span.
}
}
Each token knows exactly where it appeared in the source text, enabling precise error messages even from later compilation phases.
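As a sketch of what this enables, the diagnostic renderer below takes a token's recorded line and column (1-based, as nom_locate reports them) and points a caret at the offending source text. The function name render_diagnostic and the message format are illustrative, not part of any crate:

```rust
// Hypothetical sketch: given a recorded (line, column) and the original
// source text, render a caret diagnostic. Assumes 1-based positions, as
// produced by nom_locate's `location_line` / `get_utf8_column`.
fn render_diagnostic(source: &str, line: u32, column: usize, message: &str) -> String {
    let line_text = source.lines().nth(line as usize - 1).unwrap_or("");
    // Pad with spaces so the caret sits under the reported column.
    let caret = " ".repeat(column - 1) + "^";
    format!("error: {message}\n  --> line {line}, column {column}\n{line_text}\n{caret}")
}

fn main() {
    let source = "let x = 42\nin y + 1";
    // Pretend a later phase flagged `y` (line 2, column 4) as unbound.
    println!("{}", render_diagnostic(source, 2, 4, "unbound variable `y`"));
}
```

Because every token carries its location, even a type checker running long after lexing can produce this kind of pinpointed message.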
Binary Expression Parsing
Handling operator precedence while maintaining location information:
#![allow(unused)]
fn main() {
use nom::branch::alt;
use nom::bytes::complete::tag;
use nom::character::complete::{char, multispace0};
use nom::combinator::map;
use nom::{IResult, Parser as NomParser};
use nom_locate::position;

// Uses the Span, Spanned, Expr, and BinaryOp types defined earlier in this
// chapter.

impl Parser {
    /// Parse an expression with precedence
    fn expression(input: Span<'_>) -> IResult<Span<'_>, Spanned<Expr>> {
        Self::binary_expr(input, 0)
    }

    /// Parse binary expressions with operator precedence
    fn binary_expr(input: Span<'_>, min_prec: u8) -> IResult<Span<'_>, Spanned<Expr>> {
        let start_pos = position(input)?;
        let (input, mut left) = Self::primary_expr(input)?;
        let mut current_input = input;
        loop {
            let (input, _) = multispace0(current_input)?;
            // Try to parse an operator
            let op_result: IResult<Span<'_>, (BinaryOp, u8)> = alt((
                map(char('+'), |_| (BinaryOp::Add, 1)),
                map(char('-'), |_| (BinaryOp::Sub, 1)),
                map(char('*'), |_| (BinaryOp::Mul, 2)),
                map(char('/'), |_| (BinaryOp::Div, 2)),
                map(tag("=="), |_| (BinaryOp::Eq, 0)),
                map(char('<'), |_| (BinaryOp::Lt, 0)),
                map(char('>'), |_| (BinaryOp::Gt, 0)),
            ))
            .parse(input);
            match op_result {
                Ok((input, (op, prec))) if prec >= min_prec => {
                    let (input, _) = multispace0(input)?;
                    let (input, right) = Self::binary_expr(input, prec + 1)?;
                    let end_span = position(input)?;
                    left = Spanned::new(
                        Expr::Binary {
                            left: Box::new(left),
                            op,
                            right: Box::new(right),
                        },
                        start_pos.0,
                        end_span.0,
                    );
                    current_input = input;
                }
                _ => break,
            }
        }
        Ok((current_input, left))
    }

    /// Parse primary expressions (atoms and parenthesized expressions)
    fn primary_expr(input: Span<'_>) -> IResult<Span<'_>, Spanned<Expr>> {
        let (input, _) = multispace0(input)?;
        alt((
            Self::parenthesized_expr,
            Self::function_call,
            Self::let_expr,
            Self::number,
            Self::identifier,
        ))
        .parse(input)
    }
}
}
The parser correctly handles precedence while tracking the span of entire expressions and their sub-components.
Best Practices
Preserve spans throughout AST construction. Don’t extract the inner value and discard location information until absolutely necessary.
Use type aliases to make span types more readable: type Span<'a> = LocatedSpan<&'a str> is clearer than writing the full type everywhere.
Create wrapper functions for common patterns. Helper functions that handle span extraction and position calculation reduce boilerplate.
Test location accuracy. Include tests that verify not just parse results but also that locations are correctly preserved.
Design AST nodes to include location information from the start. Retrofitting location tracking is much harder than including it in the initial design.
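A minimal std-only sketch of that last point: a generic wrapper that pairs every AST payload with its source range, so location information survives each lowering pass. The names Span and Spanned here are illustrative, not from any crate.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Span {
    pub start: usize,
    pub end: usize,
}

/// Pairs any AST payload with the source range it was parsed from.
#[derive(Debug, Clone, PartialEq)]
pub struct Spanned<T> {
    pub node: T,
    pub span: Span,
}

impl<T> Spanned<T> {
    pub fn new(node: T, start: usize, end: usize) -> Self {
        Spanned { node, span: Span { start, end } }
    }

    /// Map the payload while preserving the span, so location
    /// information is never discarded during transformation passes.
    pub fn map<U>(self, f: impl FnOnce(T) -> U) -> Spanned<U> {
        Spanned { node: f(self.node), span: self.span }
    }
}

fn main() {
    let num = Spanned::new(42i64, 0, 2);
    let doubled = num.map(|n| n * 2);
    assert_eq!(doubled.node, 84);
    assert_eq!(doubled.span, Span { start: 0, end: 2 });
}
```

Because map threads the span through unchanged, later passes can transform nodes freely without losing the information an error reporter will eventually need.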
quote
quote provides quasi-quoting for generating Rust code programmatically while preserving the structure and formatting of the generated code. The library works in tandem with syn for parsing and proc-macro2 for token manipulation, forming the foundation of most Rust procedural macros. Unlike string-based code generation, quote maintains type safety and proper hygiene while generating syntactically correct Rust code.
The library’s interpolation syntax using # allows embedding runtime values into generated code, while repetition patterns with #(…)* enable generating loops and repeated structures. quote excels at preserving the visual structure of code templates, making generated code readable and maintainable. The ability to splice together token streams from different sources enables modular code generation patterns.
Basic Code Generation
use quote::{quote, format_ident};
use proc_macro2::TokenStream;

pub fn generate_function(name: &str, body: TokenStream) -> TokenStream {
    let fn_name = format_ident!("{}", name);
    quote! {
        pub fn #fn_name() -> i32 {
            #body
        }
    }
}

pub fn generate_struct(name: &str, fields: &[(String, String)]) -> TokenStream {
    let struct_name = format_ident!("{}", name);
    let field_defs = fields.iter().map(|(name, ty)| {
        let field_name = format_ident!("{}", name);
        let field_type = format_ident!("{}", ty);
        quote! {
            pub #field_name: #field_type
        }
    });
    quote! {
        #[derive(Debug, Clone)]
        pub struct #struct_name {
            #(#field_defs),*
        }
    }
}
The basic code generation functions demonstrate quote’s fundamental interpolation mechanism. The format_ident! macro creates identifiers from strings, ensuring they are valid Rust identifiers. The # symbol acts as an interpolation marker, embedding the identifier into the generated code. The quote! macro preserves the visual structure of the code template, making it easy to understand what code will be generated.
The repetition pattern #(#field_defs),* generates a comma-separated list of field definitions. This pattern iterates over the field_defs iterator, inserting each element separated by commas. The outer #(…) marks the repetition boundary, the inner # interpolates each item, and the ,* specifies comma separation with zero or more repetitions.
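To make the repetition concrete, here is a hand-written expansion of what generate_struct would produce for a hypothetical Point struct (the fields and name are illustrative; the code below is written out manually rather than produced by quote, so it compiles on its own):

```rust
// Hand-expanded result of generate_struct("Point",
// &[("x".into(), "i64".into()), ("y".into(), "i64".into())]):
// each element of field_defs became one `pub name: type` entry,
// separated by commas from the #(...),* repetition.

#[derive(Debug, Clone)]
pub struct Point {
    pub x: i64,
    pub y: i64,
}

fn main() {
    // The generated struct behaves like any handwritten one.
    let p = Point { x: 3, y: 4 };
    let q = p.clone();
    assert_eq!(q.x, 3);
    assert_eq!(q.y, 4);
}
```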
AST-Based Generation
#[derive(Debug, Clone)]
pub enum Expr {
    Literal(i32),
    Variable(String),
    Binary { op: BinaryOp, left: Box<Expr>, right: Box<Expr> },
}

#[derive(Debug, Clone)]
pub enum BinaryOp {
    Add,
    Sub,
    Mul,
    Div,
}

impl quote::ToTokens for Expr {
    fn to_tokens(&self, tokens: &mut TokenStream) {
        match self {
            Expr::Literal(n) => tokens.extend(quote! { #n }),
            Expr::Variable(name) => {
                let ident = format_ident!("{}", name);
                tokens.extend(quote! { #ident });
            }
            Expr::Binary { op, left, right } => {
                let op_tokens = match op {
                    BinaryOp::Add => quote! { + },
                    BinaryOp::Sub => quote! { - },
                    BinaryOp::Mul => quote! { * },
                    BinaryOp::Div => quote! { / },
                };
                tokens.extend(quote! { (#left #op_tokens #right) });
            }
        }
    }
}
Implementing ToTokens allows custom types to participate in quote’s interpolation system. The to_tokens method converts the AST representation into token streams that represent valid Rust code. This approach enables type-safe code generation where the AST structure ensures only valid combinations are possible.
The recursive nature of the Binary variant demonstrates how complex expressions naturally map to nested token generation. The parentheses in the output ensure proper precedence, while the interpolation of left and right recursively invokes their ToTokens implementations. This pattern scales to arbitrarily complex ASTs while maintaining clean separation between representation and generation.
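The same recursion can be sketched with std's Display instead of ToTokens, so it runs without the quote crate: each Binary node parenthesizes itself, meaning precedence is encoded structurally in the tree rather than re-derived during printing. The single-char op field is a simplification for this sketch.

```rust
use std::fmt;

pub enum Expr {
    Literal(i32),
    Variable(String),
    Binary { op: char, left: Box<Expr>, right: Box<Expr> },
}

impl fmt::Display for Expr {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Expr::Literal(n) => write!(f, "{}", n),
            Expr::Variable(name) => write!(f, "{}", name),
            // Recursive calls render the sub-expressions, exactly as the
            // interpolations of #left and #right invoke their ToTokens impls.
            Expr::Binary { op, left, right } => write!(f, "({} {} {})", left, op, right),
        }
    }
}

fn main() {
    // 2 + 3 * 4, built with explicit nesting
    let e = Expr::Binary {
        op: '+',
        left: Box::new(Expr::Literal(2)),
        right: Box::new(Expr::Binary {
            op: '*',
            left: Box::new(Expr::Literal(3)),
            right: Box::new(Expr::Literal(4)),
        }),
    };
    assert_eq!(e.to_string(), "(2 + (3 * 4))");
}
```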
Builder Pattern Generation
pub fn generate_builder(struct_name: &str, fields: &[(String, String)]) -> TokenStream {
    let struct_ident = format_ident!("{}", struct_name);
    let builder_ident = format_ident!("{}Builder", struct_name);

    // Collected into a Vec so the field names can be interpolated inside
    // the final quote! below; a bare `#name` there would be unbound.
    let field_names: Vec<_> = fields
        .iter()
        .map(|(name, _)| format_ident!("{}", name))
        .collect();

    let builder_fields = fields.iter().map(|(name, ty)| {
        let name = format_ident!("{}", name);
        let ty = format_ident!("{}", ty);
        quote! { #name: Option<#ty> }
    });

    let builder_methods = fields.iter().map(|(name, ty)| {
        let name = format_ident!("{}", name);
        let ty = format_ident!("{}", ty);
        quote! {
            pub fn #name(mut self, value: #ty) -> Self {
                self.#name = Some(value);
                self
            }
        }
    });

    let build_assignments = fields.iter().map(|(name, _)| {
        let name = format_ident!("{}", name);
        let error_msg = format!("Field {} is required", name);
        quote! { #name: self.#name.ok_or(#error_msg)? }
    });

    quote! {
        pub struct #builder_ident {
            #(#builder_fields),*
        }

        impl #builder_ident {
            pub fn new() -> Self {
                Self { #(#field_names: None),* }
            }

            #(#builder_methods)*

            pub fn build(self) -> Result<#struct_ident, &'static str> {
                Ok(#struct_ident { #(#build_assignments),* })
            }
        }
    }
}
The builder pattern generator showcases quote’s ability to generate complex patterns with multiple related components. Each field generates three pieces: an optional field in the builder, a setter method, and a build-time assignment. The repetition patterns handle collections of generated code, maintaining consistency across all fields.
The error handling in the build method demonstrates embedding runtime values into generated code. The format! macro creates error messages at generation time, which become string literals in the generated code. This technique allows customizing generated code based on input parameters while maintaining compile-time type checking.
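As a concrete reference, here is a hand-expanded sketch of what generate_builder would emit for a hypothetical struct Config { name: String, retries: u32 }; it is written out manually so it compiles without the quote crate, and the field names are illustrative:

```rust
#[derive(Debug)]
pub struct Config {
    pub name: String,
    pub retries: u32,
}

pub struct ConfigBuilder {
    name: Option<String>,
    retries: Option<u32>,
}

impl ConfigBuilder {
    pub fn new() -> Self {
        Self { name: None, retries: None }
    }

    pub fn name(mut self, value: String) -> Self {
        self.name = Some(value);
        self
    }

    pub fn retries(mut self, value: u32) -> Self {
        self.retries = Some(value);
        self
    }

    pub fn build(self) -> Result<Config, &'static str> {
        Ok(Config {
            // These error strings were computed at generation time by format!.
            name: self.name.ok_or("Field name is required")?,
            retries: self.retries.ok_or("Field retries is required")?,
        })
    }
}

fn main() {
    let ok = ConfigBuilder::new().name("db".to_string()).retries(3).build();
    assert_eq!(ok.unwrap().retries, 3);

    // A missing field surfaces the generation-time error message.
    let err = ConfigBuilder::new().retries(3).build();
    assert_eq!(err.unwrap_err(), "Field name is required");
}
```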
Trait Implementation Generation
pub fn generate_display_impl(
    struct_name: &str,
    format_str: &str,
    fields: &[String],
) -> TokenStream {
    let struct_ident = format_ident!("{}", struct_name);
    let field_refs = fields.iter().map(|name| {
        let field = format_ident!("{}", name);
        quote! { self.#field }
    });
    quote! {
        impl std::fmt::Display for #struct_ident {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write!(f, #format_str, #(#field_refs),*)
            }
        }
    }
}

pub fn generate_from_impl(
    target: &str,
    source: &str,
    conversion: TokenStream,
) -> TokenStream {
    let target_ident = format_ident!("{}", target);
    let source_ident = format_ident!("{}", source);
    quote! {
        impl From<#source_ident> for #target_ident {
            fn from(value: #source_ident) -> Self {
                #conversion
            }
        }
    }
}
Trait implementation generation demonstrates quote’s ability to generate standard Rust patterns. The Display implementation shows how format strings and field references combine to create formatted output. The repetition pattern in write! generates the exact number of arguments needed, matching the format string placeholders.
The From implementation generator accepts a TokenStream for the conversion logic, showing how quote enables composition of generated code. This pattern allows callers to provide complex conversion logic while the generator handles the boilerplate trait implementation structure.
Generic Code Generation
pub fn generate_generic_wrapper<T: quote::ToTokens>(
    name: &str,
    inner_type: T,
    bounds: &[String],
) -> TokenStream {
    let wrapper_ident = format_ident!("{}", name);
    // Collected into a Vec because the bounds are interpolated in two
    // separate repetitions; a bare iterator would be consumed by the first.
    let bound_tokens: Vec<_> = bounds
        .iter()
        .map(|b| format_ident!("{}", b))
        .collect();
    quote! {
        pub struct #wrapper_ident<T>
        where
            T: #(#bound_tokens)+*
        {
            inner: #inner_type,
            phantom: std::marker::PhantomData<T>,
        }

        impl<T> #wrapper_ident<T>
        where
            T: #(#bound_tokens)+*
        {
            pub fn new(inner: #inner_type) -> Self {
                Self {
                    inner,
                    phantom: std::marker::PhantomData,
                }
            }

            pub fn into_inner(self) -> #inner_type {
                self.inner
            }
        }
    }
}
Generic code generation requires careful handling of type parameters and bounds. The bound_tokens iterator generates trait bounds, while the +* repetition pattern creates the proper + separator for multiple bounds. The where clause repetition ensures bounds appear consistently in both the struct definition and implementation.
The PhantomData field demonstrates generating standard patterns for generic types that don’t directly use their type parameters. This pattern is essential for maintaining proper variance and drop checking in generic types.
Method Chain Generation
pub fn generate_method_chain(
    base: TokenStream,
    methods: &[(String, Vec<TokenStream>)],
) -> TokenStream {
    let mut result = base;
    for (method, args) in methods {
        let method_ident = format_ident!("{}", method);
        result = quote! { #result.#method_ident(#(#args),*) };
    }
    result
}

pub fn generate_builder_chain(fields: &[(String, TokenStream)]) -> TokenStream {
    let setters = fields.iter().map(|(name, value)| {
        let method = format_ident!("{}", name);
        quote! { .#method(#value) }
    });
    quote! {
        Builder::new()
            #(#setters)*
            .build()
    }
}
Method chain generation shows how quote enables building complex expressions programmatically. The iterative approach accumulates method calls, with each iteration wrapping the previous result. This pattern generates fluent interfaces and builder chains commonly used in Rust APIs.
The builder chain generator shows how to emit the idiomatic fluent-interface calls common in Rust APIs. The quote! macro preserves the visual structure of the template, keeping the generation code itself easy to read; note that the emitted TokenStream carries no whitespace of its own, so generated source is typically run through rustfmt or prettyplease when human-readable output is needed. The repetition pattern handles any number of setter calls.
Conditional Generation
pub fn generate_conditional_impl(
    condition: bool,
    true_branch: TokenStream,
    false_branch: TokenStream,
) -> TokenStream {
    if condition {
        quote! {
            #[cfg(feature = "enabled")]
            #true_branch
        }
    } else {
        quote! {
            #[cfg(not(feature = "enabled"))]
            #false_branch
        }
    }
}

pub fn generate_optional_field(name: &str, ty: &str, include: bool) -> TokenStream {
    let field_name = format_ident!("{}", name);
    let field_type = format_ident!("{}", ty);
    if include {
        quote! { pub #field_name: #field_type, }
    } else {
        quote! {}
    }
}
Conditional generation enables creating different code based on compile-time conditions. The cfg attributes in generated code allow feature-gated implementations, while the generation-time conditions customize what code gets generated. This dual-layer approach provides maximum flexibility in code generation.
Empty token streams from quote! {} allow optional elements in generated code. This pattern is useful for conditionally including fields, methods, or entire implementations based on configuration or feature flags.
Span Preservation
use quote::quote_spanned;
use proc_macro2::Span;

pub fn generate_spanned_error(span: Span, message: &str) -> TokenStream {
    quote_spanned! {span=>
        compile_error!(#message);
    }
}

pub fn generate_with_location(span: Span, code: TokenStream) -> TokenStream {
    quote_spanned! {span=>
        #code
    }
}
The quote_spanned! macro preserves source location information, crucial for error reporting in procedural macros. When the generated code contains errors, the compiler reports them at the original source location rather than pointing to the macro invocation. This feature significantly improves the debugging experience for macro users.
Span preservation enables generating helpful error messages that point to the exact location of problems in the input code. This capability is essential for creating user-friendly procedural macros that provide clear diagnostics.
Repetition Patterns
pub fn generate_match_arms(variants: &[(String, TokenStream)]) -> TokenStream {
    let arms = variants.iter().map(|(pattern, body)| {
        let pattern_ident = format_ident!("{}", pattern);
        quote! {
            Self::#pattern_ident => { #body }
        }
    });
    quote! {
        match self {
            #(#arms),*
        }
    }
}

pub fn generate_tuple_destructure(count: usize) -> TokenStream {
    // Collected into Vecs because `vars` is interpolated in two separate
    // repetitions; a bare iterator would be exhausted by the first one.
    let vars: Vec<_> = (0..count).map(|i| format_ident!("_{}", i)).collect();
    let indices: Vec<_> = (0..count).map(syn::Index::from).collect();
    quote! {
        let (#(#vars),*) = tuple;
        #(
            println!("Element {}: {:?}", #indices, #vars);
        )*
    }
}
Advanced repetition patterns demonstrate quote’s flexibility in generating complex structures. The match arm generation shows how patterns and bodies can be generated from data, while maintaining proper syntax. The comma separator in the repetition ensures valid match syntax.
The tuple destructuring example showcases numeric repetition, generating unique variable names and accessing tuple elements by index. The nested repetition pattern generates both the destructuring and the println statements, demonstrating how multiple repetitions can work together.
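For reference, this is the shape generate_match_arms produces, hand-expanded for a hypothetical two-variant enum (names are illustrative; the expansion is written out manually so it compiles without quote):

```rust
enum Opcode {
    Halt,
    Nop,
}

impl Opcode {
    fn mnemonic(&self) -> &'static str {
        // Each (pattern, body) pair became one `Self::… => { … }` arm,
        // comma-separated by the #(#arms),* repetition.
        match self {
            Self::Halt => { "halt" },
            Self::Nop => { "nop" },
        }
    }
}

fn main() {
    assert_eq!(Opcode::Halt.mnemonic(), "halt");
    assert_eq!(Opcode::Nop.mnemonic(), "nop");
}
```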
Integration with syn
use syn::{parse_quote, Expr, Stmt};

pub fn generate_assertion(left: &str, op: &str, right: &str) -> TokenStream {
    // Parse the operands into typed AST nodes first; interpolating the raw
    // &str values directly would splice in string literals, not expressions.
    let left_expr: Expr = syn::parse_str(left).expect("invalid left operand");
    let right_expr: Expr = syn::parse_str(right).expect("invalid right operand");
    let op_tokens: TokenStream = op.parse().expect("invalid operator");
    let assertion: Expr = parse_quote! {
        assert!(#left_expr #op_tokens #right_expr,
                "Assertion failed: {} {} {}", #left, #op, #right)
    };
    quote! { #assertion }
}

pub fn generate_test_function(name: &str, body: Vec<Stmt>) -> TokenStream {
    let test_name = format_ident!("test_{}", name);
    quote! {
        #[test]
        fn #test_name() {
            #(#body)*
        }
    }
}
The parse_quote! macro from syn builds typed syn values from quasi-quoted tokens, and syn::parse_str turns runtime strings into AST nodes; both results can then be interpolated with quote!. This combination enables parsing complex expressions and statements while maintaining type safety. The assertion generator shows how string inputs are converted into properly typed AST nodes rather than spliced into the output as raw string literals.
Integration with syn enables sophisticated code generation patterns where parsing and generation work together. This approach is particularly useful in procedural macros that need to analyze input code before generating output.
Best Practices
Structure generated code to match handwritten Rust conventions. Use proper indentation and formatting in quote! templates to make generated code readable. The visual structure of the template should reflect the structure of the generated code, making it easy to understand what will be generated.
Implement ToTokens for custom types that frequently appear in generated code. This approach provides better abstraction than repeatedly using quote! for the same patterns. Custom ToTokens implementations can encapsulate complex generation logic while providing a clean interface.
Use format_ident! for creating identifiers from strings, ensuring valid Rust identifiers. Never concatenate strings to build code; use quote!’s interpolation system instead. This approach maintains hygiene and prevents syntax errors in generated code.
Preserve spans when generating error messages or warnings. Use quote_spanned! to attach generated code to specific source locations, improving error messages. Good span preservation makes procedural macros feel like built-in language features.
Test generated code by comparing TokenStream representations or by compiling and running the generated code. Use snapshot testing for complex generated structures, capturing the generated code as strings for comparison. Regular testing ensures code generation remains correct as requirements evolve.
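Because token streams render with unpredictable spacing, a common trick when comparing generated code in tests is to normalize whitespace before comparing. This std-only helper is a sketch of that idea; with quote you would call .to_string() on both TokenStreams first:

```rust
/// Collapse all runs of whitespace to single spaces so that token-level
/// equality is not obscured by formatting differences.
fn normalize_ws(code: &str) -> String {
    code.split_whitespace().collect::<Vec<_>>().join(" ")
}

fn main() {
    // Simulated TokenStream::to_string() output vs. a formatted snapshot.
    let generated = "pub struct Point { pub x : i64 , pub y : i64 , }";
    let expected = "pub struct Point {
        pub x : i64 ,
        pub y : i64 ,
    }";
    assert_eq!(normalize_ws(generated), normalize_ws(expected));
}
```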
quote provides an elegant and powerful system for generating Rust code that maintains the language’s safety guarantees while enabling sophisticated metaprogramming patterns. Its integration with the procedural macro ecosystem makes it indispensable for creating derive macros, attribute macros, and code generators that feel native to Rust.
syn
Syn is a parser library for Rust code that provides a complete syntax tree representation of Rust source code. While primarily designed for procedural macros, syn’s powerful parsing capabilities make it invaluable for compiler construction tasks, especially when building languages that integrate with Rust or when analyzing Rust code itself.
The library excels at parsing complex token streams into strongly-typed abstract syntax trees. Unlike traditional parser generators that work with external grammar files, syn embeds the entire Rust grammar as Rust types, providing compile-time safety and excellent IDE support. This approach makes it particularly suitable for building domain-specific languages that extend Rust’s syntax or for creating compiler tools that analyze and transform Rust code.
Core Concepts
Syn operates on TokenStreams, which represent sequences of Rust tokens. These tokens flow from the Rust compiler through proc-macro2 into syn for parsing. The library provides three primary ways to work with syntax: parsing tokens into predefined AST types, implementing custom parsers using the Parse trait, and transforming existing AST nodes.
#![allow(unused)] fn main() { use std::collections::HashMap; use proc_macro2::{Ident, TokenStream}; use quote::quote; use syn::parse::{Parse, ParseStream}; use syn::punctuated::Punctuated; use syn::spanned::Spanned; use syn::{ parse_quote, Error, Expr, ExprLit, FnArg, ItemFn, Lit, Pat, Result, Stmt, Token, Type, Visibility, }; #[derive(Debug, Clone)] pub struct FunctionAnalysis { pub name: String, pub param_count: usize, pub params: Vec<String>, pub is_async: bool, pub is_unsafe: bool, pub has_generics: bool, pub visibility: String, } /// Example: Custom DSL parsing - Simple state machine language pub struct StateMachine { pub name: Ident, pub states: Vec<State>, pub initial: Ident, } pub struct State { pub name: Ident, pub transitions: Vec<Transition>, } pub struct Transition { pub event: Ident, pub target: Ident, pub action: Option<Expr>, } impl Parse for StateMachine { fn parse(input: ParseStream) -> Result<Self> { input.parse::<kw::state>()?; input.parse::<kw::machine>()?; let name: Ident = input.parse()?; let content; syn::braced!(content in input); // Parse initial state content.parse::<kw::initial>()?; content.parse::<Token![:]>()?; let initial: Ident = content.parse()?; content.parse::<Token![;]>()?; // Parse states let mut states = Vec::new(); while !content.is_empty() { states.push(content.parse()?); } Ok(StateMachine { name, states, initial, }) } } impl Parse for State { fn parse(input: ParseStream) -> Result<Self> { input.parse::<kw::state>()?; let name: Ident = input.parse()?; let content; syn::braced!(content in input); let mut transitions = Vec::new(); while !content.is_empty() { transitions.push(content.parse()?); } Ok(State { name, transitions }) } } impl Parse for Transition { fn parse(input: ParseStream) -> Result<Self> { input.parse::<kw::on>()?; let event: Ident = input.parse()?; input.parse::<Token![=>]>()?; let target: Ident = input.parse()?; let action = if input.peek(Token![,]) { input.parse::<Token![,]>()?; Some(input.parse()?) 
} else { None }; input.parse::<Token![;]>()?; Ok(Transition { event, target, action, }) } } mod kw { use syn::custom_keyword; custom_keyword!(state); custom_keyword!(machine); custom_keyword!(initial); custom_keyword!(on); } /// Example: AST transformation - Add logging to functions pub fn inject_logging(mut func: ItemFn) -> ItemFn { let fn_name = &func.sig.ident; let log_entry: Stmt = parse_quote! { println!("Entering function: {}", stringify!(#fn_name)); }; // Insert at the beginning of the function body func.block.stmts.insert(0, log_entry); // Add exit logging before each return let log_exit: Stmt = parse_quote! { println!("Exiting function: {}", stringify!(#fn_name)); }; let mut new_stmts = Vec::new(); for stmt in func.block.stmts.drain(..) { match &stmt { Stmt::Expr(Expr::Return(_), _) => { new_stmts.push(log_exit.clone()); new_stmts.push(stmt); } _ => new_stmts.push(stmt), } } // Add exit log at the end if there's no explicit return if !matches!(new_stmts.last(), Some(Stmt::Expr(Expr::Return(_), _))) { new_stmts.push(log_exit); } func.block.stmts = new_stmts; func } /// Example: Custom attribute parsing #[derive(Debug)] pub struct CompilerDirective { pub optimization_level: u8, pub inline: bool, pub target_features: Vec<String>, } impl Parse for CompilerDirective { fn parse(input: ParseStream) -> Result<Self> { let mut optimization_level = 0; let mut inline = false; let mut target_features = Vec::new(); let vars = Punctuated::<MetaItem, Token![,]>::parse_terminated(input)?; for var in vars { match var.name.to_string().as_str() { "opt_level" => optimization_level = var.value, "inline" => inline = true, "features" => { target_features = var .list .into_iter() .map(|s| s.trim_matches('"').to_string()) .collect(); } _ => { return Err(Error::new( var.name.span(), format!("Unknown directive: {}", var.name), )) } } } Ok(CompilerDirective { optimization_level, inline, target_features, }) } } struct MetaItem { name: Ident, value: u8, list: Vec<String>, } impl Parse 
for MetaItem { fn parse(input: ParseStream) -> Result<Self> { let name: Ident = input.parse()?; if input.peek(Token![=]) { input.parse::<Token![=]>()?; if let Ok(lit) = input.parse::<ExprLit>() { if let Lit::Int(int) = lit.lit { let value = int.base10_parse::<u8>()?; return Ok(MetaItem { name, value, list: vec![], }); } } } if input.peek(syn::token::Paren) { let content; syn::parenthesized!(content in input); let list = Punctuated::<ExprLit, Token![,]>::parse_terminated(&content)? .into_iter() .filter_map(|lit| { if let Lit::Str(s) = lit.lit { Some(s.value()) } else { None } }) .collect(); return Ok(MetaItem { name, value: 0, list, }); } Ok(MetaItem { name, value: 1, list: vec![], }) } } /// Example: Type analysis for compiler optimizations pub fn analyze_types_in_function(func: &ItemFn) -> HashMap<String, TypeInfo> { let mut type_info = HashMap::new(); // Analyze parameter types for input in &func.sig.inputs { if let FnArg::Typed(pat_type) = input { if let Pat::Ident(ident) = pat_type.pat.as_ref() { let info = analyze_type(&pat_type.ty); type_info.insert(ident.ident.to_string(), info); } } } type_info } #[derive(Debug, Clone)] pub struct TypeInfo { pub is_primitive: bool, pub is_reference: bool, pub is_mutable: bool, pub type_string: String, } fn analyze_type(ty: &Type) -> TypeInfo { match ty { Type::Path(type_path) => { let type_string = quote!(#type_path).to_string(); let is_primitive = matches!( type_string.as_str(), "i8" | "i16" | "i32" | "i64" | "i128" | "u8" | "u16" | "u32" | "u64" | "u128" | "f32" | "f64" | "bool" | "char" ); TypeInfo { is_primitive, is_reference: false, is_mutable: false, type_string, } } Type::Reference(type_ref) => { let inner = analyze_type(&type_ref.elem); TypeInfo { is_reference: true, is_mutable: type_ref.mutability.is_some(), ..inner } } _ => TypeInfo { is_primitive: false, is_reference: false, is_mutable: false, type_string: quote!(#ty).to_string(), }, } } /// Example: Generate optimized code based on const evaluation pub fn 
const_fold_binary_ops(expr: Expr) -> Expr { match expr { Expr::Binary(mut binary) => { // Recursively fold sub-expressions binary.left = Box::new(const_fold_binary_ops(*binary.left)); binary.right = Box::new(const_fold_binary_ops(*binary.right)); // Try to fold if both operands are literals if let (Expr::Lit(left_lit), Expr::Lit(right_lit)) = (binary.left.as_ref(), binary.right.as_ref()) { if let (Lit::Int(l), Lit::Int(r)) = (&left_lit.lit, &right_lit.lit) { if let (Ok(l_val), Ok(r_val)) = (l.base10_parse::<i64>(), r.base10_parse::<i64>()) { use syn::BinOp; let result = match binary.op { BinOp::Add(_) => Some(l_val + r_val), BinOp::Sub(_) => Some(l_val - r_val), BinOp::Mul(_) => Some(l_val * r_val), BinOp::Div(_) if r_val != 0 => Some(l_val / r_val), _ => None, }; if let Some(val) = result { return parse_quote!(#val); } } } } Expr::Binary(binary) } // Recursively process other expression types Expr::Paren(mut paren) => { paren.expr = Box::new(const_fold_binary_ops(*paren.expr)); Expr::Paren(paren) } Expr::Block(mut block) => { if let Some(Stmt::Expr(expr, _semi)) = block.block.stmts.last_mut() { *expr = const_fold_binary_ops(expr.clone()); } Expr::Block(block) } other => other, } } /// Error handling with span information pub fn validate_function(func: &ItemFn) -> std::result::Result<(), Vec<Error>> { let mut errors = Vec::new(); // Check function name conventions let name = func.sig.ident.to_string(); if name.starts_with('_') && func.vis != Visibility::Inherited { errors.push(Error::new( func.sig.ident.span(), "Public functions should not start with underscore", )); } // Check for missing documentation if !func.attrs.iter().any(|attr| attr.path().is_ident("doc")) { errors.push(Error::new( func.sig.ident.span(), "Missing documentation comment", )); } // Check parameter conventions for input in &func.sig.inputs { let FnArg::Typed(pat_type) = input else { continue; }; let Type::Reference(type_ref) = pat_type.ty.as_ref() else { continue; }; if 
type_ref.mutability.is_some() { continue; } let Type::Path(path) = type_ref.elem.as_ref() else { continue; }; let Some(ident) = path.path.get_ident() else { continue; }; let type_name = ident.to_string(); if matches!(type_name.as_str(), "String" | "Vec" | "HashMap") { errors.push(Error::new( pat_type.ty.span(), format!( "Consider using &{} instead of {} for better performance", type_name, type_name ), )); } } if errors.is_empty() { Ok(()) } else { Err(errors) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_function_analysis() { let input = quote! { pub async unsafe fn process_data<T>(input: &str, count: usize) -> Result<T> { todo!() } }; let analysis = analyze_function(input).unwrap(); assert_eq!(analysis.name, "process_data"); assert_eq!(analysis.param_count, 2); assert!(analysis.is_async); assert!(analysis.is_unsafe); assert!(analysis.has_generics); assert_eq!(analysis.params, vec!["input", "count"]); } #[test] fn test_inject_logging() { let input: ItemFn = parse_quote! { fn calculate(x: i32, y: i32) -> i32 { if x > y { return x - y; } x + y } }; let modified = inject_logging(input); let output = quote!(#modified).to_string(); assert!(output.contains("Entering function")); assert!(output.contains("Exiting function")); } #[test] fn test_const_folding() { // Test simple constant folding let expr: Expr = parse_quote! { 2 + 3 }; let folded = const_fold_binary_ops(expr); match &folded { Expr::Lit(lit) => { if let Lit::Int(int) = &lit.lit { assert_eq!(int.base10_parse::<i64>().unwrap(), 5); } else { panic!("Expected integer literal"); } } _ => panic!( "Expected literal after folding, got: {:?}", quote!(#folded).to_string() ), } // Test division let expr: Expr = parse_quote! { 10 / 2 }; let folded = const_fold_binary_ops(expr); if let Expr::Lit(lit) = &folded { if let Lit::Int(int) = &lit.lit { assert_eq!(int.base10_parse::<i64>().unwrap(), 5); } } // Test non-foldable expression (variable) let expr: Expr = parse_quote! 
{ x + 3 }; let folded = const_fold_binary_ops(expr); assert!(matches!(folded, Expr::Binary(_))); } #[test] fn test_type_analysis() { let func: ItemFn = parse_quote! { fn example(x: i32, s: &str, data: &mut Vec<u8>) {} }; let types = analyze_types_in_function(&func); assert!(types["x"].is_primitive); assert!(types["s"].is_reference); assert!(!types["s"].is_mutable); assert!(types["data"].is_reference); assert!(types["data"].is_mutable); } } /// Example: Parsing and analyzing a Rust function pub fn analyze_function(input: TokenStream) -> Result<FunctionAnalysis> { let func: ItemFn = syn::parse2(input)?; let param_count = func.sig.inputs.len(); let is_async = func.sig.asyncness.is_some(); let is_unsafe = func.sig.unsafety.is_some(); let has_generics = !func.sig.generics.params.is_empty(); let params = func .sig .inputs .iter() .filter_map(|arg| match arg { FnArg::Typed(pat_type) => { if let Pat::Ident(ident) = pat_type.pat.as_ref() { Some(ident.ident.to_string()) } else { None } } _ => None, }) .collect(); Ok(FunctionAnalysis { name: func.sig.ident.to_string(), param_count, params, is_async, is_unsafe, has_generics, visibility: format!("{:?}", func.vis), }) } }
The Parse trait forms the foundation of syn’s extensibility. By implementing this trait, you can create parsers for custom syntax that integrates seamlessly with Rust’s token system. This capability proves essential when building domain-specific languages or extending Rust with new syntactic constructs.
Custom Language Parsing
One of syn’s most powerful features is its ability to parse custom languages that feel native to Rust. By defining custom keywords and implementing Parse traits, you can create domain-specific languages that leverage Rust’s tokenization while introducing novel syntax.
/// Example: Custom DSL parsing - Simple state machine language
pub struct StateMachine {
    pub name: Ident,
    pub states: Vec<State>,
    pub initial: Ident,
}
pub struct State {
    pub name: Ident,
    pub transitions: Vec<Transition>,
}
pub struct Transition {
    pub event: Ident,
    pub target: Ident,
    pub action: Option<Expr>,
}
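Taken together, these three types model a small state-machine DSL. As a point of reference, input of the shape below would satisfy the `Parse` implementations; the `state_machine!` macro name is purely hypothetical (invented here for illustration), and the snippet is not compilable on its own — only the grammar itself follows from the parsers:

```rust
// Hypothetical invocation: if these parsers backed a proc macro named
// `state_machine!`, a caller would write input in this shape.
state_machine! {
    state machine TrafficLight {
        initial: Red;
        state Red   { on go   => Green; }
        state Green { on stop => Red; }
    }
}
```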
The Parse implementations for these types demonstrate how to build recursive descent parsers using syn’s parsing infrastructure:
```rust
impl Parse for StateMachine {
    fn parse(input: ParseStream) -> Result<Self> {
        input.parse::<kw::state>()?;
        input.parse::<kw::machine>()?;
        let name: Ident = input.parse()?;

        let content;
        syn::braced!(content in input);

        // Parse initial state
        content.parse::<kw::initial>()?;
        content.parse::<Token![:]>()?;
        let initial: Ident = content.parse()?;
        content.parse::<Token![;]>()?;

        // Parse states
        let mut states = Vec::new();
        while !content.is_empty() {
            states.push(content.parse()?);
        }

        Ok(StateMachine { name, states, initial })
    }
}

impl Parse for State {
    fn parse(input: ParseStream) -> Result<Self> {
        input.parse::<kw::state>()?;
        let name: Ident = input.parse()?;

        let content;
        syn::braced!(content in input);

        let mut transitions = Vec::new();
        while !content.is_empty() {
            transitions.push(content.parse()?);
        }

        Ok(State { name, transitions })
    }
}

impl Parse for Transition {
    fn parse(input: ParseStream) -> Result<Self> {
        input.parse::<kw::on>()?;
        let event: Ident = input.parse()?;
        input.parse::<Token![=>]>()?;
        let target: Ident = input.parse()?;

        let action = if input.peek(Token![,]) {
            input.parse::<Token![,]>()?;
            Some(input.parse()?)
        } else {
            None
        };

        input.parse::<Token![;]>()?;

        Ok(Transition { event, target, action })
    }
}

mod kw {
    use syn::custom_keyword;

    custom_keyword!(state);
    custom_keyword!(machine);
    custom_keyword!(initial);
    custom_keyword!(on);
}
```
This approach lets you create languages that feel natural within Rust’s syntax while retaining full control over parsing and error reporting. The custom keywords are defined with syn’s custom_keyword! macro, which scopes them to the kw module and keeps them from colliding with ordinary identifiers.
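As a point of reference, input for this grammar might look like the following. The state_machine! macro name and the Door example are hypothetical; they assume the parser above is exposed through a proc macro that forwards its tokens to syn::parse2::<StateMachine>:

```rust
// Hypothetical invocation -- `state_machine!` is an assumed wrapper macro.
state_machine! {
    state machine Door {
        initial: Closed;

        state Closed {
            on open => Open;
        }
        state Open {
            on close => Closed, log_close();
        }
    }
}
```

Each transition follows the `on event => target [, action];` shape the Transition parser expects, with the optional action expression introduced by a comma.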
AST Transformation
Compiler construction often requires transforming abstract syntax trees to implement optimizations, add instrumentation, or change program behavior. Syn provides comprehensive facilities for traversing and modifying Rust ASTs while preserving source location information.
```rust
/// Example: AST transformation - Add logging to functions
pub fn inject_logging(mut func: ItemFn) -> ItemFn {
    let fn_name = &func.sig.ident;
    let log_entry: Stmt = parse_quote! {
        println!("Entering function: {}", stringify!(#fn_name));
    };

    // Insert at the beginning of the function body
    func.block.stmts.insert(0, log_entry);

    // Add exit logging before each return
    let log_exit: Stmt = parse_quote! {
        println!("Exiting function: {}", stringify!(#fn_name));
    };

    let mut new_stmts = Vec::new();
    for stmt in func.block.stmts.drain(..) {
        match &stmt {
            Stmt::Expr(Expr::Return(_), _) => {
                new_stmts.push(log_exit.clone());
                new_stmts.push(stmt);
            }
            _ => new_stmts.push(stmt),
        }
    }

    // Add exit log at the end if there's no explicit return
    if !matches!(new_stmts.last(), Some(Stmt::Expr(Expr::Return(_), _))) {
        new_stmts.push(log_exit);
    }

    func.block.stmts = new_stmts;
    func
}
```
This transformation demonstrates several important patterns for AST manipulation. The function takes ownership of the ItemFn, mutates its statement list, and returns the modified tree with type information and source spans intact. The parse_quote! macro embeds Rust syntax directly in the transformation code, making it easy to construct new AST nodes.
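Hand-expanding the pass on a small function makes its behavior, and one limitation, concrete. The clamp_diff function below is an illustrative name, not from the crate; the commented-out original is shown first, followed by a sketch of what inject_logging would produce:

```rust
// Original input:
//
// fn clamp_diff(x: i32, y: i32) -> i32 {
//     if x > y {
//         return x - y;
//     }
//     return 0;
// }

// After inject_logging (hand-expanded sketch):
fn clamp_diff(x: i32, y: i32) -> i32 {
    println!("Entering function: {}", stringify!(clamp_diff));
    if x > y {
        // This nested return is NOT instrumented: the pass only scans the
        // top-level statement list and does not descend into control flow.
        return x - y;
    }
    println!("Exiting function: {}", stringify!(clamp_diff));
    return 0;
}

fn main() {
    assert_eq!(clamp_diff(5, 2), 3);
    assert_eq!(clamp_diff(1, 2), 0);
}
```

A production version would use syn’s VisitMut trait (behind the visit-mut feature) to rewrite returns nested inside if, match, and loop bodies as well.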
Type Analysis
Understanding type information is crucial for many compiler optimizations. Syn provides detailed type representations that enable sophisticated analysis of Rust’s type system.
```rust
#[derive(Debug, Clone)]
pub struct TypeInfo {
    pub is_primitive: bool,
    pub is_reference: bool,
    pub is_mutable: bool,
    pub type_string: String,
}

/// Example: Type analysis for compiler optimizations
pub fn analyze_types_in_function(func: &ItemFn) -> HashMap<String, TypeInfo> {
    let mut type_info = HashMap::new();

    // Analyze parameter types
    for input in &func.sig.inputs {
        if let FnArg::Typed(pat_type) = input {
            if let Pat::Ident(ident) = pat_type.pat.as_ref() {
                let info = analyze_type(&pat_type.ty);
                type_info.insert(ident.ident.to_string(), info);
            }
        }
    }

    type_info
}

fn analyze_type(ty: &Type) -> TypeInfo {
    match ty {
        Type::Path(type_path) => {
            let type_string = quote!(#type_path).to_string();
            let is_primitive = matches!(
                type_string.as_str(),
                "i8" | "i16" | "i32" | "i64" | "i128"
                    | "u8" | "u16" | "u32" | "u64" | "u128"
                    | "f32" | "f64" | "bool" | "char"
            );
            TypeInfo {
                is_primitive,
                is_reference: false,
                is_mutable: false,
                type_string,
            }
        }
        Type::Reference(type_ref) => {
            let inner = analyze_type(&type_ref.elem);
            TypeInfo {
                is_reference: true,
                is_mutable: type_ref.mutability.is_some(),
                ..inner
            }
        }
        _ => TypeInfo {
            is_primitive: false,
            is_reference: false,
            is_mutable: false,
            type_string: quote!(#ty).to_string(),
        },
    }
}
```
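The recursion in analyze_type can be illustrated without syn at all. The SimpleType enum below is a made-up miniature of syn::Type, used purely for illustration; it shows how reference layers accumulate on top of the inner type’s classification:

```rust
// Hypothetical miniature of syn::Type, for illustration only.
enum SimpleType {
    Path(&'static str),
    Reference { mutable: bool, elem: Box<SimpleType> },
}

#[derive(Debug)]
struct Info {
    is_primitive: bool,
    is_reference: bool,
    is_mutable: bool,
}

fn analyze(ty: &SimpleType) -> Info {
    match ty {
        SimpleType::Path(name) => Info {
            is_primitive: matches!(*name, "i32" | "u8" | "bool" | "f64"),
            is_reference: false,
            is_mutable: false,
        },
        SimpleType::Reference { mutable, elem } => {
            // Keep the inner type's classification, record the reference layer.
            let inner = analyze(elem);
            Info { is_reference: true, is_mutable: *mutable, ..inner }
        }
    }
}

fn main() {
    let ty = SimpleType::Reference {
        mutable: true,
        elem: Box::new(SimpleType::Path("i32")),
    };
    let info = analyze(&ty);
    assert!(info.is_primitive && info.is_reference && info.is_mutable);
    println!("{:?}", info);
}
```

The struct-update syntax (..inner) is what lets the reference arm inherit is_primitive from the wrapped type while overriding only the reference flags, mirroring the Type::Reference arm above.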
```rust
use std::collections::HashMap;

use quote::quote;
use syn::{FnArg, ItemFn, Pat, Type};

/// Example: Type analysis for compiler optimizations
pub fn analyze_types_in_function(func: &ItemFn) -> HashMap<String, TypeInfo> {
    let mut type_info = HashMap::new();

    // Analyze parameter types
    for input in &func.sig.inputs {
        if let FnArg::Typed(pat_type) = input {
            if let Pat::Ident(ident) = pat_type.pat.as_ref() {
                let info = analyze_type(&pat_type.ty);
                type_info.insert(ident.ident.to_string(), info);
            }
        }
    }

    type_info
}

#[derive(Debug, Clone)]
pub struct TypeInfo {
    pub is_primitive: bool,
    pub is_reference: bool,
    pub is_mutable: bool,
    pub type_string: String,
}

fn analyze_type(ty: &Type) -> TypeInfo {
    match ty {
        Type::Path(type_path) => {
            let type_string = quote!(#type_path).to_string();
            let is_primitive = matches!(
                type_string.as_str(),
                "i8" | "i16" | "i32" | "i64" | "i128" | "u8" | "u16" | "u32" | "u64" | "u128"
                    | "f32" | "f64" | "bool" | "char"
            );
            TypeInfo {
                is_primitive,
                is_reference: false,
                is_mutable: false,
                type_string,
            }
        }
        Type::Reference(type_ref) => {
            let inner = analyze_type(&type_ref.elem);
            TypeInfo {
                is_reference: true,
                is_mutable: type_ref.mutability.is_some(),
                ..inner
            }
        }
        _ => TypeInfo {
            is_primitive: false,
            is_reference: false,
            is_mutable: false,
            type_string: quote!(#ty).to_string(),
        },
    }
}
```
This type analysis can inform optimization decisions, such as determining whether values can be stack-allocated, identifying opportunities for specialization, or checking whether types implement specific traits.
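As a standalone sketch of how such per-parameter facts might feed an optimization decision, the snippet below mirrors the chapter's `TypeInfo` struct in plain Rust; the `fits_in_register` policy is a hypothetical example of the kind of judgment a backend could make, not part of the chapter's code:

```rust
// Mirror of the chapter's TypeInfo, kept dependency-free for illustration.
#[derive(Debug, Clone)]
pub struct TypeInfo {
    pub is_primitive: bool,
    pub is_reference: bool,
    pub is_mutable: bool,
    pub type_string: String,
}

/// Hypothetical policy: primitives and shared references are assumed to fit
/// in a register; everything else is conservatively kept on the stack.
pub fn fits_in_register(info: &TypeInfo) -> bool {
    info.is_primitive || (info.is_reference && !info.is_mutable)
}

fn main() {
    let x = TypeInfo {
        is_primitive: true,
        is_reference: false,
        is_mutable: false,
        type_string: "i32".into(),
    };
    let v = TypeInfo {
        is_primitive: false,
        is_reference: false,
        is_mutable: false,
        type_string: "Vec<u8>".into(),
    };
    println!("i32: {}, Vec<u8>: {}", fits_in_register(&x), fits_in_register(&v));
}
```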
Constant Folding
Compile-time evaluation of expressions is a fundamental compiler optimization. Syn’s expression types make it straightforward to implement constant folding and other algebraic simplifications.
```rust
use syn::{parse_quote, Expr, Lit, Stmt};

/// Example: Generate optimized code based on const evaluation
pub fn const_fold_binary_ops(expr: Expr) -> Expr {
    match expr {
        Expr::Binary(mut binary) => {
            // Recursively fold sub-expressions
            binary.left = Box::new(const_fold_binary_ops(*binary.left));
            binary.right = Box::new(const_fold_binary_ops(*binary.right));

            // Try to fold if both operands are literals
            if let (Expr::Lit(left_lit), Expr::Lit(right_lit)) =
                (binary.left.as_ref(), binary.right.as_ref())
            {
                if let (Lit::Int(l), Lit::Int(r)) = (&left_lit.lit, &right_lit.lit) {
                    if let (Ok(l_val), Ok(r_val)) =
                        (l.base10_parse::<i64>(), r.base10_parse::<i64>())
                    {
                        use syn::BinOp;
                        let result = match binary.op {
                            BinOp::Add(_) => Some(l_val + r_val),
                            BinOp::Sub(_) => Some(l_val - r_val),
                            BinOp::Mul(_) => Some(l_val * r_val),
                            BinOp::Div(_) if r_val != 0 => Some(l_val / r_val),
                            _ => None,
                        };
                        if let Some(val) = result {
                            return parse_quote!(#val);
                        }
                    }
                }
            }
            Expr::Binary(binary)
        }
        // Recursively process other expression types
        Expr::Paren(mut paren) => {
            paren.expr = Box::new(const_fold_binary_ops(*paren.expr));
            Expr::Paren(paren)
        }
        Expr::Block(mut block) => {
            if let Some(Stmt::Expr(expr, _semi)) = block.block.stmts.last_mut() {
                *expr = const_fold_binary_ops(expr.clone());
            }
            Expr::Block(block)
        }
        other => other,
    }
}
```
This example shows how to recursively traverse expression trees and apply transformations. While simple, this pattern extends to more sophisticated optimizations like strength reduction, algebraic simplification, and dead code elimination.
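To see how the same bottom-up rewrite extends to algebraic simplification, here is a dependency-free sketch on a toy AST; the `Expr` enum and `fold` function below are illustrative stand-ins rather than syn's types:

```rust
// A minimal expression AST standing in for syn's Expr.
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Lit(i64),
    Var(String),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
}

// Fold children first, then apply constant folding and identity rules.
// Note: the x * 0 => 0 rule is only sound because this toy language has
// no side effects in expressions.
fn fold(e: Expr) -> Expr {
    match e {
        Expr::Add(l, r) => match (fold(*l), fold(*r)) {
            (Expr::Lit(a), Expr::Lit(b)) => Expr::Lit(a + b), // constant fold
            (x, Expr::Lit(0)) | (Expr::Lit(0), x) => x,       // x + 0 => x
            (l, r) => Expr::Add(Box::new(l), Box::new(r)),
        },
        Expr::Mul(l, r) => match (fold(*l), fold(*r)) {
            (Expr::Lit(a), Expr::Lit(b)) => Expr::Lit(a * b), // constant fold
            (x, Expr::Lit(1)) | (Expr::Lit(1), x) => x,       // x * 1 => x
            (_, Expr::Lit(0)) | (Expr::Lit(0), _) => Expr::Lit(0), // x * 0 => 0
            (l, r) => Expr::Mul(Box::new(l), Box::new(r)),
        },
        other => other,
    }
}

fn main() {
    // (x * 1) + (2 * 3) simplifies to x + 6
    let e = Expr::Add(
        Box::new(Expr::Mul(
            Box::new(Expr::Var("x".into())),
            Box::new(Expr::Lit(1)),
        )),
        Box::new(Expr::Mul(Box::new(Expr::Lit(2)), Box::new(Expr::Lit(3)))),
    );
    println!("{:?}", fold(e));
}
```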
Custom Attributes and Directives
Compilers often need to process custom attributes that control optimization, linking, or other compilation aspects. Syn makes it easy to define and parse such attributes with full type safety.
```rust
use proc_macro2::Ident;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::{Error, ExprLit, Lit, Result, Token};

/// Example: Custom attribute parsing
#[derive(Debug)]
pub struct CompilerDirective {
    pub optimization_level: u8,
    pub inline: bool,
    pub target_features: Vec<String>,
}

impl Parse for CompilerDirective {
    fn parse(input: ParseStream) -> Result<Self> {
        let mut optimization_level = 0;
        let mut inline = false;
        let mut target_features = Vec::new();

        let vars = Punctuated::<MetaItem, Token![,]>::parse_terminated(input)?;
        for var in vars {
            match var.name.to_string().as_str() {
                "opt_level" => optimization_level = var.value,
                "inline" => inline = true,
                "features" => {
                    target_features = var
                        .list
                        .into_iter()
                        .map(|s| s.trim_matches('"').to_string())
                        .collect();
                }
                _ => {
                    return Err(Error::new(
                        var.name.span(),
                        format!("Unknown directive: {}", var.name),
                    ))
                }
            }
        }

        Ok(CompilerDirective {
            optimization_level,
            inline,
            target_features,
        })
    }
}

struct MetaItem {
    name: Ident,
    value: u8,
    list: Vec<String>,
}

impl Parse for MetaItem {
    fn parse(input: ParseStream) -> Result<Self> {
        let name: Ident = input.parse()?;

        if input.peek(Token![=]) {
            input.parse::<Token![=]>()?;
            if let Ok(lit) = input.parse::<ExprLit>() {
                if let Lit::Int(int) = lit.lit {
                    let value = int.base10_parse::<u8>()?;
                    return Ok(MetaItem {
                        name,
                        value,
                        list: vec![],
                    });
                }
            }
        }

        if input.peek(syn::token::Paren) {
            let content;
            syn::parenthesized!(content in input);
            let list = Punctuated::<ExprLit, Token![,]>::parse_terminated(&content)?
                .into_iter()
                .filter_map(|lit| {
                    if let Lit::Str(s) = lit.lit {
                        Some(s.value())
                    } else {
                        None
                    }
                })
                .collect();
            return Ok(MetaItem {
                name,
                value: 0,
                list,
            });
        }

        Ok(MetaItem {
            name,
            value: 1,
            list: vec![],
        })
    }
}
```
These custom attributes can control various aspects of compilation, from optimization levels to target-specific features, providing a clean interface between source code and compiler behavior.
Error Handling and Diagnostics
High-quality error messages are essential for any compiler. Syn provides detailed span information for every AST node, enabling precise error reporting that points directly to problematic source code.
#![allow(unused)]
fn main() {
use syn::spanned::Spanned;
use syn::{Error, FnArg, ItemFn, Type, Visibility};

/// Error handling with span information
pub fn validate_function(func: &ItemFn) -> std::result::Result<(), Vec<Error>> {
    let mut errors = Vec::new();

    // Check function name conventions
    let name = func.sig.ident.to_string();
    if name.starts_with('_') && func.vis != Visibility::Inherited {
        errors.push(Error::new(
            func.sig.ident.span(),
            "Public functions should not start with underscore",
        ));
    }

    // Check for missing documentation
    if !func.attrs.iter().any(|attr| attr.path().is_ident("doc")) {
        errors.push(Error::new(
            func.sig.ident.span(),
            "Missing documentation comment",
        ));
    }

    // Check parameter conventions
    for input in &func.sig.inputs {
        let FnArg::Typed(pat_type) = input else { continue; };
        let Type::Reference(type_ref) = pat_type.ty.as_ref() else { continue; };
        if type_ref.mutability.is_some() {
            continue;
        }
        let Type::Path(path) = type_ref.elem.as_ref() else { continue; };
        let Some(ident) = path.path.get_ident() else { continue; };
        let type_name = ident.to_string();
        if matches!(type_name.as_str(), "String" | "Vec" | "HashMap") {
            errors.push(Error::new(
                pat_type.ty.span(),
                format!(
                    "Consider using &{} instead of {} for better performance",
                    type_name, type_name
                ),
            ));
        }
    }

    if errors.is_empty() {
        Ok(())
    } else {
        Err(errors)
    }
}
}
The Error type in syn includes span information that integrates with Rust’s diagnostic system, producing error messages that feel native to the Rust compiler. This integration is particularly valuable when building tools that extend the Rust compiler or when creating lints and code analysis tools.
Integration with Quote
Syn works hand-in-hand with the quote crate for code generation. While syn parses TokenStreams into ASTs, quote converts ASTs back into TokenStreams. This bidirectional conversion enables powerful metaprogramming patterns.
The quote! macro supports interpolation of syn types, making it easy to construct complex code fragments. Syn's parse_quote! macro combines both operations, quoting with interpolation and then parsing the result directly into a syn type. Together these macros provide a complete toolkit for reading, analyzing, transforming, and generating Rust code.
Advanced Patterns
Building production compilers with syn involves several advanced patterns. Visitor traits (Visit and VisitMut) enable systematic traversal of large ASTs. Fold traits support functional transformation patterns. The punctuated module handles comma-separated lists with proper parsing of trailing commas.
For performance-critical applications, keep syn's feature set minimal: the default derive feature parses only the subset of Rust that derive macros need, while the full feature pulls in the entire expression and item grammar and increases compile times. Parsing proc_macro2::TokenStreams directly with syn::parse2 avoids re-tokenizing source text, and caching parsed ASTs between passes avoids repeated parsing when processing large codebases.
Best Practices
When using syn for compiler construction, organize your code to separate parsing, analysis, and transformation phases. Define clear AST types for your domain-specific constructs. Preserve span information throughout transformations to maintain high-quality error messages.
Test your parsers thoroughly using syn’s parsing functions directly. The library’s strong typing catches many errors at compile time, but runtime testing remains essential for ensuring correct parsing of edge cases.
Consider performance implications when designing AST transformations. While syn is highly optimized, traversing large ASTs multiple times can impact compilation speed. Combine related transformations when possible to minimize traversal overhead.
Common Patterns
Several patterns appear repeatedly in syn-based compiler tools. The parse-transform-generate pipeline forms the basis of most procedural macros. Custom parsing often combines syn’s built-in types with domain-specific structures. Hygiene preservation ensures that generated code doesn’t accidentally capture or shadow user identifiers.
Error accumulation allows reporting multiple problems in a single compilation pass. Span manipulation enables precise error messages and suggestions. Integration with the broader Rust ecosystem through traits and standard types ensures that syn-based tools compose well with other compiler infrastructure.
Syn provides a solid foundation for building sophisticated compiler tools that integrate seamlessly with Rust. Whether you’re creating procedural macros, building development tools, or implementing entirely new languages, syn’s combination of power, safety, and ergonomics makes it an invaluable tool in the compiler writer’s toolkit.
ariadne
Ariadne is a modern diagnostic reporting library that emphasizes beautiful, user-friendly error messages. Named after the Greek mythological figure who provided thread to navigate the labyrinth, ariadne helps users navigate through complex error scenarios with clear, visually appealing diagnostics. It provides more flexibility and features than codespan-reporting while maintaining ease of use.
The library excels at showing relationships between different parts of code, using colors and connecting lines to make error contexts immediately clear. It supports multi-file errors, complex label relationships, and provides excellent defaults while remaining highly customizable.
Core Architecture
Ariadne separates error data from presentation, making it easy to generate diagnostics from your existing error types. The Report type is the centerpiece, built using a fluent API that encourages clear, helpful error messages.
#![allow(unused)]
fn main() {
use std::collections::HashMap;
use std::fmt;
use std::ops::Range;

use ariadne::{Color, ColorGenerator, Fmt, Label, Report, ReportKind, Source};

/// A source file with name and content
pub struct SourceFile {
    pub name: String,
    pub content: String,
}

/// Source code manager for multi-file projects
pub struct SourceManager {
    files: HashMap<String, SourceFile>,
}

impl SourceManager {
    pub fn new() -> Self {
        Self {
            files: HashMap::new(),
        }
    }

    pub fn add_file(&mut self, name: String, content: String) {
        self.files.insert(
            name.clone(),
            SourceFile {
                name: name.clone(),
                content,
            },
        );
    }

    pub fn get_source(&self, file: &str) -> Option<Source> {
        self.files
            .get(file)
            .map(|f| Source::from(f.content.clone()))
    }
}

impl Default for SourceManager {
    fn default() -> Self {
        Self::new()
    }
}

/// Type representation for our language
#[derive(Debug, Clone, PartialEq)]
pub enum Type {
    Int,
    Float,
    String,
    Bool,
    Array(Box<Type>),
    Function(Vec<Type>, Box<Type>),
    Struct(String, Vec<(String, Type)>),
    Generic(String),
}

impl fmt::Display for Type {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Type::Int => write!(f, "int"),
            Type::Float => write!(f, "float"),
            Type::String => write!(f, "string"),
            Type::Bool => write!(f, "bool"),
            Type::Array(elem) => write!(f, "{}[]", elem),
            Type::Function(params, ret) => {
                write!(f, "(")?;
                for (i, param) in params.iter().enumerate() {
                    if i > 0 {
                        write!(f, ", ")?;
                    }
                    write!(f, "{}", param)?;
                }
                write!(f, ") -> {}", ret)
            }
            Type::Struct(name, _) => write!(f, "{}", name),
            Type::Generic(name) => write!(f, "'{}", name),
        }
    }
}

impl CompilerDiagnostic {
    pub fn to_report(&self, _file_id: &str) -> Report<'static, (&'static str, Range<usize>)> {
        match self {
            CompilerDiagnostic::TypeError {
                expected,
                found,
                expr_span,
                expected_span,
                context,
            } => {
                let mut report = Report::build(ReportKind::Error, ("file", expr_span.clone()))
                    .with_message(format!("Type mismatch in {}", context))
                    .with_label(
                        Label::new(("file", expr_span.clone()))
                            .with_message(format!(
                                "Expected {}, found {}",
                                expected.to_string().fg(Color::Green),
                                found.to_string().fg(Color::Red)
                            ))
                            .with_color(Color::Red),
                    );

                if let Some(expected_span) = expected_span {
                    report = report.with_label(
                        Label::new(("file", expected_span.clone()))
                            .with_message("Expected because of this")
                            .with_color(Color::Blue),
                    );
                }

                report
                    .with_note(format!(
                        "Cannot convert {} to {}",
                        found.to_string().fg(Color::Red),
                        expected.to_string().fg(Color::Green)
                    ))
                    .finish()
            }
            CompilerDiagnostic::UnresolvedName {
                name,
                span,
                similar_names,
                imported_modules,
            } => {
                let mut report = Report::build(ReportKind::Error, ("file", span.clone()))
                    .with_message(format!("Cannot find '{}' in scope", name))
                    .with_label(
                        Label::new(("file", span.clone()))
                            .with_message("Not found")
                            .with_color(Color::Red),
                    );

                if !similar_names.is_empty() {
                    let suggestions = similar_names
                        .iter()
                        .map(|s| s.fg(Color::Green).to_string())
                        .collect::<Vec<_>>()
                        .join(", ");
                    report = report.with_help(format!("Did you mean: {}?", suggestions));
                }

                if !imported_modules.is_empty() {
                    report = report.with_note(format!(
                        "Available in modules: {}",
                        imported_modules.join(", ")
                    ));
                }

                report.finish()
            }
            CompilerDiagnostic::SyntaxError {
                message,
                span,
                expected,
                note,
            } => {
                let mut report = Report::build(ReportKind::Error, ("file", span.clone()))
                    .with_message("Syntax error")
                    .with_label(
                        Label::new(("file", span.clone()))
                            .with_message(message)
                            .with_color(Color::Red),
                    );

                if !expected.is_empty() {
                    report = report.with_help(format!(
                        "Expected one of: {}",
                        expected
                            .iter()
                            .map(|e| format!("'{}'", e).fg(Color::Green).to_string())
                            .collect::<Vec<_>>()
                            .join(", ")
                    ));
                }

                if let Some(note) = note {
                    report = report.with_note(note);
                }

                report.finish()
            }
            CompilerDiagnostic::BorrowError {
                var_name,
                first_borrow,
                second_borrow,
                first_mutable,
                second_mutable,
            } => {
                let (first_kind, first_color) = if *first_mutable {
                    ("mutable", Color::Yellow)
                } else {
                    ("immutable", Color::Blue)
                };
                let (second_kind, second_color) = if *second_mutable {
                    ("mutable", Color::Yellow)
                } else {
                    ("immutable", Color::Blue)
                };

                Report::build(ReportKind::Error, ("file", second_borrow.clone()))
                    .with_message(format!("Cannot borrow '{}' as {}", var_name, second_kind))
                    .with_label(
                        Label::new(("file", first_borrow.clone()))
                            .with_message(format!("First {} borrow occurs here", first_kind))
                            .with_color(first_color),
                    )
                    .with_label(
                        Label::new(("file", second_borrow.clone()))
                            .with_message(format!("Second {} borrow occurs here", second_kind))
                            .with_color(second_color),
                    )
                    .with_note("Cannot have multiple mutable borrows or a mutable borrow with immutable borrows")
                    .finish()
            }
            CompilerDiagnostic::CyclicDependency { modules } => {
                let mut colors = ColorGenerator::new();
                let mut report =
                    Report::build(ReportKind::Error, ("module", modules[0].1.clone()))
                        .with_message("Cyclic module dependency detected");

                for (i, (module, span)) in modules.iter().enumerate() {
                    let color = colors.next();
                    let next_module = &modules[(i + 1) % modules.len()].0;
                    report = report.with_label(
                        Label::new(("module", span.clone()))
                            .with_message(format!("'{}' imports '{}'", module, next_module))
                            .with_color(color),
                    );
                }

                report
                    .with_note("Remove one of the imports to break the cycle")
                    .finish()
            }
        }
    }
}

/// Language server protocol-style diagnostics
pub struct LspDiagnostic {
    pub severity: DiagnosticSeverity,
    pub code: Option<String>,
    pub message: String,
    pub related_information: Vec<RelatedInformation>,
    pub tags: Vec<DiagnosticTag>,
}

#[derive(Debug, Clone, Copy)]
pub enum DiagnosticSeverity {
    Error,
    Warning,
    Information,
    Hint,
}

#[derive(Debug, Clone)]
pub struct RelatedInformation {
    pub location: (String, Range<usize>),
    pub message: String,
}

#[derive(Debug, Clone, Copy)]
pub enum DiagnosticTag {
    Unnecessary,
    Deprecated,
}

/// Convert compiler diagnostics to LSP format
pub fn to_lsp_diagnostic(diagnostic: &CompilerDiagnostic, _file: &str) -> LspDiagnostic {
    match diagnostic {
        CompilerDiagnostic::TypeError { .. } => LspDiagnostic {
            severity: DiagnosticSeverity::Error,
            code: Some("E0308".to_string()),
            message: "Type mismatch".to_string(),
            related_information: vec![],
            tags: vec![],
        },
        CompilerDiagnostic::UnresolvedName { name, .. } => LspDiagnostic {
            severity: DiagnosticSeverity::Error,
            code: Some("E0425".to_string()),
            message: format!("Cannot find '{}' in scope", name),
            related_information: vec![],
            tags: vec![],
        },
        CompilerDiagnostic::SyntaxError { message, .. } => LspDiagnostic {
            severity: DiagnosticSeverity::Error,
            code: None,
            message: message.clone(),
            related_information: vec![],
            tags: vec![],
        },
        CompilerDiagnostic::BorrowError { var_name, .. } => LspDiagnostic {
            severity: DiagnosticSeverity::Error,
            code: Some("E0502".to_string()),
            message: format!("Cannot borrow '{}'", var_name),
            related_information: vec![],
            tags: vec![],
        },
        CompilerDiagnostic::CyclicDependency { modules } => LspDiagnostic {
            severity: DiagnosticSeverity::Error,
            code: Some("E0391".to_string()),
            message: "Cyclic dependency detected".to_string(),
            related_information: modules
                .iter()
                .map(|(module, span)| RelatedInformation {
                    location: (module.clone(), span.clone()),
                    message: format!("Module '{}' is part of the cycle", module),
                })
                .collect(),
            tags: vec![],
        },
    }
}

/// Helper function to create error reports
pub fn error_report(
    _file: &str,
    span: Range<usize>,
    message: &str,
    label_msg: &str,
) -> Report<'static, (&'static str, Range<usize>)> {
    Report::build(ReportKind::Error, ("static", span.clone()))
        .with_message(message)
        .with_label(
            Label::new(("static", span))
                .with_message(label_msg)
                .with_color(Color::Red),
        )
        .finish()
}

/// Helper function to create warning reports
pub fn warning_report(
    _file: &str,
    span: Range<usize>,
    message: &str,
    label_msg: &str,
) -> Report<'static, (&'static str, Range<usize>)> {
    Report::build(ReportKind::Warning, ("static", span.clone()))
        .with_message(message)
        .with_label(
            Label::new(("static", span))
                .with_message(label_msg)
                .with_color(Color::Yellow),
        )
        .finish()
}
#[cfg(test)] mod tests { use super::*; #[test] fn test_type_display() { let int_array = Type::Array(Box::new(Type::Int)); assert_eq!(int_array.to_string(), "int[]"); let func = Type::Function(vec![Type::Int, Type::String], Box::new(Type::Bool)); assert_eq!(func.to_string(), "(int, string) -> bool"); } #[test] fn test_source_manager() { let mut manager = SourceManager::new(); manager.add_file("test.rs".to_string(), "let x = 5;".to_string()); assert!(manager.get_source("test.rs").is_some()); assert!(manager.get_source("missing.rs").is_none()); } #[test] fn test_error_report() { let report = error_report("test.rs", 10..15, "Type mismatch", "Expected int"); // Just ensure it builds without panic let _ = format!("{:?}", report); } } /// Compiler diagnostics with rich information #[derive(Debug, Clone)] pub enum CompilerDiagnostic { TypeError { expected: Type, found: Type, expr_span: Range<usize>, expected_span: Option<Range<usize>>, context: String, }, UnresolvedName { name: String, span: Range<usize>, similar_names: Vec<String>, imported_modules: Vec<String>, }, SyntaxError { message: String, span: Range<usize>, expected: Vec<String>, note: Option<String>, }, BorrowError { var_name: String, first_borrow: Range<usize>, second_borrow: Range<usize>, first_mutable: bool, second_mutable: bool, }, CyclicDependency { modules: Vec<(String, Range<usize>)>, }, } }
Each diagnostic variant captures semantic information about the error, not just locations and strings. This separation makes it easier to maintain consistent error messages and potentially provide automated fixes.
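For instance, the similar_names field of UnresolvedName can be filled by ranking in-scope identifiers by edit distance before the report is ever rendered. A minimal sketch of that idea; the helper names (edit_distance, similar_names) and the distance threshold of 2 are illustrative, not part of the module above:

```rust
/// Compute the Levenshtein edit distance between two identifiers
/// using the classic rolling-row dynamic programming formulation.
fn edit_distance(a: &str, b: &str) -> usize {
    let (a, b): (Vec<char>, Vec<char>) = (a.chars().collect(), b.chars().collect());
    let mut prev: Vec<usize> = (0..=b.len()).collect();
    for (i, &ca) in a.iter().enumerate() {
        let mut curr = vec![i + 1];
        for (j, &cb) in b.iter().enumerate() {
            let cost = if ca == cb { 0 } else { 1 };
            // Minimum of substitution, deletion, and insertion.
            curr.push((prev[j] + cost).min(prev[j + 1] + 1).min(curr[j] + 1));
        }
        prev = curr;
    }
    prev[b.len()]
}

/// Suggest in-scope names within a small edit distance of the
/// unresolved one, closest matches first.
fn similar_names(unresolved: &str, in_scope: &[&str]) -> Vec<String> {
    let mut ranked: Vec<(usize, &str)> = in_scope
        .iter()
        .map(|&name| (edit_distance(unresolved, name), name))
        .filter(|&(d, _)| d <= 2)
        .collect();
    ranked.sort();
    ranked.into_iter().map(|(_, name)| name.to_string()).collect()
}
```

Because the suggestions are data on the diagnostic rather than text baked into a message, the same field can drive both the "Did you mean" help line and an IDE quick-fix.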
Building Reports
Reports are constructed using a builder pattern that makes the intent clear and the code readable. Each report has a kind (error, warning, or advice), a main message, and can include multiple labels with different colors and priorities.
The to_report method on CompilerDiagnostic converts error data into ariadne Report objects, handling all the details of label creation and color assignment.
The type error case demonstrates several key features: colored type names for clarity, primary and secondary labels to show relationships, and helpful notes explaining why the error occurred.
Color Management
Ariadne provides intelligent color assignment through ColorGenerator, ensuring that related labels have distinct, readable colors. This is especially useful for complex errors with many related locations.
```rust
/// Helper function to create error reports
pub fn error_report(
    _file: &str,
    span: Range<usize>,
    message: &str,
    label_msg: &str,
) -> Report<'static, (&'static str, Range<usize>)> {
    Report::build(ReportKind::Error, ("static", span.clone()))
        .with_message(message)
        .with_label(
            Label::new(("static", span))
                .with_message(label_msg)
                .with_color(Color::Red),
        )
        .finish()
}

/// Helper function to create warning reports
pub fn warning_report(
    _file: &str,
    span: Range<usize>,
    message: &str,
    label_msg: &str,
) -> Report<'static, (&'static str, Range<usize>)> {
    Report::build(ReportKind::Warning, ("static", span.clone()))
        .with_message(message)
        .with_label(
            Label::new(("static", span))
                .with_message(label_msg)
                .with_color(Color::Yellow),
        )
        .finish()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_type_display() {
        let int_array = Type::Array(Box::new(Type::Int));
        assert_eq!(int_array.to_string(), "int[]");

        let func = Type::Function(vec![Type::Int, Type::String], Box::new(Type::Bool));
        assert_eq!(func.to_string(), "(int, string) -> bool");
    }

    #[test]
    fn test_source_manager() {
        let mut manager = SourceManager::new();
        manager.add_file("test.rs".to_string(), "let x = 5;".to_string());
        assert!(manager.get_source("test.rs").is_some());
        assert!(manager.get_source("missing.rs").is_none());
    }

    #[test]
    fn test_error_report() {
        let report = error_report("test.rs", 10..15, "Type mismatch", "Expected int");
        // Just ensure it builds without panic
        let _ = format!("{:?}", report);
    }
}
```
These helper functions provide a simpler way to create basic error and warning reports when you don’t need the full flexibility of the builder pattern.
The builder pattern allows for flexible report construction while ensuring all required fields are provided. This makes it harder to accidentally create incomplete error messages.
Multi-file Errors
Modern compilers often need to report errors spanning multiple files. Ariadne handles this elegantly, allowing labels to reference different sources while maintaining a cohesive error presentation.
The CyclicDependency variant in CompilerDiagnostic shows how to represent errors involving multiple files. Each module in the cycle gets its own label with a distinct color, making the relationship clear.
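The modules payload for that variant can come from a plain depth-first search over the import graph. A sketch under simplifying assumptions: the graph type and the find_cycle helper are illustrative, and spans are omitted so the snippet stands alone.

```rust
use std::collections::HashMap;

/// Find one import cycle in a module graph, if any, returning the
/// modules along the cycle in order. In the real diagnostic each name
/// would be paired with the span of its `import` statement.
fn find_cycle<'a>(imports: &HashMap<&'a str, Vec<&'a str>>) -> Option<Vec<String>> {
    fn dfs<'a>(
        node: &'a str,
        imports: &HashMap<&'a str, Vec<&'a str>>,
        path: &mut Vec<&'a str>,
    ) -> Option<Vec<String>> {
        if let Some(pos) = path.iter().position(|&m| m == node) {
            // Back edge found: the cycle is the suffix of the current path.
            return Some(path[pos..].iter().map(|m| m.to_string()).collect());
        }
        path.push(node);
        for &dep in imports.get(node).into_iter().flatten() {
            if let Some(cycle) = dfs(dep, imports, path) {
                return Some(cycle);
            }
        }
        path.pop();
        None
    }
    imports.keys().find_map(|&start| dfs(start, imports, &mut Vec::new()))
}
```

The ordered module list maps directly onto the labels in the report: module i gets a label saying it imports module i + 1, wrapping around at the end.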
Source Management
For production use, you’ll need a source management system that can efficiently provide file contents for error reporting. Ariadne works with any source provider through its Cache trait, so the manager below can be swapped for whatever storage your compiler already uses.
```rust
use std::collections::HashMap;

use ariadne::Source;

/// A source file with name and content
pub struct SourceFile {
    pub name: String,
    pub content: String,
}

/// Source code manager for multi-file projects
pub struct SourceManager {
    files: HashMap<String, SourceFile>,
}

impl SourceManager {
    pub fn new() -> Self {
        Self {
            files: HashMap::new(),
        }
    }

    pub fn add_file(&mut self, name: String, content: String) {
        self.files
            .insert(name.clone(), SourceFile { name, content });
    }

    pub fn get_source(&self, file: &str) -> Option<Source> {
        self.files
            .get(file)
            .map(|f| Source::from(f.content.clone()))
    }
}

impl Default for SourceManager {
    fn default() -> Self {
        Self::new()
    }
}
```
This manager can be extended to support incremental updates, caching, and other optimizations needed for language server implementations.
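One such extension is version tracking, which a language server needs to apply didChange updates safely. A sketch of the idea; VersionedSourceManager is a hypothetical variant of the manager above, and caching of parsed Source values could be layered onto the same structure:

```rust
use std::collections::HashMap;

/// A source manager that tracks a version number per file, as a
/// language server must when applying incremental edits from a client.
pub struct VersionedSourceManager {
    // name -> (version, content)
    files: HashMap<String, (i32, String)>,
}

impl VersionedSourceManager {
    pub fn new() -> Self {
        Self {
            files: HashMap::new(),
        }
    }

    /// Insert or replace a file, rejecting stale versions so that
    /// out-of-order updates cannot clobber newer content.
    pub fn update(&mut self, name: &str, version: i32, content: String) -> bool {
        match self.files.get(name) {
            Some(&(current, _)) if version <= current => false,
            _ => {
                self.files.insert(name.to_string(), (version, content));
                true
            }
        }
    }

    pub fn content(&self, name: &str) -> Option<&str> {
        self.files.get(name).map(|(_, c)| c.as_str())
    }
}
```

Rejecting stale versions keeps diagnostics aligned with the text the editor is actually displaying, which matters once reports are recomputed on every keystroke.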
Advanced Formatting
Ariadne supports rich formatting within messages using the Fmt trait. This allows for inline styling of important elements like type names, keywords, or suggestions.
The library provides extensive configuration options through the Config type. You can control character sets (Unicode vs ASCII), compactness, line numbering style, and more. This ensures your diagnostics look good in any environment.
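As a sketch of what that configuration looks like (method names here follow the ariadne version current at the time of writing; check your version's documentation before relying on them):

```rust
use std::ops::Range;

use ariadne::{CharSet, Config, Label, Report, ReportKind};

/// Build a report that renders with ASCII-only drawing characters and a
/// compact layout, suitable for terminals and logs without Unicode support.
fn ascii_report() -> Report<'static, (&'static str, Range<usize>)> {
    let config = Config::default()
        .with_char_set(CharSet::Ascii)
        .with_compact(true);

    Report::build(ReportKind::Warning, ("file", 0..4))
        .with_config(config)
        .with_message("unused variable")
        .with_label(Label::new(("file", 0..4)).with_message("declared here"))
        .finish()
}
```

Keeping the Config in one place means every diagnostic in the compiler renders consistently, and a single flag can switch the whole tool between Unicode and ASCII output.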
Language Server Integration
Ariadne diagnostics can be converted to Language Server Protocol format for IDE integration. This allows the same error reporting logic to power both command-line tools and IDE experiences.
```rust
/// Language server protocol-style diagnostics
pub struct LspDiagnostic {
    pub severity: DiagnosticSeverity,
    pub code: Option<String>,
    pub message: String,
    pub related_information: Vec<RelatedInformation>,
    pub tags: Vec<DiagnosticTag>,
}

#[derive(Debug, Clone, Copy)]
pub enum DiagnosticSeverity {
    Error,
    Warning,
    Information,
    Hint,
}

#[derive(Debug, Clone)]
pub struct RelatedInformation {
    pub location: (String, Range<usize>),
    pub message: String,
}

#[derive(Debug, Clone, Copy)]
pub enum DiagnosticTag {
    Unnecessary,
    Deprecated,
}

/// Convert compiler diagnostics to LSP format
pub fn to_lsp_diagnostic(diagnostic: &CompilerDiagnostic, _file: &str) -> LspDiagnostic {
    match diagnostic {
        CompilerDiagnostic::TypeError { .. } => LspDiagnostic {
            severity: DiagnosticSeverity::Error,
            code: Some("E0308".to_string()),
            message: "Type mismatch".to_string(),
            related_information: vec![],
            tags: vec![],
        },
        CompilerDiagnostic::UnresolvedName { name, .. } => LspDiagnostic {
            severity: DiagnosticSeverity::Error,
            code: Some("E0425".to_string()),
            message: format!("Cannot find '{}' in scope", name),
            related_information: vec![],
            tags: vec![],
        },
        CompilerDiagnostic::SyntaxError { message, .. } => LspDiagnostic {
            severity: DiagnosticSeverity::Error,
            code: None,
            message: message.clone(),
            related_information: vec![],
            tags: vec![],
        },
        CompilerDiagnostic::BorrowError { var_name, .. } => LspDiagnostic {
            severity: DiagnosticSeverity::Error,
            code: Some("E0502".to_string()),
            message: format!("Cannot borrow '{}'", var_name),
            related_information: vec![],
            tags: vec![],
        },
        CompilerDiagnostic::CyclicDependency { modules } => LspDiagnostic {
            severity: DiagnosticSeverity::Error,
            code: Some("E0391".to_string()),
            message: "Cyclic dependency detected".to_string(),
            related_information: modules
                .iter()
                .map(|(module, span)| RelatedInformation {
                    location: (module.clone(), span.clone()),
                    message: format!("Module '{}' is part of the cycle", module),
                })
                .collect(),
            tags: vec![],
        },
    }
}
```
The conversion preserves error codes, severity levels, and related information, ensuring a consistent experience across different tools.
Best Practices
Design your error types to capture intent, not just data. Instead of a generic “SyntaxError”, have specific variants like “MissingClosingBrace” or “UnexpectedToken”. This makes it easier to provide targeted help.
Use color meaningfully. Primary error locations should use red, secondary related locations can use blue or yellow, and informational labels can use gray. Consistency helps users quickly understand error relationships.
Write error messages that teach. Instead of just saying what’s wrong, explain why it’s wrong and how to fix it. Good diagnostics are an opportunity to educate users about language rules and best practices.
Consider error recovery when designing diagnostics. If you can guess what the user meant, include that in help text. For example, if they typed “fucntion” instead of “function”, suggest the correction.
Group related errors when they have a common cause. If a type error in one function causes errors in its callers, present them as a single diagnostic with multiple labels rather than separate errors.
codespan-reporting
The codespan-reporting crate provides beautiful diagnostic rendering for compilers and development tools. It generates the same style of error messages that rustc produces, with source code snippets, underlines, and helpful annotations, and it has become a standard choice for high-quality compiler diagnostics in the Rust ecosystem.
Unlike simple error printing, codespan-reporting handles multi-line spans, multiple files, and complex error relationships. It automatically handles terminal colors, Unicode rendering, and even provides alternatives for non-Unicode terminals. The crate integrates seamlessly with language servers and other tooling.
Core Concepts
The library revolves around three pieces: a file database for source management (the Files trait, with SimpleFiles as a ready-made implementation), Diagnostic for error information, and Label for marking specific code locations. Each diagnostic can have multiple labels pointing to different parts of the code, making it easy to show cause-and-effect relationships.
#![allow(unused)]
fn main() {
use std::ops::Range;

use codespan_reporting::diagnostic::{Diagnostic, Label};
use codespan_reporting::files::SimpleFiles;
use codespan_reporting::term::{self, Config};
use termcolor::{ColorChoice, StandardStream};

impl DiagnosticEngine {
    pub fn new() -> Self {
        Self {
            files: SimpleFiles::new(),
            config: Config::default(),
        }
    }

    pub fn add_file(&mut self, name: String, source: String) -> usize {
        self.files.add(name, source)
    }

    pub fn emit_diagnostic(&self, diagnostic: Diagnostic<usize>) {
        let writer = StandardStream::stderr(ColorChoice::Always);
        let _ = term::emit(&mut writer.lock(), &self.config, &self.files, &diagnostic);
    }
}

impl Default for DiagnosticEngine {
    fn default() -> Self {
        Self::new()
    }
}

/// Type system for a simple functional language
#[derive(Debug, Clone, PartialEq)]
pub enum Type {
    Int,
    Bool,
    String,
    Function(Box<Type>, Box<Type>),
    List(Box<Type>),
    Unknown,
}

impl std::fmt::Display for Type {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Type::Int => write!(f, "int"),
            Type::Bool => write!(f, "bool"),
            Type::String => write!(f, "string"),
            Type::Function(from, to) => write!(f, "{} -> {}", from, to),
            Type::List(elem) => write!(f, "[{}]", elem),
            Type::Unknown => write!(f, "_"),
        }
    }
}

/// Common compiler errors with location information
#[derive(Debug, Clone)]
pub enum CompilerError {
    TypeMismatch {
        expected: Type,
        found: Type,
        location: Range<usize>,
    },
    UndefinedVariable {
        name: String,
        location: Range<usize>,
        similar: Vec<String>,
    },
    ParseError {
        message: String,
        location: Range<usize>,
        hint: Option<String>,
    },
    DuplicateDefinition {
        name: String,
        first_location: Range<usize>,
        second_location: Range<usize>,
    },
}

impl CompilerError {
    pub fn to_diagnostic(&self, file_id: usize) -> Diagnostic<usize> {
        match self {
            CompilerError::TypeMismatch {
                expected,
                found,
                location,
            } => Diagnostic::error()
                .with_message("type mismatch")
                .with_labels(vec![Label::primary(file_id, location.clone())
                    .with_message(format!("expected `{}`, found `{}`", expected, found))]),
            CompilerError::UndefinedVariable {
                name,
                location,
                similar,
            } => {
                let mut diagnostic = Diagnostic::error()
                    .with_message(format!("undefined variable `{}`", name))
                    .with_labels(vec![Label::primary(file_id, location.clone())
                        .with_message("not found in scope")]);
                if !similar.is_empty() {
                    let suggestions = similar.join(", ");
                    diagnostic =
                        diagnostic.with_notes(vec![format!("did you mean: {}?", suggestions)]);
                }
                diagnostic
            }
            CompilerError::ParseError {
                message,
                location,
                hint,
            } => {
                let mut diagnostic = Diagnostic::error()
                    .with_message("parse error")
                    .with_labels(vec![
                        Label::primary(file_id, location.clone()).with_message(message)
                    ]);
                if let Some(hint) = hint {
                    diagnostic = diagnostic.with_notes(vec![hint.clone()]);
                }
                diagnostic
            }
            CompilerError::DuplicateDefinition {
                name,
                first_location,
                second_location,
            } => Diagnostic::error()
                .with_message(format!("duplicate definition of `{}`", name))
                .with_labels(vec![
                    Label::secondary(file_id, first_location.clone())
                        .with_message("first definition here"),
                    Label::primary(file_id, second_location.clone())
                        .with_message("redefined here"),
                ]),
        }
    }
}

/// A simple lexer for demonstration purposes
pub struct Lexer<'a> {
    input: &'a str,
    position: usize,
}

#[derive(Debug, Clone)]
pub struct Token {
    pub kind: TokenKind,
    pub span: Range<usize>,
}

#[derive(Debug, Clone, PartialEq)]
pub enum TokenKind {
    Number(i64),
    Identifier(String),
    Let,
    If,
    Else,
    Function,
    Arrow,
    LeftParen,
    RightParen,
    Equals,
    Plus,
    Minus,
    Star,
    Slash,
    Less,
    Greater,
    Bang,
    Semicolon,
    Eof,
}

impl<'a> Lexer<'a> {
    pub fn new(input: &'a str) -> Self {
        Self { input, position: 0 }
    }

    pub fn tokenize(&mut self) -> Result<Vec<Token>, CompilerError> {
        let mut tokens = Vec::new();
        while self.position < self.input.len() {
            self.skip_whitespace();
            if self.position >= self.input.len() {
                break;
            }
            let start = self.position;
            let token = self.next_token()?;
            let end = self.position;
            tokens.push(Token {
                kind: token,
                span: start..end,
            });
        }
        tokens.push(Token {
            kind: TokenKind::Eof,
            span: self.position..self.position,
        });
        Ok(tokens)
    }

    fn skip_whitespace(&mut self) {
        while self.position < self.input.len()
            && self.input.as_bytes()[self.position].is_ascii_whitespace()
        {
            self.position += 1;
        }
    }

    fn next_token(&mut self) -> Result<TokenKind, CompilerError> {
        let start = self.position;
        let ch = self.current_char();
        match ch {
            '0'..='9' => self.read_number(),
            'a'..='z' | 'A'..='Z' | '_' => self.read_identifier(),
            '+' => {
                self.advance();
                Ok(TokenKind::Plus)
            }
            '-' => {
                self.advance();
                if self.current_char() == '>' {
                    self.advance();
                    Ok(TokenKind::Arrow)
                } else {
                    Ok(TokenKind::Minus)
                }
            }
            '*' => {
                self.advance();
                Ok(TokenKind::Star)
            }
            '/' => {
                self.advance();
                Ok(TokenKind::Slash)
            }
            '<' => {
                self.advance();
                Ok(TokenKind::Less)
            }
            '>' => {
                self.advance();
                Ok(TokenKind::Greater)
            }
            '!' => {
                self.advance();
                Ok(TokenKind::Bang)
            }
            '=' => {
                self.advance();
                Ok(TokenKind::Equals)
            }
            '(' => {
                self.advance();
                Ok(TokenKind::LeftParen)
            }
            ')' => {
                self.advance();
                Ok(TokenKind::RightParen)
            }
            ';' => {
                self.advance();
                Ok(TokenKind::Semicolon)
            }
            _ => Err(CompilerError::ParseError {
                message: format!("unexpected character `{}`", ch),
                location: start..start + 1,
                hint: Some("expected a number, identifier, or operator".to_string()),
            }),
        }
    }

    fn current_char(&self) -> char {
        self.input.as_bytes()[self.position] as char
    }

    fn advance(&mut self) {
        self.position += 1;
    }

    fn read_number(&mut self) -> Result<TokenKind, CompilerError> {
        let start = self.position;
        while self.position < self.input.len()
            && self.input.as_bytes()[self.position].is_ascii_digit()
        {
            self.position += 1;
        }
        let num_str = &self.input[start..self.position];
        let num = num_str.parse().map_err(|_| CompilerError::ParseError {
            message: "Invalid number format".to_string(),
            location: start..self.position,
            hint: Some("Number too large to parse".to_string()),
        })?;
        Ok(TokenKind::Number(num))
    }

    fn read_identifier(&mut self) -> Result<TokenKind, CompilerError> {
        let start = self.position;
        while self.position < self.input.len() {
            let ch = self.input.as_bytes()[self.position];
            if ch.is_ascii_alphanumeric() || ch == b'_' {
                self.position += 1;
            } else {
                break;
            }
        }
        let ident = &self.input[start..self.position];
        let kind = match ident {
            "let" => TokenKind::Let,
            "if" => TokenKind::If,
            "else" => TokenKind::Else,
            "fn" => TokenKind::Function,
            _ => TokenKind::Identifier(ident.to_string()),
        };
        Ok(kind)
    }
}

/// Multi-file project support
pub struct Project {
    engine: DiagnosticEngine,
    file_ids: Vec<(String, usize)>,
}

impl Project {
    pub fn new() -> Self {
        Self {
            engine: DiagnosticEngine::new(),
            file_ids: Vec::new(),
        }
    }

    pub fn add_file(&mut self, path: String, content: String) -> usize {
        let file_id = self.engine.add_file(path.clone(), content);
        self.file_ids.push((path, file_id));
        file_id
    }

    pub fn compile(&self) -> Result<(), Vec<Diagnostic<usize>>> {
        let mut diagnostics = Vec::new();
        // Simulate compilation with various error types
        for (path, file_id) in &self.file_ids {
            if path.ends_with("types.ml") {
                // Type error example
                diagnostics.push(
                    CompilerError::TypeMismatch {
                        expected: Type::Int,
                        found: Type::String,
                        location: 45..52,
                    }
                    .to_diagnostic(*file_id),
                );
            } else if path.ends_with("undefined.ml") {
                // Undefined variable with suggestions
                diagnostics.push(
                    CompilerError::UndefinedVariable {
                        name: "lenght".to_string(),
                        location: 23..29,
                        similar: vec!["length".to_string(), "len".to_string()],
                    }
                    .to_diagnostic(*file_id),
                );
            }
        }
        if diagnostics.is_empty() {
            Ok(())
        } else {
            for diagnostic in &diagnostics {
                self.engine.emit_diagnostic(diagnostic.clone());
            }
            Err(diagnostics)
        }
    }
}

impl Default for Project {
    fn default() -> Self {
        Self::new()
    }
}

/// Create a warning diagnostic
pub fn create_warning(
    file_id: usize,
    message: &str,
    location: Range<usize>,
    note: Option<String>,
) -> Diagnostic<usize> {
    let mut diagnostic = Diagnostic::warning()
        .with_message(message)
        .with_labels(vec![Label::primary(file_id, location)]);
    if let Some(note) = note {
        diagnostic = diagnostic.with_notes(vec![note]);
    }
    diagnostic
}

/// Create an information diagnostic
pub fn create_info(file_id: usize, message: &str, location: Range<usize>) -> Diagnostic<usize> {
    Diagnostic::note()
        .with_message(message)
        .with_labels(vec![Label::primary(file_id, location)])
}

#[cfg(test)]
mod tests {
    use codespan_reporting::diagnostic::Severity;

    use super::*;

    #[test]
    fn test_lexer() {
        let mut lexer = Lexer::new("let x = 42 + 3");
        let tokens = lexer.tokenize().unwrap();
        assert_eq!(tokens.len(), 7); // let x = 42 + 3 EOF
    }

    #[test]
    fn test_type_display() {
        let int_to_bool = Type::Function(Box::new(Type::Int), Box::new(Type::Bool));
        assert_eq!(int_to_bool.to_string(), "int -> bool");
        let list_of_ints = Type::List(Box::new(Type::Int));
        assert_eq!(list_of_ints.to_string(), "[int]");
    }

    #[test]
    fn test_diagnostic_creation() {
        let error = CompilerError::TypeMismatch {
            expected: Type::Int,
            found: Type::Bool,
            location: 10..15,
        };
        let diagnostic = error.to_diagnostic(0);
        assert_eq!(diagnostic.severity, Severity::Error);
    }
}

/// A compiler diagnostic system built on codespan-reporting
pub struct DiagnosticEngine {
    files: SimpleFiles<String, String>,
    config: Config,
}
}
This wrapper provides a convenient interface for managing files and emitting diagnostics. In a real compiler, you would integrate this with your existing source management system.
Creating Diagnostics
Diagnostics are built using a fluent API that makes it easy to construct rich error messages. Each diagnostic has a severity level, a main message, and can include multiple labeled spans with their own messages.
#![allow(unused)]
fn main() {
use codespan_reporting::diagnostic::{Diagnostic, Label};

impl CompilerError {
    pub fn to_diagnostic(&self, file_id: usize) -> Diagnostic<usize> {
        match self {
            CompilerError::TypeMismatch {
                expected,
                found,
                location,
            } => Diagnostic::error()
                .with_message("type mismatch")
                .with_labels(vec![Label::primary(file_id, location.clone())
                    .with_message(format!("expected `{}`, found `{}`", expected, found))]),
            CompilerError::UndefinedVariable {
                name,
                location,
                similar,
            } => {
                let mut diagnostic = Diagnostic::error()
                    .with_message(format!("undefined variable `{}`", name))
                    .with_labels(vec![Label::primary(file_id, location.clone())
                        .with_message("not found in scope")]);
                if !similar.is_empty() {
                    let suggestions = similar.join(", ");
                    diagnostic =
                        diagnostic.with_notes(vec![format!("did you mean: {}?", suggestions)]);
                }
                diagnostic
            }
            CompilerError::ParseError {
                message,
                location,
                hint,
            } => {
                let mut diagnostic = Diagnostic::error()
                    .with_message("parse error")
                    .with_labels(vec![
                        Label::primary(file_id, location.clone()).with_message(message)
                    ]);
                if let Some(hint) = hint {
                    diagnostic = diagnostic.with_notes(vec![hint.clone()]);
                }
                diagnostic
            }
            CompilerError::DuplicateDefinition {
                name,
                first_location,
                second_location,
            } => Diagnostic::error()
                .with_message(format!("duplicate definition of `{}`", name))
                .with_labels(vec![
                    Label::secondary(file_id, first_location.clone())
                        .with_message("first definition here"),
                    Label::primary(file_id, second_location.clone())
                        .with_message("redefined here"),
                ]),
        }
    }
}
}
The supporting types (CompilerError, Type) and the demonstration lexer are the same as in the full listing above.
The to_diagnostic method converts these error types into codespan-reporting diagnostics with appropriate labels and notes.
The duplicate definition case shows how to combine primary and secondary labels. The primary label marks the error location, while secondary labels provide additional context, such as where the name was first defined.
File Management
For multi-file projects, codespan-reporting provides a simple file database interface. You can use the built-in SimpleFiles or implement your own file provider.
#![allow(unused)]
fn main() {
use codespan_reporting::diagnostic::Diagnostic;
use codespan_reporting::files::SimpleFiles;
use codespan_reporting::term::{self, Config};
use termcolor::{ColorChoice, StandardStream};

/// A compiler diagnostic system built on codespan-reporting
pub struct DiagnosticEngine {
    files: SimpleFiles<String, String>,
    config: Config,
}

impl DiagnosticEngine {
    pub fn new() -> Self {
        Self {
            files: SimpleFiles::new(),
            config: Config::default(),
        }
    }

    pub fn add_file(&mut self, name: String, source: String) -> usize {
        self.files.add(name, source)
    }

    pub fn emit_diagnostic(&self, diagnostic: Diagnostic<usize>) {
        let writer = StandardStream::stderr(ColorChoice::Always);
        let _ = term::emit(&mut writer.lock(), &self.config, &self.files, &diagnostic);
    }
}

/// Multi-file project support
pub struct Project {
    engine: DiagnosticEngine,
    file_ids: Vec<(String, usize)>,
}

impl Project {
    pub fn new() -> Self {
        Self {
            engine: DiagnosticEngine::new(),
            file_ids: Vec::new(),
        }
    }

    pub fn add_file(&mut self, path: String, content: String) -> usize {
        let file_id = self.engine.add_file(path.clone(), content);
        self.file_ids.push((path, file_id));
        file_id
    }

    pub fn compile(&self) -> Result<(), Vec<Diagnostic<usize>>> {
        let mut diagnostics = Vec::new();
        // Simulate compilation with various error types
        for (path, file_id) in &self.file_ids {
            if path.ends_with("types.ml") {
                // Type error example
                diagnostics.push(
                    CompilerError::TypeMismatch {
                        expected: Type::Int,
                        found: Type::String,
                        location: 45..52,
                    }
                    .to_diagnostic(*file_id),
                );
            } else if path.ends_with("undefined.ml") {
                // Undefined variable with suggestions
                diagnostics.push(
                    CompilerError::UndefinedVariable {
                        name: "lenght".to_string(),
                        location: 23..29,
                        similar: vec!["length".to_string(), "len".to_string()],
                    }
                    .to_diagnostic(*file_id),
                );
            }
        }
        if diagnostics.is_empty() {
            Ok(())
        } else {
            for diagnostic in &diagnostics {
                self.engine.emit_diagnostic(diagnostic.clone());
            }
            Err(diagnostics)
        }
    }
}
}
The Project struct demonstrates how to manage multiple source files and emit diagnostics across them. This is essential for real compilers that need to report errors spanning multiple modules.
Error Types
Well-designed error types make diagnostics more maintainable and consistent. Each error variant should capture all the information needed to generate a helpful diagnostic.
#![allow(unused)]
fn main() {
use std::ops::Range;

use codespan_reporting::diagnostic::{Diagnostic, Label};

/// Common compiler errors with location information
#[derive(Debug, Clone)]
pub enum CompilerError {
    TypeMismatch {
        expected: Type,
        found: Type,
        location: Range<usize>,
    },
    UndefinedVariable {
        name: String,
        location: Range<usize>,
        similar: Vec<String>,
    },
    ParseError {
        message: String,
        location: Range<usize>,
        hint: Option<String>,
    },
    DuplicateDefinition {
        name: String,
        first_location: Range<usize>,
        second_location: Range<usize>,
    },
}

impl CompilerError {
    pub fn to_diagnostic(&self, file_id: usize) -> Diagnostic<usize> {
        match self {
            CompilerError::TypeMismatch {
                expected,
                found,
                location,
            } => Diagnostic::error()
                .with_message("type mismatch")
                .with_labels(vec![Label::primary(file_id, location.clone())
                    .with_message(format!("expected `{}`, found `{}`", expected, found))]),
            CompilerError::UndefinedVariable {
                name,
                location,
                similar,
            } => {
                let mut diagnostic = Diagnostic::error()
                    .with_message(format!("undefined variable `{}`", name))
                    .with_labels(vec![Label::primary(file_id, location.clone())
                        .with_message("not found in scope")]);
                if !similar.is_empty() {
                    let suggestions = similar.join(", ");
                    diagnostic =
                        diagnostic.with_notes(vec![format!("did you mean: {}?", suggestions)]);
                }
                diagnostic
            }
            CompilerError::ParseError {
                message,
                location,
                hint,
            } => {
                let mut diagnostic = Diagnostic::error()
                    .with_message("parse error")
                    .with_labels(vec![
                        Label::primary(file_id, location.clone()).with_message(message)
                    ]);
                if let Some(hint) = hint {
                    diagnostic = diagnostic.with_notes(vec![hint.clone()]);
                }
                diagnostic
            }
            CompilerError::DuplicateDefinition {
                name,
                first_location,
                second_location,
            } => Diagnostic::error()
                .with_message(format!("duplicate definition of `{}`", name))
                .with_labels(vec![
                    Label::secondary(file_id, first_location.clone())
                        .with_message("first definition here"),
                    Label::primary(file_id, second_location.clone())
                        .with_message("redefined here"),
                ]),
        }
    }
}
}
These error types demonstrate common patterns: type mismatches with expected and actual types, undefined variables with spelling suggestions, and parse errors with recovery hints.
Advanced Features
The library supports warning and informational diagnostics, not just errors. Different severity levels help users prioritize what to fix first.
#![allow(unused)]
fn main() {
use std::ops::Range;

use codespan_reporting::diagnostic::{Diagnostic, Label};

/// Create a warning diagnostic
pub fn create_warning(
    file_id: usize,
    message: &str,
    location: Range<usize>,
    note: Option<String>,
) -> Diagnostic<usize> {
    let mut diagnostic = Diagnostic::warning()
        .with_message(message)
        .with_labels(vec![Label::primary(file_id, location)]);
    if let Some(note) = note {
        diagnostic = diagnostic.with_notes(vec![note]);
    }
    diagnostic
}

/// Create an information diagnostic
pub fn create_info(file_id: usize, message: &str, location: Range<usize>) -> Diagnostic<usize> {
    Diagnostic::note()
        .with_message(message)
        .with_labels(vec![Label::primary(file_id, location)])
}
}
Warnings can include notes with additional context or suggestions for fixing the issue. This helps guide users toward better code patterns.
Integration with Lexers
Real compiler diagnostics need accurate source locations. Here’s a simple lexer that tracks positions for error reporting:
#![allow(unused)]
fn main() {
use std::ops::Range;

/// A simple lexer for demonstration purposes
pub struct Lexer<'a> {
    input: &'a str,
    position: usize,
}

#[derive(Debug, Clone)]
pub struct Token {
    pub kind: TokenKind,
    pub span: Range<usize>,
}

#[derive(Debug, Clone, PartialEq)]
pub enum TokenKind {
    Number(i64),
    Identifier(String),
    Let,
    If,
    Else,
    Function,
    Arrow,
    LeftParen,
    RightParen,
    Equals,
    Plus,
    Minus,
    Star,
    Slash,
    Less,
    Greater,
    Bang,
    Semicolon,
    Eof,
}

impl<'a> Lexer<'a> {
    pub fn new(input: &'a str) -> Self {
        Self { input, position: 0 }
    }

    pub fn tokenize(&mut self) -> Result<Vec<Token>, CompilerError> {
        let mut tokens = Vec::new();
        while self.position < self.input.len() {
            self.skip_whitespace();
            if self.position >= self.input.len() {
                break;
            }
            let start = self.position;
            let token = self.next_token()?;
            let end = self.position;
            tokens.push(Token {
                kind: token,
                span: start..end,
            });
        }
        tokens.push(Token {
            kind: TokenKind::Eof,
            span: self.position..self.position,
        });
        Ok(tokens)
    }

    fn skip_whitespace(&mut self) {
        while self.position < self.input.len()
            && self.input.as_bytes()[self.position].is_ascii_whitespace()
        {
            self.position += 1;
        }
    }

    fn next_token(&mut self) -> Result<TokenKind, CompilerError> {
        let start = self.position;
        let ch = self.current_char();
        match ch {
            '0'..='9' => self.read_number(),
            'a'..='z' | 'A'..='Z' | '_' => self.read_identifier(),
            '+' => { self.advance(); Ok(TokenKind::Plus) }
            '-' => {
                self.advance();
                if self.current_char() == '>' {
                    self.advance();
                    Ok(TokenKind::Arrow)
                } else {
                    Ok(TokenKind::Minus)
                }
            }
            '*' => { self.advance(); Ok(TokenKind::Star) }
            '/' => { self.advance(); Ok(TokenKind::Slash) }
            '<' => { self.advance(); Ok(TokenKind::Less) }
            '>' => { self.advance(); Ok(TokenKind::Greater) }
            '!' => { self.advance(); Ok(TokenKind::Bang) }
            '=' => { self.advance(); Ok(TokenKind::Equals) }
            '(' => { self.advance(); Ok(TokenKind::LeftParen) }
            ')' => { self.advance(); Ok(TokenKind::RightParen) }
            ';' => { self.advance(); Ok(TokenKind::Semicolon) }
            _ => Err(CompilerError::ParseError {
                message: format!("unexpected character `{}`", ch),
                location: start..start + 1,
                hint: Some("expected a number, identifier, or operator".to_string()),
            }),
        }
    }

    fn current_char(&self) -> char {
        // Return NUL at end of input so one-character lookahead
        // (e.g. checking for `->` after `-`) cannot index out of bounds.
        if self.position >= self.input.len() {
            '\0'
        } else {
            self.input.as_bytes()[self.position] as char
        }
    }

    fn advance(&mut self) {
        self.position += 1;
    }

    fn read_number(&mut self) -> Result<TokenKind, CompilerError> {
        let start = self.position;
        while self.position < self.input.len()
            && self.input.as_bytes()[self.position].is_ascii_digit()
        {
            self.position += 1;
        }
        let num_str = &self.input[start..self.position];
        let num = num_str.parse().map_err(|_| CompilerError::ParseError {
            message: "Invalid number format".to_string(),
            location: start..self.position,
            hint: Some("Number too large to parse".to_string()),
        })?;
        Ok(TokenKind::Number(num))
    }

    fn read_identifier(&mut self) -> Result<TokenKind, CompilerError> {
        let start = self.position;
        while self.position < self.input.len() {
            let ch = self.input.as_bytes()[self.position];
            if ch.is_ascii_alphanumeric() || ch == b'_' {
                self.position += 1;
            } else {
                break;
            }
        }
        let ident = &self.input[start..self.position];
        let kind = match ident {
            "let" => TokenKind::Let,
            "if" => TokenKind::If,
            "else" => TokenKind::Else,
            "fn" => TokenKind::Function,
            _ => TokenKind::Identifier(ident.to_string()),
        };
        Ok(kind)
    }
}
}
The lexer maintains byte positions for each token, which can be used directly in diagnostic labels. This ensures error underlines appear in exactly the right place.
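Byte spans are all the renderer needs, but it helps to understand the line/column lookup that happens under the hood when a diagnostic is printed. The following standalone sketch shows that conversion; the `line_col` helper is illustrative only, since codespan-reporting derives this information itself through its `Files` trait.

```rust
/// Convert a byte offset into a 1-based (line, column) pair.
/// Illustrative only: codespan-reporting performs this lookup
/// internally through its `Files` trait.
fn line_col(source: &str, offset: usize) -> (usize, usize) {
    let prefix = &source[..offset.min(source.len())];
    // Line number = newlines seen before the offset, plus one.
    let line = prefix.bytes().filter(|&b| b == b'\n').count() + 1;
    // Column = distance from the start of the current line, plus one.
    let line_start = prefix.rfind('\n').map_or(0, |i| i + 1);
    let col = prefix.len() - line_start + 1;
    (line, col)
}

fn main() {
    let src = "let x\nlet y";
    assert_eq!(line_col(src, 0), (1, 1)); // `l` of the first `let`
    assert_eq!(line_col(src, 8), (2, 3)); // `t` of the second `let`
    println!("ok");
}
```

A real implementation would precompute a sorted table of line-start offsets and binary-search it, rather than rescanning the prefix on every lookup.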
Terminal Configuration
The library provides fine-grained control over output formatting through the Config type. You can customize colors, character sets, and layout options to match your project’s needs. The default configuration works well for most cases, automatically adapting to terminal capabilities.
Best Practices
Structure your error types to capture semantic information, not just strings. This makes it easier to provide consistent, helpful diagnostics throughout your compiler. Include spelling suggestions when reporting undefined names by computing edit distance to known identifiers.
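The edit-distance suggestions mentioned above take only a few lines of standard dynamic programming. The sketch below is self-contained; the `suggest` helper and its distance threshold of 2 are illustrative choices, not part of codespan-reporting.

```rust
/// Classic dynamic-programming Levenshtein distance.
fn levenshtein(a: &str, b: &str) -> usize {
    let a: Vec<char> = a.chars().collect();
    let b: Vec<char> = b.chars().collect();
    let mut prev: Vec<usize> = (0..=b.len()).collect();
    for (i, &ca) in a.iter().enumerate() {
        let mut cur = vec![i + 1];
        for (j, &cb) in b.iter().enumerate() {
            let cost = if ca == cb { 0 } else { 1 };
            // Minimum of substitution, deletion, and insertion.
            cur.push((prev[j] + cost).min(prev[j + 1] + 1).min(cur[j] + 1));
        }
        prev = cur;
    }
    prev[b.len()]
}

/// Suggest in-scope names within a small edit distance of `unknown`,
/// closest matches first. The threshold of 2 is an arbitrary choice.
fn suggest<'a>(unknown: &str, in_scope: &[&'a str]) -> Vec<&'a str> {
    let mut hits: Vec<(usize, &'a str)> = in_scope
        .iter()
        .map(|&name| (levenshtein(unknown, name), name))
        .filter(|&(d, _)| d <= 2)
        .collect();
    hits.sort();
    hits.into_iter().map(|(_, name)| name).collect()
}

fn main() {
    assert_eq!(levenshtein("lenght", "length"), 2);
    assert_eq!(suggest("lenght", &["length", "len", "width"]), vec!["length"]);
    println!("ok");
}
```

The resulting candidates can be dropped straight into the `similar` field of an `UndefinedVariable` error.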
Group related errors together when they have a common cause. For example, if a type error cascades through multiple expressions, show the root cause prominently and list the consequences as secondary information.
Use notes and help messages to educate users about language features. Good diagnostics teach users how to write better code, not just point out what’s wrong. Include examples in help text when appropriate.
For parse errors, show what tokens were expected at the error location. This helps users understand the grammar and fix syntax errors quickly. Recovery hints can suggest common fixes for typical mistakes.
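One way to phrase the expected-token set is a small formatter that handles the one-, two-, and many-token cases separately. The `format_expected` helper below is a hypothetical sketch, not part of codespan-reporting.

```rust
/// Render an expected-token set as readable English, e.g.
/// "expected `;`", "expected `;` or `)`", "expected `let`, `if`, or `fn`".
fn format_expected(expected: &[&str]) -> String {
    let quoted: Vec<String> = expected.iter().map(|t| format!("`{}`", t)).collect();
    match quoted.as_slice() {
        [] => "unexpected token".to_string(),
        [one] => format!("expected {}", one),
        [a, b] => format!("expected {} or {}", a, b),
        [rest @ .., last] => format!("expected {}, or {}", rest.join(", "), last),
    }
}

fn main() {
    assert_eq!(format_expected(&[";"]), "expected `;`");
    assert_eq!(format_expected(&[";", ")"]), "expected `;` or `)`");
    assert_eq!(
        format_expected(&["let", "if", "fn"]),
        "expected `let`, `if`, or `fn`"
    );
    println!("ok");
}
```

A message built this way slots directly into the `message` field of the `ParseError` variant shown earlier.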
The codespan-reporting crate has become essential infrastructure for Rust compiler projects. Its thoughtful design and attention to user experience set the standard for compiler diagnostics. By following its patterns, you can provide error messages that help rather than frustrate your users.
miette
Miette is a comprehensive diagnostic library that brings Rust’s excellent error reporting philosophy to your compiler projects. It provides a complete framework for creating beautiful, informative error messages with minimal boilerplate. The library excels at showing context, providing actionable help text, and maintaining consistency across all diagnostics.
Unlike simpler error reporting libraries, miette handles the entire diagnostic pipeline: from error definition through to rendering. It supports fancy Unicode rendering, screen-reader-friendly output, syntax highlighting, and even clickable error codes in supported terminals. The derive macro makes it trivial to create rich diagnostics while maintaining type safety.
Core Concepts
Miette revolves around the Diagnostic trait, which extends the standard Error trait with additional metadata. Every diagnostic can include source code snippets, labeled spans, help text, error codes, and related errors. The library handles all the complexity of rendering these elements attractively.
```rust
use miette::{Diagnostic, NamedSource, SourceSpan};
use thiserror::Error;

/// Parser error with detailed diagnostics
#[derive(Error, Debug, Diagnostic)]
#[error("Parse error in {filename}")]
#[diagnostic(
    code(compiler::parse::syntax_error),
    url(docsrs),
    help("Check for missing semicolons, unmatched brackets, or typos in keywords")
)]
pub struct ParseError {
    #[source_code]
    src: NamedSource<String>,
    #[label("Expected {expected} but found {found}")]
    err_span: SourceSpan,
    expected: String,
    found: String,
    filename: String,
    #[label("Parsing started here")]
    context_span: Option<SourceSpan>,
}

impl ParseError {
    pub fn new(
        filename: String,
        source: String,
        span: SourceSpan,
        expected: String,
        found: String,
    ) -> Self {
        Self {
            src: NamedSource::new(filename.clone(), source),
            err_span: span,
            expected,
            found,
            filename,
            context_span: None,
        }
    }

    pub fn with_context(mut self, span: SourceSpan) -> Self {
        self.context_span = Some(span);
        self
    }
}
```
This parse error demonstrates the key components: source code attachment, labeled spans with custom messages, error codes with documentation links, and contextual help. The #[source_code]
attribute tells miette where to find the source text for snippet rendering.
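Note that `SourceSpan` is built from a byte offset and length, for example `(8, 1).into()`, rather than from line/column pairs. Lexers that track editor-style positions need a small conversion step first. A minimal sketch, assuming 1-based line and column numbers (the helper name is our own, not part of miette):

```rust
/// Convert a 1-based (line, column) pair into the byte offset used to
/// build a miette `SourceSpan` as `(offset, len).into()`.
/// Hypothetical helper; miette itself only deals in byte offsets.
fn line_col_to_offset(source: &str, line: usize, col: usize) -> Option<usize> {
    let mut offset = 0;
    // `split_inclusive` keeps the trailing '\n', so offsets stay exact.
    for (i, text) in source.split_inclusive('\n').enumerate() {
        if i + 1 == line {
            // The column must fall within this line (1-based).
            return (col >= 1 && col <= text.len()).then(|| offset + col - 1);
        }
        offset += text.len();
    }
    None
}

fn main() {
    let src = "let x = 1;\nlet y = ;";
    // The offending `;` on line 2 sits at column 9, byte offset 19.
    assert_eq!(line_col_to_offset(src, 2, 9), Some(19));
}
```

Keeping spans as byte offsets internally and converting only at the boundary avoids recomputing line tables on every error.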
Type-Safe Diagnostics
Creating type-safe, reusable diagnostic types is straightforward with the derive macro. Each error type can capture all relevant context and provide specialized help based on the specific situation.
```rust
use std::fmt;
use miette::{Diagnostic, NamedSource, SourceSpan};
use thiserror::Error;

/// Type representation for our compiler
#[derive(Debug, Clone, PartialEq)]
pub enum Type {
    Int,
    Float,
    Bool,
    String,
    Array(Box<Type>),
    Function(Vec<Type>, Box<Type>),
    Struct(String),
    Never,
}

impl fmt::Display for Type {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Type::Int => write!(f, "int"),
            Type::Float => write!(f, "float"),
            Type::Bool => write!(f, "bool"),
            Type::String => write!(f, "string"),
            Type::Array(elem) => write!(f, "[{}]", elem),
            Type::Function(params, ret) => {
                write!(f, "fn(")?;
                for (i, param) in params.iter().enumerate() {
                    if i > 0 {
                        write!(f, ", ")?;
                    }
                    write!(f, "{}", param)?;
                }
                write!(f, ") -> {}", ret)
            }
            Type::Struct(name) => write!(f, "{}", name),
            Type::Never => write!(f, "!"),
        }
    }
}

/// Type mismatch error with rich diagnostics
#[derive(Error, Debug, Diagnostic)]
#[error("Type mismatch in expression")]
#[diagnostic(
    code(compiler::typecheck::type_mismatch),
    url("https://example.com/errors/type-mismatch"),
    severity(Error)
)]
pub struct TypeMismatchError {
    #[source_code]
    src: NamedSource<String>,
    #[label(primary, "Expected type `{expected}` but found `{actual}`")]
    expr_span: SourceSpan,
    #[label("Expected due to this")]
    reason_span: Option<SourceSpan>,
    expected: Type,
    actual: Type,
    #[help]
    suggestion: Option<String>,
}

impl TypeMismatchError {
    pub fn new(
        filename: String,
        source: String,
        expr_span: SourceSpan,
        expected: Type,
        actual: Type,
    ) -> Self {
        let suggestion = match (&expected, &actual) {
            (Type::String, Type::Int) => Some("Try using `.to_string()` to convert".to_string()),
            (Type::Int, Type::String) => {
                Some("Try using `.parse::<i32>()?` to convert".to_string())
            }
            (Type::Float, Type::Int) => Some("Try using `as f64` to convert".to_string()),
            _ => None,
        };
        Self {
            src: NamedSource::new(filename, source),
            expr_span,
            reason_span: None,
            expected,
            actual,
            suggestion,
        }
    }

    pub fn with_reason(mut self, span: SourceSpan) -> Self {
        self.reason_span = Some(span);
        self
    }
}
```
The type mismatch error shows how to provide contextual suggestions. By examining the expected and actual types, it can offer specific conversion advice. This pattern scales well to complex type systems with many possible conversions.
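As the number of known conversions grows, the `match` in the constructor can be replaced by a data-driven table, which keeps the hints in one place and easy to extend. A sketch of that design, using plain type-name strings in place of the `Type` enum (the table contents and function names are illustrative, not part of any library):

```rust
/// Conversion hints keyed by (expected, found) type-name pairs.
/// Illustrative table; a real compiler would key on its own type IDs.
const SUGGESTIONS: &[((&str, &str), &str)] = &[
    (("string", "int"), "Try using `.to_string()` to convert"),
    (("int", "string"), "Try using `.parse::<i32>()?` to convert"),
    (("float", "int"), "Try using `as f64` to convert"),
];

/// Look up a conversion hint for a type mismatch, if one is known.
fn suggest(expected: &str, found: &str) -> Option<&'static str> {
    SUGGESTIONS
        .iter()
        .find(|((e, f), _)| *e == expected && *f == found)
        .map(|(_, hint)| *hint)
}

fn main() {
    assert_eq!(suggest("float", "int"), Some("Try using `as f64` to convert"));
    assert_eq!(suggest("bool", "int"), None);
}
```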
Undefined Variables with Suggestions
One of miette’s strengths is showing related information. The undefined variable error demonstrates how to include suggestions and point to similar names in scope.
```rust
use miette::{Diagnostic, NamedSource, SourceSpan};
use thiserror::Error;

/// Undefined variable error with suggestions
#[derive(Error, Debug, Diagnostic)]
#[error("Undefined variable `{name}`")]
#[diagnostic(
    code(compiler::resolve::undefined_variable),
    help("Did you mean {suggestions}?")
)]
pub struct UndefinedVariableError {
    #[source_code]
    src: NamedSource<String>,
    #[label(primary, "Not found in this scope")]
    var_span: SourceSpan,
    name: String,
    suggestions: String,
    #[related]
    similar_vars: Vec<SimilarVariable>,
}

/// Similar variable found in scope
#[derive(Error, Debug, Diagnostic)]
#[error("Similar variable `{name}` defined here")]
#[diagnostic(severity(Warning))]
struct SimilarVariable {
    #[label]
    span: SourceSpan,
    name: String,
}

impl UndefinedVariableError {
    pub fn new(
        filename: String,
        source: String,
        span: SourceSpan,
        name: String,
        similar: Vec<(&str, SourceSpan)>,
    ) -> Self {
        let suggestions = similar
            .iter()
            .map(|(name, _)| format!("`{}`", name))
            .collect::<Vec<_>>()
            .join(", ");
        let similar_vars = similar
            .into_iter()
            .map(|(name, span)| SimilarVariable {
                span,
                name: name.to_string(),
            })
            .collect();
        Self {
            src: NamedSource::new(filename, source),
            var_span: span,
            name,
            suggestions,
            similar_vars,
        }
    }
}
```
The #[related]
attribute allows including sub-diagnostics that provide additional context. Each related diagnostic can have its own spans and messages, creating a rich, multi-layered error report.
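The `similar` candidates passed to the constructor have to come from somewhere; a common approach is to rank the names in scope by edit distance to the unknown name. A minimal sketch, assuming a distance threshold of 2 (the threshold and helper names are our own choices):

```rust
/// Classic dynamic-programming Levenshtein distance, row by row.
fn edit_distance(a: &str, b: &str) -> usize {
    let (a, b): (Vec<char>, Vec<char>) = (a.chars().collect(), b.chars().collect());
    let mut prev: Vec<usize> = (0..=b.len()).collect();
    for (i, ca) in a.iter().enumerate() {
        let mut curr = vec![i + 1];
        for (j, cb) in b.iter().enumerate() {
            let cost = if ca == cb { 0 } else { 1 };
            // Minimum of substitution (diagonal), deletion (above), insertion (left).
            curr.push((prev[j] + cost).min(prev[j + 1] + 1).min(curr[j] + 1));
        }
        prev = curr;
    }
    prev[b.len()]
}

/// Names in scope within edit distance 2 of the unknown name, closest
/// first -- candidates for the `similar` argument above.
fn similar_names<'a>(unknown: &str, in_scope: &[&'a str]) -> Vec<&'a str> {
    let mut hits: Vec<_> = in_scope
        .iter()
        .map(|name| (edit_distance(unknown, name), *name))
        .filter(|(d, _)| *d <= 2)
        .collect();
    hits.sort();
    hits.into_iter().map(|(_, name)| name).collect()
}

fn main() {
    let scope = ["count", "counter", "total"];
    assert_eq!(similar_names("countr", &scope), vec!["count", "counter"]);
}
```

For large scopes, a cheap length-difference check before the full distance computation keeps resolution fast.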
Collecting Multiple Errors
Real compilers often need to report multiple errors at once. Miette handles this elegantly through the related errors feature.
```rust
use miette::{Diagnostic, NamedSource};
use thiserror::Error;

/// Multiple errors collected together
#[derive(Error, Debug, Diagnostic)]
#[error("Multiple errors occurred during compilation")]
#[diagnostic(
    code(compiler::multiple_errors),
    help("Fix the errors in order, as later errors may be caused by earlier ones")
)]
pub struct CompilationErrors {
    #[source_code]
    src: NamedSource<String>,
    #[related]
    errors: Vec<Box<dyn Diagnostic + Send + Sync>>,
    error_count: usize,
    warning_count: usize,
}

impl CompilationErrors {
    pub fn new(filename: String, source: String) -> Self {
        Self {
            src: NamedSource::new(filename, source),
            errors: Vec::new(),
            error_count: 0,
            warning_count: 0,
        }
    }

    pub fn push<E: Diagnostic + Send + Sync + 'static>(&mut self, error: E) {
        match error.severity() {
            Some(miette::Severity::Warning) => self.warning_count += 1,
            _ => self.error_count += 1,
        }
        self.errors.push(Box::new(error));
    }

    pub fn is_empty(&self) -> bool {
        self.errors.is_empty()
    }
}
```
source: String, match_span: SourceSpan, missing: Vec<String>, ) -> Self { let missing_patterns = missing .iter() .map(|_pattern| LabeledSpan::underline(match_span)) .collect(); let missing_list = format!( "Missing patterns: {}\n\nEnsure all cases are covered or add a wildcard pattern `_`", missing.iter().map(|p| format!("`{}`", p)).collect::<Vec<_>>().join(", ") ); Self { src: NamedSource::new(filename, source), match_span, missing_patterns, missing_list, } } } /// Import cycle detection #[derive(Error, Debug, Diagnostic)] #[error("Circular dependency detected")] #[diagnostic(code(compiler::imports::cycle), severity(Error))] pub struct CyclicImportError { #[source_code] src: NamedSource<String>, #[label(collection, "Module in cycle")] cycle_spans: Vec<LabeledSpan>, #[help] help_text: String, } impl CyclicImportError { pub fn new(filename: String, source: String, modules: Vec<(String, SourceSpan)>) -> Self { let cycle_spans = modules .iter() .enumerate() .map(|(i, (name, span))| { let next = &modules[(i + 1) % modules.len()].0; LabeledSpan::new( Some(format!("`{}` imports `{}`", name, next)), span.offset(), span.len(), ) }) .collect(); let module_list = modules .iter() .map(|(name, _)| name.as_str()) .collect::<Vec<_>>() .join(" -> "); Self { src: NamedSource::new(filename, source), cycle_spans, help_text: format!("Break the cycle: {} -> ...", module_list), } } } /// Deprecated feature warning #[derive(Error, Debug, Diagnostic)] #[error("Use of deprecated feature `{feature}`")] #[diagnostic(code(compiler::deprecated), severity(Warning))] pub struct DeprecationWarning { #[source_code] src: NamedSource<String>, #[label(primary, "Deprecated since version {since}")] usage_span: SourceSpan, feature: String, since: String, #[help] alternative: String, } impl DeprecationWarning { pub fn new( filename: String, source: String, usage_span: SourceSpan, feature: String, since: String, alternative: String, ) -> Self { Self { src: NamedSource::new(filename, source), usage_span, 
feature, since, alternative, } } } /// Dynamic diagnostic creation pub fn create_diagnostic( filename: String, source: String, span: SourceSpan, message: String, help: Option<String>, ) -> miette::Report { let labels = vec![LabeledSpan::underline(span)]; miette::miette!( labels = labels, help = help.unwrap_or_default(), "{}", message ) .with_source_code(NamedSource::new(filename, source)) } /// Syntax highlighting support pub fn create_highlighted_error( filename: String, source: String, span: SourceSpan, ) -> impl Diagnostic { let src = NamedSource::new(filename, source).with_language("rust"); #[derive(Error, Debug, Diagnostic)] #[error("Syntax error in Rust code")] #[diagnostic(code(compiler::syntax))] struct HighlightedError { #[source_code] src: NamedSource<String>, #[label("Invalid syntax here")] span: SourceSpan, } HighlightedError { src, span } } #[cfg(test)] mod tests { use super::*; #[test] fn test_type_display() { let func_type = Type::Function(vec![Type::Int, Type::String], Box::new(Type::Bool)); assert_eq!(func_type.to_string(), "fn(int, string) -> bool"); } #[test] fn test_parse_error_creation() { let error = ParseError::new( "test.rs".to_string(), "let x = ;".to_string(), (8, 1).into(), "expression".to_string(), "semicolon".to_string(), ); assert!(error.to_string().contains("Parse error")); } #[test] fn test_multiple_errors() { let mut errors = CompilationErrors::new("test.rs".to_string(), "code".to_string()); assert!(errors.is_empty()); let parse_err = ParseError::new( "test.rs".to_string(), "code".to_string(), (0, 4).into(), "identifier".to_string(), "keyword".to_string(), ); errors.push(parse_err); assert!(!errors.is_empty()); assert_eq!(errors.error_count, 1); assert_eq!(errors.warning_count, 0); } } /// Multiple errors collected together #[derive(Error, Debug, Diagnostic)] #[error("Multiple errors occurred during compilation")] #[diagnostic( code(compiler::multiple_errors), help("Fix the errors in order, as later errors may be caused by earlier 
ones") )] pub struct CompilationErrors { #[source_code] src: NamedSource<String>, #[related] errors: Vec<Box<dyn Diagnostic + Send + Sync>>, error_count: usize, warning_count: usize, } }
This pattern allows accumulating errors during compilation and reporting them all together. The dynamic dispatch through `Box<dyn Diagnostic>` means you can mix different error types in the same collection.
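To make the flow concrete, here is a minimal sketch of the same accumulation idea using only the standard library — `ErrorSink` and `Severity` are hypothetical stand-ins, with `Box<dyn fmt::Display>` playing the role that `Box<dyn Diagnostic>` plays above:

```rust
use std::fmt;

// Severity of a pushed diagnostic; miette derives this from the
// `#[diagnostic(severity(...))]` attribute instead.
#[derive(Clone, Copy, PartialEq)]
enum Severity {
    Warning,
    Error,
}

// Heterogeneous error types share one collection through a trait object.
struct ErrorSink {
    errors: Vec<(Severity, Box<dyn fmt::Display>)>,
    error_count: usize,
    warning_count: usize,
}

impl ErrorSink {
    fn new() -> Self {
        Self { errors: Vec::new(), error_count: 0, warning_count: 0 }
    }

    // Count by severity as errors arrive, like `CompilationErrors::push`.
    fn push<E: fmt::Display + 'static>(&mut self, severity: Severity, error: E) {
        match severity {
            Severity::Warning => self.warning_count += 1,
            Severity::Error => self.error_count += 1,
        }
        self.errors.push((severity, Box::new(error)));
    }

    // Render everything at once at the end of the pipeline.
    fn report(&self) -> String {
        self.errors
            .iter()
            .map(|(_, e)| e.to_string())
            .collect::<Vec<_>>()
            .join("\n")
    }
}

fn main() {
    let mut sink = ErrorSink::new();
    sink.push(Severity::Error, "expected `;` after expression");
    sink.push(Severity::Warning, format!("unused variable `{}`", "x"));
    assert_eq!(sink.error_count, 1);
    assert_eq!(sink.warning_count, 1);
    println!("{}", sink.report());
}
```

The design point carries over directly: the sink owns the counts, so callers never tally severities themselves.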
Borrow Checker Diagnostics
Complex diagnostics like borrow checker errors benefit from multiple labeled spans showing the relationship between different code locations.
```rust
use miette::{Diagnostic, NamedSource, SourceSpan};
use thiserror::Error;

/// Borrow checker error
#[derive(Error, Debug, Diagnostic)]
#[error("Cannot borrow `{variable}` as mutable more than once")]
#[diagnostic(
    code(compiler::borrow_check::multiple_mutable),
    url(docsrs),
    help("Consider using RefCell for interior mutability")
)]
pub struct BorrowError {
    #[source_code]
    src: NamedSource<String>,
    #[label(primary, "Second mutable borrow occurs here")]
    second_borrow: SourceSpan,
    #[label("First mutable borrow occurs here")]
    first_borrow: SourceSpan,
    #[label("First borrow later used here")]
    first_use: Option<SourceSpan>,
    variable: String,
}

impl BorrowError {
    pub fn new(
        filename: String,
        source: String,
        first_borrow: SourceSpan,
        second_borrow: SourceSpan,
        variable: String,
    ) -> Self {
        Self {
            src: NamedSource::new(filename, source),
            second_borrow,
            first_borrow,
            first_use: None,
            variable,
        }
    }

    pub fn with_first_use(mut self, span: SourceSpan) -> Self {
        self.first_use = Some(span);
        self
    }
}
```
Multiple labels with different roles (primary vs secondary) help users understand the flow of borrows through their code. The optional spans allow for cases where some information might not be available.
Pattern Matching Exhaustiveness
The collection label feature is perfect for showing multiple related locations, such as missing patterns in a match expression.
```rust
use miette::{Diagnostic, LabeledSpan, NamedSource, SourceSpan};
use thiserror::Error;

/// Pattern matching exhaustiveness error
#[derive(Error, Debug, Diagnostic)]
#[error("Non-exhaustive patterns")]
#[diagnostic(code(compiler::pattern_match::non_exhaustive))]
pub struct NonExhaustiveMatch {
    #[source_code]
    src: NamedSource<String>,
    #[label(primary, "Pattern match is non-exhaustive")]
    match_span: SourceSpan,
    #[label(collection, "Missing pattern")]
    missing_patterns: Vec<LabeledSpan>,
    #[help]
    missing_list: String,
}

impl NonExhaustiveMatch {
    pub fn new(
        filename: String,
        source: String,
        match_span: SourceSpan,
        missing: Vec<String>,
    ) -> Self {
        let missing_patterns = missing
            .iter()
            .map(|_pattern| LabeledSpan::underline(match_span))
            .collect();
        let missing_list = format!(
            "Missing patterns: {}\n\nEnsure all cases are covered or add a wildcard pattern `_`",
            missing.iter().map(|p| format!("`{}`", p)).collect::<Vec<_>>().join(", ")
        );
        Self {
            src: NamedSource::new(filename, source),
            match_span,
            missing_patterns,
            missing_list,
        }
    }
}
```
The `#[label(collection)]` attribute works with any iterator of spans, making it easy to highlight multiple locations with similar issues.
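One limitation of the constructor above is that every missing pattern underlines the same `match_span`. When the missing variants can be located in the source text (for example, in the enum definition), each one can get its own span. A standard-library-only sketch — the helper name is hypothetical, and in real code the triples would become `miette::LabeledSpan`s:

```rust
// Hypothetical helper: give each missing variant its own span by locating
// its name in the source text. Returns (offset, len, label) triples.
fn spans_for_missing(source: &str, missing: &[&str]) -> Vec<(usize, usize, String)> {
    missing
        .iter()
        .filter_map(|name| {
            source
                .find(name) // naive: first occurrence only
                .map(|offset| (offset, name.len(), format!("missing pattern `{}`", name)))
        })
        .collect()
}

fn main() {
    let source = "enum Color { Red, Green, Blue }";
    let spans = spans_for_missing(source, &["Green", "Blue"]);
    assert_eq!(spans.len(), 2);
    assert_eq!(spans[0].0, source.find("Green").unwrap());
    println!("{:?}", spans);
}
```

A production compiler would take spans from the resolved enum definition rather than a text search, but the collection-label shape is the same either way.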
Integration with Standard Errors
Miette integrates seamlessly with Rust’s error handling ecosystem. The `IntoDiagnostic` trait converts any standard error into a diagnostic, while the `miette::Result` type alias provides ergonomic error handling with the `?` operator.
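The miette examples in this chapter need the crate itself to run, so here is a standard-library-only sketch of what an `into_diagnostic()`-style conversion amounts to: a `map_err` into a richer wrapper type. `Diag` and `IntoDiag` are hypothetical stand-ins for miette's `Report` and `IntoDiagnostic`:

```rust
use std::error::Error;
use std::fmt;

// Hypothetical diagnostic wrapper: pairs a high-level message with the
// underlying standard error, so both can be reported uniformly.
#[derive(Debug)]
struct Diag {
    message: String,
    source: Box<dyn Error>,
}

impl fmt::Display for Diag {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}: {}", self.message, self.source)
    }
}

// Blanket conversion on Result, analogous to calling `.into_diagnostic()`.
trait IntoDiag<T> {
    fn into_diag(self, message: &str) -> Result<T, Diag>;
}

impl<T, E: Error + 'static> IntoDiag<T> for Result<T, E> {
    fn into_diag(self, message: &str) -> Result<T, Diag> {
        self.map_err(|e| Diag {
            message: message.to_string(),
            source: Box::new(e),
        })
    }
}

fn main() {
    // Any std error (here ParseIntError) converts the same way.
    let parsed: Result<i32, Diag> = "not a number".parse::<i32>().into_diag("reading port");
    let err = parsed.unwrap_err();
    assert!(err.to_string().starts_with("reading port:"));
    println!("{}", err);
}
```

The real `IntoDiagnostic` also carries source code and spans into the wrapper, which is what makes `?` in a `miette::Result` function produce fully rendered diagnostics.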
```rust
use miette::{LabeledSpan, NamedSource, SourceSpan};

/// Dynamic diagnostic creation
pub fn create_diagnostic(
    filename: String,
    source: String,
    span: SourceSpan,
    message: String,
    help: Option<String>,
) -> miette::Report {
    let labels = vec![LabeledSpan::underline(span)];
    miette::miette!(
        labels = labels,
        help = help.unwrap_or_default(),
        "{}",
        message
    )
    .with_source_code(NamedSource::new(filename, source))
}
```
miette::Report { let labels = vec![LabeledSpan::underline(span)]; miette::miette!( labels = labels, help = help.unwrap_or_default(), "{}", message ) .with_source_code(NamedSource::new(filename, source)) } }
This function shows how to create diagnostics dynamically when you don’t know the error structure at compile time. It’s useful for scripting languages or plugin systems where errors are defined at runtime.
Screen Reader Support
Miette automatically detects when to use its screen-reader-friendly output format based on environment variables and terminal capabilities. This ensures your compiler is accessible to all users without additional configuration.
The narratable output format presents all the same information as the graphical format but in a linear, screen-reader-friendly way. Error codes become clickable links in terminals that support them, improving the documentation discovery experience.
Best Practices
Structure your diagnostics hierarchically. Top-level errors should provide overview information, while related errors can provide specific details. This helps users understand both the big picture and the specifics.
Use error codes consistently and link them to documentation. The url(docsrs) shorthand automatically generates links to your docs.rs documentation, making it easy for users to find detailed explanations.
Provide actionable help text. Instead of just describing what went wrong, suggest how to fix it. Include example code in help messages when appropriate.
Keep source spans accurate. Miette’s snippet rendering is only as good as the spans you provide. Take care to highlight exactly the relevant code, neither too much nor too little.
Use severity levels appropriately. Errors should block compilation, warnings should indicate potential issues, and notes should provide supplementary information. The fancy renderer uses different colors for each severity level.
For library code, always return concrete error types that implement Diagnostic. This gives consumers the flexibility to handle errors programmatically or render them with miette. Application code can use the more convenient Result type alias and error conversion utilities.
Miette has become essential infrastructure for Rust projects that prioritize user experience. Its thoughtful design and comprehensive features make it possible to create compiler diagnostics that genuinely help users understand and fix problems in their code.
bitflags
The bitflags crate provides a macro for generating type-safe bitmask structures. In compiler development, bitflags are essential for efficiently representing sets of boolean options or attributes that can be combined. Common use cases include representing file permissions, compiler optimization flags, access modifiers in programming languages, or node attributes in abstract syntax trees.
The primary advantage of bitflags over manual bit manipulation is type safety. Instead of working with raw integer constants and bitwise operations that invite errors, bitflags generates strongly-typed structures, so mixing incompatible flag types is a compile-time error and undefined bits cannot silently creep into a flag set.
This type safety becomes particularly valuable in large compiler codebases where flags may be passed through multiple layers of abstraction. The compiler ensures you cannot accidentally combine flags from different flag types or fabricate undefined flag values.
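To make the contrast concrete, here is a rough hand-rolled sketch of the pattern the macro generates: a newtype over the raw integer whose operators are defined only between values of the same type. This is an illustration of the idea, not the macro's actual expansion.

```rust
use std::ops::BitOr;

// Sketch of what `bitflags!` generates: a newtype over the raw integer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Permissions(u32);

// A second, unrelated flag type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct OptFlags(u32);

impl Permissions {
    const READ: Self = Permissions(1 << 0);
    const WRITE: Self = Permissions(1 << 1);

    fn contains(self, other: Self) -> bool {
        self.0 & other.0 == other.0
    }
}

// The operator is only defined between two `Permissions` values.
impl BitOr for Permissions {
    type Output = Self;
    fn bitor(self, rhs: Self) -> Self {
        Permissions(self.0 | rhs.0)
    }
}

fn main() {
    let p = Permissions::READ | Permissions::WRITE;
    assert!(p.contains(Permissions::READ));
    // Permissions::READ | OptFlags(1)  // <- rejected at compile time: mismatched types
    let _ = OptFlags(1);
    println!("{:?}", p);
}
```

With raw u32 constants, ORing a permission bit into an optimization-flag word would compile silently; the newtype makes that a type error.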
Basic Usage
The bitflags! macro generates a struct that wraps an integer type and provides methods for safely manipulating sets of flags:
```rust
use bitflags::bitflags;

bitflags! {
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct FilePermissions: u32 {
        const READ = 0b0000_0001;
        const WRITE = 0b0000_0010;
        const EXECUTE = 0b0000_0100;
        const READ_WRITE = Self::READ.bits() | Self::WRITE.bits();
        const ALL = Self::READ.bits() | Self::WRITE.bits() | Self::EXECUTE.bits();
    }
}
```
This generates a struct with associated constants for each flag. The macro automatically implements common traits like Debug, Clone, and comparison operators. Each flag is assigned a bit position, and you can define composite flags that combine multiple bits.
Compiler Flags Example
A practical example in compiler development is managing compiler flags that control various aspects of the compilation process:
```rust
use bitflags::bitflags;

bitflags! {
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct CompilerFlags: u32 {
        const OPTIMIZE = 1 << 0;
        const DEBUG_INFO = 1 << 1;
        const WARNINGS_AS_ERRORS = 1 << 2;
        const VERBOSE = 1 << 3;
        const LINK_TIME_OPTIMIZATION = 1 << 4;
        const STATIC_LINKING = 1 << 5;
        const PROFILE = 1 << 6;
        const RELEASE = Self::OPTIMIZE.bits() | Self::LINK_TIME_OPTIMIZATION.bits();
        const DEBUG = Self::DEBUG_INFO.bits() | Self::VERBOSE.bits();
    }
}
```
Note how we define composite flags like RELEASE and DEBUG that combine multiple individual flags. This pattern is common in compilers where certain modes imply specific sets of options.
Working with Flags
The generated types provide a rich API for manipulating flag sets. You can combine flags using the bitwise OR operator, check if specific flags are set, and perform set operations:
```rust
use bitflags::bitflags;

bitflags! {
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct CompilerFlags: u32 {
        const OPTIMIZE = 1 << 0;
        const DEBUG_INFO = 1 << 1;
        const WARNINGS_AS_ERRORS = 1 << 2;
        const VERBOSE = 1 << 3;
        const LINK_TIME_OPTIMIZATION = 1 << 4;
        const STATIC_LINKING = 1 << 5;
        const PROFILE = 1 << 6;
        const RELEASE = Self::OPTIMIZE.bits() | Self::LINK_TIME_OPTIMIZATION.bits();
        const DEBUG = Self::DEBUG_INFO.bits() | Self::VERBOSE.bits();
    }
}

pub fn demonstrate_compiler_flags() {
    let debug_build = CompilerFlags::DEBUG;
    let release_build = CompilerFlags::RELEASE;

    println!("Debug flags: {:?}", debug_build);
    println!("Release flags: {:?}", release_build);

    let custom =
        CompilerFlags::OPTIMIZE | CompilerFlags::DEBUG_INFO | CompilerFlags::WARNINGS_AS_ERRORS;
    println!("Custom build flags: {:?}", custom);
    println!(
        "Custom has optimization: {}",
        custom.contains(CompilerFlags::OPTIMIZE)
    );

    let common = debug_build.intersection(release_build);
    println!("Common flags between debug and release: {:?}", common);
}
```
The contains method checks if specific flags are set, while intersection returns flags common to both sets. Other useful operations include union for combining flag sets, difference for flags in one set but not another, and toggle for flipping specific flags.
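These set operations reduce to familiar bitwise identities, which is exactly what they compile down to. A plain-integer sketch of what union, difference, and toggle compute:

```rust
// Bit positions matching the CompilerFlags example above.
const OPTIMIZE: u32 = 1 << 0;
const DEBUG_INFO: u32 = 1 << 1;
const VERBOSE: u32 = 1 << 3;

fn main() {
    let a = OPTIMIZE | DEBUG_INFO;
    let b = DEBUG_INFO | VERBOSE;

    assert_eq!(a | b, OPTIMIZE | DEBUG_INFO | VERBOSE); // union
    assert_eq!(a & !b, OPTIMIZE);                       // difference: in a but not b
    assert_eq!(a ^ b, OPTIMIZE | VERBOSE);              // toggle: flip b's flags in a
    println!("ok");
}
```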
Integration with Compiler Structures
Bitflags integrate naturally with other compiler data structures. Here’s an example of using flags within a larger compiler options structure:
```rust
use bitflags::bitflags;

bitflags! {
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct CompilerFlags: u32 {
        const OPTIMIZE = 1 << 0;
        const DEBUG_INFO = 1 << 1;
        const WARNINGS_AS_ERRORS = 1 << 2;
        const VERBOSE = 1 << 3;
        const LINK_TIME_OPTIMIZATION = 1 << 4;
        const STATIC_LINKING = 1 << 5;
        const PROFILE = 1 << 6;
        const RELEASE = Self::OPTIMIZE.bits() | Self::LINK_TIME_OPTIMIZATION.bits();
        const DEBUG = Self::DEBUG_INFO.bits() | Self::VERBOSE.bits();
    }
}

#[derive(Debug)]
pub struct CompilerOptions {
    flags: CompilerFlags,
    optimization_level: u8,
}

impl CompilerOptions {
    pub fn new(flags: CompilerFlags) -> Self {
        let optimization_level = if flags.contains(CompilerFlags::OPTIMIZE) {
            if flags.contains(CompilerFlags::LINK_TIME_OPTIMIZATION) {
                3
            } else {
                2
            }
        } else {
            0
        };
        Self {
            flags,
            optimization_level,
        }
    }

    pub fn is_debug_build(&self) -> bool {
        self.flags.contains(CompilerFlags::DEBUG_INFO)
    }

    pub fn enable_profiling(&mut self) {
        self.flags.insert(CompilerFlags::PROFILE);
    }

    pub fn optimization_level(&self) -> u8 {
        self.optimization_level
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compiler_flags() {
        let debug = CompilerFlags::DEBUG;
        assert!(debug.contains(CompilerFlags::DEBUG_INFO));
        assert!(debug.contains(CompilerFlags::VERBOSE));
        assert!(!debug.contains(CompilerFlags::OPTIMIZE));
    }

    #[test]
    fn test_compiler_options() {
        let release_options = CompilerOptions::new(CompilerFlags::RELEASE);
        assert!(!release_options.is_debug_build());
        assert_eq!(release_options.optimization_level, 3);

        let debug_options = CompilerOptions::new(CompilerFlags::DEBUG);
        assert!(debug_options.is_debug_build());
        assert_eq!(debug_options.optimization_level, 0);
    }
}
```
This pattern allows you to encapsulate flag-based configuration with derived state and behavior. The compiler options structure can make decisions based on flag combinations and expose higher-level methods that abstract over the underlying bit manipulation.
File Permissions Example
Another common use case is representing file permissions or access modifiers:
```rust
use bitflags::bitflags;

bitflags! {
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct FilePermissions: u32 {
        const READ = 0b0000_0001;
        const WRITE = 0b0000_0010;
        const EXECUTE = 0b0000_0100;
        const READ_WRITE = Self::READ.bits() | Self::WRITE.bits();
        const ALL = Self::READ.bits() | Self::WRITE.bits() | Self::EXECUTE.bits();
    }
}

pub fn demonstrate_file_permissions() {
    let mut perms = FilePermissions::READ | FilePermissions::WRITE;
    println!("Initial permissions: {:?}", perms);
    println!("Can read: {}", perms.contains(FilePermissions::READ));
    println!("Can execute: {}", perms.contains(FilePermissions::EXECUTE));

    perms.insert(FilePermissions::EXECUTE);
    println!("After adding execute: {:?}", perms);

    perms.remove(FilePermissions::WRITE);
    println!("After removing write: {:?}", perms);

    let readonly = FilePermissions::READ;
    let readwrite = FilePermissions::READ_WRITE;
    println!(
        "Read-only intersects with read-write: {}",
        !readonly.intersection(readwrite).is_empty()
    );
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_file_permissions() {
        let perms = FilePermissions::READ | FilePermissions::WRITE;
        assert!(perms.contains(FilePermissions::READ));
        assert!(perms.contains(FilePermissions::WRITE));
        assert!(!perms.contains(FilePermissions::EXECUTE));
        assert_eq!(
            FilePermissions::ALL,
            FilePermissions::READ | FilePermissions::WRITE | FilePermissions::EXECUTE
        );
    }
}
```
The methods insert and remove modify flag sets in place, while intersection checks for overlapping permissions. This API is much clearer than manual bit manipulation and prevents common errors like using the wrong bit mask.
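One last pattern worth knowing is decoding a flags value back into its named parts, for example when printing diagnostics about which options are active. bitflags 2.x exposes this through iter_names (an assumption about the crate version in use); the underlying idea is a simple table lookup, sketched here in plain Rust:

```rust
// Name table matching the FilePermissions bits used throughout this chapter.
const NAMES: &[(&str, u32)] = &[("READ", 0b001), ("WRITE", 0b010), ("EXECUTE", 0b100)];

// Collect the names of all flags set in `flags`.
fn names(flags: u32) -> Vec<&'static str> {
    NAMES
        .iter()
        .filter(|(_, bit)| flags & bit != 0)
        .map(|(n, _)| *n)
        .collect()
}

fn main() {
    assert_eq!(names(0b011), vec!["READ", "WRITE"]);
    println!("{:?}", names(0b101));
}
```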
bumpalo
The bumpalo crate provides a fast bump allocation arena for Rust that dramatically improves allocation performance in compiler workloads. Bump allocation, also known as linear or arena allocation, allocates memory by simply incrementing a pointer through a contiguous block of memory. This makes allocation extremely fast - just a pointer bump and bounds check - at the cost of not being able to deallocate individual objects. Instead, the entire arena is deallocated at once when dropped.
For compiler development, bump allocation is ideal because compilation naturally proceeds in phases where large numbers of temporary allocations are created, used, and then all discarded together. AST nodes, type information, and intermediate representations can all be allocated in arenas that live only as long as needed. This allocation strategy eliminates the overhead of reference counting or garbage collection while providing excellent cache locality.
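The core mechanism is small enough to sketch in a few lines. The toy arena below hands out offsets into a fixed buffer purely by bumping a counter; it is an illustration of the idea only, not how bumpalo is actually implemented (the real crate manages growable chunks, alignment, and typed references):

```rust
// Toy bump arena: allocation is a bounds check plus an offset bump.
struct ToyArena {
    buf: Vec<u8>,
    offset: usize,
}

impl ToyArena {
    fn with_capacity(cap: usize) -> Self {
        Self { buf: vec![0; cap], offset: 0 }
    }

    // Returns the start offset of the new allocation, or None if exhausted.
    fn alloc(&mut self, size: usize) -> Option<usize> {
        if self.offset + size > self.buf.len() {
            return None;
        }
        let start = self.offset;
        self.offset += size;
        Some(start)
    }

    // "Deallocation" reclaims everything at once, like dropping the arena.
    fn reset(&mut self) {
        self.offset = 0;
    }
}

fn main() {
    let mut arena = ToyArena::with_capacity(16);
    assert_eq!(arena.alloc(8), Some(0));
    assert_eq!(arena.alloc(8), Some(8));
    assert_eq!(arena.alloc(1), None); // exhausted
    arena.reset();
    assert_eq!(arena.alloc(4), Some(0)); // same memory reused
    println!("ok");
}
```

This is why per-allocation cost is near zero and why individual objects cannot be freed: there is no bookkeeping to return a single allocation to.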
Basic Allocation
The simplest use of bumpalo is allocating individual values:
#![allow(unused)] fn main() { use std::fmt::Debug; use bumpalo::Bump; use bumpalo::boxed::Box as BumpBox; use bumpalo::collections::Vec as BumpVec; /// Shows how to allocate strings in the bump allocator pub fn allocate_strings(bump: &Bump) -> &str { // Allocate a string slice let hello = bump.alloc_str("Hello, "); let world = bump.alloc_str("World!"); // Concatenate using bump allocation (bump.alloc_str(&format!("{}{}", hello, world))) as _ } /// Allocates slices efficiently in the bump allocator pub fn allocate_slices(bump: &Bump) -> &[i32] { // Allocate a slice from a vector let data = vec![1, 2, 3, 4, 5]; bump.alloc_slice_copy(&data) } /// Demonstrates using bump-allocated collections pub fn bump_collections() -> Vec<i32> { let bump = Bump::new(); // Create a bump-allocated vector let mut vec = BumpVec::new_in(&bump); vec.push(1); vec.push(2); vec.push(3); // Convert to standard Vec for return vec.iter().copied().collect() } /// Shows arena-style allocation for AST nodes #[derive(Debug, Clone)] pub enum Expr<'a> { Number(i64), Add(&'a Expr<'a>, &'a Expr<'a>), Multiply(&'a Expr<'a>, &'a Expr<'a>), } pub fn build_ast<'a>(bump: &'a Bump) -> &'a Expr<'a> { // Build expression: (2 + 3) * 4 let two = bump.alloc(Expr::Number(2)); let three = bump.alloc(Expr::Number(3)); let four = bump.alloc(Expr::Number(4)); let add = bump.alloc(Expr::Add(two, three)); bump.alloc(Expr::Multiply(add, four)) } /// Evaluates an AST expression pub fn eval_expr(expr: &Expr) -> i64 { match expr { Expr::Number(n) => *n, Expr::Add(a, b) => eval_expr(a) + eval_expr(b), Expr::Multiply(a, b) => eval_expr(a) * eval_expr(b), } } /// Demonstrates using bump allocation for a simple compiler IR pub struct Function<'a> { pub name: &'a str, pub params: BumpVec<'a, &'a str>, pub body: BumpVec<'a, Statement<'a>>, } pub enum Statement<'a> { Let(&'a str, &'a Expr<'a>), Return(&'a Expr<'a>), } pub fn build_function<'a>(bump: &'a Bump) -> Function<'a> { let mut params = BumpVec::new_in(bump); let x = 
bump.alloc_str("x"); let y = bump.alloc_str("y"); params.push(&*x); params.push(&*y); let mut body = BumpVec::new_in(bump); // let sum = x + y let x = bump.alloc(Expr::Number(10)); let y = bump.alloc(Expr::Number(20)); let sum_expr = bump.alloc(Expr::Add(x, y)); body.push(Statement::Let("sum", sum_expr)); // return sum * 2 let two = bump.alloc(Expr::Number(2)); let result = bump.alloc(Expr::Multiply(sum_expr, two)); body.push(Statement::Return(result)); Function { name: bump.alloc_str("calculate"), params, body, } } /// Shows how to reset and reuse a bump allocator pub fn reset_and_reuse() -> (Vec<i32>, Vec<i32>) { let mut bump = Bump::new(); // First allocation cycle let first = { let vec = BumpVec::new_in(&bump); let mut vec = vec; vec.extend([1, 2, 3].iter().copied()); vec.iter().copied().collect::<Vec<_>>() }; // Reset the allocator to reclaim all memory bump.reset(); // Second allocation cycle reuses the same memory let second = { let vec = BumpVec::new_in(&bump); let mut vec = vec; vec.extend([4, 5, 6].iter().copied()); vec.iter().copied().collect::<Vec<_>>() }; (first, second) } /// Demonstrates scoped allocation for temporary computations pub fn scoped_allocation() -> i32 { let bump = Bump::new(); // Create a temporary allocation scope // Memory is automatically freed when bump goes out of scope { // Allocate temporary working data let mut temps = BumpVec::new_in(&bump); for i in 0..100 { temps.push(i); } // Process data temps.iter().sum::<i32>() } } /// Shows using bump allocation with closures pub fn with_allocator<F, R>(f: F) -> R where F: FnOnce(&Bump) -> R, { let bump = Bump::new(); f(&bump) } pub fn closure_example() -> i32 { with_allocator(|bump| { let numbers = bump.alloc_slice_copy(&[1, 2, 3, 4, 5]); numbers.iter().sum() }) } /// Custom type that uses bump allocation internally pub struct SymbolTable<'a> { bump: &'a Bump, symbols: BumpVec<'a, &'a str>, } impl<'a> SymbolTable<'a> { pub fn new(bump: &'a Bump) -> Self { Self { bump, symbols: 
BumpVec::new_in(bump), } } pub fn intern(&mut self, s: &str) -> usize { // Check if symbol already exists for (i, &sym) in self.symbols.iter().enumerate() { if sym == s { return i; } } // Allocate new symbol let symbol = self.bump.alloc_str(s); let id = self.symbols.len(); self.symbols.push(symbol); id } pub fn get(&self, id: usize) -> Option<&'a str> { self.symbols.get(id).copied() } } /// Demonstrates using bump allocation for graph structures pub struct Node<'a> { pub value: i32, pub children: BumpVec<'a, &'a Node<'a>>, } pub fn build_tree<'a>(bump: &'a Bump) -> &'a Node<'a> { // Build a simple tree structure let leaf1 = bump.alloc(Node { value: 1, children: BumpVec::new_in(bump), }); let leaf2 = bump.alloc(Node { value: 2, children: BumpVec::new_in(bump), }); let mut branch_children = BumpVec::new_in(bump); branch_children.push(&*leaf1); branch_children.push(&*leaf2); bump.alloc(Node { value: 3, children: branch_children, }) } /// Shows statistics about memory usage pub fn allocation_stats() { let mut bump = Bump::new(); // Allocate some data for i in 0..1000 { bump.alloc(i); } // Get allocation statistics let allocated = bump.allocated_bytes(); println!("Allocated: {} bytes", allocated); // Reset and check again bump.reset(); let after_reset = bump.allocated_bytes(); println!("After reset: {} bytes", after_reset); } /// Demonstrates bump boxes for single-value allocation pub fn bump_box_example() -> i32 { let bump = Bump::new(); // Create a bump-allocated box let boxed: BumpBox<i32> = BumpBox::new_in(100, &bump); // Bump boxes can be dereferenced like regular boxes *boxed } /// Shows efficient string building with bump allocation pub struct StringBuilder<'a> { bump: &'a Bump, parts: BumpVec<'a, &'a str>, } impl<'a> StringBuilder<'a> { pub fn new(bump: &'a Bump) -> Self { Self { bump, parts: BumpVec::new_in(bump), } } pub fn append(&mut self, s: &str) { let part = self.bump.alloc_str(s); self.parts.push(part); } pub fn build(&self) -> String { 
self.parts.iter().flat_map(|s| s.chars()).collect() } } #[cfg(test)] mod tests { use super::*; #[test] fn test_basic_allocation() { let result = basic_allocation(); assert_eq!(result, vec![10, 20, 30]); } #[test] fn test_string_allocation() { let bump = Bump::new(); let result = allocate_strings(&bump); assert_eq!(result, "Hello, World!"); } #[test] fn test_ast_evaluation() { let bump = Bump::new(); let ast = build_ast(&bump); assert_eq!(eval_expr(ast), 20); // (2 + 3) * 4 = 20 } #[test] fn test_reset_reuse() { let (first, second) = reset_and_reuse(); assert_eq!(first, vec![1, 2, 3]); assert_eq!(second, vec![4, 5, 6]); } #[test] fn test_symbol_table() { let bump = Bump::new(); let mut table = SymbolTable::new(&bump); let id1 = table.intern("hello"); let id2 = table.intern("world"); let id3 = table.intern("hello"); // Should return same ID assert_eq!(id1, id3); assert_ne!(id1, id2); assert_eq!(table.get(id1), Some("hello")); assert_eq!(table.get(id2), Some("world")); } #[test] fn test_tree_building() { let bump = Bump::new(); let tree = build_tree(&bump); assert_eq!(tree.value, 3); assert_eq!(tree.children.len(), 2); assert_eq!(tree.children[0].value, 1); assert_eq!(tree.children[1].value, 2); } #[test] fn test_string_builder() { let bump = Bump::new(); let mut builder = StringBuilder::new(&bump); builder.append("Hello"); builder.append(", "); builder.append("World!"); assert_eq!(builder.build(), "Hello, World!"); } } /// Demonstrates basic bump allocation with simple types pub fn basic_allocation() -> Vec<i32> { let bump = Bump::new(); // Allocate individual values let x = bump.alloc(10); let y = bump.alloc(20); let z = bump.alloc(30); // Values are valid for the lifetime of the bump allocator vec![*x, *y, *z] } }
Values allocated in the bump allocator are returned as references with the allocator’s lifetime. This ensures they remain valid as long as the allocator exists.
String Allocation
Bumpalo provides specialized methods for string allocation:
use bumpalo::Bump;

/// Shows how to allocate strings in the bump allocator
pub fn allocate_strings(bump: &Bump) -> &str {
    // Allocate a string slice
    let hello = bump.alloc_str("Hello, ");
    let world = bump.alloc_str("World!");

    // Concatenate using bump allocation
    bump.alloc_str(&format!("{}{}", hello, world))
}
The alloc_str method is particularly efficient for building up strings during parsing or code generation, as it avoids the overhead of String's capacity management.
Slice Allocation
Copying slices into the arena is straightforward:
use bumpalo::Bump;

/// Allocates slices efficiently in the bump allocator
pub fn allocate_slices(bump: &Bump) -> &[i32] {
    // Allocate a slice from a vector
    let data = vec![1, 2, 3, 4, 5];
    bump.alloc_slice_copy(&data)
}
This is useful for storing parsed tokens, symbol tables, or any sequence of data that needs to outlive its original source.
Bump-Allocated Collections
Bumpalo provides arena-allocated versions of common collections:
use bumpalo::Bump;
use bumpalo::collections::Vec as BumpVec;

/// Demonstrates using bump-allocated collections
pub fn bump_collections() -> Vec<i32> {
    let bump = Bump::new();

    // Create a bump-allocated vector
    let mut vec = BumpVec::new_in(&bump);
    vec.push(1);
    vec.push(2);
    vec.push(3);

    // Convert to a standard Vec for return
    vec.iter().copied().collect()
}
These collections store their data directly in the arena rather than making separate heap allocations, which makes them a good fit for temporary collections during compilation passes.
AST Construction
Arena allocation shines for building recursive data structures like ASTs:
/// Shows arena-style allocation for AST nodes
#[derive(Debug, Clone)]
pub enum Expr<'a> {
    Number(i64),
    Add(&'a Expr<'a>, &'a Expr<'a>),
    Multiply(&'a Expr<'a>, &'a Expr<'a>),
}
```rust
use bumpalo::Bump;

/// Shows arena-style allocation for AST nodes
#[derive(Debug, Clone)]
pub enum Expr<'a> {
    Number(i64),
    Add(&'a Expr<'a>, &'a Expr<'a>),
    Multiply(&'a Expr<'a>, &'a Expr<'a>),
}

pub fn build_ast<'a>(bump: &'a Bump) -> &'a Expr<'a> {
    // Build expression: (2 + 3) * 4
    let two = bump.alloc(Expr::Number(2));
    let three = bump.alloc(Expr::Number(3));
    let four = bump.alloc(Expr::Number(4));
    let add = bump.alloc(Expr::Add(two, three));
    bump.alloc(Expr::Multiply(add, four))
}
```
```rust
/// Evaluates an AST expression
pub fn eval_expr(expr: &Expr) -> i64 {
    match expr {
        Expr::Number(n) => *n,
        Expr::Add(a, b) => eval_expr(a) + eval_expr(b),
        Expr::Multiply(a, b) => eval_expr(a) * eval_expr(b),
    }
}
```
The entire AST is allocated in a single contiguous memory region, providing excellent cache locality during traversal. Nodes can freely reference each other without worrying about ownership or lifetimes beyond the arena’s lifetime.
Compiler IR Structures
More complex compiler structures benefit from arena allocation:
```rust
use bumpalo::collections::Vec as BumpVec;

/// Demonstrates using bump allocation for a simple compiler IR
pub struct Function<'a> {
    pub name: &'a str,
    pub params: BumpVec<'a, &'a str>,
    pub body: BumpVec<'a, Statement<'a>>,
}
```
```rust
pub enum Statement<'a> {
    Let(&'a str, &'a Expr<'a>),
    Return(&'a Expr<'a>),
}
```
```rust
pub fn build_function<'a>(bump: &'a Bump) -> Function<'a> {
    let mut params = BumpVec::new_in(bump);
    let x = bump.alloc_str("x");
    let y = bump.alloc_str("y");
    params.push(&*x);
    params.push(&*y);

    let mut body = BumpVec::new_in(bump);

    // let sum = x + y
    let x = bump.alloc(Expr::Number(10));
    let y = bump.alloc(Expr::Number(20));
    let sum_expr = bump.alloc(Expr::Add(x, y));
    body.push(Statement::Let("sum", sum_expr));

    // return sum * 2
    let two = bump.alloc(Expr::Number(2));
    let result = bump.alloc(Expr::Multiply(sum_expr, two));
    body.push(Statement::Return(result));

    Function {
        name: bump.alloc_str("calculate"),
        params,
        body,
    }
}
```
test_ast_evaluation() { let bump = Bump::new(); let ast = build_ast(&bump); assert_eq!(eval_expr(ast), 20); // (2 + 3) * 4 = 20 } #[test] fn test_reset_reuse() { let (first, second) = reset_and_reuse(); assert_eq!(first, vec![1, 2, 3]); assert_eq!(second, vec![4, 5, 6]); } #[test] fn test_symbol_table() { let bump = Bump::new(); let mut table = SymbolTable::new(&bump); let id1 = table.intern("hello"); let id2 = table.intern("world"); let id3 = table.intern("hello"); // Should return same ID assert_eq!(id1, id3); assert_ne!(id1, id2); assert_eq!(table.get(id1), Some("hello")); assert_eq!(table.get(id2), Some("world")); } #[test] fn test_tree_building() { let bump = Bump::new(); let tree = build_tree(&bump); assert_eq!(tree.value, 3); assert_eq!(tree.children.len(), 2); assert_eq!(tree.children[0].value, 1); assert_eq!(tree.children[1].value, 2); } #[test] fn test_string_builder() { let bump = Bump::new(); let mut builder = StringBuilder::new(&bump); builder.append("Hello"); builder.append(", "); builder.append("World!"); assert_eq!(builder.build(), "Hello, World!"); } } pub fn build_function<'a>(bump: &'a Bump) -> Function<'a> { let mut params = BumpVec::new_in(bump); let x = bump.alloc_str("x"); let y = bump.alloc_str("y"); params.push(&*x); params.push(&*y); let mut body = BumpVec::new_in(bump); // let sum = x + y let x = bump.alloc(Expr::Number(10)); let y = bump.alloc(Expr::Number(20)); let sum_expr = bump.alloc(Expr::Add(x, y)); body.push(Statement::Let("sum", sum_expr)); // return sum * 2 let two = bump.alloc(Expr::Number(2)); let result = bump.alloc(Expr::Multiply(sum_expr, two)); body.push(Statement::Return(result)); Function { name: bump.alloc_str("calculate"), params, body, } } }
This pattern works well for intermediate representations where you build up complex structures during one compilation phase and discard them after lowering or code generation.
Reset and Reuse
Bump allocators can be reset to reclaim all memory at once:
#![allow(unused)]
fn main() {
use bumpalo::Bump;
use bumpalo::collections::Vec as BumpVec;

/// Shows how to reset and reuse a bump allocator
pub fn reset_and_reuse() -> (Vec<i32>, Vec<i32>) {
    let mut bump = Bump::new();

    // First allocation cycle
    let first = {
        let mut vec = BumpVec::new_in(&bump);
        vec.extend([1, 2, 3].iter().copied());
        vec.iter().copied().collect::<Vec<_>>()
    };

    // Reset the allocator to reclaim all memory
    bump.reset();

    // Second allocation cycle reuses the same memory
    let second = {
        let mut vec = BumpVec::new_in(&bump);
        vec.extend([4, 5, 6].iter().copied());
        vec.iter().copied().collect::<Vec<_>>()
    };

    (first, second)
}
}
This works well for compilers that process multiple files or compilation units sequentially: reset the allocator between units and the same backing memory is reused, avoiding repeated allocation and deallocation.
Scoped Allocation
Use bump allocation for temporary computations:
#![allow(unused)]
fn main() {
use bumpalo::Bump;
use bumpalo::collections::Vec as BumpVec;

/// Demonstrates scoped allocation for temporary computations
pub fn scoped_allocation() -> i32 {
    let bump = Bump::new();

    // Create a temporary allocation scope
    // Memory is automatically freed when bump goes out of scope
    {
        // Allocate temporary working data
        let mut temps = BumpVec::new_in(&bump);
        for i in 0..100 {
            temps.push(i);
        }

        // Process data
        temps.iter().sum::<i32>()
    }
}
}
The arena automatically frees all memory when it goes out of scope, making it ideal for temporary working memory during optimization passes.
Higher-Order Patterns
Encapsulate arena lifetime management with closures:
#![allow(unused)]
fn main() {
use bumpalo::Bump;

/// Shows using bump allocation with closures
pub fn with_allocator<F, R>(f: F) -> R
where
    F: FnOnce(&Bump) -> R,
{
    let bump = Bump::new();
    f(&bump)
}
}
#![allow(unused)]
fn main() {
use bumpalo::Bump;

pub fn closure_example() -> i32 {
    with_allocator(|bump| {
        let numbers = bump.alloc_slice_copy(&[1, 2, 3, 4, 5]);
        numbers.iter().sum()
    })
}
}
This pattern ensures the arena is properly scoped and makes it easy to add arena allocation to existing code.
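The same closure-scoping idiom applies to any resource whose lifetime should be bounded to a single call, not just a `Bump` arena. A minimal std-only sketch of the pattern (the `with_scratch` name is illustrative, not part of bumpalo):

```rust
/// Runs `f` with a scratch buffer that lives exactly as long as the call.
/// The caller can never leak a reference to the buffer, because the
/// borrow ends when the closure returns.
fn with_scratch<F, R>(f: F) -> R
where
    F: FnOnce(&mut Vec<u8>) -> R,
{
    let mut scratch = Vec::with_capacity(1024);
    let result = f(&mut scratch);
    // `scratch` is dropped here, reclaiming the memory in one shot,
    // just as a `Bump` frees everything when it goes out of scope.
    result
}

fn main() {
    let len = with_scratch(|buf| {
        buf.extend_from_slice(b"temporary working data");
        buf.len()
    });
    println!("{}", len); // prints 22
}
```

The `FnOnce(&Bump) -> R` bound is what enforces the scoping: because `R` is chosen before the arena exists, the closure cannot smuggle out a reference tied to the arena's lifetime.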
Symbol Tables
Arena allocation works well for symbol interning:
#![allow(unused)]
fn main() {
use bumpalo::Bump;
use bumpalo::collections::Vec as BumpVec;

/// Custom type that uses bump allocation internally
pub struct SymbolTable<'a> {
    bump: &'a Bump,
    symbols: BumpVec<'a, &'a str>,
}

impl<'a> SymbolTable<'a> {
    pub fn new(bump: &'a Bump) -> Self {
        Self {
            bump,
            symbols: BumpVec::new_in(bump),
        }
    }

    pub fn intern(&mut self, s: &str) -> usize {
        // Check if symbol already exists
        for (i, &sym) in self.symbols.iter().enumerate() {
            if sym == s {
                return i;
            }
        }

        // Allocate new symbol
        let symbol = self.bump.alloc_str(s);
        let id = self.symbols.len();
        self.symbols.push(symbol);
        id
    }

    pub fn get(&self, id: usize) -> Option<&'a str> {
        self.symbols.get(id).copied()
    }
}
}
Interned strings live as long as the compilation unit needs them, with minimal allocation overhead and excellent cache performance.
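One caveat: `intern` above does a linear scan, so each lookup is O(n) in the number of symbols. For larger tables, a hash map from name to ID brings lookups to O(1). A std-only sketch of the same interface, with owned `String`s standing in for the arena-allocated `&'a str`s (the `Interner` type is illustrative, not from bumpalo or this guide's crates):

```rust
use std::collections::HashMap;

/// Map-backed interner: each distinct string gets a stable numeric ID.
struct Interner {
    ids: HashMap<String, usize>,
    names: Vec<String>,
}

impl Interner {
    fn new() -> Self {
        Interner { ids: HashMap::new(), names: Vec::new() }
    }

    /// Returns the ID for `s`, inserting it on first use.
    fn intern(&mut self, s: &str) -> usize {
        if let Some(&id) = self.ids.get(s) {
            return id;
        }
        let id = self.names.len();
        self.ids.insert(s.to_string(), id);
        self.names.push(s.to_string());
        id
    }

    /// Looks a symbol back up by its ID.
    fn get(&self, id: usize) -> Option<&str> {
        self.names.get(id).map(|s| s.as_str())
    }
}

fn main() {
    let mut interner = Interner::new();
    let a = interner.intern("hello");
    let b = interner.intern("world");
    let c = interner.intern("hello");
    assert_eq!(a, c); // same string, same ID
    assert_ne!(a, b);
    assert_eq!(interner.get(b), Some("world"));
}
```

In a real compiler you would combine the two approaches: arena-allocate the string data as the guide's `SymbolTable` does, and keep a `HashMap<&'a str, usize>` over those arena references for constant-time lookup.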
Graph Structures
Build complex graph structures like control flow graphs:
#![allow(unused)]
fn main() {
use bumpalo::Bump;
use bumpalo::collections::Vec as BumpVec;

/// Demonstrates using bump allocation for graph structures
pub struct Node<'a> {
    pub value: i32,
    pub children: BumpVec<'a, &'a Node<'a>>,
}

pub fn build_tree<'a>(bump: &'a Bump) -> &'a Node<'a> {
    // Build a simple tree structure
    let leaf1 = bump.alloc(Node {
        value: 1,
        children: BumpVec::new_in(bump),
    });
    let leaf2 = bump.alloc(Node {
        value: 2,
        children: BumpVec::new_in(bump),
    });

    let mut branch_children = BumpVec::new_in(bump);
    branch_children.push(&*leaf1);
    branch_children.push(&*leaf2);

    bump.alloc(Node {
        value: 3,
        children: branch_children,
    })
}
}
Nodes can freely reference each other without complex lifetime management or reference counting overhead.
Bump Boxes
For single-value allocations with ownership semantics:
use bumpalo::Bump;
use bumpalo::boxed::Box as BumpBox;

/// Demonstrates bump boxes for single-value allocation
pub fn bump_box_example() -> i32 {
    let bump = Bump::new();

    // Create a bump-allocated box
    let boxed: BumpBox<i32> = BumpBox::new_in(100, &bump);

    // Bump boxes can be dereferenced like regular boxes
    *boxed
}
Bump boxes provide a Box-like interface while using arena allocation under the hood.
String Building
Efficient string construction without repeated allocations:
use bumpalo::Bump;
use bumpalo::collections::Vec as BumpVec;

/// Shows efficient string building with bump allocation
pub struct StringBuilder<'a> {
    bump: &'a Bump,
    parts: BumpVec<'a, &'a str>,
}

impl<'a> StringBuilder<'a> {
    pub fn new(bump: &'a Bump) -> Self {
        Self {
            bump,
            parts: BumpVec::new_in(bump),
        }
    }

    pub fn append(&mut self, s: &str) {
        // Each appended part is copied into the arena, not into a growing String
        let part = self.bump.alloc_str(s);
        self.parts.push(part);
    }

    pub fn build(&self) -> String {
        self.parts.iter().flat_map(|s| s.chars()).collect()
    }
}
This avoids the repeated allocations that would occur with String::push_str
or format strings.
Performance Characteristics
Bump allocation provides several performance advantages for compilers:
Allocation Speed: O(1) allocation with just a pointer increment and bounds check. No searching for free blocks or managing free lists.
Deallocation Speed: O(1) for the entire arena. No need to track individual object lifetimes or run destructors.
Memory Locality: Sequential allocations are contiguous in memory, providing excellent cache performance during traversal.
Low Overhead: No per-allocation metadata like headers or reference counts. The only overhead is unused space at the end of the current chunk.
Predictable Performance: No garbage collection pauses or reference counting overhead. Performance is deterministic and easy to reason about.
Best Practices
Structure your compiler passes to match arena lifetimes. Each major phase (parsing, type checking, optimization, code generation) can use its own arena that’s dropped when the phase completes.
Avoid storing bump-allocated values in long-lived data structures. The arena lifetime must outlive all references to its allocated values.
Use typed arenas for hot paths. Creating type-specific arenas can eliminate pointer indirection and improve cache performance for frequently accessed types.
Reset and reuse arenas when processing multiple compilation units. This amortizes the cost of the initial memory allocation across all units.
Consider using multiple arenas for different lifetimes. For example, use one arena for the AST that lives through type checking, and another for temporary values during each optimization pass.
Profile your allocator usage to find the optimal chunk size. Larger chunks mean fewer allocations from the system allocator but potentially more wasted space.
id-arena
The id-arena crate provides a simple arena allocator that assigns unique IDs to values. In compiler development, arenas solve numerous challenges related to managing AST nodes, type representations, and intermediate representations. Traditional approaches using references or Rc/Arc lead to complex lifetime management and potential cycles. Arena allocation with IDs provides stable references that are copyable, comparable, and safe to store anywhere.
Compilers deal with many interconnected data structures: AST nodes reference other nodes, types reference other types, and IR instructions reference values and blocks. Using an arena with IDs instead of pointers simplifies these relationships dramatically. IDs are just integers, so they can be copied freely, stored in hash maps, and serialized without concern for lifetimes or ownership.
AST Construction
Building an AST with id-arena involves allocating nodes in the arena and using IDs for references:
use std::collections::HashMap;

use id_arena::{Arena, Id};

#[derive(Debug, Clone)]
pub struct AstNode {
    pub kind: NodeKind,
    pub ty: Option<Id<Type>>,
    pub children: Vec<Id<AstNode>>,
}

#[derive(Debug, Clone)]
pub enum NodeKind {
    Program,
    Function {
        name: String,
        params: Vec<Id<AstNode>>,
        body: Id<AstNode>,
    },
    Parameter {
        name: String,
    },
    Block,
    VariableDecl {
        name: String,
        init: Option<Id<AstNode>>,
    },
    BinaryOp {
        op: BinaryOperator,
        left: Id<AstNode>,
        right: Id<AstNode>,
    },
    Literal(Literal),
    Identifier(String),
}

#[derive(Debug, Clone)]
pub enum BinaryOperator {
    Add,
    Sub,
    Mul,
    Div,
    Eq,
    Lt,
}

#[derive(Debug, Clone)]
pub enum Literal {
    Integer(i64),
    Float(f64),
    String(String),
    Bool(bool),
}

#[derive(Debug, Clone)]
pub struct Type {
    pub kind: TypeKind,
}

#[derive(Debug, Clone)]
pub enum TypeKind {
    Int,
    Float,
    Bool,
    String,
    Function {
        params: Vec<Id<Type>>,
        ret: Id<Type>,
    },
    Unknown,
}

pub struct Compiler {
    pub ast_arena: Arena<AstNode>,
    pub type_arena: Arena<Type>,
    pub symbol_table: HashMap<String, Id<AstNode>>,
}

impl Compiler {
    pub fn new() -> Self {
        Self {
            ast_arena: Arena::new(),
            type_arena: Arena::new(),
            symbol_table: HashMap::new(),
        }
    }

    /// Builds the AST for a simple `add(x, y)` function returning `x + y`.
    pub fn build_example_ast(&mut self) -> Id<AstNode> {
        let int_type = self.type_arena.alloc(Type {
            kind: TypeKind::Int,
        });

        let x_param = self.ast_arena.alloc(AstNode {
            kind: NodeKind::Parameter {
                name: "x".to_string(),
            },
            ty: Some(int_type),
            children: vec![],
        });
        let y_param = self.ast_arena.alloc(AstNode {
            kind: NodeKind::Parameter {
                name: "y".to_string(),
            },
            ty: Some(int_type),
            children: vec![],
        });

        let x_ident = self.ast_arena.alloc(AstNode {
            kind: NodeKind::Identifier("x".to_string()),
            ty: Some(int_type),
            children: vec![],
        });
        let y_ident = self.ast_arena.alloc(AstNode {
            kind: NodeKind::Identifier("y".to_string()),
            ty: Some(int_type),
            children: vec![],
        });

        let add_expr = self.ast_arena.alloc(AstNode {
            kind: NodeKind::BinaryOp {
                op: BinaryOperator::Add,
                left: x_ident,
                right: y_ident,
            },
            ty: Some(int_type),
            children: vec![x_ident, y_ident],
        });

        let body = self.ast_arena.alloc(AstNode {
            kind: NodeKind::Block,
            ty: None,
            children: vec![add_expr],
        });

        let add_func = self.ast_arena.alloc(AstNode {
            kind: NodeKind::Function {
                name: "add".to_string(),
                params: vec![x_param, y_param],
                body,
            },
            ty: None,
            children: vec![x_param, y_param, body],
        });

        self.symbol_table.insert("add".to_string(), add_func);

        self.ast_arena.alloc(AstNode {
            kind: NodeKind::Program,
            ty: None,
            children: vec![add_func],
        })
    }
}
use std::collections::HashMap;

use id_arena::{Arena, Id};

pub struct Compiler {
    pub ast_arena: Arena<AstNode>,
    pub type_arena: Arena<Type>,
    pub symbol_table: HashMap<String, Id<AstNode>>,
}
The compiler struct owns the arenas for both AST nodes and types. Nodes reference each other through IDs rather than pointers, eliminating lifetime concerns and enabling flexible tree manipulation.
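The mechanics behind this are worth seeing without any dependencies. The sketch below is hand-rolled, not the id_arena API (the names NodeArena and NodeId are illustrative): because `alloc` returns a plain Copy index rather than a reference, the borrow on the arena ends immediately, and ids can be stored in child vectors or symbol tables without any lifetime annotations.

```rust
// Minimal Vec-backed arena sketch: an id is just an index, so it is Copy
// and carries no lifetime. (Illustrative names, not the id_arena API.)

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub struct NodeId(usize);

pub struct NodeArena {
    nodes: Vec<String>,
}

impl NodeArena {
    pub fn new() -> Self {
        Self { nodes: Vec::new() }
    }

    // Returns an index, never a reference: the mutable borrow of the arena
    // ends here, so the caller can keep allocating while holding old ids.
    pub fn alloc(&mut self, node: String) -> NodeId {
        self.nodes.push(node);
        NodeId(self.nodes.len() - 1)
    }

    pub fn get(&self, id: NodeId) -> &str {
        &self.nodes[id.0]
    }
}

fn main() {
    let mut arena = NodeArena::new();
    let lhs = arena.alloc("x".to_string());
    let rhs = arena.alloc("y".to_string());
    // Ids are plain values: storing them in a child list borrows nothing.
    let children = vec![lhs, rhs];
    for &id in &children {
        println!("{}", arena.get(id));
    }
    assert_eq!(arena.get(lhs), "x");
}
```

A reference-based tree would tie every child to the lifetime of its parent's allocation; the index-based version has no such coupling, which is what makes multi-pass compilers over a shared arena ergonomic.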
Building Complex Trees
The arena pattern makes building complex AST structures straightforward:
impl Compiler {
    pub fn new() -> Self {
        Self {
            ast_arena: Arena::new(),
            type_arena: Arena::new(),
            symbol_table: HashMap::new(),
        }
    }

    pub fn build_example_ast(&mut self) -> Id<AstNode> {
        let int_type = self.type_arena.alloc(Type {
            kind: TypeKind::Int,
        });
        let x_param = self.ast_arena.alloc(AstNode {
            kind: NodeKind::Parameter {
                name: "x".to_string(),
            },
            ty: Some(int_type),
            children: vec![],
        });
        let y_param = self.ast_arena.alloc(AstNode {
            kind: NodeKind::Parameter {
                name: "y".to_string(),
            },
            ty: Some(int_type),
            children: vec![],
        });
        let x_ident = self.ast_arena.alloc(AstNode {
            kind: NodeKind::Identifier("x".to_string()),
            ty: Some(int_type),
            children: vec![],
        });
        let y_ident = self.ast_arena.alloc(AstNode {
            kind: NodeKind::Identifier("y".to_string()),
            ty: Some(int_type),
            children: vec![],
        });
        let add_expr = self.ast_arena.alloc(AstNode {
            kind: NodeKind::BinaryOp {
                op: BinaryOperator::Add,
                left: x_ident,
                right: y_ident,
            },
            ty: Some(int_type),
            children: vec![x_ident, y_ident],
        });
        let body = self.ast_arena.alloc(AstNode {
            kind: NodeKind::Block,
            ty: None,
            children: vec![add_expr],
        });
        let add_func = self.ast_arena.alloc(AstNode {
            kind: NodeKind::Function {
                name: "add".to_string(),
                params: vec![x_param, y_param],
                body,
            },
            ty: None,
            children: vec![x_param, y_param, body],
        });
        self.symbol_table.insert("add".to_string(), add_func);
        self.ast_arena.alloc(AstNode {
            kind: NodeKind::Program,
            ty: None,
            children: vec![add_func],
        })
    }

    pub fn print_ast(&self, id: Id<AstNode>, depth: usize) {
        let indent = "  ".repeat(depth);
        let node = &self.ast_arena[id];
        match &node.kind {
            NodeKind::Program => println!("{}Program", indent),
            NodeKind::Function { name, params, body } => {
                println!("{}Function: {}", indent, name);
                println!("{}  Parameters:", indent);
                for &param_id in params {
                    self.print_ast(param_id, depth + 2);
                }
                println!("{}  Body:", indent);
                self.print_ast(*body, depth + 2);
            }
            NodeKind::Parameter { name } => {
                println!(
                    "{}Parameter: {} (type: {:?})",
                    indent,
                    name,
                    node.ty.map(|t| &self.type_arena[t].kind)
                );
            }
            NodeKind::Block => {
                println!("{}Block", indent);
                for &child in &node.children {
                    self.print_ast(child, depth + 1);
                }
            }
            NodeKind::BinaryOp { op, left, right } => {
                println!("{}BinaryOp: {:?}", indent, op);
                self.print_ast(*left, depth + 1);
                self.print_ast(*right, depth + 1);
            }
            NodeKind::Identifier(name) => println!("{}Identifier: {}", indent, name),
            NodeKind::Literal(lit) => println!("{}Literal: {:?}", indent, lit),
            NodeKind::VariableDecl { name, init } => {
                println!("{}VariableDecl: {}", indent, name);
                if let Some(init_id) = init {
                    self.print_ast(*init_id, depth + 1);
                }
            }
        }
    }
}
This example shows how function definitions, parameters, and expressions are allocated in the arena with proper parent-child relationships maintained through ID vectors.
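The same property makes later rewrite passes cheap. As a minimal, dependency-free sketch (a Vec-backed arena standing in for id_arena; `Expr` and `fold` here are illustrative, not part of the guide's example), a constant-folding pass can overwrite a node in place through its id, and every id held by a parent remains valid because ids are indices, not pointers:

```rust
// Sketch of id-based tree rewriting: fold `1 + 2` into `3` by replacing
// the Add node in its arena slot. Illustrative types, not id_arena's API.

#[derive(Clone, Copy, Debug, PartialEq)]
struct ExprId(usize);

#[derive(Clone, Debug, PartialEq)]
enum Expr {
    Lit(i64),
    Add(ExprId, ExprId),
}

struct ExprArena(Vec<Expr>);

impl ExprArena {
    fn alloc(&mut self, e: Expr) -> ExprId {
        self.0.push(e);
        ExprId(self.0.len() - 1)
    }

    // Read the children by id, then overwrite the parent's slot. Cloning
    // the small Expr values keeps the borrows short and the pass simple.
    fn fold(&mut self, id: ExprId) {
        if let Expr::Add(l, r) = self.0[id.0].clone() {
            if let (Expr::Lit(a), Expr::Lit(b)) = (self.0[l.0].clone(), self.0[r.0].clone()) {
                self.0[id.0] = Expr::Lit(a + b);
            }
        }
    }
}

fn main() {
    let mut arena = ExprArena(Vec::new());
    let one = arena.alloc(Expr::Lit(1));
    let two = arena.alloc(Expr::Lit(2));
    let sum = arena.alloc(Expr::Add(one, two));
    arena.fold(sum);
    // The slot that held Add(one, two) now holds Lit(3); the id `sum`
    // that any parent stored is untouched.
    assert_eq!(arena.0[sum.0], Expr::Lit(3));
}
```

With `Box`- or reference-based trees, the same rewrite means reconstructing every ancestor or fighting aliasing rules; here it is a single slot assignment.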
Intermediate Representation
Arenas work equally well for compiler IR where instructions reference values and basic blocks:
pub struct InstructionArena {
    instructions: Arena<Instruction>,
    blocks: Arena<BasicBlock>,
}

#[derive(Debug)]
pub struct Instruction {
    pub opcode: Opcode,
    pub operands: Vec<Operand>,
    pub result: Option<Id<Value>>,
}

#[derive(Debug)]
pub enum Opcode {
    Add,
    Sub,
    Mul,
    Load,
    Store,
    Jump,
    Branch,
    Return,
}

#[derive(Debug)]
pub enum Operand {
    Value(Id<Value>),
    Block(Id<BasicBlock>),
    Immediate(i64),
}

#[derive(Debug)]
pub struct BasicBlock {
    pub label: String,
    pub instructions: Vec<Id<Instruction>>,
    pub terminator: Option<Id<Instruction>>,
}

#[derive(Debug)]
pub struct Value {
    pub name: String,
    pub ty: ValueType,
}

#[derive(Debug)]
pub enum ValueType {
    I32,
    I64,
    F32,
    F64,
    Ptr,
}
impl InstructionArena {
    pub fn new() -> Self {
        Self {
            instructions: Arena::new(),
            blocks: Arena::new(),
        }
    }

    pub fn create_example_ir(&mut self, values: &mut Arena<Value>) -> Id<BasicBlock> {
        let x = values.alloc(Value {
            name: "%x".to_string(),
            ty: ValueType::I32,
        });
        let y = values.alloc(Value {
            name: "%y".to_string(),
            ty: ValueType::I32,
        });
        let result = values.alloc(Value {
            name: "%result".to_string(),
            ty: ValueType::I32,
        });
        let load_x = self.instructions.alloc(Instruction {
            opcode: Opcode::Load,
            operands: vec![Operand::Value(x)],
            result: Some(x),
        });
        let load_y = self.instructions.alloc(Instruction {
            opcode: Opcode::Load,
            operands: vec![Operand::Value(y)],
            result: Some(y),
        });
        let add = self.instructions.alloc(Instruction {
            opcode: Opcode::Add,
            operands: vec![Operand::Value(x), Operand::Value(y)],
            result: Some(result),
        });
        let ret = self.instructions.alloc(Instruction {
            opcode: Opcode::Return,
            operands: vec![Operand::Value(result)],
            result: None,
        });
        self.blocks.alloc(BasicBlock {
            label: "entry".to_string(),
            instructions: vec![load_x, load_y, add],
            terminator: Some(ret),
        })
    }
}
Instructions can reference values and blocks through IDs. The arena owns all the data, making memory management automatic and efficient.
Type Representation
Type systems benefit from arena allocation when types can be recursive or mutually referential:
```rust
#![allow(unused)]
fn main() {
use id_arena::Id;

#[derive(Debug, Clone)]
pub struct Type {
    pub kind: TypeKind,
}

#[derive(Debug, Clone)]
pub enum TypeKind {
    Int,
    Float,
    Bool,
    String,
    Function {
        params: Vec<Id<Type>>,
        ret: Id<Type>,
    },
    Unknown,
}
}
```
Function types reference parameter and return types through IDs, avoiding the complexity of boxed recursive types while maintaining type safety.
Traversal and Printing
Arena-based structures are easy to traverse since IDs can be followed without lifetime concerns:
The print_ast method is part of the Compiler impl:
```rust
#![allow(unused)]
fn main() {
pub fn print_ast(&self, id: Id<AstNode>, depth: usize) {
    let indent = "  ".repeat(depth);
    let node = &self.ast_arena[id];
    match &node.kind {
        NodeKind::Program => println!("{}Program", indent),
        NodeKind::Function { name, params, body } => {
            println!("{}Function: {}", indent, name);
            println!("{}  Parameters:", indent);
            for &param_id in params {
                self.print_ast(param_id, depth + 2);
            }
            println!("{}  Body:", indent);
            self.print_ast(*body, depth + 2);
        }
        // ... other node types
    }
}
}
```
The print function recursively follows IDs to traverse the tree. The arena provides indexed access to retrieve nodes by ID.
Performance Benefits
Arena allocation provides several performance advantages:
```rust
#![allow(unused)]
fn main() {
use id_arena::Arena;

pub fn demonstrate_arena_efficiency() {
    let mut arena = Arena::<String>::new();
    let mut ids = Vec::new();
    for i in 0..1000 {
        let id = arena.alloc(format!("Node {}", i));
        ids.push(id);
    }
    println!("Arena statistics:");
    println!("  Allocated {} strings", ids.len());
    println!("  Arena len: {}", arena.len());
    let sample_ids: Vec<_> = ids.iter().step_by(100).take(5).collect();
    println!("  Sample accesses:");
    for &id in &sample_ids {
        println!("    {} -> {}", id.index(), arena[*id]);
    }
}
}
```
Arenas allocate memory in large chunks, reducing allocator overhead. All nodes are stored contiguously, improving cache locality during traversal.
Best Practices
Structure your compiler with dedicated arenas for different types of data. Separate arenas for AST nodes, types, and IR allow independent manipulation and clearer ownership. Each compiler pass can create its own arenas for intermediate data.
Use newtype wrappers around IDs when you have multiple arena types. While id-arena provides type safety through phantom types, additional newtype wrappers can prevent mixing IDs from different logical domains.
Consider arena granularity carefully. One arena for all AST nodes is simpler but prevents partial deallocation. Multiple smaller arenas allow freeing memory between compiler phases but require more careful ID management.
Leverage ID stability for caching and incremental compilation. Unlike pointers, IDs remain valid even if you add more items to the arena. This makes them ideal keys for analysis results and cached computations.
Use arena iteration for whole-program analyses. Arenas provide efficient iteration over all allocated items, useful for passes that need to examine every node, type, or instruction.
Be mindful of arena growth patterns. Arenas never shrink, so long-lived arenas in language servers or watch modes can accumulate memory. Consider periodic arena recreation for long-running processes.
Take advantage of ID copyability for parallel analysis. IDs can be freely sent between threads, enabling parallel compiler passes without complex synchronization. Each thread can safely read from shared arenas while building its own result arenas.
indexmap
The indexmap crate provides hash maps and sets that maintain insertion order. In compiler development, preserving the order of definitions, declarations, and operations is often crucial for deterministic output, meaningful error messages, and correct code generation. While standard hash maps offer O(1) access, they randomize iteration order, which can lead to non-deterministic compiler behavior. IndexMap combines the performance of hash maps with predictable iteration order.
Compilers frequently need ordered collections for symbol tables, import lists, struct field definitions, function parameters, and type registries. IndexMap ensures that iterating over these collections always produces the same order as insertion, which is essential for reproducible builds and stable compiler output across different runs.
Symbol Tables with Scopes
Symbol tables are fundamental to compilers, tracking identifiers and their associated information:
```rust
#![allow(unused)]
fn main() {
use indexmap::IndexMap;

#[derive(Debug, Clone, PartialEq)]
pub struct Symbol {
    pub name: String,
    pub kind: SymbolKind,
    pub scope_level: usize,
}

#[derive(Debug, Clone, PartialEq)]
pub enum SymbolKind {
    Variable { mutable: bool, ty: String },
    Function { params: Vec<String>, ret_ty: String },
    Type { definition: String },
}

pub struct SymbolTable {
    scopes: Vec<IndexMap<String, Symbol>>,
}

impl Default for SymbolTable {
    fn default() -> Self {
        Self::new()
    }
}

impl SymbolTable {
    pub fn new() -> Self {
        Self {
            scopes: vec![IndexMap::new()],
        }
    }

    pub fn push_scope(&mut self) {
        self.scopes.push(IndexMap::new());
    }

    pub fn pop_scope(&mut self) -> Option<IndexMap<String, Symbol>> {
        if self.scopes.len() > 1 {
            self.scopes.pop()
        } else {
            None
        }
    }

    pub fn insert(&mut self, name: String, symbol: Symbol) -> Option<Symbol> {
        // We maintain the invariant that there's always at least one scope
        debug_assert!(
            !self.scopes.is_empty(),
            "SymbolTable should always have at least one scope"
        );
        if let Some(scope) = self.scopes.last_mut() {
            scope.insert(name, symbol)
        } else {
            None
        }
    }

    pub fn lookup(&self, name: &str) -> Option<&Symbol> {
        for scope in self.scopes.iter().rev() {
            if let Some(symbol) = scope.get(name) {
                return Some(symbol);
            }
        }
        None
    }

    pub fn current_scope_symbols(&self) -> Vec<(&String, &Symbol)> {
        self.scopes
            .last()
            .map(|scope| scope.iter().collect())
            .unwrap_or_default()
    }
}
}
```
The symbol table uses a stack of IndexMaps to represent nested scopes. When looking up a symbol, it searches from the innermost scope outward, implementing proper lexical scoping while maintaining declaration order within each scope.
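The lookup strategy can be sketched on its own. This is a minimal illustration of the innermost-first search, using std's `HashMap` instead of `IndexMap` so it compiles without external crates (unlike `IndexMap`, a plain `HashMap` does not preserve declaration order within a scope); the `Scopes` type is invented here for illustration.

```rust
use std::collections::HashMap;

struct Scopes {
    // name -> type, innermost scope last
    stack: Vec<HashMap<String, String>>,
}

impl Scopes {
    fn new() -> Self {
        Scopes { stack: vec![HashMap::new()] }
    }
    fn push(&mut self) {
        self.stack.push(HashMap::new());
    }
    fn pop(&mut self) {
        // Never pop the global scope
        if self.stack.len() > 1 {
            self.stack.pop();
        }
    }
    fn bind(&mut self, name: &str, ty: &str) {
        self.stack.last_mut().unwrap().insert(name.into(), ty.into());
    }
    // Search from the innermost scope outward so shadowing resolves correctly.
    fn lookup(&self, name: &str) -> Option<&String> {
        self.stack.iter().rev().find_map(|s| s.get(name))
    }
}

fn main() {
    let mut scopes = Scopes::new();
    scopes.bind("x", "i32");
    scopes.push();
    scopes.bind("x", "bool"); // shadows the outer x
    assert_eq!(scopes.lookup("x").unwrap(), "bool");
    scopes.pop();
    assert_eq!(scopes.lookup("x").unwrap(), "i32");
}
```

The same reverse-iteration pattern is what `SymbolTable::lookup` uses above; swapping `HashMap` for `IndexMap` adds ordered iteration within each scope without changing the lookup logic.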
Struct Field Layout
Struct field ordering is critical for memory layout and ABI compatibility:
#![allow(unused)]
fn main() {
use indexmap::IndexMap;

#[derive(Debug, Clone)]
pub struct StructField {
    pub name: String,
    pub ty: String,
    pub offset: usize,
}

pub fn create_struct_layout() -> IndexMap<String, StructField> {
    let mut fields = IndexMap::new();
    fields.insert(
        "id".to_string(),
        StructField {
            name: "id".to_string(),
            ty: "u64".to_string(),
            offset: 0,
        },
    );
    fields.insert(
        "name".to_string(),
        StructField {
            name: "name".to_string(),
            ty: "String".to_string(),
            offset: 8,
        },
    );
    fields.insert(
        "data".to_string(),
        StructField {
            name: "data".to_string(),
            ty: "Vec<u8>".to_string(),
            offset: 32,
        },
    );
    fields
}

pub fn demonstrate_field_ordering() {
    let fields = create_struct_layout();
    println!("Struct fields in definition order:");
    for (i, (name, field)) in fields.iter().enumerate() {
        println!(
            "  {}: {} ({}) at offset {}",
            i, name, field.ty, field.offset
        );
    }
    println!();
    println!("Field access by name:");
    if let Some(field) = fields.get("name") {
        println!("  fields[\"name\"] = {:?}", field);
    }
    println!();
    println!("Field access by index:");
    if let Some((_name, field)) = fields.get_index(1) {
        println!("  fields[1] = {:?}", field);
    }
}
}
IndexMap preserves field definition order while providing both name-based and index-based access. This is essential for generating correct struct layouts and for error messages that reference fields in source order.
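The dual access modes can be illustrated in isolation. This is a sketch of the observable behavior `IndexMap` provides, backed by a `Vec` of pairs purely so the example runs without external crates (`IndexMap` achieves the same interface with O(1) hashed lookups); the `OrderedFields` type is invented for this example.

```rust
struct OrderedFields {
    // (field name, byte offset), kept in definition order
    entries: Vec<(String, u64)>,
}

impl OrderedFields {
    fn insert(&mut self, name: &str, offset: u64) {
        self.entries.push((name.to_string(), offset));
    }
    // Name-based access, like IndexMap::get
    fn get(&self, name: &str) -> Option<u64> {
        self.entries.iter().find(|(n, _)| n == name).map(|(_, o)| *o)
    }
    // Position-based access, like IndexMap::get_index
    fn get_index(&self, i: usize) -> Option<(&str, u64)> {
        self.entries.get(i).map(|(n, o)| (n.as_str(), *o))
    }
}

fn main() {
    let mut fields = OrderedFields { entries: Vec::new() };
    fields.insert("id", 0);
    fields.insert("name", 8);
    fields.insert("data", 32);
    assert_eq!(fields.get("name"), Some(8));             // by name
    assert_eq!(fields.get_index(2), Some(("data", 32))); // by position
}
```

Having both modes in one structure is what lets a compiler compute offsets positionally while still resolving `point.x` by name.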
Import Resolution
Managing imports requires both deduplication and order preservation:
#![allow(unused)]
fn main() {
use indexmap::map::Entry;
use indexmap::{IndexMap, IndexSet};

pub struct ImportResolver {
    imports: IndexMap<String, IndexSet<String>>,
}

impl Default for ImportResolver {
    fn default() -> Self {
        Self::new()
    }
}

impl ImportResolver {
    pub fn new() -> Self {
        Self {
            imports: IndexMap::new(),
        }
    }

    pub fn add_import(&mut self, module: String, items: Vec<String>) {
        match self.imports.entry(module) {
            Entry::Occupied(mut e) => {
                for item in items {
                    e.get_mut().insert(item);
                }
            }
            Entry::Vacant(e) => {
                let mut set = IndexSet::new();
                for item in items {
                    set.insert(item);
                }
                e.insert(set);
            }
        }
    }

    pub fn get_imports(&self) -> Vec<(String, Vec<String>)> {
        self.imports
            .iter()
            .map(|(module, items)| (module.clone(), items.iter().cloned().collect()))
            .collect()
    }
}

pub fn demonstrate_import_resolution() {
    let mut resolver = ImportResolver::new();
    resolver.add_import(
        "std::collections".to_string(),
        vec!["HashMap".to_string(), "Vec".to_string()],
    );
    resolver.add_import(
        "std::io".to_string(),
        vec!["Read".to_string(), "Write".to_string()],
    );
    resolver.add_import("std::collections".to_string(), vec!["HashSet".to_string()]);
    println!("Import resolution order:");
    for (module, items) in resolver.get_imports() {
        println!("  use {} {{ {} }};", module, items.join(", "));
    }
}
}
The import resolver uses an IndexMap keyed by module and an IndexSet for each module's imported items. This ensures imports are processed in a consistent order and that duplicates are dropped automatically, with each item keeping the position of its first occurrence.
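The first-occurrence semantics are worth seeing in isolation. This sketch reproduces `IndexSet`'s observable dedup behavior with a plain `Vec` so it runs without external crates (`IndexSet` gives the same result with O(1) membership tests); the `add_items` helper is invented for this example.

```rust
// Insert items, skipping any already present; an item that reappears
// keeps the position where it was first inserted.
fn add_items(set: &mut Vec<String>, items: &[&str]) {
    for item in items {
        if !set.iter().any(|x| x == item) {
            set.push(item.to_string());
        }
    }
}

fn main() {
    let mut imports = Vec::new();
    add_items(&mut imports, &["HashMap", "Vec"]);
    add_items(&mut imports, &["Vec", "HashSet"]); // "Vec" is a duplicate
    // "Vec" stays in its original slot; only "HashSet" is appended.
    assert_eq!(imports, vec!["HashMap", "Vec", "HashSet"]);
}
```

This is why re-importing `std::collections` in the demonstration above merges `HashSet` into the existing entry instead of emitting a second `use` line.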
Type Registry
Type systems benefit from ordered type definitions:
```rust
use indexmap::map::Entry;
use indexmap::{IndexMap, IndexSet};

#[derive(Debug, Clone)]
pub struct TypeDefinition {
    pub name: String,
    pub kind: TypeKind,
}

#[derive(Debug, Clone)]
pub enum TypeKind {
    Primitive,
    Struct { fields: IndexMap<String, String> },
    Enum { variants: IndexSet<String> },
    Alias { target: String },
}

pub struct TypeRegistry {
    types: IndexMap<String, TypeDefinition>,
}

impl Default for TypeRegistry {
    fn default() -> Self {
        Self::new()
    }
}

impl TypeRegistry {
    pub fn new() -> Self {
        let mut registry = Self {
            types: IndexMap::new(),
        };
        // Primitives are registered first, so they always iterate
        // before any user-defined types.
        registry.types.insert(
            "i32".to_string(),
            TypeDefinition {
                name: "i32".to_string(),
                kind: TypeKind::Primitive,
            },
        );
        registry.types.insert(
            "bool".to_string(),
            TypeDefinition {
                name: "bool".to_string(),
                kind: TypeKind::Primitive,
            },
        );
        registry.types.insert(
            "String".to_string(),
            TypeDefinition {
                name: "String".to_string(),
                kind: TypeKind::Primitive,
            },
        );
        registry
    }

    pub fn register_type(&mut self, def: TypeDefinition) -> bool {
        match self.types.entry(def.name.clone()) {
            Entry::Vacant(e) => {
                e.insert(def);
                true
            }
            Entry::Occupied(_) => false,
        }
    }

    pub fn get_type(&self, name: &str) -> Option<&TypeDefinition> {
        self.types.get(name)
    }

    pub fn iter_types(&self) -> impl Iterator<Item = (&String, &TypeDefinition)> {
        self.types.iter()
    }
}
```
The registry maintains types in registration order, which is important for error messages, documentation generation, and ensuring primitive types are always processed before user-defined types.
Local Variable Bindings
Tracking local variables in their declaration order helps with debugging and error reporting:
```rust
use std::hash::Hash;

use indexmap::IndexMap;

pub struct LocalScope<K: Hash + Eq, V> {
    bindings: IndexMap<K, V>,
}

impl<K: Hash + Eq + Clone, V> Default for LocalScope<K, V> {
    fn default() -> Self {
        Self::new()
    }
}

impl<K: Hash + Eq + Clone, V> LocalScope<K, V> {
    pub fn new() -> Self {
        Self {
            bindings: IndexMap::new(),
        }
    }

    pub fn bind(&mut self, name: K, value: V) -> Option<V> {
        self.bindings.insert(name, value)
    }

    pub fn lookup(&self, name: &K) -> Option<&V> {
        self.bindings.get(name)
    }

    pub fn bindings_ordered(&self) -> Vec<(K, &V)> {
        self.bindings.iter().map(|(k, v)| (k.clone(), v)).collect()
    }
}
```
This generic scope structure can track any kind of bindings while preserving order. The ordered iteration is particularly useful for displaying variable dumps or generating debug information.
Best Practices
Use IndexMap for any collection where iteration order matters for correctness or user experience. This includes symbol tables, type definitions, struct fields, function parameters, and import lists. The small overhead compared to HashMap is usually negligible compared to the benefits of deterministic behavior.
Leverage both map and index access patterns. IndexMap allows you to look up entries by key in O(1) time and also access them by position. This is useful for positional parameters, struct field offsets, and anywhere you need both named and indexed access.
Use IndexSet for ordered unique collections. Import lists, keyword sets, and type parameter bounds are good candidates. IndexSet provides the same ordering guarantees as IndexMap while ensuring uniqueness.
Consider using the Entry API for efficient insertions and updates. This avoids double lookups and clearly expresses the intent to either update existing entries or insert new ones.
For deterministic compilation, ensure all collections that affect output use ordered variants. This includes not just IndexMap but also considering BTreeMap for sorted output or Vec for purely sequential access.
When implementing compiler passes that transform data structures, preserve ordering information. If a pass reads from an IndexMap and produces a new collection, use IndexMap for the output to maintain order invariants throughout the compilation pipeline.
Remember that IndexMap is not a replacement for Vec when you need purely sequential access. Use Vec for instruction sequences, basic blocks, and other truly linear data. Use IndexMap when you need both key-based lookup and order preservation.
smallvec
The smallvec crate provides a vector type that stores a small number of elements inline, avoiding heap allocation for common cases. In compiler development, many data structures contain a small number of elements most of the time. For example, most expressions have only a few operands, most basic blocks have only a few instructions, and most functions have only a few parameters. Using SmallVec for these cases can significantly reduce allocations and improve cache locality.
The key insight is that compilers often deal with collections that are usually small but occasionally large. Traditional vectors always allocate on the heap, even for a single element. SmallVec stores up to N elements inline within the struct itself, only spilling to heap allocation when the capacity is exceeded. This optimization is particularly effective in AST nodes, instruction operands, and symbol tables.
Basic Usage
SmallVec is used similarly to standard vectors but with an inline capacity specified in the type:
```rust
use smallvec::SmallVec;

pub fn demonstrate_capacity() {
    let mut vec: SmallVec<[i32; 4]> = SmallVec::new();
    println!("Initial capacity: {}", vec.inline_size());
    println!("Is heap allocated: {}", vec.spilled());
    for i in 0..6 {
        vec.push(i);
        println!(
            "After pushing {}: capacity = {}, spilled = {}",
            i,
            vec.capacity(),
            vec.spilled()
        );
    }
}
```
The type parameter `[i32; 4]` specifies both the element type and the inline capacity. The vector starts with space for four elements stored inline; when a fifth element is pushed, it spills to the heap with a larger capacity.
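The mechanism behind this spill can be sketched with the standard library alone. The hand-rolled `SmallBuf` below is a hypothetical, deliberately simplified stand-in for SmallVec (the real crate manages uninitialized memory, `Drop`, iteration, and a generic inline size); it keeps values in a fixed inline array until they overflow, then moves them into a `Vec`:

```rust
// A minimal sketch of the small-buffer optimization behind SmallVec:
// values live in a fixed inline array until they no longer fit, then
// move to a heap-allocated Vec. Illustrative only.
enum SmallBuf {
    Inline { data: [i32; 4], len: usize },
    Heap(Vec<i32>),
}

impl SmallBuf {
    fn new() -> Self {
        SmallBuf::Inline { data: [0; 4], len: 0 }
    }

    fn push(&mut self, value: i32) {
        match self {
            // Room left in the inline array: store in place, no allocation.
            SmallBuf::Inline { data, len } if *len < data.len() => {
                data[*len] = value;
                *len += 1;
            }
            // Inline array full: copy the elements into a Vec on the heap.
            SmallBuf::Inline { data, len } => {
                let mut vec = data[..*len].to_vec();
                vec.push(value);
                *self = SmallBuf::Heap(vec);
            }
            SmallBuf::Heap(vec) => vec.push(value),
        }
    }

    fn spilled(&self) -> bool {
        matches!(self, SmallBuf::Heap(_))
    }

    fn len(&self) -> usize {
        match self {
            SmallBuf::Inline { len, .. } => *len,
            SmallBuf::Heap(vec) => vec.len(),
        }
    }
}

fn main() {
    let mut buf = SmallBuf::new();
    for i in 0..5 {
        buf.push(i);
    }
    // The fifth push forced the spill to the heap.
    println!("len = {}, spilled = {}", buf.len(), buf.spilled());
}
```

Running it shows the same inline-to-heap transition that `demonstrate_capacity` prints for the real SmallVec.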
Tokenization
A common use case in compilers is storing token streams. Most expressions contain a moderate number of tokens that fit well within inline storage:
```rust
use smallvec::SmallVec;

#[derive(Debug, Clone, PartialEq)]
pub struct Token {
    pub kind: TokenKind,
    pub span: Span,
}

#[derive(Debug, Clone, PartialEq)]
pub enum TokenKind {
    Identifier(String),
    Number(i64),
    Operator(char),
    Keyword(String),
}

#[derive(Debug, Clone, PartialEq)]
pub struct Span {
    pub start: usize,
    pub end: usize,
}

pub type TokenStream = SmallVec<[Token; 32]>;

pub fn tokenize_expression(input: &str) -> TokenStream {
    let mut tokens = SmallVec::new();
    let mut chars = input.char_indices().peekable();
    while let Some((i, ch)) = chars.next() {
        match ch {
            ' ' | '\t' | '\n' => continue,
            '+' | '-' | '*' | '/' | '(' | ')' => {
                tokens.push(Token {
                    kind: TokenKind::Operator(ch),
                    span: Span { start: i, end: i + 1 },
                });
            }
            '0'..='9' => {
                let start = i;
                let mut value = ch.to_digit(10).unwrap() as i64;
                while let Some(&(_, ch)) = chars.peek() {
                    if ch.is_ascii_digit() {
                        value = value * 10 + ch.to_digit(10).unwrap() as i64;
                        chars.next();
                    } else {
                        break;
                    }
                }
                let end = chars.peek().map(|&(i, _)| i).unwrap_or(input.len());
                tokens.push(Token {
                    kind: TokenKind::Number(value),
                    span: Span { start, end },
                });
            }
            'a'..='z' | 'A'..='Z' | '_' => {
                let start = i;
                let mut ident = String::new();
                ident.push(ch);
                while let Some(&(_, ch)) = chars.peek() {
                    if ch.is_alphanumeric() || ch == '_' {
                        ident.push(ch);
                        chars.next();
                    } else {
                        break;
                    }
                }
                let end = chars.peek().map(|&(i, _)| i).unwrap_or(input.len());
                let kind = match ident.as_str() {
                    "if" | "else" | "while" | "for" | "return" => TokenKind::Keyword(ident),
                    _ => TokenKind::Identifier(ident),
                };
                tokens.push(Token {
                    kind,
                    span: Span { start, end },
                });
            }
            _ => {}
        }
    }
    tokens
}
```
The `TokenStream` type alias uses a SmallVec with inline capacity for 32 tokens. This covers most expressions without heap allocation while still handling arbitrarily large inputs when needed.
AST Nodes
Abstract syntax tree nodes often have a small, variable number of children. SmallVec is ideal for storing these children:
```rust
use smallvec::{smallvec, SmallVec};

#[derive(Debug, Clone)]
pub struct AstNode {
    pub kind: AstKind,
    pub children: SmallVec<[Box<AstNode>; 4]>,
}

#[derive(Debug, Clone)]
pub enum AstKind {
    Program,
    Function(String),
    Block,
    Expression,
    Statement,
    Identifier(String),
    Number(i64),
}

pub fn build_simple_ast() -> AstNode {
    AstNode {
        kind: AstKind::Program,
        children: smallvec![Box::new(AstNode {
            kind: AstKind::Function("main".to_string()),
            children: smallvec![Box::new(AstNode {
                kind: AstKind::Block,
                children: smallvec![Box::new(AstNode {
                    kind: AstKind::Expression,
                    children: smallvec![Box::new(AstNode {
                        kind: AstKind::Number(42),
                        children: SmallVec::new(),
                    })],
                })],
            })],
        })],
    }
}
```
Most AST nodes have fewer than four children, so this inline capacity avoids heap allocation in the common case. Nodes with longer child lists, such as calls with many arguments, are handled transparently by spilling to the heap.
Instruction Operands
Compiler intermediate representations often model instructions with a variable number of operands:
```rust
use smallvec::{smallvec, SmallVec};

#[derive(Debug, Clone)]
pub struct Instruction {
    pub opcode: Opcode,
    pub operands: SmallVec<[Operand; 3]>,
}

#[derive(Debug, Clone)]
pub enum Opcode {
    Load,
    Store,
    Add,
    Sub,
    Mul,
    Jmp,
    Ret,
}

#[derive(Debug, Clone)]
pub enum Operand {
    Register(u8),
    Immediate(i32),
    Memory(u32),
}

pub fn create_instruction_sequence() -> Vec<Instruction> {
    vec![
        Instruction {
            opcode: Opcode::Load,
            operands: smallvec![Operand::Register(0), Operand::Memory(0x1000)],
        },
        Instruction {
            opcode: Opcode::Load,
            operands: smallvec![Operand::Register(1), Operand::Memory(0x1004)],
        },
        Instruction {
            opcode: Opcode::Add,
            operands: smallvec![
                Operand::Register(2),
                Operand::Register(0),
                Operand::Register(1)
            ],
        },
        Instruction {
            opcode: Opcode::Store,
            operands: smallvec![Operand::Memory(0x1008), Operand::Register(2)],
        },
        Instruction {
            opcode: Opcode::Ret,
            operands: SmallVec::new(),
        },
    ]
}
```
Most instructions have zero to three operands, making a SmallVec with an inline capacity of 3 an excellent choice. This keeps instruction objects compact and cache-friendly.
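The inline-then-spill behavior this relies on can be sketched with a hand-rolled stand-in for SmallVec<[Operand; 3]>. The Operands type below is a simplified illustrative model (with u32 standing in for a real operand type), not the smallvec crate's implementation:

```rust
// A hand-rolled stand-in for SmallVec<[Operand; 3]>: a fixed inline array
// plus a length, falling back to a heap Vec only when it overflows.
#[derive(Debug)]
enum Operands {
    Inline { buf: [u32; 3], len: usize },
    Heap(Vec<u32>),
}

impl Operands {
    fn new() -> Self {
        Operands::Inline { buf: [0; 3], len: 0 }
    }

    fn push(&mut self, v: u32) {
        match self {
            // Room left in the inline buffer: no allocation needed.
            Operands::Inline { buf, len } if *len < 3 => {
                buf[*len] = v;
                *len += 1;
            }
            // Inline buffer full: spill everything to the heap.
            Operands::Inline { buf, len } => {
                let mut heap: Vec<u32> = buf[..*len].to_vec();
                heap.push(v);
                *self = Operands::Heap(heap);
            }
            Operands::Heap(items) => items.push(v),
        }
    }

    fn len(&self) -> usize {
        match self {
            Operands::Inline { len, .. } => *len,
            Operands::Heap(items) => items.len(),
        }
    }

    fn spilled(&self) -> bool {
        matches!(self, Operands::Heap(_))
    }
}

fn main() {
    let mut ops = Operands::new();
    for v in [1, 2, 3] {
        ops.push(v);
    }
    assert!(!ops.spilled()); // three operands fit inline
    ops.push(4);
    assert!(ops.spilled()); // a fourth operand forces a heap spill
    assert_eq!(ops.len(), 4);
}
```

Since almost every instruction takes this fast inline path, the common case never touches the allocator.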
Symbol Tables
Symbol tables benefit from SmallVec at multiple levels. Most scopes contain few symbols, and the scope stack itself is usually shallow:
#![allow(unused)]
fn main() {
use smallvec::{smallvec, SmallVec};

pub struct SymbolTable {
    scopes: SmallVec<[Scope; 8]>,
}

pub struct Scope {
    symbols: SmallVec<[(String, SymbolInfo); 16]>,
}

#[derive(Debug, Clone)]
pub struct SymbolInfo {
    pub kind: SymbolKind,
    pub offset: usize,
}

#[derive(Debug, Clone)]
pub enum SymbolKind {
    Variable,
    Function,
    Parameter,
}

impl Default for SymbolTable {
    fn default() -> Self {
        Self::new()
    }
}

impl SymbolTable {
    pub fn new() -> Self {
        Self {
            scopes: smallvec![Scope {
                symbols: SmallVec::new(),
            }],
        }
    }

    pub fn push_scope(&mut self) {
        self.scopes.push(Scope {
            symbols: SmallVec::new(),
        });
    }

    pub fn pop_scope(&mut self) {
        if self.scopes.len() > 1 {
            self.scopes.pop();
        }
    }

    pub fn insert(&mut self, name: String, info: SymbolInfo) {
        if let Some(scope) = self.scopes.last_mut() {
            scope.symbols.push((name, info));
        }
    }

    pub fn lookup(&self, name: &str) -> Option<&SymbolInfo> {
        for scope in self.scopes.iter().rev() {
            for (sym_name, info) in &scope.symbols {
                if sym_name == name {
                    return Some(info);
                }
            }
        }
        None
    }
}
}
This implementation uses SmallVec for both the scope stack (usually fewer than 8 scopes deep) and the symbol list within each scope (often fewer than 16 symbols). This gives excellent performance on typical programs while still handling deeply nested or symbol-heavy code gracefully.
Error Context
Compiler errors often need to track context through multiple levels. SmallVec efficiently stores this context:
#![allow(unused)]
fn main() {
use smallvec::{smallvec, SmallVec};

#[derive(Debug)]
pub struct CompactError {
    pub messages: SmallVec<[String; 2]>,
    pub locations: SmallVec<[Location; 2]>,
}

#[derive(Debug)]
pub struct Location {
    pub file: String,
    pub line: u32,
    pub column: u32,
}

impl CompactError {
    pub fn new(message: String, location: Location) -> Self {
        Self {
            messages: smallvec![message],
            locations: smallvec![location],
        }
    }

    pub fn add_context(&mut self, message: String, location: Location) {
        self.messages.push(message);
        self.locations.push(location);
    }
}
}
Most errors have only one or two context levels, so inline storage of 2 elements covers the common case without allocation.
Best Practices
Choose inline capacity based on profiling and typical use cases. Too small wastes the optimization opportunity, while too large wastes stack space. Common sweet spots are 2-4 for AST children, 8-16 for local collections, and 32-64 for token buffers.
Be aware of the size implications. A SmallVec<[T; N]> occupies roughly the size of its N inline elements plus a length word that doubles as the spill discriminant. This can make structs larger, potentially affecting cache behavior. Measure the trade-offs in your specific use case.
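A rough std-only model of that arithmetic is shown below. InlineBuf is a hypothetical stand-in, not the crate's actual layout (which overlaps the inline buffer with the heap pointer and capacity in a union), but it demonstrates why a large inline capacity inflates every struct that embeds it:

```rust
use std::mem::size_of;

// Simplified model of inline storage: a length word plus the inline buffer.
#[allow(dead_code)]
struct InlineBuf<const N: usize> {
    len: usize,
    buf: [u8; N],
}

fn main() {
    // A Vec<T> is three words regardless of T: pointer, length, capacity.
    assert_eq!(size_of::<Vec<u8>>(), 3 * size_of::<usize>());

    // Anything that stores 64 bytes inline must be at least 64 bytes plus
    // bookkeeping, so it dwarfs the plain Vec header.
    assert!(size_of::<InlineBuf<64>>() >= 64 + size_of::<usize>());
    assert!(size_of::<InlineBuf<64>>() > size_of::<Vec<u8>>());
}
```

The same reasoning applies to the real SmallVec: a token buffer type with inline capacity 32 embedded in another struct adds that inline storage to every instance, spilled or not.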
Use type aliases to make code more readable and to centralize capacity decisions. This makes it easy to tune capacities based on profiling data.
Consider using SmallVec in hot paths where allocation overhead matters. Parser combinators, visitor patterns, and iterative algorithms often benefit significantly.
The smallvec! macro provides convenient initialization similar to vec!. Use it for clarity when creating SmallVecs with initial values.
For recursive structures like ASTs, SmallVec can dramatically reduce total allocations. A tree with depth D and branching factor B normally requires O(B^D) child-vector allocations; with SmallVec, any node whose child count fits the inline capacity needs none.
symbol_table
The symbol_table crate provides fast and concurrent symbol interning for compilers. Symbol interning is the process of storing only one copy of each distinct string value, which provides both memory efficiency and fast equality comparisons. In compiler development, symbols appear everywhere: identifiers, keywords, string literals, type names, and module paths. By interning these strings, compilers can use simple pointer comparisons instead of string comparisons, dramatically improving performance.
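The core idea can be illustrated with a toy interner built only on the standard library. This is a sketch of the general technique, not the symbol_table crate's implementation (the Interner name and u32 symbol type are assumptions for this example):

```rust
use std::collections::HashMap;

// Each distinct string is stored once and identified by a small integer,
// so symbol equality is an integer compare rather than a byte-by-byte
// string compare.
#[derive(Default)]
struct Interner {
    map: HashMap<String, u32>,
    strings: Vec<String>,
}

impl Interner {
    fn intern(&mut self, s: &str) -> u32 {
        if let Some(&id) = self.map.get(s) {
            return id; // already interned: hand back the existing symbol
        }
        let id = self.strings.len() as u32;
        self.strings.push(s.to_string());
        self.map.insert(s.to_string(), id);
        id
    }

    fn resolve(&self, id: u32) -> &str {
        &self.strings[id as usize]
    }
}

fn main() {
    let mut interner = Interner::default();
    let a = interner.intern("variable_name");
    let b = interner.intern("variable_name");
    let c = interner.intern("another_name");
    assert_eq!(a, b); // same string, same symbol
    assert_ne!(a, c); // different strings, different symbols
    assert_eq!(interner.resolve(a), "variable_name");
}
```

The crate layers thread safety, leak-free storage, and cheap string access on top of this basic scheme.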
The crate offers two main APIs: a thread-local SymbolTable for single-threaded use and a global GlobalSymbol type that provides concurrent access across threads. The global symbols are particularly useful in modern compilers that use parallel parsing, type checking, or code generation. The static_symbol! macro enables compile-time symbol creation for known strings like keywords, avoiding runtime overhead entirely.
Static Symbols
For symbols known at compile time, the static_symbol! macro provides the fastest possible access:
#![allow(unused)]
fn main() {
use symbol_table::{static_symbol, GlobalSymbol};

pub fn demonstrate_static_symbols() {
    let if_sym = static_symbol!("if");
    let else_sym = static_symbol!("else");
    let while_sym = static_symbol!("while");
    let for_sym = static_symbol!("for");
    let return_sym = static_symbol!("return");

    println!("Static symbols created:");
    println!("  if: {:?}", if_sym);
    println!("  else: {:?}", else_sym);
    println!("  while: {:?}", while_sym);
    println!("  for: {:?}", for_sym);
    println!("  return: {:?}", return_sym);

    let if_sym2 = static_symbol!("if");
    println!("\nSymbol equality: if == if2: {}", if_sym == if_sym2);
    println!(
        "Pointer equality: {:p} == {:p}",
        if_sym.as_str(),
        if_sym2.as_str()
    );
}
}
Static symbols are created once and cached forever. Subsequent calls with the same string return the exact same symbol without any synchronization overhead. This makes them perfect for language keywords and built-in identifiers.
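The caching behavior is easy to model with only the standard library. The sketch below is an illustration of the create-once, return-the-same-pointer idea, not symbol_table's actual mechanism; `kw_if` is a hypothetical helper invented for this example:

```rust
use std::sync::OnceLock;

// Model of a "static symbol": the string is interned once on first use,
// and every later call returns the identical &'static str, so comparison
// can be pointer equality.
fn kw_if() -> &'static str {
    static SYM: OnceLock<&'static str> = OnceLock::new();
    *SYM.get_or_init(|| {
        let leaked: &'static str = Box::leak("if".to_string().into_boxed_str());
        leaked
    })
}

fn main() {
    let a = kw_if();
    let b = kw_if();
    // Both calls observe the same cached string.
    assert!(std::ptr::eq(a, b));
    println!("cached keyword: {}", a);
}
```

After the one-time initialization inside `OnceLock`, every call is a cheap read with no locking, which is the property that makes this pattern attractive for keywords.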
Global Symbol Interning
The GlobalSymbol type provides thread-safe symbol interning:
#![allow(unused)]
fn main() {
use symbol_table::GlobalSymbol;

pub fn demonstrate_global_symbols() {
    let sym1 = GlobalSymbol::from("variable_name");
    let sym2 = GlobalSymbol::from("variable_name");
    let sym3 = GlobalSymbol::from("another_name");
    println!("Symbol interning:");
    println!("  sym1 == sym2: {}", sym1 == sym2);
    println!("  sym1 == sym3: {}", sym1 == sym3);
    println!(
        "  Pointers equal: {}",
        std::ptr::eq(sym1.as_str(), sym2.as_str())
    );
}
}
When multiple threads intern the same string, they receive symbols that compare equal and point to the same underlying string data. This enables efficient symbol sharing across parallel compiler passes.
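A minimal model of how such a global intern pool can hand out pointer-identical strings across threads is sketched below, using only the standard library. This is a simplified illustration of the concept (a Mutex-guarded set of leaked strings); symbol_table's real implementation uses a faster scheme, and the `intern` function here is an assumption of this example, not the crate's API:

```rust
use std::collections::HashSet;
use std::sync::Mutex;
use std::thread;

// Global pool of interned strings. Leaking each unique string gives it a
// stable 'static address that every thread can share.
static POOL: Mutex<Option<HashSet<&'static str>>> = Mutex::new(None);

fn intern(s: &str) -> &'static str {
    let mut pool = POOL.lock().unwrap();
    let set = pool.get_or_insert_with(HashSet::new);
    if let Some(&existing) = set.get(s) {
        return existing; // already interned: same pointer as before
    }
    let leaked: &'static str = Box::leak(s.to_string().into_boxed_str());
    set.insert(leaked);
    leaked
}

fn main() {
    let a = intern("shared_name");
    let handle = thread::spawn(|| intern("shared_name"));
    let b = handle.join().unwrap();
    // Both threads observe the same underlying string data.
    assert!(std::ptr::eq(a, b));
    println!("equal pointers across threads: {}", std::ptr::eq(a, b));
}
```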
Compiler Context
A typical compiler pattern is to maintain a context with interned keywords and symbols:
#![allow(unused)]
fn main() {
use std::collections::HashMap;
use symbol_table::{Symbol, SymbolTable};

pub struct CompilerContext {
    pub symbols: SymbolTable,
    pub keywords: HashMap<Symbol, TokenKind>,
    pub string_literals: Vec<Symbol>,
}

#[derive(Debug, Clone, PartialEq)]
pub enum TokenKind {
    If,
    Else,
    While,
    For,
    Return,
    Function,
    Let,
    Const,
}
}
#![allow(unused)]
fn main() {
use std::collections::HashMap;
use symbol_table::{Symbol, SymbolTable};

impl CompilerContext {
    pub fn new() -> Self {
        let symbols = SymbolTable::new();
        let mut keywords = HashMap::new();
        keywords.insert(symbols.intern("if"), TokenKind::If);
        keywords.insert(symbols.intern("else"), TokenKind::Else);
        keywords.insert(symbols.intern("while"), TokenKind::While);
        keywords.insert(symbols.intern("for"), TokenKind::For);
        keywords.insert(symbols.intern("return"), TokenKind::Return);
        keywords.insert(symbols.intern("function"), TokenKind::Function);
        keywords.insert(symbols.intern("let"), TokenKind::Let);
        keywords.insert(symbols.intern("const"), TokenKind::Const);
        Self {
            symbols,
            keywords,
            string_literals: Vec::new(),
        }
    }

    pub fn intern_string(&mut self, s: &str) -> Symbol {
        self.symbols.intern(s)
    }

    pub fn is_keyword(&self, sym: Symbol) -> Option<&TokenKind> {
        self.keywords.get(&sym)
    }

    pub fn add_string_literal(&mut self, s: &str) -> usize {
        let sym = self.symbols.intern(s);
        self.string_literals.push(sym);
        self.string_literals.len() - 1
    }
}
}
This approach pre-interns all keywords during initialization, which speeds up lexical analysis: the is_keyword method reduces to a single hash lookup on a compact symbol value rather than a character-by-character string comparison.
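The payoff can be seen with a toy version of the same pattern, built on std only. The `Interner`, `classify`, and `TokenKind` names below are invented for this sketch (a stand-in for symbol_table's types); the point is that once strings are reduced to dense integer ids, keyword dispatch is a lookup keyed by a small Copy value:

```rust
use std::collections::HashMap;

// Toy interner mapping strings to dense u32 ids.
struct Interner {
    map: HashMap<String, u32>,
}

impl Interner {
    fn new() -> Self {
        Interner { map: HashMap::new() }
    }

    fn intern(&mut self, s: &str) -> u32 {
        let next = self.map.len() as u32;
        // Existing entries keep their original id; new strings get the next id.
        *self.map.entry(s.to_string()).or_insert(next)
    }
}

#[derive(Debug, Clone, Copy, PartialEq)]
enum TokenKind {
    If,
    Else,
    Ident,
}

fn classify(keywords: &HashMap<u32, TokenKind>, sym: u32) -> TokenKind {
    // A hash lookup on a u32, not a string comparison.
    keywords.get(&sym).copied().unwrap_or(TokenKind::Ident)
}

fn main() {
    let mut interner = Interner::new();
    let mut keywords = HashMap::new();
    keywords.insert(interner.intern("if"), TokenKind::If);
    keywords.insert(interner.intern("else"), TokenKind::Else);

    let sym = interner.intern("if"); // same id as the pre-interned "if"
    assert_eq!(classify(&keywords, sym), TokenKind::If);

    let var = interner.intern("my_var");
    assert_eq!(classify(&keywords, var), TokenKind::Ident);
    println!("keyword dispatch via symbol ids works");
}
```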
Identifiers with Spans
Compilers need to track where symbols appear in source code. Combining symbols with span information creates efficient identifier representations:
#![allow(unused)]
fn main() {
use symbol_table::GlobalSymbol;

#[derive(Debug, Clone)]
pub struct Identifier {
    pub symbol: GlobalSymbol,
    pub span: Span,
}

#[derive(Debug, Clone, Copy)]
pub struct Span {
    pub start: usize,
    pub end: usize,
}

impl Identifier {
    pub fn new(name: &str, span: Span) -> Self {
        Self {
            symbol: GlobalSymbol::from(name),
            span,
        }
    }

    pub fn as_str(&self) -> &str {
        self.symbol.as_str()
    }
}
}
use symbol_table::GlobalSymbol;

#[derive(Debug, Clone)]
pub struct Identifier {
    pub symbol: GlobalSymbol,
    pub span: Span,
}

#[derive(Debug, Clone, Copy)]
pub struct Span {
    pub start: usize,
    pub end: usize,
}

impl Identifier {
    pub fn new(name: &str, span: Span) -> Self {
        Self {
            symbol: GlobalSymbol::from(name),
            span,
        }
    }

    pub fn as_str(&self) -> &str {
        self.symbol.as_str()
    }
}
The identifier stores an interned symbol plus location information. The as_str
method provides convenient access to the underlying string without allocation.
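To make the benefit of interning concrete, here is a minimal, dependency-free sketch of the idea behind GlobalSymbol: identical strings map to the same small integer id, so identifier equality is an integer comparison rather than a string comparison. The `Sym` and `Interner` names are invented for this sketch and are not part of the symbol_table crate.

```rust
use std::collections::HashMap;

// A tiny stand-in for an interned symbol: just an index into the interner.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct Sym(u32);

#[derive(Default)]
struct Interner {
    map: HashMap<String, Sym>,
    strings: Vec<String>,
}

impl Interner {
    // Return the existing id for a string, or assign a fresh one.
    fn intern(&mut self, s: &str) -> Sym {
        if let Some(&sym) = self.map.get(s) {
            return sym;
        }
        let sym = Sym(self.strings.len() as u32);
        self.strings.push(s.to_string());
        self.map.insert(s.to_string(), sym);
        sym
    }

    // Resolve an id back to its string without allocating.
    fn as_str(&self, sym: Sym) -> &str {
        &self.strings[sym.0 as usize]
    }
}

fn main() {
    let mut interner = Interner::default();
    let a = interner.intern("my_variable");
    let b = interner.intern("my_variable");
    let c = interner.intern("other");
    // The same string interns to the same id; equality is O(1).
    assert_eq!(a, b);
    assert_ne!(a, c);
    assert_eq!(interner.as_str(a), "my_variable");
}
```

The symbol_table crate follows the same shape but adds thread-safe global interning and leaked, stable string storage, which is why GlobalSymbol::as_str can hand out a plain &str.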
Module Symbol Tables
Complex compilers often organize symbols by module, distinguishing between exported and internal symbols:
use std::collections::HashMap;
use symbol_table::GlobalSymbol;

#[derive(Debug)]
pub struct ModuleSymbolTable {
    pub name: GlobalSymbol,
    pub exported: HashMap<GlobalSymbol, SymbolInfo>,
    pub internal: HashMap<GlobalSymbol, SymbolInfo>,
}

#[derive(Debug, Clone)]
pub struct SymbolInfo {
    pub kind: SymbolKind,
    pub defined_at: Option<Location>,
    pub type_info: Option<String>,
}

#[derive(Debug, Clone)]
pub enum SymbolKind {
    Function,
    Variable,
    Type,
    Module,
}

#[derive(Debug, Clone)]
pub struct Location {
    pub file: GlobalSymbol,
    pub line: u32,
    pub column: u32,
}
impl ModuleSymbolTable {
    pub fn new(name: &str) -> Self {
        Self {
            name: GlobalSymbol::from(name),
            exported: HashMap::new(),
            internal: HashMap::new(),
        }
    }

    pub fn define_exported(&mut self, name: &str, info: SymbolInfo) {
        self.exported.insert(GlobalSymbol::from(name), info);
    }

    pub fn define_internal(&mut self, name: &str, info: SymbolInfo) {
        self.internal.insert(GlobalSymbol::from(name), info);
    }

    pub fn lookup(&self, name: &GlobalSymbol) -> Option<&SymbolInfo> {
        self.exported.get(name).or_else(|| self.internal.get(name))
    }
}
This structure efficiently represents module interfaces. Exported symbols are available to other modules, while internal symbols remain private. The lookup method searches both tables, respecting visibility rules.
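The visibility rule can be sketched without the symbol_table dependency by using String keys. The `ModuleScope` type and the `lookup_external` helper below are invented for this sketch; the point is that an in-module lookup consults both tables while a cross-module lookup consults only the exported one.

```rust
use std::collections::HashMap;

// Hypothetical stand-in for ModuleSymbolTable, mapping names to type info.
#[derive(Default)]
struct ModuleScope {
    exported: HashMap<String, String>,
    internal: HashMap<String, String>,
}

impl ModuleScope {
    // Lookup from inside the module: both tables are visible.
    fn lookup(&self, name: &str) -> Option<&String> {
        self.exported.get(name).or_else(|| self.internal.get(name))
    }

    // Lookup from another module: only exported symbols are visible.
    fn lookup_external(&self, name: &str) -> Option<&String> {
        self.exported.get(name)
    }
}

fn main() {
    let mut module = ModuleScope::default();
    module.exported.insert("public_func".into(), "fn() -> i32".into());
    module.internal.insert("private_var".into(), "String".into());

    assert!(module.lookup("private_var").is_some());
    assert!(module.lookup_external("private_var").is_none());
    assert!(module.lookup_external("public_func").is_some());
}
```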
Concurrent Access Patterns
For compiler passes that need to share mutable symbol data across threads:
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use symbol_table::GlobalSymbol;

pub type ConcurrentSymbolCache = Arc<RwLock<HashMap<GlobalSymbol, String>>>;

pub fn create_concurrent_cache() -> ConcurrentSymbolCache {
    Arc::new(RwLock::new(HashMap::new()))
}
pub fn demonstrate_concurrent_access() {
    let cache = create_concurrent_cache();
    let symbols: Vec<_> = (0..10)
        .map(|i| GlobalSymbol::from(format!("symbol_{}", i)))
        .collect();

    {
        let mut cache_write = cache.write().unwrap();
        for (i, sym) in symbols.iter().enumerate() {
            cache_write.insert(*sym, format!("Value {}", i));
        }
    }

    {
        let cache_read = cache.read().unwrap();
        println!("Concurrent cache contents:");
        for sym in &symbols[..5] {
            if let Some(value) = cache_read.get(sym) {
                println!("  {} => {}", sym.as_str(), value);
            }
        }
    }
}
This pattern uses Arc<RwLock<HashMap>> to allow multiple readers or a single writer. The GlobalSymbol keys ensure fast lookups even with concurrent access.
Performance Characteristics
Symbol interning provides several performance benefits: equality checks reduce to integer comparisons, hashing operates on a small fixed-size id rather than string contents, and each unique string is stored exactly once. The following benchmark compares a local SymbolTable instance against the global interner:
use std::time::Instant;

use symbol_table::{GlobalSymbol, SymbolTable};

pub fn benchmark_symbol_creation() {
    let iterations = 10000;
    let unique_symbols = 1000;

    let start = Instant::now();
    let table = SymbolTable::new();
    for i in 0..iterations {
        let name = format!("symbol_{}", i % unique_symbols);
        table.intern(&name);
    }
    let local_time = start.elapsed();

    let start = Instant::now();
    for i in 0..iterations {
        let name = format!("symbol_{}", i % unique_symbols);
        let _ = GlobalSymbol::from(name);
    }
    let global_time = start.elapsed();

    println!(
        "Symbol creation benchmark ({} iterations, {} unique):",
        iterations, unique_symbols
    );
    println!("  Local SymbolTable: {:?}", local_time);
    println!("  GlobalSymbol: {:?}", global_time);
    println!(
        "  Ratio: {:.2}x",
        local_time.as_nanos() as f64 / global_time.as_nanos() as f64
    );
}
The benchmark demonstrates that global symbols have competitive performance with local symbol tables while providing thread safety. The actual performance depends on contention levels and symbol reuse patterns.
Best Practices
Use static_symbol! for all keywords and operators known at compile time. This eliminates runtime interning overhead for the most common symbols. Create a module specifically for language keywords to centralize these definitions.
Prefer GlobalSymbol over a local SymbolTable in multi-threaded compilers. The global approach simplifies code and enables better parallelization. Local tables only make sense for isolated processing with no symbol sharing.
Design data structures to store symbols rather than strings. This applies to AST nodes, type representations, and error messages. Converting to strings should only happen at boundaries like error reporting or code generation.
Be aware of symbol lifetime. Both SymbolTable and GlobalSymbol keep strings alive forever. This is rarely a problem for compilers since the set of unique identifiers is bounded, but consider the implications for long-running language servers.
Use symbols as hash map keys freely. They implement Hash and Eq in terms of the interned id, so hashing and comparison run in constant time regardless of string length. Many compiler algorithms become simpler when symbols can be used directly as keys.
For incremental compilation, symbols provide stable identities across compilations. Two runs that encounter the same identifier will produce symbols that compare equal, enabling effective caching strategies.
petgraph
The petgraph crate provides a general-purpose graph data structure library for Rust. In compiler development, graphs are fundamental data structures used to represent control flow graphs (CFGs), call graphs, dependency graphs, dominance trees, and data flow analyses. The crate offers both directed and undirected graphs with flexible node and edge weights, along with a comprehensive collection of graph algorithms.
Control flow graphs are perhaps the most common use of graphs in compilers. Each node represents a basic block of instructions, and edges represent possible control flow between blocks. Call graphs track function calling relationships, which is essential for interprocedural analysis and optimization. Dependency graphs help with instruction scheduling and detecting parallelization opportunities.
Building Control Flow Graphs
A control flow graph represents the flow of control through a program. Here we create a simple CFG for an if-then-else statement:
use std::collections::HashMap;

use petgraph::graph::{DiGraph, NodeIndex};

#[derive(Debug, Clone)]
pub struct BasicBlock {
    pub id: String,
    pub instructions: Vec<String>,
}

pub type ControlFlowGraph = DiGraph<BasicBlock, String>;

pub fn build_simple_cfg() -> (ControlFlowGraph, HashMap<String, NodeIndex>) {
    let mut cfg = ControlFlowGraph::new();
    let mut blocks = HashMap::new();

    let entry = cfg.add_node(BasicBlock {
        id: "entry".to_string(),
        instructions: vec!["x = 10".to_string()],
    });
    blocks.insert("entry".to_string(), entry);

    let cond = cfg.add_node(BasicBlock {
        id: "cond".to_string(),
        instructions: vec!["if x > 5".to_string()],
    });
    blocks.insert("cond".to_string(), cond);

    let then_block = cfg.add_node(BasicBlock {
        id: "then".to_string(),
        instructions: vec!["x = x * 2".to_string()],
    });
    blocks.insert("then".to_string(), then_block);

    let else_block = cfg.add_node(BasicBlock {
        id: "else".to_string(),
        instructions: vec!["x = x + 1".to_string()],
    });
    blocks.insert("else".to_string(), else_block);

    let merge = cfg.add_node(BasicBlock {
        id: "merge".to_string(),
        instructions: vec!["print(x)".to_string()],
    });
    blocks.insert("merge".to_string(), merge);

    cfg.add_edge(entry, cond, "fallthrough".to_string());
    cfg.add_edge(cond, then_block, "true".to_string());
    cfg.add_edge(cond, else_block, "false".to_string());
    cfg.add_edge(then_block, merge, "fallthrough".to_string());
    cfg.add_edge(else_block, merge, "fallthrough".to_string());

    (cfg, blocks)
}
This function builds a CFG with an entry block, a conditional block that branches to either a then or else block, and finally a merge block where control flow reconverges. The graph structure makes it easy to analyze properties like dominance relationships and reachability.
Graph Traversal
Compilers frequently need to traverse graphs in specific orders. Depth-first search (DFS) is used for tasks like computing reverse postorder for dataflow analysis:
use petgraph::graph::NodeIndex;
use petgraph::visit::{Dfs, DfsPostOrder};

pub fn perform_dfs(graph: &ControlFlowGraph, start: NodeIndex) -> Vec<NodeIndex> {
    let mut dfs = Dfs::new(&graph, start);
    let mut visited = Vec::new();
    while let Some(node) = dfs.next(&graph) {
        visited.push(node);
    }
    visited
}

pub fn reverse_postorder(graph: &ControlFlowGraph, entry: NodeIndex) -> Vec<NodeIndex> {
    // Compute a postorder traversal, then reverse it. Reverse postorder
    // visits each node before its successors along forward edges, which is
    // the iteration order most forward dataflow analyses want.
    let mut dfs = DfsPostOrder::new(&graph, entry);
    let mut order = Vec::new();
    while let Some(node) = dfs.next(&graph) {
        order.push(node);
    }
    order.reverse();
    order
}
Breadth-first search (BFS) is useful for level-order traversals and finding shortest paths:
use petgraph::graph::NodeIndex;
use petgraph::visit::Bfs;

pub fn perform_bfs(graph: &ControlFlowGraph, start: NodeIndex) -> Vec<NodeIndex> {
    let mut bfs = Bfs::new(&graph, start);
    let mut visited = Vec::new();
    while let Some(node) = bfs.next(&graph) {
        visited.push(node);
    }
    visited
}
Dominance Analysis
Dominance is a fundamental concept in compiler optimization. A node A dominates a node B if every path from the entry node to B passes through A:
```rust
use std::collections::HashMap;

use petgraph::algo::dominators;
use petgraph::graph::NodeIndex;

/// Map each node to its immediate dominator. The entry node has no
/// immediate dominator and is omitted from the map.
pub fn find_dominators(
    graph: &ControlFlowGraph,
    entry: NodeIndex,
) -> HashMap<NodeIndex, NodeIndex> {
    let dom_tree = dominators::simple_fast(&graph, entry);
    let mut dom_map = HashMap::new();
    for node in graph.node_indices() {
        if let Some(idom) = dom_tree.immediate_dominator(node) {
            if idom != node {
                dom_map.insert(node, idom);
            }
        }
    }
    dom_map
}
```
The dominance frontier is used in SSA form construction, and immediate dominators help build the dominator tree used in many optimizations.
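The guide's petgraph examples stop at immediate dominators, but a dominance frontier is only a short walk up the dominator tree away. Below is a minimal sketch of the Cooper–Harvey–Kennedy algorithm over plain index-based `preds`/`idom` arrays rather than petgraph types; the function and its input layout are illustrative, not part of the guide's code.

```rust
use std::collections::BTreeSet;

/// Cooper–Harvey–Kennedy dominance frontiers. `preds[b]` lists the
/// predecessors of block b; `idom[b]` is b's immediate dominator,
/// with `idom[entry] == entry`.
fn dominance_frontiers(preds: &[Vec<usize>], idom: &[usize]) -> Vec<BTreeSet<usize>> {
    let mut df = vec![BTreeSet::new(); preds.len()];
    for b in 0..preds.len() {
        // Only join points (two or more predecessors) appear in any frontier.
        if preds[b].len() < 2 {
            continue;
        }
        for &p in &preds[b] {
            // Walk up the dominator tree from each predecessor until we hit
            // b's immediate dominator; b is in the frontier of every node
            // visited along the way.
            let mut runner = p;
            while runner != idom[b] {
                df[runner].insert(b);
                runner = idom[runner];
            }
        }
    }
    df
}

fn main() {
    // The diamond from build_simple_cfg: 0=entry, 1=cond, 2=then, 3=else, 4=merge.
    let preds = vec![vec![], vec![0], vec![1], vec![1], vec![2, 3]];
    let idom = vec![0, 0, 1, 1, 1];
    let df = dominance_frontiers(&preds, &idom);
    // merge (4) lies in the frontier of both branch blocks, which is
    // exactly where SSA construction would place a phi for x.
    println!("{:?}", df); // df[2] and df[3] each contain {4}
}
```

The `runner` walk works because any node on the path from a predecessor of a join point up to (but excluding) the join point's immediate dominator cannot dominate the join point, yet reaches it, which is the definition of the frontier.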
Loop Detection
Detecting loops is crucial for loop optimizations. A directed graph contains a cycle exactly when topological sorting fails:
```rust
use std::collections::HashMap;

use petgraph::algo::toposort;
use petgraph::graph::NodeIndex;

pub fn build_loop_cfg() -> (ControlFlowGraph, HashMap<String, NodeIndex>) {
    let mut cfg = ControlFlowGraph::new();
    let mut blocks = HashMap::new();
    let entry = cfg.add_node(BasicBlock {
        id: "entry".to_string(),
        instructions: vec!["i = 0".to_string()],
    });
    blocks.insert("entry".to_string(), entry);
    let loop_header = cfg.add_node(BasicBlock {
        id: "loop_header".to_string(),
        instructions: vec!["if i < 10".to_string()],
    });
    blocks.insert("loop_header".to_string(), loop_header);
    let loop_body = cfg.add_node(BasicBlock {
        id: "loop_body".to_string(),
        instructions: vec!["sum += i".to_string(), "i += 1".to_string()],
    });
    blocks.insert("loop_body".to_string(), loop_body);
    let exit = cfg.add_node(BasicBlock {
        id: "exit".to_string(),
        instructions: vec!["return sum".to_string()],
    });
    blocks.insert("exit".to_string(), exit);
    cfg.add_edge(entry, loop_header, "fallthrough".to_string());
    cfg.add_edge(loop_header, loop_body, "true".to_string());
    cfg.add_edge(loop_header, exit, "false".to_string());
    cfg.add_edge(loop_body, loop_header, "backedge".to_string());
    (cfg, blocks)
}

/// Returns None when the graph contains a cycle, so a None result
/// signals the presence of at least one loop.
pub fn topological_ordering(graph: &ControlFlowGraph) -> Option<Vec<NodeIndex>> {
    toposort(&graph, None).ok()
}
```
This creates a CFG with a simple while loop. The backedge from the loop body to the header creates a cycle in the graph.
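Once dominators identify a back edge (an edge t → h where h dominates t), the natural loop it defines can be collected by walking predecessors backward from t, stopping at the header. A sketch over plain index-based predecessor lists; the helper name and input layout are illustrative, not part of the guide's code.

```rust
/// Collect the natural loop of back edge `tail -> header`: the header plus
/// every node that can reach `tail` without passing through `header`.
fn natural_loop(preds: &[Vec<usize>], header: usize, tail: usize) -> Vec<usize> {
    let mut in_loop = vec![false; preds.len()];
    in_loop[header] = true; // never walk past the header
    let mut stack = vec![tail];
    while let Some(n) = stack.pop() {
        if !in_loop[n] {
            in_loop[n] = true;
            stack.extend(&preds[n]);
        }
    }
    (0..preds.len()).filter(|&n| in_loop[n]).collect()
}

fn main() {
    // build_loop_cfg's shape: 0=entry, 1=loop_header, 2=loop_body, 3=exit;
    // the back edge is 2 -> 1.
    let preds = vec![vec![], vec![0, 2], vec![1], vec![1]];
    println!("{:?}", natural_loop(&preds, 1, 2)); // [1, 2]
}
```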
Dead Code Detection
Unreachable code detection helps identify blocks that can never be executed:
```rust
use petgraph::algo::has_path_connecting;
use petgraph::graph::NodeIndex;

/// Collect every block that has no path from the entry node.
pub fn detect_unreachable_code(graph: &ControlFlowGraph, entry: NodeIndex) -> Vec<NodeIndex> {
    let mut unreachable = Vec::new();
    for node in graph.node_indices() {
        if !has_path_connecting(&graph, entry, node, None) {
            unreachable.push(node);
        }
    }
    unreachable
}
```
This uses path connectivity to find nodes that cannot be reached from the entry point. Such blocks can be safely removed during optimization.
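Note that calling `has_path_connecting` once per node runs a traversal per query, roughly O(V·(V+E)) overall; for large functions a single DFS from the entry that marks every reachable node is linear. A sketch over an index-based successor list (illustrative, not the guide's petgraph code):

```rust
/// One DFS from the entry; anything left unvisited is dead code.
fn unreachable_blocks(succs: &[Vec<usize>], entry: usize) -> Vec<usize> {
    let mut seen = vec![false; succs.len()];
    seen[entry] = true;
    let mut stack = vec![entry];
    while let Some(n) = stack.pop() {
        for &s in &succs[n] {
            if !seen[s] {
                seen[s] = true;
                stack.push(s);
            }
        }
    }
    (0..succs.len()).filter(|&n| !seen[n]).collect()
}

fn main() {
    // 0 -> 1 -> 2, with block 3 orphaned (e.g. code after a return).
    let succs = vec![vec![1], vec![2], vec![], vec![2]];
    println!("{:?}", unreachable_blocks(&succs, 0)); // [3]
}
```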
Call Graphs
Call graphs represent the calling relationships between functions in a program:
```rust
use std::collections::HashMap;

use petgraph::graph::{DiGraph, NodeIndex};

#[derive(Debug, Clone)]
pub struct CallGraphNode {
    pub name: String,
    pub is_recursive: bool,
}

pub type CallGraph = DiGraph<CallGraphNode, ()>;

pub fn build_call_graph() -> (CallGraph, HashMap<String, NodeIndex>) {
    let mut cg = CallGraph::new();
    let mut funcs = HashMap::new();
    let main = cg.add_node(CallGraphNode {
        name: "main".to_string(),
        is_recursive: false,
    });
    funcs.insert("main".to_string(), main);
    let parse = cg.add_node(CallGraphNode {
        name: "parse".to_string(),
        is_recursive: false,
    });
    funcs.insert("parse".to_string(), parse);
    let analyze = cg.add_node(CallGraphNode {
        name: "analyze".to_string(),
        is_recursive: false,
    });
    funcs.insert("analyze".to_string(), analyze);
    let codegen = cg.add_node(CallGraphNode {
        name: "codegen".to_string(),
        is_recursive: false,
    });
    funcs.insert("codegen".to_string(), codegen);
    let optimize = cg.add_node(CallGraphNode {
        name: "optimize".to_string(),
        is_recursive: true,
    });
    funcs.insert("optimize".to_string(), optimize);
    cg.add_edge(main, parse, ());
    cg.add_edge(main, analyze, ());
    cg.add_edge(main, codegen, ());
    cg.add_edge(analyze, optimize, ());
    cg.add_edge(optimize, optimize, ());
    (cg, funcs)
}
```
Call graphs are essential for interprocedural analysis, inlining decisions, and detecting recursive functions:
```rust
use petgraph::graph::NodeIndex;
use petgraph::visit::Dfs;

pub fn find_recursive_functions(graph: &CallGraph) -> Vec<NodeIndex> {
    let mut recursive = Vec::new();
    for node in graph.node_indices() {
        // Check for self-loops (direct recursion)
        if graph.find_edge(node, node).is_some() {
            recursive.push(node);
            continue;
        }
        // Check for indirect recursion (cycles that include this node)
        let mut dfs = Dfs::new(graph, node);
        while let Some(visited) = dfs.next(graph) {
            if visited != node && graph.find_edge(visited, node).is_some() {
                recursive.push(node);
                break;
            }
        }
    }
    recursive
}
```
Reverse Postorder
Reverse postorder is the standard iteration order for forward dataflow analyses:
```rust
use petgraph::graph::NodeIndex;
use petgraph::visit::DfsPostOrder;

pub fn reverse_postorder(graph: &ControlFlowGraph, entry: NodeIndex) -> Vec<NodeIndex> {
    // A postorder DFS yields a node only after all of its successors have
    // been explored; reversing that order gives reverse postorder, in
    // which every node precedes its successors (except along back edges).
    let mut dfs = DfsPostOrder::new(graph, entry);
    let mut order = Vec::new();
    while let Some(node) = dfs.next(graph) {
        order.push(node);
    }
    order.reverse();
    order
}
```
This ordering ensures that when visiting a node, most of its predecessors have already been visited, leading to faster convergence in iterative dataflow algorithms.
Graph Visualization
For debugging and understanding complex graphs, petgraph can generate DOT format output:
```rust
use petgraph::dot::{Config, Dot};

pub fn print_cfg_dot(graph: &ControlFlowGraph) -> String {
    // `Dot` wraps the graph and renders Graphviz DOT syntax through its
    // `Debug` implementation; `EdgeNoLabel` suppresses edge weights.
    format!("{:?}", Dot::with_config(&graph, &[Config::EdgeNoLabel]))
}
```
The DOT output can be rendered with Graphviz to visualize the graph structure, which is invaluable for debugging compiler passes.
Best Practices
Choose the appropriate graph type for your use case. Use `DiGraph` for directed graphs like CFGs and call graphs. Use `UnGraph` for undirected graphs like interference graphs in register allocation.

Node indices are not stable across node removals. If you need stable identifiers, store them in the node weight or use a separate mapping. For large graphs, consider using `StableGraph`, which maintains indices across removals at a small performance cost.
Many compiler algorithms benefit from caching graph properties. For example, dominance information should be computed once and reused rather than recalculated for each query. Similarly, strongly connected components and topological orderings can be cached.
For performance-critical paths, be aware that some algorithms have different implementations with different trade-offs. The `algo` module provides both simple and optimized versions of many algorithms.
The visitor traits allow you to implement custom traversals efficiently. Use `DfsPostOrder` for the postorder traversals needed in many analyses. The `visit` module provides building blocks for implementing sophisticated graph algorithms.
cranelift
Cranelift is a fast, secure code generator designed as a backend for WebAssembly and programming language implementations. Unlike traditional compiler backends like LLVM, Cranelift prioritizes compilation speed and simplicity over maximum runtime performance, making it ideal for JIT compilation scenarios. The library provides a complete infrastructure for generating machine code from a low-level intermediate representation, handling register allocation, instruction selection, and machine code emission across multiple architectures.
The core design philosophy of Cranelift centers on predictable compilation time and memory usage. It achieves fast compilation through a streamlined architecture that avoids expensive optimization passes, while still producing reasonably efficient code. This makes Cranelift particularly suitable for scenarios where compilation happens at runtime, such as WebAssembly engines, database query compilers, and language virtual machines.
Core Architecture
```rust
use cranelift::codegen::ir::types::*;
use cranelift::codegen::ir::{AbiParam, Function, InstBuilder, Signature, UserFuncName};
use cranelift::codegen::settings::{self, Configurable};
use cranelift::codegen::verifier::verify_function;
use cranelift::codegen::Context;
use cranelift::frontend::{FunctionBuilder, FunctionBuilderContext, Variable};
use cranelift::prelude::*;
use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::{FuncId, Linkage, Module};

/// A simple JIT compiler using Cranelift
pub struct JitCompiler {
    builder_context: FunctionBuilderContext,
    ctx: Context,
    module: JITModule,
}

impl JitCompiler {
    pub fn new() -> Self {
        let mut flag_builder = settings::builder();
        flag_builder.set("use_colocated_libcalls", "false").unwrap();
        flag_builder.set("is_pic", "false").unwrap();
        let isa_builder = cranelift_native::builder().unwrap_or_else(|msg| {
            panic!("host machine is not supported: {}", msg);
        });
        let isa = isa_builder
            .finish(settings::Flags::new(flag_builder))
            .unwrap();
        let mut builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
        builder.symbol_lookup_fn(Box::new(|name| {
            // Hook up external functions
            match name {
                "println_i64" => Some(println_i64 as *const u8),
                "println_f64" => Some(println_f64 as *const u8),
                _ => None,
            }
        }));
        let module = JITModule::new(builder);
        Self {
            builder_context: FunctionBuilderContext::new(),
            ctx: module.make_context(),
            module,
        }
    }

    pub fn compile_function(
        &mut self,
        name: &str,
        params: Vec<Type>,
        returns: Vec<Type>,
        build_fn: impl FnOnce(&mut FunctionBuilder, &[Variable]),
    ) -> Result<FuncId, String> {
        // Clear the context
        self.ctx.func = Function::with_name_signature(
            UserFuncName::user(0, 0),
            self.make_signature(params.clone(), returns.clone()),
        );

        // Create the function builder
        let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);

        // Create variables for parameters
        let variables: Vec<Variable> = params.iter().map(|ty| builder.declare_var(*ty)).collect();

        // Create entry block and append parameters
        let entry_block = builder.create_block();
        builder.append_block_params_for_function_params(entry_block);
        builder.switch_to_block(entry_block);
        builder.seal_block(entry_block);

        // Define parameters
        for (i, var) in variables.iter().enumerate() {
            let val = builder.block_params(entry_block)[i];
            builder.def_var(*var, val);
        }

        // Call the user's function to build the body
        build_fn(&mut builder, &variables);

        // Finalize the function
        builder.finalize();

        // Verify the function
        if let Err(errors) = verify_function(&self.ctx.func, self.module.isa()) {
            return Err(format!("Function verification failed: {}", errors));
        }

        // Define the function in the module
        let func_id = self
            .module
            .declare_function(name, Linkage::Export, &self.ctx.func.signature)
            .map_err(|e| e.to_string())?;
        self.module
            .define_function(func_id, &mut self.ctx)
            .map_err(|e| e.to_string())?;

        // Clear the context for next use
        self.module.clear_context(&mut self.ctx);
        Ok(func_id)
    }

    pub fn finalize(&mut self) {
        self.module.finalize_definitions().unwrap();
    }

    pub fn get_function(&self, func_id: FuncId) -> *const u8 {
        self.module.get_finalized_function(func_id)
    }

    fn make_signature(&self, params: Vec<Type>, returns: Vec<Type>) -> Signature {
        let mut sig = self.module.make_signature();
        for param in params {
            sig.params.push(AbiParam::new(param));
        }
        for ret in returns {
            sig.returns.push(AbiParam::new(ret));
        }
        sig
    }
}

impl Default for JitCompiler {
    fn default() -> Self {
        Self::new()
    }
}

extern "C" fn println_i64(x: i64) {
    println!("{}", x);
}

extern "C" fn println_f64(x: f64) {
    println!("{}", x);
}

/// Example: Compile a simple arithmetic function
pub fn compile_add_function(jit: &mut JitCompiler) -> Result<FuncId, String> {
    jit.compile_function("add", vec![I64, I64], vec![I64], |builder, params| {
        let x = builder.use_var(params[0]);
        let y = builder.use_var(params[1]);
        let sum = builder.ins().iadd(x, y);
        builder.ins().return_(&[sum]);
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compile_add() {
        let mut jit = JitCompiler::new();
        let func_id = compile_add_function(&mut jit).unwrap();
        jit.finalize();
        let code = jit.get_function(func_id);
        let add_fn = unsafe { std::mem::transmute::<*const u8, fn(i64, i64) -> i64>(code) };
        assert_eq!(add_fn(2, 3), 5);
        assert_eq!(add_fn(10, -5), 5);
    }
}
```
The JitCompiler structure encapsulates the Cranelift compilation pipeline. The builder context maintains reusable state across function compilations, the codegen context holds the intermediate representation of the function currently being compiled, and the JITModule manages the generated machine code and symbol resolution.
pub fn new() -> Self {
    let mut flag_builder = settings::builder();
    flag_builder.set("use_colocated_libcalls", "false").unwrap();
    flag_builder.set("is_pic", "false").unwrap();
    let isa_builder = cranelift_native::builder().unwrap_or_else(|msg| {
        panic!("host machine is not supported: {}", msg);
    });
    let isa = isa_builder
        .finish(settings::Flags::new(flag_builder))
        .unwrap();

    let mut builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
    builder.symbol_lookup_fn(Box::new(|name| {
        // Hook up external functions
        match name {
            "println_i64" => Some(println_i64 as *const u8),
            "println_f64" => Some(println_f64 as *const u8),
            _ => None,
        }
    }));

    let module = JITModule::new(builder);
    Self {
        builder_context: FunctionBuilderContext::new(),
        ctx: module.make_context(),
        module,
    }
}
Initialization configures the target architecture and compilation settings. The ISA (Instruction Set Architecture) builder automatically detects the host CPU features, while settings control trade-offs between compilation speed and code quality. The symbol lookup function enables linking to external functions, crucial for runtime library calls.
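The symbol lookup closure is, in essence, a name-to-pointer table. As a host-side sketch of the same idea, independent of the cranelift API (the `runtime_abs` helper and `RuntimeSymbols` type are illustrative names, not part of the chapter's code):

```rust
use std::collections::HashMap;

// A runtime helper that JIT-compiled code could call. Any `extern "C"`
// function with a known signature works; this one is illustrative.
extern "C" fn runtime_abs(x: i64) -> i64 {
    x.abs()
}

/// A name -> raw-pointer table with the same shape the
/// `symbol_lookup_fn` closure resolves against.
struct RuntimeSymbols {
    table: HashMap<&'static str, *const u8>,
}

impl RuntimeSymbols {
    fn new() -> Self {
        let mut table = HashMap::new();
        // Register each exported runtime function under a stable name.
        table.insert("runtime_abs", runtime_abs as *const u8);
        Self { table }
    }

    fn lookup(&self, name: &str) -> Option<*const u8> {
        self.table.get(name).copied()
    }
}

fn main() {
    let syms = RuntimeSymbols::new();
    assert!(syms.lookup("runtime_abs").is_some());
    assert!(syms.lookup("missing").is_none());
    println!("symbol table ok");
}
```

Keeping the runtime surface in one table like this makes it easy to audit exactly which host functions generated code can reach.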
Function Compilation
pub fn compile_function(
    &mut self,
    name: &str,
    params: Vec<Type>,
    returns: Vec<Type>,
    build_fn: impl FnOnce(&mut FunctionBuilder, &[Variable]),
) -> Result<FuncId, String> {
    // Clear the context
    self.ctx.func = Function::with_name_signature(
        UserFuncName::user(0, 0),
        self.make_signature(params.clone(), returns.clone()),
    );

    // Create the function builder
    let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);

    // Create variables for parameters
    let variables: Vec<Variable> = params.iter().map(|ty| builder.declare_var(*ty)).collect();

    // Create entry block and append parameters
    let entry_block = builder.create_block();
    builder.append_block_params_for_function_params(entry_block);
    builder.switch_to_block(entry_block);
    builder.seal_block(entry_block);

    // Define parameters
    for (i, var) in variables.iter().enumerate() {
        let val = builder.block_params(entry_block)[i];
        builder.def_var(*var, val);
    }

    // Call the user's function to build the body
    build_fn(&mut builder, &variables);

    // Finalize the function
    builder.finalize();

    // Verify the function
    if let Err(errors) = verify_function(&self.ctx.func, self.module.isa()) {
        return Err(format!("Function verification failed: {}", errors));
    }

    // Define the function in the module
    let func_id = self
        .module
        .declare_function(name, Linkage::Export, &self.ctx.func.signature)
        .map_err(|e| e.to_string())?;
    self.module
        .define_function(func_id, &mut self.ctx)
        .map_err(|e| e.to_string())?;

    // Clear the context for next use
    self.module.clear_context(&mut self.ctx);

    Ok(func_id)
}
Function compilation transforms high-level operations into machine code through several phases. The FunctionBuilder provides a convenient API for constructing the control flow graph and instruction sequences. Variable management connects high-level variables to SSA values, while block sealing enables efficient phi node insertion. The verification step ensures the generated IR satisfies Cranelift’s invariants before code generation.
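Block parameters are Cranelift's explicit form of phi nodes: every value that changes across a loop iteration is passed to the header block as an argument. In plain Rust the same dataflow looks like ordinary mutable loop variables. A minimal host-side sketch of the dataflow behind the factorial loop built later in this chapter (illustrative only, no cranelift involved):

```rust
// header(i, result) receives the loop-carried values; body computes the
// next pair and jumps back; exit returns `result`.
fn factorial_loop(n: i64) -> i64 {
    // Entry: jump to header with initial values (1, 1)
    let (mut i, mut result) = (1i64, 1i64);
    // Header: check i <= n
    while i <= n {
        // Body: result *= i; i += 1; jump back to header
        result *= i;
        i += 1;
    }
    // Exit: return result
    result
}

fn main() {
    assert_eq!(factorial_loop(0), 1);
    assert_eq!(factorial_loop(5), 120);
    println!("ok");
}
```

Reading the Cranelift blocks as this while loop, with `i` and `result` as the header's block parameters, is a reliable way to check a hand-built CFG before running the verifier.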
Instruction Building
use std::collections::HashMap;

use cranelift::codegen::ir::types::*;
use cranelift::codegen::ir::{AbiParam, Function, InstBuilder, Signature, UserFuncName};
use cranelift::codegen::settings::{self, Configurable};
use cranelift::codegen::verifier::verify_function;
use cranelift::codegen::Context;
use cranelift::frontend::{FunctionBuilder, FunctionBuilderContext, Variable};
use cranelift::prelude::*;
use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::{FuncId, Linkage, Module};

impl JitCompiler {
    pub fn finalize(&mut self) {
        self.module.finalize_definitions().unwrap();
    }

    pub fn get_function(&self, func_id: FuncId) -> *const u8 {
        self.module.get_finalized_function(func_id)
    }

    fn make_signature(&self, params: Vec<Type>, returns: Vec<Type>) -> Signature {
        let mut sig = self.module.make_signature();
        for param in params {
            sig.params.push(AbiParam::new(param));
        }
        for ret in returns {
            sig.returns.push(AbiParam::new(ret));
        }
        sig
    }
}

impl Default for JitCompiler {
    fn default() -> Self {
        Self::new()
    }
}

extern "C" fn println_i64(x: i64) {
    println!("{}", x);
}

extern "C" fn println_f64(x: f64) {
    println!("{}", x);
}

/// Example: Compile a factorial function
pub fn compile_factorial(jit: &mut JitCompiler) -> Result<FuncId, String> {
    jit.compile_function("factorial", vec![I64], vec![I64], |builder, params| {
        let n = params[0];

        // Create blocks
        let header_block = builder.create_block();
        let body_block = builder.create_block();
        let exit_block = builder.create_block();

        // Add block parameters
        builder.append_block_param(header_block, I64); // i
        builder.append_block_param(header_block, I64); // result

        // Entry: jump to header with initial values
        let one = builder.ins().iconst(I64, 1);
        builder.ins().jump(header_block, &[one.into(), one.into()]);

        // Header block: check if i <= n
        builder.switch_to_block(header_block);
        let i = builder.block_params(header_block)[0];
        let result = builder.block_params(header_block)[1];
        let n_val = builder.use_var(n);
        let cmp = builder.ins().icmp(IntCC::SignedLessThanOrEqual, i, n_val);
        builder.ins().brif(cmp, body_block, &[], exit_block, &[]);

        // Body block: result *= i; i++
        builder.switch_to_block(body_block);
        builder.seal_block(body_block);
        let new_result = builder.ins().imul(result, i);
        let new_i = builder.ins().iadd_imm(i, 1);
        builder
            .ins()
            .jump(header_block, &[new_i.into(), new_result.into()]);

        // Exit block: return result
        builder.switch_to_block(exit_block);
        builder.seal_block(exit_block);
        builder.seal_block(header_block);
        builder.ins().return_(&[result]);
    })
}

/// Example: Working with floating point
pub fn compile_quadratic(jit: &mut JitCompiler) -> Result<FuncId, String> {
    jit.compile_function(
        "quadratic",
        vec![F64, F64, F64, F64],
        vec![F64],
        |builder, params| {
            // f(x) = ax² + bx + c
            let x = builder.use_var(params[0]);
            let a = builder.use_var(params[1]);
            let b = builder.use_var(params[2]);
            let c = builder.use_var(params[3]);

            let x_squared = builder.ins().fmul(x, x);
            let ax_squared = builder.ins().fmul(a, x_squared);
            let bx = builder.ins().fmul(b, x);
            let ax_squared_plus_bx = builder.ins().fadd(ax_squared, bx);
            let result = builder.ins().fadd(ax_squared_plus_bx, c);
            builder.ins().return_(&[result]);
        },
    )
}

/// Example: Using external function calls
pub fn compile_with_print(jit: &mut JitCompiler) -> Result<FuncId, String> {
    // First declare the external function
    let mut sig = jit.module.make_signature();
    sig.params.push(AbiParam::new(I64));
    let println_id = jit
        .module
        .declare_function("println_i64", Linkage::Import, &sig)
        .unwrap();

    // Define the function
    let func_id = jit
        .module
        .declare_function(
            "print_sum",
            Linkage::Export,
            &jit.make_signature(vec![I64, I64], vec![]),
        )
        .unwrap();

    // Create function context
    jit.ctx.func = Function::with_name_signature(
        UserFuncName::user(0, 0),
        jit.make_signature(vec![I64, I64], vec![]),
    );

    // Build the function
    {
        let mut builder = FunctionBuilder::new(&mut jit.ctx.func, &mut jit.builder_context);
        let entry_block = builder.create_block();
        builder.append_block_params_for_function_params(entry_block);
        builder.switch_to_block(entry_block);
        builder.seal_block(entry_block);

        let x = builder.declare_var(I64);
        let y = builder.declare_var(I64);
        let x_val = builder.block_params(entry_block)[0];
        let y_val = builder.block_params(entry_block)[1];
        builder.def_var(x, x_val);
        builder.def_var(y, y_val);

        let x_use = builder.use_var(x);
        let y_use = builder.use_var(y);
        let sum = builder.ins().iadd(x_use, y_use);

        // Declare the function reference for calling
        let println_ref = jit.module.declare_func_in_func(println_id, builder.func);
        builder.ins().call(println_ref, &[sum]);
        builder.ins().return_(&[]);
        builder.finalize();
    }

    // Verify the function
    if let Err(errors) = verify_function(&jit.ctx.func, jit.module.isa()) {
        return Err(format!("Function verification failed: {}", errors));
    }

    jit.module
        .define_function(func_id, &mut jit.ctx)
        .map_err(|e| e.to_string())?;
    jit.module.clear_context(&mut jit.ctx);

    Ok(func_id)
}

/// Example: Compile a simple arithmetic function
pub fn compile_add_function(jit: &mut JitCompiler) -> Result<FuncId, String> {
    jit.compile_function("add", vec![I64, I64], vec![I64], |builder, params| {
        let x = builder.use_var(params[0]);
        let y = builder.use_var(params[1]);
        let sum = builder.ins().iadd(x, y);
        builder.ins().return_(&[sum]);
    })
}
Simple arithmetic operations demonstrate the instruction builder interface. Variables provide a high-level abstraction over SSA values, automatically handling phi nodes at control flow merge points. The return instruction explicitly specifies which values to return, supporting multiple return values naturally.
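A useful habit when JIT-compiling an AST like the chapter's `Expr` is to keep a host-side interpreter with identical semantics and cross-check the two implementations on the same inputs. A minimal sketch (the `eval` function is an illustrative addition, not part of the chapter's code; the `Expr` shape mirrors the one defined above):

```rust
#[derive(Debug, Clone)]
enum Expr {
    Const(i64),
    Add(Box<Expr>, Box<Expr>),
    Sub(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
    Var(usize),
}

// Direct tree-walking evaluation with the semantics the JIT-compiled
// code is expected to reproduce: iadd/isub/imul on i64 values.
fn eval(expr: &Expr, vars: &[i64]) -> i64 {
    match expr {
        Expr::Const(n) => *n,
        Expr::Add(a, b) => eval(a, vars).wrapping_add(eval(b, vars)),
        Expr::Sub(a, b) => eval(a, vars).wrapping_sub(eval(b, vars)),
        Expr::Mul(a, b) => eval(a, vars).wrapping_mul(eval(b, vars)),
        Expr::Var(idx) => vars[*idx],
    }
}

fn main() {
    // (x + 3) * (y - 2), the same expression used in the chapter's tests
    let expr = Expr::Mul(
        Box::new(Expr::Add(Box::new(Expr::Var(0)), Box::new(Expr::Const(3)))),
        Box::new(Expr::Sub(Box::new(Expr::Var(1)), Box::new(Expr::Const(2)))),
    );
    assert_eq!(eval(&expr, &[5, 7]), 40); // (5+3) * (7-2)
    println!("interpreter agrees");
}
```

Wrapping arithmetic is used because Cranelift's integer instructions wrap on overflow; a differential test that feeds random expressions to both the interpreter and the JIT catches codegen bugs early.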
Control Flow
/// Example: Compile a Fibonacci function
pub fn compile_fibonacci(jit: &mut JitCompiler) -> Result<FuncId, String> {
    jit.compile_function("fibonacci", vec![I64], vec![I64], |builder, params| {
        let n = params[0];

        // Create blocks
        let check_base = builder.create_block();
        let recursive = builder.create_block();
        let return_n = builder.create_block();

        // Jump to check_base
        builder.ins().jump(check_base, &[]);

        // Check if n <= 1
        builder.switch_to_block(check_base);
        let n_val = builder.use_var(n);
        let one = builder.ins().iconst(I64, 1);
        let cmp = builder.ins().icmp(IntCC::SignedLessThanOrEqual, n_val, one);
        builder.ins().brif(cmp, return_n, &[], recursive, &[]);

        // Return n for base case
        builder.switch_to_block(return_n);
        builder.seal_block(return_n);
        builder.ins().return_(&[n_val]);

        // Recursive case: fib(n-1) + fib(n-2)
        builder.switch_to_block(recursive);
        builder.seal_block(recursive);
        builder.seal_block(check_base);

        // For simplicity, we'll compute iteratively
        let two = builder.ins().iconst(I64, 2);
        let a = builder.ins().iconst(I64, 0);
        let b = builder.ins().iconst(I64, 1);

        // Create loop blocks
        let loop_header = builder.create_block();
        let loop_body = builder.create_block();
        let loop_exit = builder.create_block();

        builder.append_block_param(loop_header, I64); // counter
        builder.append_block_param(loop_header, I64); // a
        builder.append_block_param(loop_header, I64); // b
        builder
            .ins()
            .jump(loop_header, &[two.into(), a.into(), b.into()]);

        // Loop header: check if counter <= n
        builder.switch_to_block(loop_header);
        let counter = builder.block_params(loop_header)[0];
        let curr_a = builder.block_params(loop_header)[1];
        let curr_b = builder.block_params(loop_header)[2];
        let cmp = builder
            .ins()
            .icmp(IntCC::SignedLessThanOrEqual, counter, n_val);
        builder.ins().brif(cmp, loop_body, &[], loop_exit, &[]);

        // Loop body: compute next fibonacci number
        builder.switch_to_block(loop_body);
        builder.seal_block(loop_body);
        let next_fib = builder.ins().iadd(curr_a, curr_b);
        let next_counter = builder.ins().iadd_imm(counter, 1);
        builder.ins().jump(
            loop_header,
            &[next_counter.into(), curr_b.into(), next_fib.into()],
        );

        // Loop exit: return b
        builder.switch_to_block(loop_exit);
        builder.seal_block(loop_exit);
        builder.seal_block(loop_header);
        builder.ins().return_(&[curr_b]);
    })
}

/// Example: Control flow with multiple returns
pub fn compile_max(jit: &mut JitCompiler) -> Result<FuncId, String> {
    jit.compile_function("max", vec![I64, I64], vec![I64], |builder, params| {
        let x = builder.use_var(params[0]);
        let y = builder.use_var(params[1]);

        let then_block = builder.create_block();
        let else_block = builder.create_block();

        // if x > y
        let cmp = builder.ins().icmp(IntCC::SignedGreaterThan, x, y);
        builder.ins().brif(cmp, then_block, &[], else_block, &[]);

        // then: return x
        builder.switch_to_block(then_block);
        builder.seal_block(then_block);
        builder.ins().return_(&[x]);

        // else: return y
        builder.switch_to_block(else_block);
        builder.seal_block(else_block);
        builder.ins().return_(&[y]);
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn
test_compile_expression() { let mut jit = JitCompiler::new(); // (x + 3) * (y - 2) let expr = Expr::Mul( Box::new(Expr::Add(Box::new(Expr::Var(0)), Box::new(Expr::Const(3)))), Box::new(Expr::Sub(Box::new(Expr::Var(1)), Box::new(Expr::Const(2)))), ); let func_id = compile_expression(&mut jit, expr).unwrap(); jit.finalize(); let code = jit.get_function(func_id); let eval_fn = unsafe { std::mem::transmute::<*const u8, fn(i64, i64) -> i64>(code) }; assert_eq!(eval_fn(5, 7), 40); // (5+3) * (7-2) = 8 * 5 = 40 assert_eq!(eval_fn(2, 4), 10); // (2+3) * (4-2) = 5 * 2 = 10 } #[test] fn test_quadratic() { let mut jit = JitCompiler::new(); let func_id = compile_quadratic(&mut jit).unwrap(); jit.finalize(); let code = jit.get_function(func_id); let quad_fn = unsafe { std::mem::transmute::<*const u8, fn(f64, f64, f64, f64) -> f64>(code) }; // f(x) = 2x² + 3x + 1 // f(2) = 2*4 + 3*2 + 1 = 8 + 6 + 1 = 15 assert_eq!(quad_fn(2.0, 2.0, 3.0, 1.0), 15.0); } } /// Example: Compile a factorial function pub fn compile_factorial(jit: &mut JitCompiler) -> Result<FuncId, String> { jit.compile_function("factorial", vec![I64], vec![I64], |builder, params| { let n = params[0]; // Create blocks let header_block = builder.create_block(); let body_block = builder.create_block(); let exit_block = builder.create_block(); // Add block parameters builder.append_block_param(header_block, I64); // i builder.append_block_param(header_block, I64); // result // Entry: jump to header with initial values let one = builder.ins().iconst(I64, 1); builder.ins().jump(header_block, &[one.into(), one.into()]); // Header block: check if i <= n builder.switch_to_block(header_block); let i = builder.block_params(header_block)[0]; let result = builder.block_params(header_block)[1]; let n_val = builder.use_var(n); let cmp = builder.ins().icmp(IntCC::SignedLessThanOrEqual, i, n_val); builder.ins().brif(cmp, body_block, &[], exit_block, &[]); // Body block: result *= i; i++ builder.switch_to_block(body_block); 
builder.seal_block(body_block); let new_result = builder.ins().imul(result, i); let new_i = builder.ins().iadd_imm(i, 1); builder .ins() .jump(header_block, &[new_i.into(), new_result.into()]); // Exit block: return result builder.switch_to_block(exit_block); builder.seal_block(exit_block); builder.seal_block(header_block); builder.ins().return_(&[result]); }) } }
Loop construction requires explicit block management and parameter passing. Block parameters implement SSA form, making data flow explicit at control flow joins. The seal operation indicates when all predecessors of a block are known, enabling efficient phi node insertion. Conditional branches carry arguments for the taken branch, implementing a form of conditional move at the IR level.
/// Example: Compile a Fibonacci function
pub fn compile_fibonacci(jit: &mut JitCompiler) -> Result<FuncId, String> {
    jit.compile_function("fibonacci", vec![I64], vec![I64], |builder, params| {
        let n = params[0];

        // Create blocks
        let check_base = builder.create_block();
        let recursive = builder.create_block();
        let return_n = builder.create_block();

        // Jump to check_base
        builder.ins().jump(check_base, &[]);

        // Check if n <= 1
        builder.switch_to_block(check_base);
        let n_val = builder.use_var(n);
        let one = builder.ins().iconst(I64, 1);
        let cmp = builder.ins().icmp(IntCC::SignedLessThanOrEqual, n_val, one);
        builder.ins().brif(cmp, return_n, &[], recursive, &[]);

        // Return n for base case
        builder.switch_to_block(return_n);
        builder.seal_block(return_n);
        builder.ins().return_(&[n_val]);

        // Recursive case: fib(n-1) + fib(n-2)
        builder.switch_to_block(recursive);
        builder.seal_block(recursive);
        builder.seal_block(check_base);

        // For simplicity, we'll compute iteratively
        let two = builder.ins().iconst(I64, 2);
        let a = builder.ins().iconst(I64, 0);
        let b = builder.ins().iconst(I64, 1);

        // Create loop blocks
        let loop_header = builder.create_block();
        let loop_body = builder.create_block();
        let loop_exit = builder.create_block();

        builder.append_block_param(loop_header, I64); // counter
        builder.append_block_param(loop_header, I64); // a
        builder.append_block_param(loop_header, I64); // b

        builder
            .ins()
            .jump(loop_header, &[two.into(), a.into(), b.into()]);

        // Loop header: check if counter <= n
        builder.switch_to_block(loop_header);
        let counter = builder.block_params(loop_header)[0];
        let curr_a = builder.block_params(loop_header)[1];
        let curr_b = builder.block_params(loop_header)[2];
        let cmp = builder
            .ins()
            .icmp(IntCC::SignedLessThanOrEqual, counter, n_val);
        builder.ins().brif(cmp, loop_body, &[], loop_exit, &[]);

        // Loop body: compute next fibonacci number
        builder.switch_to_block(loop_body);
        builder.seal_block(loop_body);
        let next_fib = builder.ins().iadd(curr_a, curr_b);
        let next_counter = builder.ins().iadd_imm(counter, 1);
        builder.ins().jump(
            loop_header,
            &[next_counter.into(), curr_b.into(), next_fib.into()],
        );

        // Loop exit: return b
        builder.switch_to_block(loop_exit);
        builder.seal_block(loop_exit);
        builder.seal_block(loop_header);
        builder.ins().return_(&[curr_b]);
    })
}
The Fibonacci implementation demonstrates iterative computation with loop-carried dependencies. The loop structure uses block parameters to maintain loop variables without mutable state. This SSA-based approach enables straightforward optimization and register allocation.
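For reference, here is the same computation as our plain-Rust analog. Each mutable rebinding in the `while` loop corresponds to a fresh set of block arguments passed to `loop_header` on the back edge; the SSA form simply makes each rebinding a distinct named value.

```rust
// Plain-Rust mirror of the SSA loop: (counter, a, b) are the three
// loop-carried values that become block parameters in the IR version.
fn fib_iterative(n: i64) -> i64 {
    // Base case mirrors the `return_n` block
    if n <= 1 {
        return n;
    }
    let (mut a, mut b) = (0i64, 1i64);
    let mut counter = 2i64;
    while counter <= n {
        let next = a + b; // next_fib = iadd(curr_a, curr_b)
        a = b;            // curr_b flows into the `a` slot
        b = next;         // next_fib flows into the `b` slot
        counter += 1;
    }
    b
}
```

Reading the two side by side makes it clear that the seemingly verbose block-parameter plumbing encodes exactly the state updates a hand-written loop would perform.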
Floating Point Operations
#![allow(unused)] fn main() { use std::collections::HashMap; use cranelift::codegen::ir::types::*; use cranelift::codegen::ir::{AbiParam, Function, InstBuilder, Signature, UserFuncName}; use cranelift::codegen::settings::{self, Configurable}; use cranelift::codegen::verifier::verify_function; use cranelift::codegen::Context; use cranelift::frontend::{FunctionBuilder, FunctionBuilderContext, Variable}; use cranelift::prelude::*; use cranelift_jit::{JITBuilder, JITModule}; use cranelift_module::{FuncId, Linkage, Module}; /// A simple JIT compiler using Cranelift pub struct JitCompiler { builder_context: FunctionBuilderContext, ctx: Context, module: JITModule, } impl JitCompiler { pub fn new() -> Self { let mut flag_builder = settings::builder(); flag_builder.set("use_colocated_libcalls", "false").unwrap(); flag_builder.set("is_pic", "false").unwrap(); let isa_builder = cranelift_native::builder().unwrap_or_else(|msg| { panic!("host machine is not supported: {}", msg); }); let isa = isa_builder .finish(settings::Flags::new(flag_builder)) .unwrap(); let mut builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names()); builder.symbol_lookup_fn(Box::new(|name| { // Hook up external functions match name { "println_i64" => Some(println_i64 as *const u8), "println_f64" => Some(println_f64 as *const u8), _ => None, } })); let module = JITModule::new(builder); Self { builder_context: FunctionBuilderContext::new(), ctx: module.make_context(), module, } } pub fn compile_function( &mut self, name: &str, params: Vec<Type>, returns: Vec<Type>, build_fn: impl FnOnce(&mut FunctionBuilder, &[Variable]), ) -> Result<FuncId, String> { // Clear the context self.ctx.func = Function::with_name_signature( UserFuncName::user(0, 0), self.make_signature(params.clone(), returns.clone()), ); // Create the function builder let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context); // Create variables for parameters let variables: Vec<Variable> = 
params.iter().map(|ty| builder.declare_var(*ty)).collect(); // Create entry block and append parameters let entry_block = builder.create_block(); builder.append_block_params_for_function_params(entry_block); builder.switch_to_block(entry_block); builder.seal_block(entry_block); // Define parameters for (i, var) in variables.iter().enumerate() { let val = builder.block_params(entry_block)[i]; builder.def_var(*var, val); } // Call the user's function to build the body build_fn(&mut builder, &variables); // Finalize the function builder.finalize(); // Verify the function if let Err(errors) = verify_function(&self.ctx.func, self.module.isa()) { return Err(format!("Function verification failed: {}", errors)); } // Define the function in the module let func_id = self .module .declare_function(name, Linkage::Export, &self.ctx.func.signature) .map_err(|e| e.to_string())?; self.module .define_function(func_id, &mut self.ctx) .map_err(|e| e.to_string())?; // Clear the context for next use self.module.clear_context(&mut self.ctx); Ok(func_id) } pub fn finalize(&mut self) { self.module.finalize_definitions().unwrap(); } pub fn get_function(&self, func_id: FuncId) -> *const u8 { self.module.get_finalized_function(func_id) } fn make_signature(&self, params: Vec<Type>, returns: Vec<Type>) -> Signature { let mut sig = self.module.make_signature(); for param in params { sig.params.push(AbiParam::new(param)); } for ret in returns { sig.returns.push(AbiParam::new(ret)); } sig } } impl Default for JitCompiler { fn default() -> Self { Self::new() } } extern "C" fn println_i64(x: i64) { println!("{}", x); } extern "C" fn println_f64(x: f64) { println!("{}", x); } /// Example: Compile a simple arithmetic function pub fn compile_add_function(jit: &mut JitCompiler) -> Result<FuncId, String> { jit.compile_function("add", vec![I64, I64], vec![I64], |builder, params| { let x = builder.use_var(params[0]); let y = builder.use_var(params[1]); let sum = builder.ins().iadd(x, y); 
builder.ins().return_(&[sum]); }) } /// Example: Compile a factorial function pub fn compile_factorial(jit: &mut JitCompiler) -> Result<FuncId, String> { jit.compile_function("factorial", vec![I64], vec![I64], |builder, params| { let n = params[0]; // Create blocks let header_block = builder.create_block(); let body_block = builder.create_block(); let exit_block = builder.create_block(); // Add block parameters builder.append_block_param(header_block, I64); // i builder.append_block_param(header_block, I64); // result // Entry: jump to header with initial values let one = builder.ins().iconst(I64, 1); builder.ins().jump(header_block, &[one.into(), one.into()]); // Header block: check if i <= n builder.switch_to_block(header_block); let i = builder.block_params(header_block)[0]; let result = builder.block_params(header_block)[1]; let n_val = builder.use_var(n); let cmp = builder.ins().icmp(IntCC::SignedLessThanOrEqual, i, n_val); builder.ins().brif(cmp, body_block, &[], exit_block, &[]); // Body block: result *= i; i++ builder.switch_to_block(body_block); builder.seal_block(body_block); let new_result = builder.ins().imul(result, i); let new_i = builder.ins().iadd_imm(i, 1); builder .ins() .jump(header_block, &[new_i.into(), new_result.into()]); // Exit block: return result builder.switch_to_block(exit_block); builder.seal_block(exit_block); builder.seal_block(header_block); builder.ins().return_(&[result]); }) } /// Example: Compile a Fibonacci function pub fn compile_fibonacci(jit: &mut JitCompiler) -> Result<FuncId, String> { jit.compile_function("fibonacci", vec![I64], vec![I64], |builder, params| { let n = params[0]; // Create blocks let check_base = builder.create_block(); let recursive = builder.create_block(); let return_n = builder.create_block(); // Jump to check_base builder.ins().jump(check_base, &[]); // Check if n <= 1 builder.switch_to_block(check_base); let n_val = builder.use_var(n); let one = builder.ins().iconst(I64, 1); let cmp = 
builder.ins().icmp(IntCC::SignedLessThanOrEqual, n_val, one); builder.ins().brif(cmp, return_n, &[], recursive, &[]); // Return n for base case builder.switch_to_block(return_n); builder.seal_block(return_n); builder.ins().return_(&[n_val]); // Recursive case: fib(n-1) + fib(n-2) builder.switch_to_block(recursive); builder.seal_block(recursive); builder.seal_block(check_base); // For simplicity, we'll compute iteratively let two = builder.ins().iconst(I64, 2); let a = builder.ins().iconst(I64, 0); let b = builder.ins().iconst(I64, 1); // Create loop blocks let loop_header = builder.create_block(); let loop_body = builder.create_block(); let loop_exit = builder.create_block(); builder.append_block_param(loop_header, I64); // counter builder.append_block_param(loop_header, I64); // a builder.append_block_param(loop_header, I64); // b builder .ins() .jump(loop_header, &[two.into(), a.into(), b.into()]); // Loop header: check if counter <= n builder.switch_to_block(loop_header); let counter = builder.block_params(loop_header)[0]; let curr_a = builder.block_params(loop_header)[1]; let curr_b = builder.block_params(loop_header)[2]; let cmp = builder .ins() .icmp(IntCC::SignedLessThanOrEqual, counter, n_val); builder.ins().brif(cmp, loop_body, &[], loop_exit, &[]); // Loop body: compute next fibonacci number builder.switch_to_block(loop_body); builder.seal_block(loop_body); let next_fib = builder.ins().iadd(curr_a, curr_b); let next_counter = builder.ins().iadd_imm(counter, 1); builder.ins().jump( loop_header, &[next_counter.into(), curr_b.into(), next_fib.into()], ); // Loop exit: return b builder.switch_to_block(loop_exit); builder.seal_block(loop_exit); builder.seal_block(loop_header); builder.ins().return_(&[curr_b]); }) } /// Example: Using external function calls pub fn compile_with_print(jit: &mut JitCompiler) -> Result<FuncId, String> { // First declare the external function let mut sig = jit.module.make_signature(); sig.params.push(AbiParam::new(I64)); let 
println_id = jit .module .declare_function("println_i64", Linkage::Import, &sig) .unwrap(); // Define the function let func_id = jit .module .declare_function( "print_sum", Linkage::Export, &jit.make_signature(vec![I64, I64], vec![]), ) .unwrap(); // Create function context jit.ctx.func = Function::with_name_signature( UserFuncName::user(0, 0), jit.make_signature(vec![I64, I64], vec![]), ); // Build the function { let mut builder = FunctionBuilder::new(&mut jit.ctx.func, &mut jit.builder_context); let entry_block = builder.create_block(); builder.append_block_params_for_function_params(entry_block); builder.switch_to_block(entry_block); builder.seal_block(entry_block); let x = builder.declare_var(I64); let y = builder.declare_var(I64); let x_val = builder.block_params(entry_block)[0]; let y_val = builder.block_params(entry_block)[1]; builder.def_var(x, x_val); builder.def_var(y, y_val); let x_use = builder.use_var(x); let y_use = builder.use_var(y); let sum = builder.ins().iadd(x_use, y_use); // Declare the function reference for calling let println_ref = jit.module.declare_func_in_func(println_id, builder.func); builder.ins().call(println_ref, &[sum]); builder.ins().return_(&[]); builder.finalize(); } // Verify the function if let Err(errors) = verify_function(&jit.ctx.func, jit.module.isa()) { return Err(format!("Function verification failed: {}", errors)); } jit.module .define_function(func_id, &mut jit.ctx) .map_err(|e| e.to_string())?; jit.module.clear_context(&mut jit.ctx); Ok(func_id) } /// Example: Control flow with multiple returns pub fn compile_max(jit: &mut JitCompiler) -> Result<FuncId, String> { jit.compile_function("max", vec![I64, I64], vec![I64], |builder, params| { let x = builder.use_var(params[0]); let y = builder.use_var(params[1]); let then_block = builder.create_block(); let else_block = builder.create_block(); // if x > y let cmp = builder.ins().icmp(IntCC::SignedGreaterThan, x, y); builder.ins().brif(cmp, then_block, &[], else_block, &[]); 
        // then: return x
        builder.switch_to_block(then_block);
        builder.seal_block(then_block);
        builder.ins().return_(&[x]);

        // else: return y
        builder.switch_to_block(else_block);
        builder.seal_block(else_block);
        builder.ins().return_(&[y]);
    })
}

/// Example: Array/memory operations
pub fn compile_sum_array(jit: &mut JitCompiler) -> Result<FuncId, String> {
    jit.compile_function(
        "sum_array",
        vec![I64, I64], // ptr, len
        vec![I64],
        |builder, params| {
            let ptr = params[0];
            let len = params[1];

            // Create blocks
            let header_block = builder.create_block();
            let body_block = builder.create_block();
            let exit_block = builder.create_block();

            // Block parameters
            builder.append_block_param(header_block, I64); // index
            builder.append_block_param(header_block, I64); // sum
            builder.append_block_param(header_block, I64); // current_ptr

            // Initialize loop
            let zero = builder.ins().iconst(I64, 0);
            let ptr_val = builder.use_var(ptr);
            builder
                .ins()
                .jump(header_block, &[zero.into(), zero.into(), ptr_val.into()]);

            // Header: check if index < len
            builder.switch_to_block(header_block);
            let index = builder.block_params(header_block)[0];
            let sum = builder.block_params(header_block)[1];
            let current_ptr = builder.block_params(header_block)[2];
            let len_val = builder.use_var(len);
            let cmp = builder.ins().icmp(IntCC::UnsignedLessThan, index, len_val);
            builder.ins().brif(cmp, body_block, &[], exit_block, &[]);

            // Body: load value and add to sum
            builder.switch_to_block(body_block);
            builder.seal_block(body_block);
            let flags = MemFlags::new();
            let value = builder.ins().load(I64, flags, current_ptr, 0);
            let new_sum = builder.ins().iadd(sum, value);
            let new_index = builder.ins().iadd_imm(index, 1);
            let new_ptr = builder.ins().iadd_imm(current_ptr, 8); // 8 bytes for i64
            builder.ins().jump(
                header_block,
                &[new_index.into(), new_sum.into(), new_ptr.into()],
            );

            // Exit: return sum
            builder.switch_to_block(exit_block);
            builder.seal_block(exit_block);
            builder.seal_block(header_block);
            builder.ins().return_(&[sum]);
        },
    )
}

/// Example: Compile a simple expression evaluator
#[derive(Debug, Clone)]
pub enum Expr {
    Const(i64),
    Add(Box<Expr>, Box<Expr>),
    Sub(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
    Var(usize),
}

impl Expr {
    pub fn compile(&self, builder: &mut FunctionBuilder, vars: &[Variable]) -> Value {
        match self {
            Expr::Const(n) => builder.ins().iconst(I64, *n),
            Expr::Add(a, b) => {
                let a_val = a.compile(builder, vars);
                let b_val = b.compile(builder, vars);
                builder.ins().iadd(a_val, b_val)
            }
            Expr::Sub(a, b) => {
                let a_val = a.compile(builder, vars);
                let b_val = b.compile(builder, vars);
                builder.ins().isub(a_val, b_val)
            }
            Expr::Mul(a, b) => {
                let a_val = a.compile(builder, vars);
                let b_val = b.compile(builder, vars);
                builder.ins().imul(a_val, b_val)
            }
            Expr::Var(idx) => builder.use_var(vars[*idx]),
        }
    }
}

pub fn compile_expression(jit: &mut JitCompiler, expr: Expr) -> Result<FuncId, String> {
    jit.compile_function(
        "eval_expr",
        vec![I64, I64], // two variables
        vec![I64],
        |builder, params| {
            let result = expr.compile(builder, params);
            builder.ins().return_(&[result]);
        },
    )
}

/// Symbol table for variable management
pub struct SymbolTable {
    variables: HashMap<String, Variable>,
    next_var: usize,
}

impl SymbolTable {
    pub fn new() -> Self {
        Self {
            variables: HashMap::new(),
            next_var: 0,
        }
    }

    pub fn declare(&mut self, name: String, builder: &mut FunctionBuilder, ty: Type) -> Variable {
        let var = builder.declare_var(ty);
        self.variables.insert(name.clone(), var);
        self.next_var += 1;
        var
    }

    pub fn get(&self, name: &str) -> Option<Variable> {
        self.variables.get(name).copied()
    }
}

impl Default for SymbolTable {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compile_add() {
        let mut jit = JitCompiler::new();
        let func_id = compile_add_function(&mut jit).unwrap();
        jit.finalize();

        let code = jit.get_function(func_id);
        let add_fn = unsafe { std::mem::transmute::<*const u8, fn(i64, i64) -> i64>(code) };

        assert_eq!(add_fn(2, 3), 5);
        assert_eq!(add_fn(10, -5), 5);
    }

    #[test]
    fn test_compile_factorial() {
        let mut jit = JitCompiler::new();
        let func_id = compile_factorial(&mut jit).unwrap();
        jit.finalize();

        let code = jit.get_function(func_id);
        let factorial_fn = unsafe { std::mem::transmute::<*const u8, fn(i64) -> i64>(code) };

        assert_eq!(factorial_fn(0), 1);
        assert_eq!(factorial_fn(1), 1);
        assert_eq!(factorial_fn(5), 120);
    }

    #[test]
    fn test_compile_max() {
        let mut jit = JitCompiler::new();
        let func_id = compile_max(&mut jit).unwrap();
        jit.finalize();

        let code = jit.get_function(func_id);
        let max_fn = unsafe { std::mem::transmute::<*const u8, fn(i64, i64) -> i64>(code) };

        assert_eq!(max_fn(5, 3), 5);
        assert_eq!(max_fn(2, 8), 8);
        assert_eq!(max_fn(-5, -3), -3);
    }

    #[test]
    fn test_compile_expression() {
        let mut jit = JitCompiler::new();

        // (x + 3) * (y - 2)
        let expr = Expr::Mul(
            Box::new(Expr::Add(Box::new(Expr::Var(0)), Box::new(Expr::Const(3)))),
            Box::new(Expr::Sub(Box::new(Expr::Var(1)), Box::new(Expr::Const(2)))),
        );

        let func_id = compile_expression(&mut jit, expr).unwrap();
        jit.finalize();

        let code = jit.get_function(func_id);
        let eval_fn = unsafe { std::mem::transmute::<*const u8, fn(i64, i64) -> i64>(code) };

        assert_eq!(eval_fn(5, 7), 40); // (5+3) * (7-2) = 8 * 5 = 40
        assert_eq!(eval_fn(2, 4), 10); // (2+3) * (4-2) = 5 * 2 = 10
    }

    #[test]
    fn test_quadratic() {
        let mut jit = JitCompiler::new();
        let func_id = compile_quadratic(&mut jit).unwrap();
        jit.finalize();

        let code = jit.get_function(func_id);
        let quad_fn =
            unsafe { std::mem::transmute::<*const u8, fn(f64, f64, f64, f64) -> f64>(code) };

        // f(x) = 2x² + 3x + 1
        // f(2) = 2*4 + 3*2 + 1 = 8 + 6 + 1 = 15
        assert_eq!(quad_fn(2.0, 2.0, 3.0, 1.0), 15.0);
    }
}

/// Example: Working with floating point
pub fn compile_quadratic(jit: &mut JitCompiler) -> Result<FuncId, String> {
    jit.compile_function(
        "quadratic",
        vec![F64, F64, F64, F64],
        vec![F64],
        |builder, params| {
            // f(x) = ax² + bx + c
            let x = builder.use_var(params[0]);
            let a =
builder.use_var(params[1]);
            let b = builder.use_var(params[2]);
            let c = builder.use_var(params[3]);

            let x_squared = builder.ins().fmul(x, x);
            let ax_squared = builder.ins().fmul(a, x_squared);
            let bx = builder.ins().fmul(b, x);
            let ax_squared_plus_bx = builder.ins().fadd(ax_squared, bx);
            let result = builder.ins().fadd(ax_squared_plus_bx, c);

            builder.ins().return_(&[result]);
        },
    )
}
}
Floating-point arithmetic follows IEEE 754 semantics with explicit operation chains. The instruction builder maintains type safety, preventing integer and floating-point operations from being mixed. Compound expressions decompose into primitive operations, exposing optimization opportunities to the code generator.
External Function Calls
#![allow(unused)]
fn main() {
use std::collections::HashMap;

use cranelift::codegen::ir::types::*;
use cranelift::codegen::ir::{AbiParam, Function, InstBuilder, Signature, UserFuncName};
use cranelift::codegen::settings::{self, Configurable};
use cranelift::codegen::verifier::verify_function;
use cranelift::codegen::Context;
use cranelift::frontend::{FunctionBuilder, FunctionBuilderContext, Variable};
use cranelift::prelude::*;
use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::{FuncId, Linkage, Module};

/// A simple JIT compiler using Cranelift
pub struct JitCompiler {
    builder_context: FunctionBuilderContext,
    ctx: Context,
    module: JITModule,
}

impl JitCompiler {
    pub fn new() -> Self {
        let mut flag_builder = settings::builder();
        flag_builder.set("use_colocated_libcalls", "false").unwrap();
        flag_builder.set("is_pic", "false").unwrap();

        let isa_builder = cranelift_native::builder().unwrap_or_else(|msg| {
            panic!("host machine is not supported: {}", msg);
        });
        let isa = isa_builder
            .finish(settings::Flags::new(flag_builder))
            .unwrap();

        let mut builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
        builder.symbol_lookup_fn(Box::new(|name| {
            // Hook up external functions
            match name {
                "println_i64" => Some(println_i64 as *const u8),
                "println_f64" => Some(println_f64 as *const u8),
                _ => None,
            }
        }));

        let module = JITModule::new(builder);

        Self {
            builder_context: FunctionBuilderContext::new(),
            ctx: module.make_context(),
            module,
        }
    }

    pub fn compile_function(
        &mut self,
        name: &str,
        params: Vec<Type>,
        returns: Vec<Type>,
        build_fn: impl FnOnce(&mut FunctionBuilder, &[Variable]),
    ) -> Result<FuncId, String> {
        // Clear the context
        self.ctx.func = Function::with_name_signature(
            UserFuncName::user(0, 0),
            self.make_signature(params.clone(), returns.clone()),
        );

        // Create the function builder
        let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);

        // Create variables for parameters
        let variables: Vec<Variable> =
params.iter().map(|ty| builder.declare_var(*ty)).collect();

        // Create entry block and append parameters
        let entry_block = builder.create_block();
        builder.append_block_params_for_function_params(entry_block);
        builder.switch_to_block(entry_block);
        builder.seal_block(entry_block);

        // Define parameters
        for (i, var) in variables.iter().enumerate() {
            let val = builder.block_params(entry_block)[i];
            builder.def_var(*var, val);
        }

        // Call the user's function to build the body
        build_fn(&mut builder, &variables);

        // Finalize the function
        builder.finalize();

        // Verify the function
        if let Err(errors) = verify_function(&self.ctx.func, self.module.isa()) {
            return Err(format!("Function verification failed: {}", errors));
        }

        // Define the function in the module
        let func_id = self
            .module
            .declare_function(name, Linkage::Export, &self.ctx.func.signature)
            .map_err(|e| e.to_string())?;
        self.module
            .define_function(func_id, &mut self.ctx)
            .map_err(|e| e.to_string())?;

        // Clear the context for next use
        self.module.clear_context(&mut self.ctx);

        Ok(func_id)
    }

    pub fn finalize(&mut self) {
        self.module.finalize_definitions().unwrap();
    }

    pub fn get_function(&self, func_id: FuncId) -> *const u8 {
        self.module.get_finalized_function(func_id)
    }

    fn make_signature(&self, params: Vec<Type>, returns: Vec<Type>) -> Signature {
        let mut sig = self.module.make_signature();
        for param in params {
            sig.params.push(AbiParam::new(param));
        }
        for ret in returns {
            sig.returns.push(AbiParam::new(ret));
        }
        sig
    }
}

impl Default for JitCompiler {
    fn default() -> Self {
        Self::new()
    }
}

extern "C" fn println_i64(x: i64) {
    println!("{}", x);
}

extern "C" fn println_f64(x: f64) {
    println!("{}", x);
}

/// Example: Compile a simple arithmetic function
pub fn compile_add_function(jit: &mut JitCompiler) -> Result<FuncId, String> {
    jit.compile_function("add", vec![I64, I64], vec![I64], |builder, params| {
        let x = builder.use_var(params[0]);
        let y = builder.use_var(params[1]);
        let sum = builder.ins().iadd(x, y);
builder.ins().return_(&[sum]);
    })
}

/// Example: Compile a factorial function
pub fn compile_factorial(jit: &mut JitCompiler) -> Result<FuncId, String> {
    jit.compile_function("factorial", vec![I64], vec![I64], |builder, params| {
        let n = params[0];

        // Create blocks
        let header_block = builder.create_block();
        let body_block = builder.create_block();
        let exit_block = builder.create_block();

        // Add block parameters
        builder.append_block_param(header_block, I64); // i
        builder.append_block_param(header_block, I64); // result

        // Entry: jump to header with initial values
        let one = builder.ins().iconst(I64, 1);
        builder.ins().jump(header_block, &[one.into(), one.into()]);

        // Header block: check if i <= n
        builder.switch_to_block(header_block);
        let i = builder.block_params(header_block)[0];
        let result = builder.block_params(header_block)[1];
        let n_val = builder.use_var(n);
        let cmp = builder.ins().icmp(IntCC::SignedLessThanOrEqual, i, n_val);
        builder.ins().brif(cmp, body_block, &[], exit_block, &[]);

        // Body block: result *= i; i++
        builder.switch_to_block(body_block);
        builder.seal_block(body_block);
        let new_result = builder.ins().imul(result, i);
        let new_i = builder.ins().iadd_imm(i, 1);
        builder
            .ins()
            .jump(header_block, &[new_i.into(), new_result.into()]);

        // Exit block: return result
        builder.switch_to_block(exit_block);
        builder.seal_block(exit_block);
        builder.seal_block(header_block);
        builder.ins().return_(&[result]);
    })
}

/// Example: Compile a Fibonacci function
pub fn compile_fibonacci(jit: &mut JitCompiler) -> Result<FuncId, String> {
    jit.compile_function("fibonacci", vec![I64], vec![I64], |builder, params| {
        let n = params[0];

        // Create blocks
        let check_base = builder.create_block();
        let recursive = builder.create_block();
        let return_n = builder.create_block();

        // Jump to check_base
        builder.ins().jump(check_base, &[]);

        // Check if n <= 1
        builder.switch_to_block(check_base);
        let n_val = builder.use_var(n);
        let one = builder.ins().iconst(I64, 1);
        let cmp =
builder.ins().icmp(IntCC::SignedLessThanOrEqual, n_val, one);
        builder.ins().brif(cmp, return_n, &[], recursive, &[]);

        // Return n for base case
        builder.switch_to_block(return_n);
        builder.seal_block(return_n);
        builder.ins().return_(&[n_val]);

        // Recursive case: fib(n-1) + fib(n-2)
        builder.switch_to_block(recursive);
        builder.seal_block(recursive);
        builder.seal_block(check_base);

        // For simplicity, we'll compute iteratively
        let two = builder.ins().iconst(I64, 2);
        let a = builder.ins().iconst(I64, 0);
        let b = builder.ins().iconst(I64, 1);

        // Create loop blocks
        let loop_header = builder.create_block();
        let loop_body = builder.create_block();
        let loop_exit = builder.create_block();

        builder.append_block_param(loop_header, I64); // counter
        builder.append_block_param(loop_header, I64); // a
        builder.append_block_param(loop_header, I64); // b

        builder
            .ins()
            .jump(loop_header, &[two.into(), a.into(), b.into()]);

        // Loop header: check if counter <= n
        builder.switch_to_block(loop_header);
        let counter = builder.block_params(loop_header)[0];
        let curr_a = builder.block_params(loop_header)[1];
        let curr_b = builder.block_params(loop_header)[2];
        let cmp = builder
            .ins()
            .icmp(IntCC::SignedLessThanOrEqual, counter, n_val);
        builder.ins().brif(cmp, loop_body, &[], loop_exit, &[]);

        // Loop body: compute next fibonacci number
        builder.switch_to_block(loop_body);
        builder.seal_block(loop_body);
        let next_fib = builder.ins().iadd(curr_a, curr_b);
        let next_counter = builder.ins().iadd_imm(counter, 1);
        builder.ins().jump(
            loop_header,
            &[next_counter.into(), curr_b.into(), next_fib.into()],
        );

        // Loop exit: return b
        builder.switch_to_block(loop_exit);
        builder.seal_block(loop_exit);
        builder.seal_block(loop_header);
        builder.ins().return_(&[curr_b]);
    })
}

// (compile_quadratic, compile_max, compile_sum_array, Expr, compile_expression,
// SymbolTable, and the test module are identical to the previous listing.)

/// Example: Using external function calls
pub fn compile_with_print(jit: &mut JitCompiler) -> Result<FuncId, String> {
    // First declare the external function
    let mut sig = jit.module.make_signature();
    sig.params.push(AbiParam::new(I64));
    let println_id = jit
        .module
        .declare_function("println_i64", Linkage::Import, &sig)
        .unwrap();

    // Define the function
    let func_id = jit
        .module
        .declare_function(
            "print_sum",
            Linkage::Export,
            &jit.make_signature(vec![I64, I64], vec![]),
        )
        .unwrap();

    // Create function context
    jit.ctx.func = Function::with_name_signature(
        UserFuncName::user(0, 0),
        jit.make_signature(vec![I64, I64], vec![]),
    );

    // Build the function
    {
        let mut builder = FunctionBuilder::new(&mut jit.ctx.func, &mut jit.builder_context);
        let entry_block = builder.create_block();
        builder.append_block_params_for_function_params(entry_block);
        builder.switch_to_block(entry_block);
        builder.seal_block(entry_block);

        let x = builder.declare_var(I64);
        let y = builder.declare_var(I64);
        let x_val = builder.block_params(entry_block)[0];
        let y_val = builder.block_params(entry_block)[1];
        builder.def_var(x, x_val);
        builder.def_var(y, y_val);

        let x_use = builder.use_var(x);
        let y_use = builder.use_var(y);
        let sum = builder.ins().iadd(x_use, y_use);

        // Declare the function reference for calling
        let println_ref = jit.module.declare_func_in_func(println_id, builder.func);
builder.ins().call(println_ref, &[sum]);
        builder.ins().return_(&[]);

        builder.finalize();
    }

    // Verify the function
    if let Err(errors) = verify_function(&jit.ctx.func, jit.module.isa()) {
        return Err(format!("Function verification failed: {}", errors));
    }

    jit.module
        .define_function(func_id, &mut jit.ctx)
        .map_err(|e| e.to_string())?;
    jit.module.clear_context(&mut jit.ctx);

    Ok(func_id)
}
}
External function integration enables interaction with the runtime environment. Function declarations specify the calling convention and signature, while the import linkage indicates external resolution. The module system manages cross-function references, supporting both ahead-of-time and just-in-time linking models.
Memory Operations
#![allow(unused)]
fn main() {
use std::collections::HashMap;

use cranelift::codegen::ir::types::*;
use cranelift::codegen::ir::{AbiParam, Function, InstBuilder, Signature, UserFuncName};
use cranelift::codegen::settings::{self, Configurable};
use cranelift::codegen::verifier::verify_function;
use cranelift::codegen::Context;
use cranelift::frontend::{FunctionBuilder, FunctionBuilderContext, Variable};
use cranelift::prelude::*;
use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::{FuncId, Linkage, Module};

// (The JitCompiler scaffolding and the example functions compile_add_function,
// compile_factorial, compile_fibonacci, compile_quadratic, compile_with_print,
// compile_max, Expr, and compile_expression are identical to the previous
// listings.)

/// Symbol
table for variable management pub struct SymbolTable { variables: HashMap<String, Variable>, next_var: usize, } impl SymbolTable { pub fn new() -> Self { Self { variables: HashMap::new(), next_var: 0, } } pub fn declare(&mut self, name: String, builder: &mut FunctionBuilder, ty: Type) -> Variable { let var = builder.declare_var(ty); self.variables.insert(name.clone(), var); self.next_var += 1; var } pub fn get(&self, name: &str) -> Option<Variable> { self.variables.get(name).copied() } } impl Default for SymbolTable { fn default() -> Self { Self::new() } } #[cfg(test)] mod tests { use super::*; #[test] fn test_compile_add() { let mut jit = JitCompiler::new(); let func_id = compile_add_function(&mut jit).unwrap(); jit.finalize(); let code = jit.get_function(func_id); let add_fn = unsafe { std::mem::transmute::<*const u8, fn(i64, i64) -> i64>(code) }; assert_eq!(add_fn(2, 3), 5); assert_eq!(add_fn(10, -5), 5); } #[test] fn test_compile_factorial() { let mut jit = JitCompiler::new(); let func_id = compile_factorial(&mut jit).unwrap(); jit.finalize(); let code = jit.get_function(func_id); let factorial_fn = unsafe { std::mem::transmute::<*const u8, fn(i64) -> i64>(code) }; assert_eq!(factorial_fn(0), 1); assert_eq!(factorial_fn(1), 1); assert_eq!(factorial_fn(5), 120); } #[test] fn test_compile_max() { let mut jit = JitCompiler::new(); let func_id = compile_max(&mut jit).unwrap(); jit.finalize(); let code = jit.get_function(func_id); let max_fn = unsafe { std::mem::transmute::<*const u8, fn(i64, i64) -> i64>(code) }; assert_eq!(max_fn(5, 3), 5); assert_eq!(max_fn(2, 8), 8); assert_eq!(max_fn(-5, -3), -3); } #[test] fn test_compile_expression() { let mut jit = JitCompiler::new(); // (x + 3) * (y - 2) let expr = Expr::Mul( Box::new(Expr::Add(Box::new(Expr::Var(0)), Box::new(Expr::Const(3)))), Box::new(Expr::Sub(Box::new(Expr::Var(1)), Box::new(Expr::Const(2)))), ); let func_id = compile_expression(&mut jit, expr).unwrap(); jit.finalize(); let code = 
jit.get_function(func_id); let eval_fn = unsafe { std::mem::transmute::<*const u8, fn(i64, i64) -> i64>(code) }; assert_eq!(eval_fn(5, 7), 40); // (5+3) * (7-2) = 8 * 5 = 40 assert_eq!(eval_fn(2, 4), 10); // (2+3) * (4-2) = 5 * 2 = 10 } #[test] fn test_quadratic() { let mut jit = JitCompiler::new(); let func_id = compile_quadratic(&mut jit).unwrap(); jit.finalize(); let code = jit.get_function(func_id); let quad_fn = unsafe { std::mem::transmute::<*const u8, fn(f64, f64, f64, f64) -> f64>(code) }; // f(x) = 2x² + 3x + 1 // f(2) = 2*4 + 3*2 + 1 = 8 + 6 + 1 = 15 assert_eq!(quad_fn(2.0, 2.0, 3.0, 1.0), 15.0); } } /// Example: Array/memory operations pub fn compile_sum_array(jit: &mut JitCompiler) -> Result<FuncId, String> { jit.compile_function( "sum_array", vec![I64, I64], // ptr, len vec![I64], |builder, params| { let ptr = params[0]; let len = params[1]; // Create blocks let header_block = builder.create_block(); let body_block = builder.create_block(); let exit_block = builder.create_block(); // Block parameters builder.append_block_param(header_block, I64); // index builder.append_block_param(header_block, I64); // sum builder.append_block_param(header_block, I64); // current_ptr // Initialize loop let zero = builder.ins().iconst(I64, 0); let ptr_val = builder.use_var(ptr); builder .ins() .jump(header_block, &[zero.into(), zero.into(), ptr_val.into()]); // Header: check if index < len builder.switch_to_block(header_block); let index = builder.block_params(header_block)[0]; let sum = builder.block_params(header_block)[1]; let current_ptr = builder.block_params(header_block)[2]; let len_val = builder.use_var(len); let cmp = builder.ins().icmp(IntCC::UnsignedLessThan, index, len_val); builder.ins().brif(cmp, body_block, &[], exit_block, &[]); // Body: load value and add to sum builder.switch_to_block(body_block); builder.seal_block(body_block); let flags = MemFlags::new(); let value = builder.ins().load(I64, flags, current_ptr, 0); let new_sum = 
builder.ins().iadd(sum, value); let new_index = builder.ins().iadd_imm(index, 1); let new_ptr = builder.ins().iadd_imm(current_ptr, 8); // 8 bytes for i64 builder.ins().jump( header_block, &[new_index.into(), new_sum.into(), new_ptr.into()], ); // Exit: return sum builder.switch_to_block(exit_block); builder.seal_block(exit_block); builder.seal_block(header_block); builder.ins().return_(&[sum]); }, ) } }
The array-summing example demonstrates pointer arithmetic and load operations. MemFlags specifies aliasing and alignment properties, enabling optimization while preserving correctness. The explicit pointer increment reflects the low-level nature of Cranelift IR, giving direct control over memory access patterns.
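The loop that the IR encodes can be written as an ordinary Rust function with explicit pointer increments, which makes a convenient reference implementation when testing the JIT-compiled `sum_array`. This helper (`sum_array_ref`) is illustrative and not part of the example crate:

```rust
/// Reference implementation mirroring the `sum_array` IR loop:
/// walk the buffer one 8-byte element at a time, accumulating a sum.
///
/// Safety: `ptr` must point to at least `len` valid `i64` values.
unsafe fn sum_array_ref(mut ptr: *const i64, len: u64) -> i64 {
    let mut sum = 0i64;
    let mut index = 0u64;
    while index < len {
        sum = sum.wrapping_add(*ptr); // the `load` instruction
        ptr = ptr.add(1); // `iadd_imm(current_ptr, 8)` in the IR
        index += 1; // `iadd_imm(index, 1)` in the IR
    }
    sum
}
```

Comparing the JIT-compiled function's output against this reference over a range of inputs is a quick way to catch block-parameter or sealing mistakes in the IR construction.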
Expression Trees
#![allow(unused)]
fn main() {
/// Example: Compile a simple expression evaluator
#[derive(Debug, Clone)]
pub enum Expr {
    Const(i64),
    Add(Box<Expr>, Box<Expr>),
    Sub(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
    Var(usize),
}
}
The expression enumeration represents abstract syntax trees for compilation. This recursive structure maps naturally to Cranelift’s instruction builder API.
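A direct interpreter over the same tree makes a useful oracle for testing: the JIT-compiled `eval_expr` should agree with it on every input. Here is a minimal sketch; the `eval` method is an illustrative addition, not part of the example crate:

```rust
// Mirror of the guide's Expr enum, small enough to interpret directly.
#[derive(Debug, Clone)]
pub enum Expr {
    Const(i64),
    Add(Box<Expr>, Box<Expr>),
    Sub(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
    Var(usize),
}

impl Expr {
    /// Evaluate the tree against a slice of variable values, using
    /// wrapping arithmetic to match Cranelift's two's-complement I64 ops.
    pub fn eval(&self, vars: &[i64]) -> i64 {
        match self {
            Expr::Const(n) => *n,
            Expr::Add(a, b) => a.eval(vars).wrapping_add(b.eval(vars)),
            Expr::Sub(a, b) => a.eval(vars).wrapping_sub(b.eval(vars)),
            Expr::Mul(a, b) => a.eval(vars).wrapping_mul(b.eval(vars)),
            Expr::Var(idx) => vars[*idx],
        }
    }
}
```

Pairing the interpreter with the compiler in a property test (random trees, random inputs, assert both agree) exercises every constructor of the enum with little extra code.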
#![allow(unused)]
fn main() {
pub fn compile_expression(jit: &mut JitCompiler, expr: Expr) -> Result<FuncId, String> {
    jit.compile_function(
        "eval_expr",
        vec![I64, I64], // two variables
        vec![I64],
        |builder, params| {
            let result = expr.compile(builder, params);
            builder.ins().return_(&[result]);
        },
    )
}

/// Symbol table for variable management
pub struct SymbolTable {
    variables: HashMap<String, Variable>,
    next_var: usize,
}

impl SymbolTable {
    pub fn new() -> Self {
        Self {
            variables: HashMap::new(),
            next_var: 0,
        }
    }

    pub fn declare(&mut self, name: String, builder: &mut FunctionBuilder, ty: Type) -> Variable {
        let var = builder.declare_var(ty);
        self.variables.insert(name.clone(), var);
        self.next_var += 1;
        var
    }

    pub fn get(&self, name: &str) -> Option<Variable> {
        self.variables.get(name).copied()
    }
}

impl Default for SymbolTable {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_compile_add() {
        let mut jit = JitCompiler::new();
        let func_id = compile_add_function(&mut jit).unwrap();
        jit.finalize();
        let code = jit.get_function(func_id);
        let add_fn = unsafe { std::mem::transmute::<*const u8,
fn(i64, i64) -> i64>(code) }; assert_eq!(add_fn(2, 3), 5); assert_eq!(add_fn(10, -5), 5); } #[test] fn test_compile_factorial() { let mut jit = JitCompiler::new(); let func_id = compile_factorial(&mut jit).unwrap(); jit.finalize(); let code = jit.get_function(func_id); let factorial_fn = unsafe { std::mem::transmute::<*const u8, fn(i64) -> i64>(code) }; assert_eq!(factorial_fn(0), 1); assert_eq!(factorial_fn(1), 1); assert_eq!(factorial_fn(5), 120); } #[test] fn test_compile_max() { let mut jit = JitCompiler::new(); let func_id = compile_max(&mut jit).unwrap(); jit.finalize(); let code = jit.get_function(func_id); let max_fn = unsafe { std::mem::transmute::<*const u8, fn(i64, i64) -> i64>(code) }; assert_eq!(max_fn(5, 3), 5); assert_eq!(max_fn(2, 8), 8); assert_eq!(max_fn(-5, -3), -3); } #[test] fn test_compile_expression() { let mut jit = JitCompiler::new(); // (x + 3) * (y - 2) let expr = Expr::Mul( Box::new(Expr::Add(Box::new(Expr::Var(0)), Box::new(Expr::Const(3)))), Box::new(Expr::Sub(Box::new(Expr::Var(1)), Box::new(Expr::Const(2)))), ); let func_id = compile_expression(&mut jit, expr).unwrap(); jit.finalize(); let code = jit.get_function(func_id); let eval_fn = unsafe { std::mem::transmute::<*const u8, fn(i64, i64) -> i64>(code) }; assert_eq!(eval_fn(5, 7), 40); // (5+3) * (7-2) = 8 * 5 = 40 assert_eq!(eval_fn(2, 4), 10); // (2+3) * (4-2) = 5 * 2 = 10 } #[test] fn test_quadratic() { let mut jit = JitCompiler::new(); let func_id = compile_quadratic(&mut jit).unwrap(); jit.finalize(); let code = jit.get_function(func_id); let quad_fn = unsafe { std::mem::transmute::<*const u8, fn(f64, f64, f64, f64) -> f64>(code) }; // f(x) = 2x² + 3x + 1 // f(2) = 2*4 + 3*2 + 1 = 8 + 6 + 1 = 15 assert_eq!(quad_fn(2.0, 2.0, 3.0, 1.0), 15.0); } } impl Expr { pub fn compile(&self, builder: &mut FunctionBuilder, vars: &[Variable]) -> Value { match self { Expr::Const(n) => builder.ins().iconst(I64, *n), Expr::Add(a, b) => { let a_val = a.compile(builder, vars); let b_val = 
b.compile(builder, vars); builder.ins().iadd(a_val, b_val) } Expr::Sub(a, b) => { let a_val = a.compile(builder, vars); let b_val = b.compile(builder, vars); builder.ins().isub(a_val, b_val) } Expr::Mul(a, b) => { let a_val = a.compile(builder, vars); let b_val = b.compile(builder, vars); builder.ins().imul(a_val, b_val) } Expr::Var(idx) => builder.use_var(vars[*idx]), } } } }
Recursive compilation transforms expression trees into SSA values. The method directly maps expression nodes to Cranelift instructions, demonstrating the correspondence between high-level operations and low-level IR. Variable references connect to the function’s parameter space, enabling parameterized expression evaluation.
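To make the tree-to-SSA correspondence concrete, here is a minimal, dependency-free sketch (plain Rust, not Cranelift's API) that lowers the same kind of expression tree into a flat list of single-assignment instructions. Each node defines exactly one new value id, just as each `builder.ins()` call above produces one Cranelift `Value`; the `Node` and `Inst` types are hypothetical stand-ins.

```rust
#[derive(Debug, Clone)]
enum Node {
    Const(i64),
    Add(Box<Node>, Box<Node>),
    Mul(Box<Node>, Box<Node>),
    Var(usize),
}

#[derive(Debug, PartialEq)]
enum Inst {
    Param(usize),      // vN = function parameter i
    Const(i64),        // vN = constant
    Add(usize, usize), // vN = vA + vB
    Mul(usize, usize), // vN = vA * vB
}

// Post-order walk: children are lowered first, then the parent consumes
// their value ids -- the same shape as the recursive Expr::compile method.
fn lower(node: &Node, out: &mut Vec<Inst>) -> usize {
    let inst = match node {
        Node::Const(n) => Inst::Const(*n),
        Node::Var(i) => Inst::Param(*i),
        Node::Add(a, b) => {
            let (a, b) = (lower(a, out), lower(b, out));
            Inst::Add(a, b)
        }
        Node::Mul(a, b) => {
            let (a, b) = (lower(a, out), lower(b, out));
            Inst::Mul(a, b)
        }
    };
    out.push(inst);
    out.len() - 1 // id of the value this node defines
}

fn main() {
    // (x + 3) * x
    let tree = Node::Mul(
        Box::new(Node::Add(Box::new(Node::Var(0)), Box::new(Node::Const(3)))),
        Box::new(Node::Var(0)),
    );
    let mut insts = Vec::new();
    let root = lower(&tree, &mut insts);
    assert_eq!(root, 4);
    assert_eq!(insts[4], Inst::Mul(2, 3));
    println!("{:?}", insts);
}
```

Every value is defined once and referred to by its index afterwards, which is the essential SSA property the Cranelift builder maintains for you.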
Symbol Management
#![allow(unused)]
fn main() {
use std::collections::HashMap;

use cranelift::codegen::ir::types::*;
use cranelift::frontend::{FunctionBuilder, Variable};
use cranelift::prelude::*;

/// Symbol table for variable management
pub struct SymbolTable {
    variables: HashMap<String, Variable>,
    next_var: usize,
}

impl SymbolTable {
    pub fn new() -> Self {
        Self {
            variables: HashMap::new(),
            next_var: 0,
        }
    }

    pub fn declare(&mut self, name: String, builder: &mut FunctionBuilder, ty: Type) -> Variable {
        let var = builder.declare_var(ty);
        self.variables.insert(name.clone(), var);
        self.next_var += 1;
        var
    }

    pub fn get(&self, name: &str) -> Option<Variable> {
        self.variables.get(name).copied()
    }
}

impl Default for SymbolTable {
    fn default() -> Self {
        Self::new()
    }
}
}
Symbol tables manage the mapping between source-level names and Cranelift variables. Each declaration allocates a fresh variable index, so every name resolves to a unique variable for the duration of compilation.
#![allow(unused)]
fn main() {
pub fn declare(&mut self, name: String, builder: &mut FunctionBuilder, ty: Type) -> Variable {
    let var = builder.declare_var(ty);
    self.variables.insert(name.clone(), var);
    self.next_var += 1;
    var
}
}
Variable declaration combines allocation with type specification. The builder’s declare_var call registers the variable in the function’s metadata, enabling the use_var and def_var operations that connect high-level variables to SSA values.
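The same name-management pattern can be sketched without Cranelift at all. In the hypothetical `Scope` below, `declare` stands in for `builder.declare_var`, handing out monotonically increasing indices so that every declared name maps to a distinct variable:

```rust
use std::collections::HashMap;

// Hypothetical stand-in for the Cranelift-backed SymbolTable above:
// names map to monotonically allocated variable indices.
struct Scope {
    vars: HashMap<String, usize>,
    next: usize,
}

impl Scope {
    fn new() -> Self {
        Scope { vars: HashMap::new(), next: 0 }
    }

    fn declare(&mut self, name: &str) -> usize {
        let idx = self.next; // fresh index, never reused
        self.next += 1;
        self.vars.insert(name.to_string(), idx);
        idx
    }

    fn get(&self, name: &str) -> Option<usize> {
        self.vars.get(name).copied()
    }
}

fn main() {
    let mut scope = Scope::new();
    let x = scope.declare("x");
    let y = scope.declare("y");
    assert_eq!((x, y), (0, 1));
    assert_eq!(scope.get("x"), Some(0));
    assert_eq!(scope.get("z"), None);
    println!("declared {} variables", scope.next);
}
```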
Optimization Considerations
Cranelift performs several optimizations during code generation despite prioritizing compilation speed. The instruction combiner merges adjacent operations when beneficial, such as combining additions with small constants into immediate-mode instructions. Simple dead code elimination removes unreachable blocks and unused values.
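A toy version of that kind of combining, written over a hypothetical three-address IR rather than Cranelift's real data structures: a single forward pass remembers which values are known constants and rewrites an add whose right operand is constant into an immediate-mode instruction. The leftover `Const` then becomes dead and is exactly what the dead code elimination mentioned above would remove.

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, PartialEq)]
enum Ir {
    Const { dst: u32, val: i64 },
    Add { dst: u32, a: u32, b: u32 },
    AddImm { dst: u32, a: u32, imm: i64 },
}

// Fold add(v, const) into an add-immediate in one forward pass.
fn combine(insts: &[Ir]) -> Vec<Ir> {
    let mut consts: HashMap<u32, i64> = HashMap::new();
    let mut out = Vec::with_capacity(insts.len());
    for inst in insts {
        match inst {
            Ir::Const { dst, val } => {
                consts.insert(*dst, *val);
                out.push(inst.clone());
            }
            Ir::Add { dst, a, b } if consts.contains_key(b) => {
                out.push(Ir::AddImm { dst: *dst, a: *a, imm: consts[b] });
            }
            _ => out.push(inst.clone()),
        }
    }
    out
}

fn main() {
    let before = vec![
        Ir::Const { dst: 1, val: 8 },
        Ir::Add { dst: 2, a: 0, b: 1 },
    ];
    let after = combine(&before);
    assert_eq!(after[1], Ir::AddImm { dst: 2, a: 0, imm: 8 });
    println!("{:?}", after);
}
```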
Register allocation uses a fast linear scan algorithm that produces good code without the compilation time cost of graph coloring or PBQP approaches. The allocator handles live range splitting and spilling automatically, generating reload code as needed.
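The core of linear scan fits in a few lines. This toy version (not Cranelift's actual allocator, which also splits live ranges and generates reload code) walks intervals sorted by start point, returns registers whose intervals have expired to the free pool, and marks an interval as spilled when no register remains:

```rust
// Toy linear-scan allocation over precomputed live intervals (start, end),
// assumed sorted by start. Returns Some(register) or None (spilled).
fn linear_scan(intervals: &[(usize, usize)], num_regs: usize) -> Vec<Option<usize>> {
    let mut assignment = vec![None; intervals.len()];
    let mut active: Vec<(usize, usize)> = Vec::new(); // (end, reg)
    let mut free: Vec<usize> = (0..num_regs).rev().collect();

    for (i, &(start, end)) in intervals.iter().enumerate() {
        // Expire intervals that ended before this one starts.
        active.retain(|&(e, reg)| {
            if e < start {
                free.push(reg);
                false
            } else {
                true
            }
        });
        if let Some(reg) = free.pop() {
            assignment[i] = Some(reg);
            active.push((end, reg));
        }
        // No register free: the value stays None, i.e. spilled to the stack.
    }
    assignment
}

fn main() {
    // The first two intervals overlap; the third starts after both end.
    let intervals = [(0, 3), (1, 2), (4, 6)];
    let regs = linear_scan(&intervals, 1);
    assert_eq!(regs, vec![Some(0), None, Some(0)]);
    println!("{:?}", regs);
}
```

The single pass over sorted intervals is what makes the algorithm linear, in contrast to the iterative interference analysis of graph coloring.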
The code generator exploits CPU features when available, using vector instructions for appropriate operations and conditional moves to avoid branches. Target-specific optimizations include addressing mode selection and instruction scheduling within basic blocks.
Integration Patterns
Cranelift integrates into larger systems through several abstraction layers. The Module trait provides a uniform interface for both JIT and AOT compilation, abstracting over linking and symbol resolution differences. The cranelift-wasm crate demonstrates translation from WebAssembly to Cranelift IR, while maintaining semantic equivalence.
Runtime code generation benefits from Cranelift’s incremental compilation model. Functions can be compiled on-demand, with lazy linking deferring symbol resolution until needed. The JIT module supports code invalidation and recompilation, essential for adaptive optimization systems.
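The compile-on-demand shape can be sketched in plain Rust. The compile step here is a hypothetical closure standing in for a real Cranelift invocation; the point is that the expensive work runs only on the first request, and invalidation simply drops the cached entry so the next request recompiles:

```rust
use std::collections::HashMap;

// Cache of "compiled" functions keyed by name.
struct LazyJit {
    compiled: HashMap<String, Vec<u8>>, // name -> generated code bytes
}

impl LazyJit {
    fn new() -> Self {
        LazyJit { compiled: HashMap::new() }
    }

    // Run the compile closure only if this name has no cached code.
    fn get_or_compile(&mut self, name: &str, compile: impl FnOnce() -> Vec<u8>) -> &[u8] {
        self.compiled.entry(name.to_string()).or_insert_with(compile)
    }

    // Invalidation: drop the cached code so the next request recompiles.
    fn invalidate(&mut self, name: &str) {
        self.compiled.remove(name);
    }
}

fn main() {
    let mut jit = LazyJit::new();
    let mut compiles = 0;
    jit.get_or_compile("fib", || { compiles += 1; vec![0xC3] });
    jit.get_or_compile("fib", || { compiles += 1; vec![0xC3] });
    assert_eq!(compiles, 1); // second call hit the cache

    jit.invalidate("fib");
    jit.get_or_compile("fib", || { compiles += 1; vec![0xC3] });
    assert_eq!(compiles, 2); // recompiled after invalidation
}
```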
Debugging support includes source location tracking through the IR pipeline, enabling accurate debugging information in generated code. The cranelift-reader crate provides a textual IR format for testing and debugging, while the verifier catches IR inconsistencies early in development.
Performance Characteristics
Compilation speed typically ranges from 10 to 100 MB/s of generated code, orders of magnitude faster than optimizing compilers. Memory usage scales linearly with function size, avoiding the exponential growth of some optimization algorithms. The generated code typically performs within 2-3x of optimized C code, which is acceptable for many JIT scenarios.
Cranelift’s architecture enables predictable performance across different input programs. The lack of iterative optimization passes ensures bounded compilation time, while the streaming code generation minimizes memory residence. These properties make Cranelift suitable for latency-sensitive applications where compilation happens on the critical path.
Error Handling
The verifier catches most IR construction errors before code generation, providing clear diagnostics about invalid instruction sequences or type mismatches. Runtime errors manifest as traps, with preservation of source location information for debugging. The compilation pipeline propagates errors explicitly, avoiding panics in production use.
Best Practices
Structure IR generation to minimize variable live ranges, reducing register pressure and improving code quality. Use block parameters instead of variables for values that cross block boundaries, enabling better optimization. Seal blocks as soon as all predecessors are known to enable efficient SSA construction.
Profile compilation time to identify bottlenecks, particularly in function builder usage patterns. Large functions may benefit from splitting into smaller units that compile independently. Consider caching compiled code when possible to amortize compilation costs across multiple executions.
Design the IR generation to preserve high-level semantics where possible. Cranelift’s optimizer works best when the intent of operations is clear, such as using specific instructions for bounds checks rather than generic comparisons.
dynasm-rs
dynasm-rs is a runtime assembler for Rust that allows you to dynamically generate and execute machine code. It provides a procedural macro and a runtime library that work together to offer an assembly-like syntax embedded directly in Rust code. This makes it ideal for JIT compilers, runtime code specialization, and high-performance computing scenarios where static compilation isn’t sufficient.
The library stands out for its compile-time syntax checking and seamless integration with Rust’s type system. Unlike traditional assemblers that process text at runtime, dynasm-rs verifies assembly syntax during compilation, catching errors early while maintaining the flexibility of runtime code generation.
Core Architecture
dynasm-rs consists of two main components: a procedural macro that processes assembly syntax at compile time, and a runtime library that manages code buffers and relocation. The macro translates assembly directives into API calls that construct machine code at runtime.
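The runtime half's central job, appending encoded bytes and patching label offsets once their targets are known, can be illustrated with a toy buffer (this is not dynasm-rs's actual API, just the underlying backpatching idea using an x86-64 short jump):

```rust
// Minimal forward-label backpatching: emit a short jump with a placeholder
// offset, then fill it in when the label is bound.
struct Asm {
    code: Vec<u8>,
    fixups: Vec<usize>, // byte positions of pending rel8 placeholders
}

impl Asm {
    fn new() -> Self {
        Asm { code: Vec::new(), fixups: Vec::new() }
    }

    fn emit(&mut self, bytes: &[u8]) {
        self.code.extend_from_slice(bytes);
    }

    // x86-64 short jump: opcode 0xEB followed by a signed 8-bit offset.
    fn jmp_forward(&mut self) {
        self.emit(&[0xEB, 0x00]); // 0x00 is a placeholder
        self.fixups.push(self.code.len() - 1);
    }

    // Bind the label at the current position and patch pending jumps.
    fn bind(&mut self) {
        let target = self.code.len();
        for &pos in &self.fixups {
            // rel8 is measured from the end of the jump instruction.
            self.code[pos] = (target - (pos + 1)) as u8;
        }
        self.fixups.clear();
    }
}

fn main() {
    let mut asm = Asm::new();
    asm.jmp_forward();             // jump over the next three bytes
    asm.emit(&[0x90, 0x90, 0x90]); // three NOPs
    asm.bind();                    // label lands here
    assert_eq!(asm.code, vec![0xEB, 0x03, 0x90, 0x90, 0x90]);
    println!("{:02x?}", asm.code);
}
```

dynasm-rs performs this bookkeeping for you: global, local, and dynamic labels in the macro syntax compile down to relocation records that the assembler resolves when the buffer is finalized.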
#![allow(unused)]
fn main() {
use std::{io, slice};

use dynasmrt::{dynasm, DynasmApi, DynasmLabelApi, ExecutableBuffer};

/// External function called from assembly to print a buffer.
///
/// # Safety
///
/// The caller must ensure that:
/// - `buffer` points to valid memory for at least `length` bytes
/// - The memory pointed to by `buffer` is initialized
pub unsafe extern "C" fn print(buffer: *const u8, length: u64) -> bool {
    io::Write::write_all(
        &mut io::stdout(),
        slice::from_raw_parts(buffer, length as usize),
    )
    .is_ok()
}

/// Generates an optimized addition function for two integers.
///
/// Creates machine code equivalent to: `fn add(a: i32, b: i32) -> i32 { a + b
/// }`
pub fn generate_add_function() -> ExecutableBuffer {
    let mut ops = dynasmrt::aarch64::Assembler::new().unwrap();

    dynasm!(ops
        ; .arch aarch64
        ; add w0, w0, w1    // Add w1 to w0, result in w0 (ARM64 ABI)
        ; ret               // Return with result in w0
    );

    ops.finalize().unwrap()
}

/// Generates a factorial function using recursion.
///
/// Demonstrates more complex control flow with conditional jumps and recursive
/// calls.
pub fn generate_factorial() -> ExecutableBuffer {
    let mut ops = dynasmrt::aarch64::Assembler::new().unwrap();
    let entry_label = ops.new_dynamic_label();

    ops.dynamic_label(entry_label);
    dynasm!(ops
        ; .arch aarch64
        ; cmp w0, #1                // Compare n with 1
        ; b.le ->base_case          // Jump if n <= 1
        ; stp x29, x30, [sp, #-16]! // Save frame pointer and link register
        ; stp x19, x20, [sp, #-16]! // Save callee-saved registers
        ; mov w19, w0               // Save n in w19
        ; sub w0, w0, #1            // n - 1
        ; adr x1, =>entry_label     // Load our own address for recursion
        ; blr x1                    // Recursive call with n-1
        ; mul w0, w0, w19           // Multiply result by n
        ; ldp x19, x20, [sp], #16   // Restore callee-saved registers
        ; ldp x29, x30, [sp], #16   // Restore frame pointer and link register
        ; ret
        ; ->base_case:
        ; mov w0, #1                // Return 1 for base case
        ; ret
    );

    ops.finalize().unwrap()
}

/// Generates a loop that sums an array of integers.
///
/// Takes a pointer to an i32 array and its length, returns the sum.
pub fn generate_array_sum() -> ExecutableBuffer {
    let mut ops = dynasmrt::aarch64::Assembler::new().unwrap();

    dynasm!(ops
        ; .arch aarch64
        ; mov w2, #0            // Initialize sum to 0
        ; cbz x1, ->done        // If length is 0, return 0
        ; ->loop_start:
        ; ldr w3, [x0], #4      // Load element and increment pointer
        ; add w2, w2, w3        // Add to sum
        ; sub x1, x1, #1        // Decrement counter
        ; cbnz x1, ->loop_start // Continue if not zero
        ; ->done:
        ; mov w0, w2            // Move result to return register
        ; ret
    );

    ops.finalize().unwrap()
}

/// Generates a function that adds two 2-element integer vectors.
///
/// Adds the vectors element by element using paired loads and stores with
/// scalar 64-bit adds (rather than NEON vector registers).
pub fn generate_vector_add() -> ExecutableBuffer {
    let mut ops = dynasmrt::aarch64::Assembler::new().unwrap();

    dynasm!(ops
        ; .arch aarch64
        ; ldp x3, x4, [x0]  // Load first two elements from vector 1
        ; ldp x5, x6, [x1]  // Load first two elements from vector 2
        ; add x3, x3, x5    // Add first elements
        ; add x4, x4, x6    // Add second elements
        ; stp x3, x4, [x2]  // Store result
        ; ret
    );

    ops.finalize().unwrap()
}

/// Demonstrates conditional compilation based on runtime values.
///
/// Generates specialized code for specific constant values.
pub fn generate_multiply_by_constant(constant: i32) -> ExecutableBuffer {
    let mut ops = dynasmrt::aarch64::Assembler::new().unwrap();

    // Optimize for powers of two using shifts
    if constant > 0 && (constant & (constant - 1)) == 0 {
        let shift = constant.trailing_zeros();
        dynasm!(ops
            ; .arch aarch64
            ; lsl w0, w0, shift // Shift left for power of 2
            ; ret
        );
    } else {
        dynasm!(ops
            ; .arch aarch64
            ; mov w1, constant as u32   // Load constant
            ; mul w0, w0, w1            // Multiply
            ; ret
        );
    }

    ops.finalize().unwrap()
}

/// Generates a memcpy implementation optimized for small sizes.
///
/// Uses a simple byte copy loop for ARM64.
pub fn generate_memcpy() -> ExecutableBuffer {
    let mut ops = dynasmrt::aarch64::Assembler::new().unwrap();

    dynasm!(ops
        ; .arch aarch64
        ; mov x3, x0            // Save destination for return
        ; cbz x2, ->done        // If count is 0, return
        ; ->copy_loop:
        ; ldrb w4, [x1], #1     // Load byte from source and increment
        ; strb w4, [x0], #1     // Store byte to dest and increment
        ; sub x2, x2, #1        // Decrement count
        ; cbnz x2, ->copy_loop  // Continue if not zero
        ; ->done:
        ; mov x0, x3            // Return original destination
        ; ret
    );

    ops.finalize().unwrap()
}

/// Helper function to execute generated code safely.
///
/// Converts the generated bytes into an executable function pointer.
///
/// # Safety
///
/// The caller must ensure that:
/// - `code` contains valid machine code for the target architecture
/// - The code follows the expected calling convention
/// - The function pointer type matches the actual generated code signature
pub unsafe fn execute_generated_code<F, R>(code: &[u8], f: F) -> R
where
    F: FnOnce(*const u8) -> R,
{
    f(code.as_ptr())
}

#[cfg(test)]
mod tests {
    use super::*;

    // Tests that verify code generation works without executing
    mod generation {
        use super::*;

        #[test]
        fn test_add_function_generation() {
            let code = generate_add_function();
            // Verify that code was generated (non-empty buffer)
            assert!(!code.is_empty());
            // ARM64 instructions are 4 bytes each
            assert_eq!(code.len() % 4, 0);
        }

        #[test]
        fn test_factorial_generation() {
            let code = generate_factorial();
            assert!(!code.is_empty());
            assert_eq!(code.len() % 4, 0);
            // Factorial should generate more code due to recursion logic
            assert!(code.len() > 20);
        }

        #[test]
        fn test_array_sum_generation() {
            let code = generate_array_sum();
            assert!(!code.is_empty());
            assert_eq!(code.len() % 4, 0);
        }

        #[test]
        fn test_multiply_by_constant_generation() {
            // Test power of two (uses shift) - should generate less code
            let code_pow2 = generate_multiply_by_constant(8);
            assert!(!code_pow2.is_empty());
            assert_eq!(code_pow2.len() % 4, 0);

            // Test non-power of two (uses mul) - might generate slightly more code
            let code_regular = generate_multiply_by_constant(7);
            assert!(!code_regular.is_empty());
            assert_eq!(code_regular.len() % 4, 0);
        }

        #[test]
        fn test_vector_add_generation() {
            let code = generate_vector_add();
            assert!(!code.is_empty());
            assert_eq!(code.len() % 4, 0);
        }

        #[test]
        fn test_memcpy_generation() {
            let code = generate_memcpy();
            assert!(!code.is_empty());
            assert_eq!(code.len() % 4, 0);
        }

        #[test]
        #[cfg(target_arch = "aarch64")]
        fn test_hello_world_generation() {
            let code = generate_hello_world();
            assert!(!code.is_empty());
            // Should include both the string data and the code
            assert!(code.len() > "Hello World!".len());
        }

        #[test]
        #[cfg(not(target_arch = "aarch64"))]
        fn test_hello_world_generation_skipped() {
            // Skip this test on non-ARM64 architectures because
            // the function address calculation is architecture-specific
            println!("Skipping hello_world generation test on non-ARM64 architecture");
        }
    }

    // Tests that execute the generated code - only run on ARM64
    #[cfg(all(test, target_arch = "aarch64"))]
    #[allow(unused_unsafe)]
    mod execution {
        use std::mem;

        use super::*;

        #[test]
        fn test_add_function_execution() {
            let code = generate_add_function();
            let add_fn: extern "C" fn(i32, i32) -> i32 =
                unsafe { mem::transmute(code.as_ptr()) };

            assert_eq!(unsafe { add_fn(5, 3) }, 8);
            assert_eq!(unsafe { add_fn(-10, 20) }, 10);
        }

        #[test]
        fn test_factorial_execution() {
            let code = generate_factorial();
            let factorial_fn: extern "C" fn(i32) -> i32 =
                unsafe { mem::transmute(code.as_ptr()) };

            assert_eq!(unsafe { factorial_fn(0) }, 1);
            assert_eq!(unsafe { factorial_fn(1) }, 1);
            assert_eq!(unsafe { factorial_fn(5) }, 120);
        }

        #[test]
        fn test_array_sum_execution() {
            let code = generate_array_sum();
            let sum_fn: extern "C" fn(*const i32, usize) -> i32 =
                unsafe { mem::transmute(code.as_ptr()) };

            let array = [1, 2, 3, 4, 5];
            assert_eq!(unsafe { sum_fn(array.as_ptr(), array.len()) }, 15);

            let empty: [i32; 0] = [];
            assert_eq!(unsafe {
sum_fn(empty.as_ptr(), 0) }, 0); } #[test] fn test_multiply_by_constant_execution() { // Test power of two (uses shift) let code = generate_multiply_by_constant(8); let mul_fn: extern "C" fn(i32) -> i32 = unsafe { mem::transmute(code.as_ptr()) }; assert_eq!(unsafe { mul_fn(5) }, 40); // Test non-power of two (uses mul) let code = generate_multiply_by_constant(7); let mul_fn: extern "C" fn(i32) -> i32 = unsafe { mem::transmute(code.as_ptr()) }; assert_eq!(unsafe { mul_fn(6) }, 42); } } } /// Generates a simple "Hello World" function using ARM64 assembly. /// /// This example demonstrates: /// - Embedding data directly in the assembly /// - Using labels for addressing /// - Calling external Rust functions from assembly /// - Stack management for ARM64 ABI pub fn generate_hello_world() -> ExecutableBuffer { let mut ops = dynasmrt::aarch64::Assembler::new().unwrap(); let string = "Hello World!"; // Embed the string data with a label dynasm!(ops ; .arch aarch64 ; ->hello: ; .bytes string.as_bytes() ); // Generate the function that prints the string // Load the 64-bit function address in chunks (16 bits at a time) let print_addr = print as *const () as usize; dynasm!(ops ; .arch aarch64 ; adr x0, ->hello // Load string address into first arg (ARM64 ABI) ; mov w1, string.len() as u32 // Load string length into second arg ; movz x2, (print_addr & 0xFFFF) as u32 ; movk x2, ((print_addr >> 16) & 0xFFFF) as u32, lsl 16 ; movk x2, ((print_addr >> 32) & 0xFFFF) as u32, lsl 32 ; movk x2, ((print_addr >> 48) & 0xFFFF) as u32, lsl 48 ; blr x2 // Call the print function ; ret // Return ); ops.finalize().unwrap() } }
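The `movz`/`movk` sequence above splits a 64-bit address into four 16-bit immediates, since no single ARM64 instruction can load a full 64-bit constant. The decomposition itself is plain integer arithmetic, sketched here in standalone Rust (the function names are illustrative, not part of dynasm-rs):

```rust
// Decompose a 64-bit value into the four 16-bit chunks that movz/movk
// load, then reassemble them to verify the round trip.
fn chunks16(addr: u64) -> [u16; 4] {
    [
        (addr & 0xFFFF) as u16,
        ((addr >> 16) & 0xFFFF) as u16,
        ((addr >> 32) & 0xFFFF) as u16,
        ((addr >> 48) & 0xFFFF) as u16,
    ]
}

fn reassemble(c: [u16; 4]) -> u64 {
    // movz writes chunk 0; each movk ORs in the next chunk at its shift.
    (c[0] as u64) | ((c[1] as u64) << 16) | ((c[2] as u64) << 32) | ((c[3] as u64) << 48)
}

fn main() {
    let addr: u64 = 0x0000_7FFF_DEAD_BEEF;
    assert_eq!(reassemble(chunks16(addr)), addr);
}
```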
The `dynasm!` macro parses the assembly syntax and generates Rust code that emits the corresponding machine instructions. Labels (prefixed with `->`) are resolved automatically, handling forward and backward references transparently.
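Under the hood, resolving a forward reference means emitting a placeholder, recording a fixup, and patching the offset once the label's position is known. The following is a minimal sketch of that two-pass idea in plain Rust; it is not dynasm-rs's actual implementation, and the opcode byte is arbitrary:

```rust
use std::collections::HashMap;

// A toy assembler that supports forward label references by backpatching.
struct Asm {
    buf: Vec<u8>,
    labels: HashMap<&'static str, usize>, // label name -> offset in buf
    fixups: Vec<(usize, &'static str)>,   // placeholder offset -> label name
}

impl Asm {
    fn new() -> Self {
        Asm { buf: Vec::new(), labels: HashMap::new(), fixups: Vec::new() }
    }

    // Define a label at the current position.
    fn label(&mut self, name: &'static str) {
        self.labels.insert(name, self.buf.len());
    }

    // Emit a "jump" with a 4-byte placeholder target to patch later.
    fn jump(&mut self, target: &'static str) {
        self.buf.push(0xE9); // arbitrary opcode byte for illustration
        self.fixups.push((self.buf.len(), target));
        self.buf.extend_from_slice(&[0; 4]);
    }

    // Resolve all recorded fixups now that every label offset is known.
    fn finalize(mut self) -> Vec<u8> {
        for (at, name) in &self.fixups {
            let dest = self.labels[name] as u32;
            self.buf[*at..*at + 4].copy_from_slice(&dest.to_le_bytes());
        }
        self.buf
    }
}

fn main() {
    let mut asm = Asm::new();
    asm.jump("done");  // forward reference: "done" is not defined yet
    asm.label("done"); // lands at offset 5 (1 opcode byte + 4 placeholder bytes)
    let code = asm.finalize();
    assert_eq!(&code[1..5], &5u32.to_le_bytes());
}
```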
Calling Conventions
Different platforms use different calling conventions. dynasm-rs supports multiple conventions, allowing generated code to interface correctly with external functions:
```rust
use std::{io, slice};

/// External function called from assembly to print a buffer.
///
/// # Safety
///
/// The caller must ensure that:
/// - `buffer` points to valid memory for at least `length` bytes
/// - The memory pointed to by `buffer` is initialized
pub unsafe extern "C" fn print(buffer: *const u8, length: u64) -> bool {
    io::Write::write_all(
        &mut io::stdout(),
        slice::from_raw_parts(buffer, length as usize),
    )
    .is_ok()
}
```
This `print` function uses the standard C calling convention (`extern "C"`), which on ARM64 passes the first two arguments in the X0 and X1 registers. The generated assembly follows this convention when calling the function.
Simple Code Generation
For straightforward operations, dynasm-rs makes code generation remarkably concise:
```rust
use dynasmrt::{dynasm, DynasmApi, ExecutableBuffer};

/// Generates an optimized addition function for two integers.
///
/// Creates machine code equivalent to:
/// `fn add(a: i32, b: i32) -> i32 { a + b }`
pub fn generate_add_function() -> ExecutableBuffer {
    let mut ops = dynasmrt::aarch64::Assembler::new().unwrap();

    dynasm!(ops
        ; .arch aarch64
        ; add w0, w0, w1 // Add w1 to w0, result in w0 (ARM64 ABI)
        ; ret            // Return with result in w0
    );

    ops.finalize().unwrap()
}
```
This generates machine code equivalent to a simple addition function. The assembly directly manipulates registers according to the calling convention, avoiding any overhead from function prologue or epilogue when unnecessary.
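Because code is generated at runtime, it can also be specialized to runtime values. The crate's examples include a multiply-by-constant generator that strength-reduces powers of two to a single left shift (`lsl`) instead of a `mul`. The selection predicate is ordinary Rust:

```rust
// Decide how to generate code for `x * c`: a shift amount for a
// power-of-two constant, or None to fall back to a general multiply.
// `c & (c - 1) == 0` holds exactly for powers of two when c > 0.
fn multiply_strategy(c: i32) -> Option<u32> {
    if c > 0 && (c & (c - 1)) == 0 {
        Some(c.trailing_zeros()) // emit: lsl w0, w0, #shift
    } else {
        None                     // emit: mov w1, c; mul w0, w0, w1
    }
}

fn main() {
    assert_eq!(multiply_strategy(8), Some(3)); // 8 == 1 << 3
    assert_eq!(multiply_strategy(7), None);    // not a power of two
    assert_eq!(multiply_strategy(1), Some(0)); // multiply by 1: shift by 0
}
```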
Control Flow
More complex control flow patterns like recursion and loops are fully supported:
```rust
use dynasmrt::{dynasm, DynasmApi, DynasmLabelApi, ExecutableBuffer};

/// Generates a factorial function using recursion.
///
/// Demonstrates more complex control flow with conditional jumps and
/// recursive calls.
pub fn generate_factorial() -> ExecutableBuffer {
    let mut ops = dynasmrt::aarch64::Assembler::new().unwrap();
    let entry_label = ops.new_dynamic_label();
    ops.dynamic_label(entry_label);

    dynasm!(ops
        ; .arch aarch64
        ; cmp w0, #1                // Compare n with 1
        ; b.le ->base_case          // Jump if n <= 1
        ; stp x29, x30, [sp, #-16]! // Save frame pointer and link register
        ; stp x19, x20, [sp, #-16]! // Save callee-saved registers
        ; mov w19, w0               // Save n in w19
        ; sub w0, w0, #1            // n - 1
        ; adr x1, =>entry_label     // Load our own address for recursion
        ; blr x1                    // Recursive call with n-1
        ; mul w0, w0, w19           // Multiply result by n
        ; ldp x19, x20, [sp], #16   // Restore callee-saved registers
        ; ldp x29, x30, [sp], #16   // Restore frame pointer and link register
        ; ret
        ; ->base_case:
        ; mov w0, #1                // Return 1 for base case
        ; ret
    );

    ops.finalize().unwrap()
}
```
The factorial implementation demonstrates conditional jumps, stack management for callee-saved registers, and recursive calls. The assembler handles label resolution and relative addressing automatically.
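When developing a routine like this, it helps to keep a plain-Rust reference implementation alongside the generator for differential testing. The sketch below mirrors the semantics of the generated assembly; the name `factorial_ref` is ours, not part of the crate.

```rust
/// Reference implementation mirroring the generated assembly:
/// returns 1 for n <= 1, otherwise n * factorial(n - 1).
fn factorial_ref(n: i32) -> i32 {
    if n <= 1 {
        1
    } else {
        n * factorial_ref(n - 1)
    }
}

fn main() {
    // Cross-check a few values against what the JIT-compiled
    // function is expected to return.
    assert_eq!(factorial_ref(0), 1);
    assert_eq!(factorial_ref(1), 1);
    assert_eq!(factorial_ref(5), 120);
    println!("reference factorial agrees: 5! = {}", factorial_ref(5));
}
```

Running the reference and the JIT-compiled function over the same inputs catches encoding mistakes early, long before a debugger session over raw machine code becomes necessary.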
Working with Memory
dynasm-rs excels at generating efficient memory access patterns:
```rust
use dynasmrt::{dynasm, DynasmApi, DynasmLabelApi, ExecutableBuffer};

/// Generates a loop that sums an array of integers.
///
/// Takes a pointer to an i32 array and its length, returns the sum.
pub fn generate_array_sum() -> ExecutableBuffer {
    let mut ops = dynasmrt::aarch64::Assembler::new().unwrap();

    dynasm!(ops
        ; .arch aarch64
        ; mov w2, #0            // Initialize sum to 0
        ; cbz x1, ->done        // If length is 0, return 0
        ; ->loop_start:
        ; ldr w3, [x0], #4      // Load element and increment pointer
        ; add w2, w2, w3        // Add to sum
        ; sub x1, x1, #1        // Decrement counter
        ; cbnz x1, ->loop_start // Continue if not zero
        ; ->done:
        ; mov w0, w2            // Move result to return register
        ; ret
    );

    ops.finalize().unwrap()
}
```
This array summation routine showcases post-indexed addressing (`ldr w3, [x0], #4` advances the pointer as it loads) and loop control with `cbz`/`cbnz`. Because `dynasm!` assembles exactly the instructions you write, the generated code is hand-written assembly with no abstraction overhead.
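A plain-Rust reference of the same loop is useful for cross-checking the generated code. One subtlety worth encoding in the reference: the hardware `add` instruction wraps on overflow, while Rust's `+` panics in debug builds, so the sketch below (our naming, not part of the crate) uses `wrapping_add` to match the assembly's semantics exactly.

```rust
/// Reference mirroring the generated loop: sums i32 elements with
/// wrapping arithmetic, matching the behavior of the `add` instruction.
fn array_sum_ref(data: &[i32]) -> i32 {
    data.iter().fold(0i32, |acc, &x| acc.wrapping_add(x))
}

fn main() {
    assert_eq!(array_sum_ref(&[1, 2, 3, 4, 5]), 15);
    assert_eq!(array_sum_ref(&[]), 0);
    // Hardware `add` wraps on overflow; the reference does too.
    assert_eq!(array_sum_ref(&[i32::MAX, 1]), i32::MIN);
    println!("array_sum_ref ok");
}
```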
SIMD Operations
Modern processors provide SIMD instructions for parallel data processing. dynasm-rs supports these advanced instruction sets:
```rust
use dynasmrt::{dynasm, DynasmApi, ExecutableBuffer};

/// Adds two integer vectors element by element.
///
/// Uses paired loads and stores on general-purpose registers; NEON vector
/// registers would take over for wider parallel operations.
pub fn generate_vector_add() -> ExecutableBuffer {
    let mut ops = dynasmrt::aarch64::Assembler::new().unwrap();

    dynasm!(ops
        ; .arch aarch64
        ; ldp x3, x4, [x0]  // Load first two elements from vector 1
        ; ldp x5, x6, [x1]  // Load first two elements from vector 2
        ; add x3, x3, x5    // Add first elements
        ; add x4, x4, x6    // Add second elements
        ; stp x3, x4, [x2]  // Store result
        ; ret
    );

    ops.finalize().unwrap()
}
```
This example demonstrates vector operations on ARM64. While the current implementation uses general-purpose registers for simplicity, ARM64’s NEON instruction set provides extensive SIMD capabilities for more complex parallel operations.
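The contract of the generated routine is simple enough to pin down in a reference: each output lane is the wrapping sum of the corresponding input lanes. The sketch below (our naming, not part of the crate) states that contract in plain Rust.

```rust
/// Reference for the two-lane add: out[i] = a[i] + b[i] for i in 0..2,
/// with wrapping semantics matching the `add` instruction.
fn vector_add_ref(a: [i64; 2], b: [i64; 2]) -> [i64; 2] {
    [a[0].wrapping_add(b[0]), a[1].wrapping_add(b[1])]
}

fn main() {
    assert_eq!(vector_add_ref([1, 2], [10, 20]), [11, 22]);
    // Lanes are independent: no carry propagates between them.
    assert_eq!(vector_add_ref([i64::MAX, 0], [1, 1]), [i64::MIN, 1]);
    println!("vector_add_ref ok");
}
```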
Runtime Specialization
One of dynasm-rs’s key strengths is generating specialized code based on runtime information:
```rust
use dynasmrt::{dynasm, DynasmApi, ExecutableBuffer};

/// Demonstrates conditional compilation based on runtime values.
///
/// Generates specialized code for specific constant values.
pub fn generate_multiply_by_constant(constant: i32) -> ExecutableBuffer {
    let mut ops = dynasmrt::aarch64::Assembler::new().unwrap();

    // Optimize for powers of two using shifts
    if constant > 0 && (constant & (constant - 1)) == 0 {
        let shift = constant.trailing_zeros();
        dynasm!(ops
            ; .arch aarch64
            ; lsl w0, w0, shift     // Shift left for power of 2
            ; ret
        );
    } else {
        dynasm!(ops
            ; .arch aarch64
            ; mov w1, constant as u32   // Load constant
            ; mul w0, w0, w1            // Multiply
            ; ret
        );
    }

    ops.finalize().unwrap()
}
```
This function generates different code depending on the constant. For powers of two it emits a single shift instead of a multiply, demonstrating how JIT compilation can specialize on values that only become known at run time, an optimization a static compiler cannot perform when the constant arrives as program input.
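The power-of-two test itself is a standard bit trick: a positive integer n is a power of two exactly when n & (n - 1) == 0, and trailing_zeros then yields the shift amount. A minimal sketch of just that decision logic (the helper name is ours, not part of the example above):

```rust
/// Returns Some(shift) when `constant` is a positive power of two,
/// meaning a left shift can replace the multiply; None otherwise.
fn shift_for_power_of_two(constant: i32) -> Option<u32> {
    if constant > 0 && (constant & (constant - 1)) == 0 {
        // e.g. 8 = 0b1000, 8 - 1 = 0b0111, AND is zero
        Some(constant.trailing_zeros())
    } else {
        None
    }
}

fn main() {
    assert_eq!(shift_for_power_of_two(8), Some(3)); // 8 == 1 << 3
    assert_eq!(shift_for_power_of_two(1), Some(0));
    assert_eq!(shift_for_power_of_two(7), None); // falls back to mul
    println!("ok");
}
```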
Memory Management
The generated code must reside in executable memory. dynasm-rs handles the platform-specific details:
#![allow(unused)] fn main() { use std::{io, slice}; use dynasmrt::{dynasm, DynasmApi, DynasmLabelApi, ExecutableBuffer}; /// Generates a simple "Hello World" function using ARM64 assembly. /// /// This example demonstrates: /// - Embedding data directly in the assembly /// - Using labels for addressing /// - Calling external Rust functions from assembly /// - Stack management for ARM64 ABI pub fn generate_hello_world() -> ExecutableBuffer { let mut ops = dynasmrt::aarch64::Assembler::new().unwrap(); let string = "Hello World!"; // Embed the string data with a label dynasm!(ops ; .arch aarch64 ; ->hello: ; .bytes string.as_bytes() ); // Generate the function that prints the string // Load the 64-bit function address in chunks (16 bits at a time) let print_addr = print as *const () as usize; dynasm!(ops ; .arch aarch64 ; adr x0, ->hello // Load string address into first arg (ARM64 ABI) ; mov w1, string.len() as u32 // Load string length into second arg ; movz x2, (print_addr & 0xFFFF) as u32 ; movk x2, ((print_addr >> 16) & 0xFFFF) as u32, lsl 16 ; movk x2, ((print_addr >> 32) & 0xFFFF) as u32, lsl 32 ; movk x2, ((print_addr >> 48) & 0xFFFF) as u32, lsl 48 ; blr x2 // Call the print function ; ret // Return ); ops.finalize().unwrap() } /// External function called from assembly to print a buffer. /// /// # Safety /// /// The caller must ensure that: /// - `buffer` points to valid memory for at least `length` bytes /// - The memory pointed to by `buffer` is initialized pub unsafe extern "C" fn print(buffer: *const u8, length: u64) -> bool { io::Write::write_all( &mut io::stdout(), slice::from_raw_parts(buffer, length as usize), ) .is_ok() } /// Generates an optimized addition function for two integers. 
/// /// Creates machine code equivalent to: `fn add(a: i32, b: i32) -> i32 { a + b /// }` pub fn generate_add_function() -> ExecutableBuffer { let mut ops = dynasmrt::aarch64::Assembler::new().unwrap(); dynasm!(ops ; .arch aarch64 ; add w0, w0, w1 // Add w1 to w0, result in w0 (ARM64 ABI) ; ret // Return with result in w0 ); ops.finalize().unwrap() } /// Generates a factorial function using recursion. /// /// Demonstrates more complex control flow with conditional jumps and recursive /// calls. pub fn generate_factorial() -> ExecutableBuffer { let mut ops = dynasmrt::aarch64::Assembler::new().unwrap(); let entry_label = ops.new_dynamic_label(); ops.dynamic_label(entry_label); dynasm!(ops ; .arch aarch64 ; cmp w0, #1 // Compare n with 1 ; b.le ->base_case // Jump if n <= 1 ; stp x29, x30, [sp, #-16]! // Save frame pointer and link register ; stp x19, x20, [sp, #-16]! // Save callee-saved registers ; mov w19, w0 // Save n in w19 ; sub w0, w0, #1 // n - 1 ; adr x1, =>entry_label // Load our own address for recursion ; blr x1 // Recursive call with n-1 ; mul w0, w0, w19 // Multiply result by n ; ldp x19, x20, [sp], #16 // Restore callee-saved registers ; ldp x29, x30, [sp], #16 // Restore frame pointer and link register ; ret ; ->base_case: ; mov w0, #1 // Return 1 for base case ; ret ); ops.finalize().unwrap() } /// Generates a loop that sums an array of integers. /// /// Takes a pointer to an i32 array and its length, returns the sum. 
pub fn generate_array_sum() -> ExecutableBuffer { let mut ops = dynasmrt::aarch64::Assembler::new().unwrap(); dynasm!(ops ; .arch aarch64 ; mov w2, #0 // Initialize sum to 0 ; cbz x1, ->done // If length is 0, return 0 ; ->loop_start: ; ldr w3, [x0], #4 // Load element and increment pointer ; add w2, w2, w3 // Add to sum ; sub x1, x1, #1 // Decrement counter ; cbnz x1, ->loop_start // Continue if not zero ; ->done: ; mov w0, w2 // Move result to return register ; ret ); ops.finalize().unwrap() } /// Generates a function that performs SIMD operations using NEON instructions. /// /// Adds two integer vectors element by element. pub fn generate_vector_add() -> ExecutableBuffer { let mut ops = dynasmrt::aarch64::Assembler::new().unwrap(); dynasm!(ops ; .arch aarch64 ; ldp x3, x4, [x0] // Load first two elements from vector 1 ; ldp x5, x6, [x1] // Load first two elements from vector 2 ; add x3, x3, x5 // Add first elements ; add x4, x4, x6 // Add second elements ; stp x3, x4, [x2] // Store result ; ret ); ops.finalize().unwrap() } /// Demonstrates conditional compilation based on runtime values. /// /// Generates specialized code for specific constant values. pub fn generate_multiply_by_constant(constant: i32) -> ExecutableBuffer { let mut ops = dynasmrt::aarch64::Assembler::new().unwrap(); // Optimize for powers of two using shifts if constant > 0 && (constant & (constant - 1)) == 0 { let shift = constant.trailing_zeros(); dynasm!(ops ; .arch aarch64 ; lsl w0, w0, shift // Shift left for power of 2 ; ret ); } else { dynasm!(ops ; .arch aarch64 ; mov w1, constant as u32 // Load constant ; mul w0, w0, w1 // Multiply ; ret ); } ops.finalize().unwrap() } /// Generates a memcpy implementation optimized for small sizes. /// /// Uses a simple byte copy loop for ARM64. 
pub fn generate_memcpy() -> ExecutableBuffer { let mut ops = dynasmrt::aarch64::Assembler::new().unwrap(); dynasm!(ops ; .arch aarch64 ; mov x3, x0 // Save destination for return ; cbz x2, ->done // If count is 0, return ; ->copy_loop: ; ldrb w4, [x1], #1 // Load byte from source and increment ; strb w4, [x0], #1 // Store byte to dest and increment ; sub x2, x2, #1 // Decrement count ; cbnz x2, ->copy_loop // Continue if not zero ; ->done: ; mov x0, x3 // Return original destination ; ret ); ops.finalize().unwrap() } #[cfg(test)] mod tests { use super::*; // Tests that verify code generation works without executing mod generation { use super::*; #[test] fn test_add_function_generation() { let code = generate_add_function(); // Verify that code was generated (non-empty buffer) assert!(!code.is_empty()); // ARM64 instructions are 4 bytes each assert_eq!(code.len() % 4, 0); } #[test] fn test_factorial_generation() { let code = generate_factorial(); assert!(!code.is_empty()); assert_eq!(code.len() % 4, 0); // Factorial should generate more code due to recursion logic assert!(code.len() > 20); } #[test] fn test_array_sum_generation() { let code = generate_array_sum(); assert!(!code.is_empty()); assert_eq!(code.len() % 4, 0); } #[test] fn test_multiply_by_constant_generation() { // Test power of two (uses shift) - should generate less code let code_pow2 = generate_multiply_by_constant(8); assert!(!code_pow2.is_empty()); assert_eq!(code_pow2.len() % 4, 0); // Test non-power of two (uses mul) - might generate slightly more code let code_regular = generate_multiply_by_constant(7); assert!(!code_regular.is_empty()); assert_eq!(code_regular.len() % 4, 0); } #[test] fn test_vector_add_generation() { let code = generate_vector_add(); assert!(!code.is_empty()); assert_eq!(code.len() % 4, 0); } #[test] fn test_memcpy_generation() { let code = generate_memcpy(); assert!(!code.is_empty()); assert_eq!(code.len() % 4, 0); } #[test] #[cfg(target_arch = "aarch64")] fn 
test_hello_world_generation() { let code = generate_hello_world(); assert!(!code.is_empty()); // Should include both the string data and the code assert!(code.len() > "Hello World!".len()); } #[test] #[cfg(not(target_arch = "aarch64"))] fn test_hello_world_generation_skipped() { // Skip this test on non-ARM64 architectures because // the function address calculation is architecture-specific println!("Skipping hello_world generation test on non-ARM64 architecture"); } } // Tests that execute the generated code - only run on ARM64 #[cfg(all(test, target_arch = "aarch64"))] #[allow(unused_unsafe)] mod execution { use std::mem; use super::*; #[test] fn test_add_function_execution() { let code = generate_add_function(); let add_fn: extern "C" fn(i32, i32) -> i32 = unsafe { mem::transmute(code.as_ptr()) }; assert_eq!(unsafe { add_fn(5, 3) }, 8); assert_eq!(unsafe { add_fn(-10, 20) }, 10); } #[test] fn test_factorial_execution() { let code = generate_factorial(); let factorial_fn: extern "C" fn(i32) -> i32 = unsafe { mem::transmute(code.as_ptr()) }; assert_eq!(unsafe { factorial_fn(0) }, 1); assert_eq!(unsafe { factorial_fn(1) }, 1); assert_eq!(unsafe { factorial_fn(5) }, 120); } #[test] fn test_array_sum_execution() { let code = generate_array_sum(); let sum_fn: extern "C" fn(*const i32, usize) -> i32 = unsafe { mem::transmute(code.as_ptr()) }; let array = [1, 2, 3, 4, 5]; assert_eq!(unsafe { sum_fn(array.as_ptr(), array.len()) }, 15); let empty: [i32; 0] = []; assert_eq!(unsafe { sum_fn(empty.as_ptr(), 0) }, 0); } #[test] fn test_multiply_by_constant_execution() { // Test power of two (uses shift) let code = generate_multiply_by_constant(8); let mul_fn: extern "C" fn(i32) -> i32 = unsafe { mem::transmute(code.as_ptr()) }; assert_eq!(unsafe { mul_fn(5) }, 40); // Test non-power of two (uses mul) let code = generate_multiply_by_constant(7); let mul_fn: extern "C" fn(i32) -> i32 = unsafe { mem::transmute(code.as_ptr()) }; assert_eq!(unsafe { mul_fn(6) }, 42); } } } /// 
Helper function to execute generated code safely. /// /// Converts the generated bytes into an executable function pointer. /// /// # Safety /// /// The caller must ensure that: /// - `code` contains valid machine code for the target architecture /// - The code follows the expected calling convention /// - The function pointer type matches the actual generated code signature pub unsafe fn execute_generated_code<F, R>(code: &[u8], f: F) -> R where F: FnOnce(*const u8) -> R, { f(code.as_ptr()) } }
The library manages memory protection flags and ensures proper alignment. On Unix systems, it uses mmap with PROT_EXEC; on Windows, it uses VirtualAlloc with PAGE_EXECUTE_READWRITE.
Label System
dynasm-rs provides a sophisticated label system for managing control flow:
- Global labels (prefixed with ->) may be defined only once per assembler and can be referenced across multiple dynasm! invocations
- Local labels (bare identifiers) may be defined repeatedly; each reference resolves to the nearest matching definition
- Dynamic labels (prefixed with =>) are allocated at runtime with new_dynamic_label, enabling computed jumps and recursion, as in the factorial example above
#![allow(unused)]
fn main() {
    dynasm!(ops
        ; =>function_start:
        ; test rax, rax
        ; jz ->skip
        ; call ->helper
        ; ->skip:
        ; ret
        ; ->helper:
        ; xor rax, rax
        ; ret
    );
}
Architecture Support
dynasm-rs supports multiple architectures with comprehensive instruction set coverage:
- ARM64/AArch64: Modern 64-bit ARM with NEON SIMD support (demonstrated in these examples)
- x86/x64: Full instruction set including SSE, AVX, and AVX-512
- ARM: 32-bit ARM instruction sets
Each architecture has its own syntax and register naming conventions, but the overall API remains consistent. The examples in this documentation use ARM64 assembly, the architecture used by Apple Silicon and most modern ARM processors.
Integration Patterns
dynasm-rs integrates well with existing compiler infrastructure. Here’s a pattern for compiling expressions to machine code:
#![allow(unused)]
fn main() {
    enum Expr {
        Const(i32),
        Add(Box<Expr>, Box<Expr>),
        Mul(Box<Expr>, Box<Expr>),
    }

    fn compile_expr(expr: &Expr, ops: &mut dynasmrt::aarch64::Assembler) {
        match expr {
            Expr::Const(val) => {
                dynasm!(ops; .arch aarch64; mov w0, *val as u32);
            }
            Expr::Add(a, b) => {
                compile_expr(a, ops);
                dynasm!(ops; .arch aarch64; str w0, [sp, #-16]!);
                compile_expr(b, ops);
                dynasm!(ops; .arch aarch64; ldr w1, [sp], #16; add w0, w0, w1);
            }
            Expr::Mul(a, b) => {
                compile_expr(a, ops);
                dynasm!(ops; .arch aarch64; str w0, [sp, #-16]!);
                compile_expr(b, ops);
                dynasm!(ops; .arch aarch64; ldr w1, [sp], #16; mul w0, w0, w1);
            }
        }
    }
}
This recursive compilation strategy works well for tree-structured intermediate representations.
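When testing a JIT built this way, a direct interpreter over the same Expr type makes a convenient oracle: compile and interpret each tree, then compare results. A sketch under that assumption (eval is our hypothetical helper, not part of dynasm-rs):

```rust
// Same Expr shape as in the compilation example above.
enum Expr {
    Const(i32),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
}

// Direct evaluation: each node yields the value the generated
// code would leave in w0. Wrapping arithmetic matches the
// two's-complement behavior of the add/mul instructions.
fn eval(expr: &Expr) -> i32 {
    match expr {
        Expr::Const(v) => *v,
        Expr::Add(a, b) => eval(a).wrapping_add(eval(b)),
        Expr::Mul(a, b) => eval(a).wrapping_mul(eval(b)),
    }
}

fn main() {
    // (2 + 3) * 4
    let e = Expr::Mul(
        Box::new(Expr::Add(
            Box::new(Expr::Const(2)),
            Box::new(Expr::Const(3)),
        )),
        Box::new(Expr::Const(4)),
    );
    assert_eq!(eval(&e), 20);
    println!("ok");
}
```

In a test suite you would feed the same tree to compile_expr, transmute the finalized buffer to a function pointer, and assert that both paths agree.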
inkwell
The inkwell
crate provides safe, idiomatic Rust bindings to LLVM, enabling compiler developers to generate highly optimized machine code while leveraging LLVM’s mature optimization infrastructure and broad platform support. LLVM IR serves as a universal intermediate representation that can be compiled to native code for virtually any modern processor architecture. The inkwell bindings wrap LLVM’s C++ API with Rust’s type system and ownership model, preventing common errors like use-after-free and type mismatches that plague direct LLVM usage.
The architecture of inkwell mirrors LLVM’s conceptual model while providing Rust-native abstractions. Contexts manage the lifetime of LLVM data structures, modules contain functions and global variables, builders construct instruction sequences, and execution engines provide JIT compilation capabilities. This design ensures memory safety through Rust’s lifetime system while maintaining the full power and flexibility of LLVM’s code generation capabilities.
Installation and Setup
Inkwell currently supports LLVM versions 8 through 18. You must have LLVM installed on your system and specify the version in your Cargo.toml
dependencies.
On macOS, you can use Homebrew to install LLVM. For example, to install LLVM 18:
brew install llvm@18
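Homebrew installs LLVM keg-only, so the build script may not find it automatically. The underlying llvm-sys crate reads a versioned prefix variable; the exact name depends on the llvm-sys major version your inkwell feature selects (e.g. LLVM_SYS_180_PREFIX or LLVM_SYS_181_PREFIX for LLVM 18 — the build error names the one it wants). Something like:

```shell
# Point llvm-sys at Homebrew's keg-only LLVM. The path shown is the
# Apple Silicon default; substitute the output of `brew --prefix llvm@18`.
# The variable name (LLVM_SYS_180_PREFIX here) is an assumption -- check
# your llvm-sys version if the build still fails to locate LLVM.
export LLVM_SYS_180_PREFIX=/opt/homebrew/opt/llvm@18
echo "$LLVM_SYS_180_PREFIX"
```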
Add inkwell to your Cargo.toml
with the appropriate LLVM version feature flag:
[dependencies]
inkwell = { version = "0.6.0", features = ["llvm18-1"] }
Feature flags follow the pattern llvmM-N, where M is the LLVM major version (8-18); most versions use N = 0 (e.g. llvm17-0), while LLVM 18 uses llvm18-1.
Basic Usage
Creating an LLVM context for code generation:
#![allow(unused)] fn main() { use std::error::Error; use std::path::Path; use inkwell::context::Context; use inkwell::module::Module; use inkwell::passes::PassBuilderOptions; use inkwell::targets::{ CodeModel, FileType, InitializationConfig, RelocMode, Target, TargetMachine, }; use inkwell::types::{BasicMetadataTypeEnum, BasicType, BasicTypeEnum}; use inkwell::values::FunctionValue; use inkwell::{AddressSpace, IntPredicate, OptimizationLevel}; /// Creates a simple function that adds two integers pub fn create_add_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false); let function = module.add_function("add", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); // Get function parameters let x = function.get_nth_param(0).unwrap().into_int_value(); let y = function.get_nth_param(1).unwrap().into_int_value(); // Build addition and return let sum = builder.build_int_add(x, y, "sum").unwrap(); builder.build_return(Some(&sum)).unwrap(); function } /// Creates a function that multiplies two integers pub fn create_multiply_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false); let function = module.add_function("multiply", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); let x = function.get_nth_param(0).unwrap().into_int_value(); let y = function.get_nth_param(1).unwrap().into_int_value(); let product = builder.build_int_mul(x, y, "product").unwrap(); builder.build_return(Some(&product)).unwrap(); function } /// Creates a function with a constant return value pub fn create_constant_function<'ctx>( 
context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[], false); let function = module.add_function("get_constant", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); let constant = i32_type.const_int(42, false); builder.build_return(Some(&constant)).unwrap(); function } /// Demonstrates integer comparison operations pub fn create_comparison_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let bool_type = context.bool_type(); let fn_type = bool_type.fn_type(&[i32_type.into(), i32_type.into()], false); let function = module.add_function("compare_ints", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); let x = function.get_nth_param(0).unwrap().into_int_value(); let y = function.get_nth_param(1).unwrap().into_int_value(); // Compare x > y let comparison = builder .build_int_compare(IntPredicate::SGT, x, y, "cmp") .unwrap(); builder.build_return(Some(&comparison)).unwrap(); function } /// Creates a function with conditional branching (if-then-else) pub fn create_conditional_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[i32_type.into()], false); let function = module.add_function("conditional", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); let then_block = context.append_basic_block(function, "then"); let else_block = context.append_basic_block(function, "else"); let merge_block = context.append_basic_block(function, "merge"); // Entry block builder.position_at_end(entry); let x = function.get_nth_param(0).unwrap().into_int_value(); let zero 
= i32_type.const_int(0, false); let condition = builder .build_int_compare(IntPredicate::SGT, x, zero, "cond") .unwrap(); builder .build_conditional_branch(condition, then_block, else_block) .unwrap(); // Then block: return x * 2 builder.position_at_end(then_block); let two = i32_type.const_int(2, false); let then_val = builder.build_int_mul(x, two, "then_val").unwrap(); builder.build_unconditional_branch(merge_block).unwrap(); // Else block: return x * -1 builder.position_at_end(else_block); let neg_one = i32_type.const_int(-1i64 as u64, true); let else_val = builder.build_int_mul(x, neg_one, "else_val").unwrap(); builder.build_unconditional_branch(merge_block).unwrap(); // Merge block with phi node builder.position_at_end(merge_block); let phi = builder.build_phi(i32_type, "result").unwrap(); phi.add_incoming(&[(&then_val, then_block), (&else_val, else_block)]); builder.build_return(Some(&phi.as_basic_value())).unwrap(); function } /// Creates a simple loop that counts from 0 to n pub fn create_loop_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[i32_type.into()], false); let function = module.add_function("count_loop", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); let loop_block = context.append_basic_block(function, "loop"); let exit_block = context.append_basic_block(function, "exit"); // Entry block: initialize counter builder.position_at_end(entry); let n = function.get_nth_param(0).unwrap().into_int_value(); let zero = i32_type.const_int(0, false); builder.build_unconditional_branch(loop_block).unwrap(); // Loop block builder.position_at_end(loop_block); let counter = builder.build_phi(i32_type, "counter").unwrap(); counter.add_incoming(&[(&zero, entry)]); // Increment counter let one = i32_type.const_int(1, false); let next_counter = builder .build_int_add( 
counter.as_basic_value().into_int_value(), one, "next_counter", ) .unwrap(); // Check loop condition let condition = builder .build_int_compare(IntPredicate::SLT, next_counter, n, "loop_cond") .unwrap(); // Add incoming value for next iteration let loop_block_end = builder.get_insert_block().unwrap(); counter.add_incoming(&[(&next_counter, loop_block_end)]); builder .build_conditional_branch(condition, loop_block, exit_block) .unwrap(); // Exit block builder.position_at_end(exit_block); builder .build_return(Some(&counter.as_basic_value())) .unwrap(); function } /// Creates a function that allocates and uses local variables (stack /// allocation) pub fn create_alloca_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false); let function = module.add_function("use_alloca", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); // Get parameters let x = function.get_nth_param(0).unwrap().into_int_value(); let y = function.get_nth_param(1).unwrap().into_int_value(); // Allocate stack variables let x_ptr = builder.build_alloca(i32_type, "x_ptr").unwrap(); let y_ptr = builder.build_alloca(i32_type, "y_ptr").unwrap(); let result_ptr = builder.build_alloca(i32_type, "result_ptr").unwrap(); // Store values builder.build_store(x_ptr, x).unwrap(); builder.build_store(y_ptr, y).unwrap(); // Load values let x_val = builder.build_load(i32_type, x_ptr, "x_val").unwrap(); let y_val = builder.build_load(i32_type, y_ptr, "y_val").unwrap(); // Compute and store result let sum = builder .build_int_add(x_val.into_int_value(), y_val.into_int_value(), "sum") .unwrap(); builder.build_store(result_ptr, sum).unwrap(); // Load and return result let result = builder.build_load(i32_type, result_ptr, "result").unwrap(); 
builder.build_return(Some(&result)).unwrap(); function } /// Creates a function that works with arrays pub fn create_array_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let i32_ptr_type = context.ptr_type(AddressSpace::default()); let fn_type = i32_type.fn_type(&[i32_ptr_type.into(), i32_type.into()], false); let function = module.add_function("sum_array", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); let loop_block = context.append_basic_block(function, "loop"); let exit_block = context.append_basic_block(function, "exit"); // Entry block builder.position_at_end(entry); let array_ptr = function.get_nth_param(0).unwrap().into_pointer_value(); let size = function.get_nth_param(1).unwrap().into_int_value(); let zero = i32_type.const_int(0, false); builder.build_unconditional_branch(loop_block).unwrap(); // Loop block builder.position_at_end(loop_block); let index = builder.build_phi(i32_type, "index").unwrap(); let sum = builder.build_phi(i32_type, "sum").unwrap(); index.add_incoming(&[(&zero, entry)]); sum.add_incoming(&[(&zero, entry)]); // Load array element let elem_ptr = unsafe { builder .build_gep( i32_type, array_ptr, &[index.as_basic_value().into_int_value()], "elem_ptr", ) .unwrap() }; let elem = builder.build_load(i32_type, elem_ptr, "elem").unwrap(); // Update sum let new_sum = builder .build_int_add( sum.as_basic_value().into_int_value(), elem.into_int_value(), "new_sum", ) .unwrap(); // Update index let one = i32_type.const_int(1, false); let next_index = builder .build_int_add(index.as_basic_value().into_int_value(), one, "next_index") .unwrap(); // Check loop condition let condition = builder .build_int_compare(IntPredicate::SLT, next_index, size, "loop_cond") .unwrap(); // Update phi nodes let loop_end = builder.get_insert_block().unwrap(); index.add_incoming(&[(&next_index, loop_end)]); 
sum.add_incoming(&[(&new_sum, loop_end)]); builder .build_conditional_branch(condition, loop_block, exit_block) .unwrap(); // Exit block builder.position_at_end(exit_block); builder.build_return(Some(&sum.as_basic_value())).unwrap(); function } /// Creates a global variable pub fn create_global_variable<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> inkwell::values::GlobalValue<'ctx> { let i32_type = context.i32_type(); let global = module.add_global(i32_type, Some(AddressSpace::default()), "global_counter"); global.set_initializer(&i32_type.const_int(0, false)); global.set_linkage(inkwell::module::Linkage::Internal); global } /// Creates a function that uses a global variable pub fn create_global_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[], false); let function = module.add_function("increment_global", fn_type, None); // Create or get global variable let global = module .get_global("global_counter") .unwrap_or_else(|| create_global_variable(context, module)); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); // Load global value let global_ptr = global.as_pointer_value(); let current = builder.build_load(i32_type, global_ptr, "current").unwrap(); // Increment let one = i32_type.const_int(1, false); let incremented = builder .build_int_add(current.into_int_value(), one, "incremented") .unwrap(); // Store back to global builder.build_store(global_ptr, incremented).unwrap(); // Return new value builder.build_return(Some(&incremented)).unwrap(); function } /// Creates a recursive factorial function pub fn create_recursive_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[i32_type.into()], false); let function = module.add_function("factorial", fn_type, 
None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); let recurse = context.append_basic_block(function, "recurse"); let base = context.append_basic_block(function, "base"); // Entry block builder.position_at_end(entry); let n = function.get_nth_param(0).unwrap().into_int_value(); let one = i32_type.const_int(1, false); let is_base = builder .build_int_compare(IntPredicate::SLE, n, one, "is_base") .unwrap(); builder .build_conditional_branch(is_base, base, recurse) .unwrap(); // Base case: return 1 builder.position_at_end(base); builder.build_return(Some(&one)).unwrap(); // Recursive case: return n * factorial(n-1) builder.position_at_end(recurse); let n_minus_1 = builder.build_int_sub(n, one, "n_minus_1").unwrap(); let rec_result = builder .build_call(function, &[n_minus_1.into()], "rec_result") .unwrap(); let result = builder .build_int_mul( n, rec_result .try_as_basic_value() .left() .unwrap() .into_int_value(), "result", ) .unwrap(); builder.build_return(Some(&result)).unwrap(); function } /// Creates a struct type and a function that uses it pub fn create_struct_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let f64_type = context.f64_type(); // Define a Point struct with x and y fields let field_types = [i32_type.into(), f64_type.into()]; let struct_type = context.struct_type(&field_types, false); let fn_type = f64_type.fn_type(&[struct_type.into()], false); let function = module.add_function("get_point_y", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); // Get the struct parameter let point = function.get_nth_param(0).unwrap().into_struct_value(); // Extract the y field (index 1) let y_field = builder.build_extract_value(point, 1, "y_field").unwrap(); builder.build_return(Some(&y_field)).unwrap(); function } /// Runs optimization 
passes on a module using the modern pass manager (LLVM 18) pub fn optimize_module<'ctx>(module: &Module<'ctx>) -> Result<(), String> { // First verify the module is valid module.verify().map_err(|e| e.to_string())?; // Initialize targets for optimization Target::initialize_all(&InitializationConfig::default()); let target_triple = TargetMachine::get_default_triple(); let target = Target::from_triple(&target_triple) .map_err(|e| format!("Failed to create target: {}", e))?; let target_machine = target .create_target_machine( &target_triple, "generic", "", OptimizationLevel::Default, RelocMode::Default, CodeModel::Default, ) .ok_or("Failed to create target machine")?; // Common optimization passes let passes = [ "instcombine", // Combine instructions "reassociate", // Reassociate expressions "gvn", // Global value numbering "simplifycfg", // Simplify control flow graph "mem2reg", // Promote memory to registers ]; let pass_builder_options = PassBuilderOptions::create(); pass_builder_options.set_loop_vectorization(true); pass_builder_options.set_loop_unrolling(true); pass_builder_options.set_merge_functions(true); module .run_passes(&passes.join(","), &target_machine, pass_builder_options) .map_err(|e| e.to_string()) } /// Runs specific optimization passes on a module pub fn run_custom_passes<'ctx>(module: &Module<'ctx>, passes: &[&str]) -> Result<(), String> { // Verify module first module.verify().map_err(|e| e.to_string())?; // Initialize targets Target::initialize_all(&InitializationConfig::default()); let target_triple = TargetMachine::get_default_triple(); let target = Target::from_triple(&target_triple) .map_err(|e| format!("Failed to create target: {}", e))?; let target_machine = target .create_target_machine( &target_triple, "generic", "", OptimizationLevel::None, RelocMode::Default, CodeModel::Default, ) .ok_or("Failed to create target machine")?; module .run_passes( &passes.join(","), &target_machine, PassBuilderOptions::create(), ) .map_err(|e| 
e.to_string()) } /// Writes LLVM IR to a file pub fn write_ir_to_file<'ctx>(module: &Module<'ctx>, path: &Path) -> Result<(), String> { module .print_to_file(path) .map_err(|e| format!("Failed to write IR: {}", e)) } /// Compiles module to object file pub fn compile_to_object_file<'ctx>( module: &Module<'ctx>, path: &Path, ) -> Result<(), Box<dyn Error>> { Target::initialize_native(&InitializationConfig::default())?; let target_triple = TargetMachine::get_default_triple(); let target = Target::from_triple(&target_triple)?; let target_machine = target .create_target_machine( &target_triple, "generic", "", OptimizationLevel::Default, RelocMode::Default, CodeModel::Default, ) .ok_or("Could not create target machine")?; target_machine.write_to_file(module, FileType::Object, path)?; Ok(()) } /// Verifies that a module is valid pub fn verify_module<'ctx>(module: &Module<'ctx>) -> Result<(), String> { module.verify().map_err(|e| e.to_string()) } /// Helper to create a function type pub fn create_function_type<'ctx>( context: &'ctx Context, param_types: Vec<BasicMetadataTypeEnum<'ctx>>, return_type: Option<BasicTypeEnum<'ctx>>, is_var_args: bool, ) -> inkwell::types::FunctionType<'ctx> { match return_type { Some(ret) => ret.fn_type(&param_types, is_var_args), None => context.void_type().fn_type(&param_types, is_var_args), } } /// Simple JIT execution example pub fn create_execution_engine<'ctx>( module: &Module<'ctx>, ) -> Result<inkwell::execution_engine::ExecutionEngine<'ctx>, String> { module .create_jit_execution_engine(OptimizationLevel::None) .map_err(|e| e.to_string()) } /// Example of JIT compiling and executing a function pub fn jit_compile_and_execute(context: &Context) -> Result<u64, Box<dyn Error>> { let module = context.create_module("jit_example"); let builder = context.create_builder(); // Create a simple sum function: sum(x, y, z) = x + y + z let i64_type = context.i64_type(); let fn_type = i64_type.fn_type(&[i64_type.into(), i64_type.into(), i64_type.into()],
false);
    let function = module.add_function("sum", fn_type, None);
    let basic_block = context.append_basic_block(function, "entry");
    builder.position_at_end(basic_block);

    let x = function.get_nth_param(0).unwrap().into_int_value();
    let y = function.get_nth_param(1).unwrap().into_int_value();
    let z = function.get_nth_param(2).unwrap().into_int_value();

    let sum = builder.build_int_add(x, y, "sum").unwrap();
    let sum = builder.build_int_add(sum, z, "sum").unwrap();
    builder.build_return(Some(&sum)).unwrap();

    // Create execution engine
    let execution_engine = module.create_jit_execution_engine(OptimizationLevel::None)?;

    // Get the compiled function
    type SumFunc = unsafe extern "C" fn(u64, u64, u64) -> u64;
    let sum_fn = unsafe { execution_engine.get_function::<SumFunc>("sum")? };

    // Execute the function
    let result = unsafe { sum_fn.call(1, 2, 3) };
    Ok(result)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_function() {
        let context = Context::create();
        let module = context.create_module("test");
        let function = create_add_function(&context, &module);
        assert_eq!(function.get_name().to_str().unwrap(), "add");
        assert_eq!(function.count_params(), 2);
        assert!(verify_module(&module).is_ok());
    }

    #[test]
    fn test_constant_function() {
        let context = Context::create();
        let module = context.create_module("test");
        let function = create_constant_function(&context, &module);
        assert_eq!(function.get_name().to_str().unwrap(), "get_constant");
        assert_eq!(function.count_params(), 0);
        assert!(verify_module(&module).is_ok());
    }

    #[test]
    fn test_conditional_function() {
        let context = Context::create();
        let module = context.create_module("test");
        let function = create_conditional_function(&context, &module);
        assert_eq!(function.count_basic_blocks(), 4); // entry, then, else, merge
        assert!(verify_module(&module).is_ok());
    }

    #[test]
    fn test_loop_function() {
        let context = Context::create();
        let module = context.create_module("test");
        let function = create_loop_function(&context, &module);
        assert_eq!(function.count_basic_blocks(), 3); // entry, loop, exit
        assert!(verify_module(&module).is_ok());
    }

    #[test]
    fn test_global_variable() {
        let context = Context::create();
        let module = context.create_module("test");
        let global = create_global_variable(&context, &module);
        assert_eq!(global.get_name().to_str().unwrap(), "global_counter");
        assert!(verify_module(&module).is_ok());
    }

    #[test]
    fn test_recursive_function() {
        let context = Context::create();
        let module = context.create_module("test");
        let function = create_recursive_function(&context, &module);
        assert_eq!(function.get_name().to_str().unwrap(), "factorial");
        assert!(verify_module(&module).is_ok());
    }

    #[test]
    fn test_optimization() {
        let context = Context::create();
        let module = context.create_module("test");

        // Create several functions
        create_add_function(&context, &module);
        create_multiply_function(&context, &module);
        create_constant_function(&context, &module);

        // Apply optimizations
        assert!(optimize_module(&module).is_ok());

        // Module should still be valid after optimization
        assert!(verify_module(&module).is_ok());
    }

    #[test]
    fn test_custom_passes() {
        let context = Context::create();
        let module = context.create_module("test");

        // Create a simple function
        create_add_function(&context, &module);

        // Run specific optimization passes
        let passes = ["instcombine", "simplifycfg"];
        assert!(run_custom_passes(&module, &passes).is_ok());

        // Module should still be valid
        assert!(verify_module(&module).is_ok());
    }

    #[test]
    fn test_jit_execution() {
        let context = Context::create();

        // Test JIT compilation and execution
        match jit_compile_and_execute(&context) {
            Ok(result) => assert_eq!(result, 6), // 1 + 2 + 3 = 6
            Err(e) => panic!("JIT execution failed: {}", e),
        }
    }
}

/// Creates a basic LLVM context
pub fn create_context() -> Context {
    Context::create()
}
}
The context is the core of LLVM’s infrastructure, owning all types, values, and metadata. Every LLVM operation requires a context, which manages memory and ensures proper cleanup of LLVM data structures.
Function Creation
Building simple arithmetic functions:
#![allow(unused)]
fn main() {
use std::error::Error;
use std::path::Path;

use inkwell::context::Context;
use inkwell::module::Module;
use inkwell::passes::PassBuilderOptions;
use inkwell::targets::{
    CodeModel, FileType, InitializationConfig, RelocMode, Target, TargetMachine,
};
use inkwell::types::{BasicMetadataTypeEnum, BasicType, BasicTypeEnum};
use inkwell::values::FunctionValue;
use inkwell::{AddressSpace, IntPredicate, OptimizationLevel};

/// Creates a basic LLVM context
pub fn create_context() -> Context {
    Context::create()
}

/// Creates a simple function that adds two integers
pub fn create_add_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false);
    let function = module.add_function("add", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    // Get function parameters
    let x = function.get_nth_param(0).unwrap().into_int_value();
    let y = function.get_nth_param(1).unwrap().into_int_value();

    // Build addition and return
    let sum = builder.build_int_add(x, y, "sum").unwrap();
    builder.build_return(Some(&sum)).unwrap();

    function
}

/// Creates a function with a constant return value
pub fn create_constant_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[], false);
    let function = module.add_function("get_constant", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    let constant = i32_type.const_int(42, false);
    builder.build_return(Some(&constant)).unwrap();

    function
}

/// Demonstrates integer comparison operations
pub fn create_comparison_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let bool_type = context.bool_type();
    let fn_type = bool_type.fn_type(&[i32_type.into(), i32_type.into()], false);
    let function = module.add_function("compare_ints", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    let x = function.get_nth_param(0).unwrap().into_int_value();
    let y = function.get_nth_param(1).unwrap().into_int_value();

    // Compare x > y
    let comparison = builder
        .build_int_compare(IntPredicate::SGT, x, y, "cmp")
        .unwrap();
    builder.build_return(Some(&comparison)).unwrap();

    function
}

/// Creates a function with conditional branching (if-then-else)
pub fn create_conditional_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into()], false);
    let function = module.add_function("conditional", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    let then_block = context.append_basic_block(function, "then");
    let else_block = context.append_basic_block(function, "else");
    let merge_block = context.append_basic_block(function, "merge");

    // Entry block
    builder.position_at_end(entry);
    let x = function.get_nth_param(0).unwrap().into_int_value();
    let zero = i32_type.const_int(0, false);
    let condition = builder
        .build_int_compare(IntPredicate::SGT, x, zero, "cond")
        .unwrap();
    builder
        .build_conditional_branch(condition, then_block, else_block)
        .unwrap();

    // Then block: return x * 2
    builder.position_at_end(then_block);
    let two = i32_type.const_int(2, false);
    let then_val = builder.build_int_mul(x, two, "then_val").unwrap();
    builder.build_unconditional_branch(merge_block).unwrap();

    // Else block: return x * -1
    builder.position_at_end(else_block);
    let neg_one = i32_type.const_int(-1i64 as u64, true);
    let else_val = builder.build_int_mul(x, neg_one, "else_val").unwrap();
    builder.build_unconditional_branch(merge_block).unwrap();

    // Merge block with phi node
    builder.position_at_end(merge_block);
    let phi = builder.build_phi(i32_type, "result").unwrap();
    phi.add_incoming(&[(&then_val, then_block), (&else_val, else_block)]);
    builder.build_return(Some(&phi.as_basic_value())).unwrap();

    function
}

/// Creates a simple loop that counts from 0 to n
pub fn create_loop_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into()], false);
    let function = module.add_function("count_loop", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    let loop_block = context.append_basic_block(function, "loop");
    let exit_block = context.append_basic_block(function, "exit");

    // Entry block: initialize counter
    builder.position_at_end(entry);
    let n = function.get_nth_param(0).unwrap().into_int_value();
    let zero = i32_type.const_int(0, false);
    builder.build_unconditional_branch(loop_block).unwrap();

    // Loop block
    builder.position_at_end(loop_block);
    let counter = builder.build_phi(i32_type, "counter").unwrap();
    counter.add_incoming(&[(&zero, entry)]);

    // Increment counter
    let one = i32_type.const_int(1, false);
    let next_counter = builder
        .build_int_add(
            counter.as_basic_value().into_int_value(),
            one,
            "next_counter",
        )
        .unwrap();

    // Check loop condition
    let condition = builder
        .build_int_compare(IntPredicate::SLT, next_counter, n, "loop_cond")
        .unwrap();

    // Add incoming value for next iteration
    let loop_block_end = builder.get_insert_block().unwrap();
    counter.add_incoming(&[(&next_counter, loop_block_end)]);
    builder
        .build_conditional_branch(condition, loop_block, exit_block)
        .unwrap();

    // Exit block
    builder.position_at_end(exit_block);
    builder
        .build_return(Some(&counter.as_basic_value()))
        .unwrap();

    function
}

/// Creates a function that allocates and uses local variables (stack
/// allocation)
pub fn create_alloca_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false);
    let function = module.add_function("use_alloca", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    // Get parameters
    let x = function.get_nth_param(0).unwrap().into_int_value();
    let y = function.get_nth_param(1).unwrap().into_int_value();

    // Allocate stack variables
    let x_ptr = builder.build_alloca(i32_type, "x_ptr").unwrap();
    let y_ptr = builder.build_alloca(i32_type, "y_ptr").unwrap();
    let result_ptr = builder.build_alloca(i32_type, "result_ptr").unwrap();

    // Store values
    builder.build_store(x_ptr, x).unwrap();
    builder.build_store(y_ptr, y).unwrap();

    // Load values
    let x_val = builder.build_load(i32_type, x_ptr, "x_val").unwrap();
    let y_val = builder.build_load(i32_type, y_ptr, "y_val").unwrap();

    // Compute and store result
    let sum = builder
        .build_int_add(x_val.into_int_value(), y_val.into_int_value(), "sum")
        .unwrap();
    builder.build_store(result_ptr, sum).unwrap();

    // Load and return result
    let result = builder.build_load(i32_type, result_ptr, "result").unwrap();
    builder.build_return(Some(&result)).unwrap();

    function
}

/// Creates a function that works with arrays
pub fn create_array_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let i32_ptr_type = context.ptr_type(AddressSpace::default());
    let fn_type = i32_type.fn_type(&[i32_ptr_type.into(), i32_type.into()], false);
    let function = module.add_function("sum_array", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    let loop_block = context.append_basic_block(function, "loop");
    let exit_block = context.append_basic_block(function, "exit");

    // Entry block
    builder.position_at_end(entry);
    let array_ptr = function.get_nth_param(0).unwrap().into_pointer_value();
    let size = function.get_nth_param(1).unwrap().into_int_value();
    let zero = i32_type.const_int(0, false);
    builder.build_unconditional_branch(loop_block).unwrap();

    // Loop block
    builder.position_at_end(loop_block);
    let index = builder.build_phi(i32_type, "index").unwrap();
    let sum = builder.build_phi(i32_type, "sum").unwrap();
    index.add_incoming(&[(&zero, entry)]);
    sum.add_incoming(&[(&zero, entry)]);

    // Load array element
    let elem_ptr = unsafe {
        builder
            .build_gep(
                i32_type,
                array_ptr,
                &[index.as_basic_value().into_int_value()],
                "elem_ptr",
            )
            .unwrap()
    };
    let elem = builder.build_load(i32_type, elem_ptr, "elem").unwrap();

    // Update sum
    let new_sum = builder
        .build_int_add(
            sum.as_basic_value().into_int_value(),
            elem.into_int_value(),
            "new_sum",
        )
        .unwrap();

    // Update index
    let one = i32_type.const_int(1, false);
    let next_index = builder
        .build_int_add(index.as_basic_value().into_int_value(), one, "next_index")
        .unwrap();

    // Check loop condition
    let condition = builder
        .build_int_compare(IntPredicate::SLT, next_index, size, "loop_cond")
        .unwrap();

    // Update phi nodes
    let loop_end = builder.get_insert_block().unwrap();
    index.add_incoming(&[(&next_index, loop_end)]);
    sum.add_incoming(&[(&new_sum, loop_end)]);
    builder
        .build_conditional_branch(condition, loop_block, exit_block)
        .unwrap();

    // Exit block
    builder.position_at_end(exit_block);
    builder.build_return(Some(&sum.as_basic_value())).unwrap();

    function
}

/// Creates a global variable
pub fn create_global_variable<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> inkwell::values::GlobalValue<'ctx> {
    let i32_type = context.i32_type();
    let global = module.add_global(i32_type, Some(AddressSpace::default()), "global_counter");
    global.set_initializer(&i32_type.const_int(0, false));
    global.set_linkage(inkwell::module::Linkage::Internal);
    global
}

/// Creates a function that uses a global variable
pub fn create_global_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[], false);
    let function = module.add_function("increment_global", fn_type, None);

    // Create or get global variable
    let global = module
        .get_global("global_counter")
        .unwrap_or_else(|| create_global_variable(context, module));

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    // Load global value
    let global_ptr = global.as_pointer_value();
    let current = builder.build_load(i32_type, global_ptr, "current").unwrap();

    // Increment
    let one = i32_type.const_int(1, false);
    let incremented = builder
        .build_int_add(current.into_int_value(), one, "incremented")
        .unwrap();

    // Store back to global
    builder.build_store(global_ptr, incremented).unwrap();

    // Return new value
    builder.build_return(Some(&incremented)).unwrap();

    function
}

/// Creates a recursive factorial function
pub fn create_recursive_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into()], false);
    let function = module.add_function("factorial", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    let recurse = context.append_basic_block(function, "recurse");
    let base = context.append_basic_block(function, "base");

    // Entry block
    builder.position_at_end(entry);
    let n = function.get_nth_param(0).unwrap().into_int_value();
    let one = i32_type.const_int(1, false);
    let is_base = builder
        .build_int_compare(IntPredicate::SLE, n, one, "is_base")
        .unwrap();
    builder
        .build_conditional_branch(is_base, base, recurse)
        .unwrap();

    // Base case: return 1
    builder.position_at_end(base);
    builder.build_return(Some(&one)).unwrap();

    // Recursive case: return
n * factorial(n-1) builder.position_at_end(recurse); let n_minus_1 = builder.build_int_sub(n, one, "n_minus_1").unwrap(); let rec_result = builder .build_call(function, &[n_minus_1.into()], "rec_result") .unwrap(); let result = builder .build_int_mul( n, rec_result .try_as_basic_value() .left() .unwrap() .into_int_value(), "result", ) .unwrap(); builder.build_return(Some(&result)).unwrap(); function } /// Creates a struct type and a function that uses it pub fn create_struct_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let f64_type = context.f64_type(); // Define a Point struct with x and y fields let field_types = [i32_type.into(), f64_type.into()]; let struct_type = context.struct_type(&field_types, false); let fn_type = f64_type.fn_type(&[struct_type.into()], false); let function = module.add_function("get_point_y", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); // Get the struct parameter let point = function.get_nth_param(0).unwrap().into_struct_value(); // Extract the y field (index 1) let y_field = builder.build_extract_value(point, 1, "y_field").unwrap(); builder.build_return(Some(&y_field)).unwrap(); function } /// Runs optimization passes on a module using the modern pass manager (LLVM 18) pub fn optimize_module<'ctx>(module: &Module<'ctx>) -> Result<(), String> { // First verify the module is valid module.verify().map_err(|e| e.to_string())?; // Initialize targets for optimization Target::initialize_all(&InitializationConfig::default()); let target_triple = TargetMachine::get_default_triple(); let target = Target::from_triple(&target_triple) .map_err(|e| format!("Failed to create target: {}", e))?; let target_machine = target .create_target_machine( &target_triple, "generic", "", OptimizationLevel::Default, RelocMode::Default, CodeModel::Default, ) .ok_or("Failed to create 
target machine")?; // Common optimization passes let passes = [ "instcombine", // Combine instructions "reassociate", // Reassociate expressions "gvn", // Global value numbering "simplifycfg", // Simplify control flow graph "mem2reg", // Promote memory to registers ]; let pass_builder_options = PassBuilderOptions::create(); pass_builder_options.set_loop_vectorization(true); pass_builder_options.set_loop_unrolling(true); pass_builder_options.set_merge_functions(true); module .run_passes(&passes.join(","), &target_machine, pass_builder_options) .map_err(|e| e.to_string()) } /// Runs specific optimization passes on a module pub fn run_custom_passes<'ctx>(module: &Module<'ctx>, passes: &[&str]) -> Result<(), String> { // Verify module first module.verify().map_err(|e| e.to_string())?; // Initialize targets Target::initialize_all(&InitializationConfig::default()); let target_triple = TargetMachine::get_default_triple(); let target = Target::from_triple(&target_triple) .map_err(|e| format!("Failed to create target: {}", e))?; let target_machine = target .create_target_machine( &target_triple, "generic", "", OptimizationLevel::None, RelocMode::Default, CodeModel::Default, ) .ok_or("Failed to create target machine")?; module .run_passes( &passes.join(","), &target_machine, PassBuilderOptions::create(), ) .map_err(|e| e.to_string()) } /// Writes LLVM IR to a file pub fn write_ir_to_file<'ctx>(module: &Module<'ctx>, path: &Path) -> Result<(), String> { module .print_to_file(path) .map_err(|e| format!("Failed to write IR: {}", e)) } /// Compiles module to object file pub fn compile_to_object_file<'ctx>( module: &Module<'ctx>, path: &Path, ) -> Result<(), Box<dyn Error>> { Target::initialize_native(&InitializationConfig::default())?; let target_triple = TargetMachine::get_default_triple(); let target = Target::from_triple(&target_triple)?; let target_machine = target .create_target_machine( &target_triple, "generic", "", OptimizationLevel::Default, RelocMode::Default, 
CodeModel::Default, ) .ok_or("Could not create target machine")?; target_machine.write_to_file(module, FileType::Object, path)?; Ok(()) } /// Verifies that a module is valid pub fn verify_module<'ctx>(module: &Module<'ctx>) -> Result<(), String> { module.verify().map_err(|e| e.to_string()) } /// Helper to create a function type pub fn create_function_type<'ctx>( context: &'ctx Context, param_types: Vec<BasicMetadataTypeEnum<'ctx>>, return_type: Option<BasicTypeEnum<'ctx>>, is_var_args: bool, ) -> inkwell::types::FunctionType<'ctx> { match return_type { Some(ret) => ret.fn_type(¶m_types, is_var_args), None => context.void_type().fn_type(¶m_types, is_var_args), } } /// Simple JIT execution example pub fn create_execution_engine<'ctx>( module: &Module<'ctx>, ) -> Result<inkwell::execution_engine::ExecutionEngine<'ctx>, String> { module .create_jit_execution_engine(OptimizationLevel::None) .map_err(|e| e.to_string()) } /// Example of JIT compiling and executing a function pub fn jit_compile_and_execute(context: &Context) -> Result<u64, Box<dyn Error>> { let module = context.create_module("jit_example"); let builder = context.create_builder(); // Create a simple sum function: sum(x, y, z) = x + y + z let i64_type = context.i64_type(); let fn_type = i64_type.fn_type(&[i64_type.into(), i64_type.into(), i64_type.into()], false); let function = module.add_function("sum", fn_type, None); let basic_block = context.append_basic_block(function, "entry"); builder.position_at_end(basic_block); let x = function.get_nth_param(0).unwrap().into_int_value(); let y = function.get_nth_param(1).unwrap().into_int_value(); let z = function.get_nth_param(2).unwrap().into_int_value(); let sum = builder.build_int_add(x, y, "sum").unwrap(); let sum = builder.build_int_add(sum, z, "sum").unwrap(); builder.build_return(Some(&sum)).unwrap(); // Create execution engine let execution_engine = module.create_jit_execution_engine(OptimizationLevel::None)?; // Get the compiled function type SumFunc = 
unsafe extern "C" fn(u64, u64, u64) -> u64; let sum_fn = unsafe { execution_engine.get_function::<SumFunc>("sum")? }; // Execute the function let result = unsafe { sum_fn.call(1, 2, 3) }; Ok(result) } #[cfg(test)] mod tests { use super::*; #[test] fn test_add_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_add_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "add"); assert_eq!(function.count_params(), 2); assert!(verify_module(&module).is_ok()); } #[test] fn test_constant_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_constant_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "get_constant"); assert_eq!(function.count_params(), 0); assert!(verify_module(&module).is_ok()); } #[test] fn test_conditional_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_conditional_function(&context, &module); assert_eq!(function.count_basic_blocks(), 4); // entry, then, else, merge assert!(verify_module(&module).is_ok()); } #[test] fn test_loop_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_loop_function(&context, &module); assert_eq!(function.count_basic_blocks(), 3); // entry, loop, exit assert!(verify_module(&module).is_ok()); } #[test] fn test_global_variable() { let context = Context::create(); let module = context.create_module("test"); let global = create_global_variable(&context, &module); assert_eq!(global.get_name().to_str().unwrap(), "global_counter"); assert!(verify_module(&module).is_ok()); } #[test] fn test_recursive_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_recursive_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "factorial"); 
assert!(verify_module(&module).is_ok()); } #[test] fn test_optimization() { let context = Context::create(); let module = context.create_module("test"); // Create several functions create_add_function(&context, &module); create_multiply_function(&context, &module); create_constant_function(&context, &module); // Apply optimizations assert!(optimize_module(&module).is_ok()); // Module should still be valid after optimization assert!(verify_module(&module).is_ok()); } #[test] fn test_custom_passes() { let context = Context::create(); let module = context.create_module("test"); // Create a simple function create_add_function(&context, &module); // Run specific optimization passes let passes = ["instcombine", "simplifycfg"]; assert!(run_custom_passes(&module, &passes).is_ok()); // Module should still be valid assert!(verify_module(&module).is_ok()); } #[test] fn test_jit_execution() { let context = Context::create(); // Test JIT compilation and execution match jit_compile_and_execute(&context) { Ok(result) => assert_eq!(result, 6), // 1 + 2 + 3 = 6 Err(e) => panic!("JIT execution failed: {}", e), } } } /// Creates a function that multiplies two integers pub fn create_multiply_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false); let function = module.add_function("multiply", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); let x = function.get_nth_param(0).unwrap().into_int_value(); let y = function.get_nth_param(1).unwrap().into_int_value(); let product = builder.build_int_mul(x, y, "product").unwrap(); builder.build_return(Some(&product)).unwrap(); function } }
Function creation involves defining the function signature through LLVM’s type system, adding the function to the module, and creating basic blocks for the function body. The entry block serves as the starting point for instruction generation. Parameters are accessed through the function value and can be used in subsequent instructions.
Constants and Comparisons
Creating constant values and comparison operations:
#![allow(unused)]
fn main() {
use inkwell::context::Context;
use inkwell::module::Module;
use inkwell::values::FunctionValue;
use inkwell::IntPredicate;

/// Demonstrates integer comparison operations
pub fn create_comparison_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let bool_type = context.bool_type();
    let fn_type = bool_type.fn_type(&[i32_type.into(), i32_type.into()], false);
    let function = module.add_function("compare_ints", fn_type, None);
    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);
    let x = function.get_nth_param(0).unwrap().into_int_value();
    let y = function.get_nth_param(1).unwrap().into_int_value();
    // Compare x > y
    let comparison = builder
        .build_int_compare(IntPredicate::SGT, x, y, "cmp")
        .unwrap();
    builder.build_return(Some(&comparison)).unwrap();
    function
}

/// Creates a function with a constant return value
pub fn create_constant_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[], false);
    let function = module.add_function("get_constant", fn_type, None);
    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);
    let constant = i32_type.const_int(42, false);
    builder.build_return(Some(&constant)).unwrap();
    function
}
}
#![allow(unused)]
fn main() {
use std::error::Error;
use std::path::Path;

use inkwell::context::Context;
use inkwell::module::Module;
use inkwell::passes::PassBuilderOptions;
use inkwell::targets::{
    CodeModel, FileType, InitializationConfig, RelocMode, Target, TargetMachine,
};
use inkwell::types::{BasicMetadataTypeEnum, BasicType, BasicTypeEnum};
use inkwell::values::FunctionValue;
use inkwell::{AddressSpace, IntPredicate, OptimizationLevel};

/// Creates a basic LLVM context
pub fn create_context() -> Context {
    Context::create()
}

/// Creates a simple function that adds two integers
pub fn create_add_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false);
    let function = module.add_function("add", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    // Get function parameters
    let x = function.get_nth_param(0).unwrap().into_int_value();
    let y = function.get_nth_param(1).unwrap().into_int_value();

    // Build addition and return
    let sum = builder.build_int_add(x, y, "sum").unwrap();
    builder.build_return(Some(&sum)).unwrap();

    function
}

/// Creates a function that multiplies two integers
pub fn create_multiply_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false);
    let function = module.add_function("multiply", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    let x = function.get_nth_param(0).unwrap().into_int_value();
    let y = function.get_nth_param(1).unwrap().into_int_value();

    let product = builder.build_int_mul(x, y, "product").unwrap();
    builder.build_return(Some(&product)).unwrap();

    function
}

/// Creates a function with a constant return value
pub fn create_constant_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[], false);
    let function = module.add_function("get_constant", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    let constant = i32_type.const_int(42, false);
    builder.build_return(Some(&constant)).unwrap();

    function
}

/// Creates a function with conditional branching (if-then-else)
pub fn create_conditional_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into()], false);
    let function = module.add_function("conditional", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    let then_block = context.append_basic_block(function, "then");
    let else_block = context.append_basic_block(function, "else");
    let merge_block = context.append_basic_block(function, "merge");

    // Entry block
    builder.position_at_end(entry);
    let x = function.get_nth_param(0).unwrap().into_int_value();
    let zero = i32_type.const_int(0, false);
    let condition = builder
        .build_int_compare(IntPredicate::SGT, x, zero, "cond")
        .unwrap();
    builder
        .build_conditional_branch(condition, then_block, else_block)
        .unwrap();

    // Then block: return x * 2
    builder.position_at_end(then_block);
    let two = i32_type.const_int(2, false);
    let then_val = builder.build_int_mul(x, two, "then_val").unwrap();
    builder.build_unconditional_branch(merge_block).unwrap();

    // Else block: return x * -1
    builder.position_at_end(else_block);
    let neg_one = i32_type.const_int(-1i64 as u64, true);
    let else_val = builder.build_int_mul(x, neg_one, "else_val").unwrap();
    builder.build_unconditional_branch(merge_block).unwrap();

    // Merge block with phi node
    builder.position_at_end(merge_block);
    let phi = builder.build_phi(i32_type, "result").unwrap();
    phi.add_incoming(&[(&then_val, then_block), (&else_val, else_block)]);
    builder.build_return(Some(&phi.as_basic_value())).unwrap();

    function
}

/// Creates a simple loop that counts from 0 to n
pub fn create_loop_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into()], false);
    let function = module.add_function("count_loop", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    let loop_block = context.append_basic_block(function, "loop");
    let exit_block = context.append_basic_block(function, "exit");

    // Entry block: initialize counter
    builder.position_at_end(entry);
    let n = function.get_nth_param(0).unwrap().into_int_value();
    let zero = i32_type.const_int(0, false);
    builder.build_unconditional_branch(loop_block).unwrap();

    // Loop block
    builder.position_at_end(loop_block);
    let counter = builder.build_phi(i32_type, "counter").unwrap();
    counter.add_incoming(&[(&zero, entry)]);

    // Increment counter
    let one = i32_type.const_int(1, false);
    let next_counter = builder
        .build_int_add(
            counter.as_basic_value().into_int_value(),
            one,
            "next_counter",
        )
        .unwrap();

    // Check loop condition
    let condition = builder
        .build_int_compare(IntPredicate::SLT, next_counter, n, "loop_cond")
        .unwrap();

    // Add incoming value for next iteration
    let loop_block_end = builder.get_insert_block().unwrap();
    counter.add_incoming(&[(&next_counter, loop_block_end)]);
    builder
        .build_conditional_branch(condition, loop_block, exit_block)
        .unwrap();

    // Exit block
    builder.position_at_end(exit_block);
    builder
        .build_return(Some(&counter.as_basic_value()))
        .unwrap();

    function
}

/// Creates a function that allocates and uses local variables (stack
/// allocation)
pub fn create_alloca_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false);
    let function = module.add_function("use_alloca", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    // Get parameters
    let x = function.get_nth_param(0).unwrap().into_int_value();
    let y = function.get_nth_param(1).unwrap().into_int_value();

    // Allocate stack variables
    let x_ptr = builder.build_alloca(i32_type, "x_ptr").unwrap();
    let y_ptr = builder.build_alloca(i32_type, "y_ptr").unwrap();
    let result_ptr = builder.build_alloca(i32_type, "result_ptr").unwrap();

    // Store values
    builder.build_store(x_ptr, x).unwrap();
    builder.build_store(y_ptr, y).unwrap();

    // Load values
    let x_val = builder.build_load(i32_type, x_ptr, "x_val").unwrap();
    let y_val = builder.build_load(i32_type, y_ptr, "y_val").unwrap();

    // Compute and store result
    let sum = builder
        .build_int_add(x_val.into_int_value(), y_val.into_int_value(), "sum")
        .unwrap();
    builder.build_store(result_ptr, sum).unwrap();

    // Load and return result
    let result = builder.build_load(i32_type, result_ptr, "result").unwrap();
    builder.build_return(Some(&result)).unwrap();

    function
}

/// Creates a function that works with arrays
pub fn create_array_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let i32_ptr_type = context.ptr_type(AddressSpace::default());
    let fn_type = i32_type.fn_type(&[i32_ptr_type.into(), i32_type.into()], false);
    let function = module.add_function("sum_array", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    let loop_block = context.append_basic_block(function, "loop");
    let exit_block = context.append_basic_block(function, "exit");

    // Entry block
    builder.position_at_end(entry);
    let array_ptr = function.get_nth_param(0).unwrap().into_pointer_value();
    let size = function.get_nth_param(1).unwrap().into_int_value();
    let zero = i32_type.const_int(0, false);
    builder.build_unconditional_branch(loop_block).unwrap();

    // Loop block
    builder.position_at_end(loop_block);
    let index = builder.build_phi(i32_type, "index").unwrap();
    let sum = builder.build_phi(i32_type, "sum").unwrap();
    index.add_incoming(&[(&zero, entry)]);
    sum.add_incoming(&[(&zero, entry)]);

    // Load array element
    let elem_ptr = unsafe {
        builder
            .build_gep(
                i32_type,
                array_ptr,
                &[index.as_basic_value().into_int_value()],
                "elem_ptr",
            )
            .unwrap()
    };
    let elem = builder.build_load(i32_type, elem_ptr, "elem").unwrap();

    // Update sum
    let new_sum = builder
        .build_int_add(
            sum.as_basic_value().into_int_value(),
            elem.into_int_value(),
            "new_sum",
        )
        .unwrap();

    // Update index
    let one = i32_type.const_int(1, false);
    let next_index = builder
        .build_int_add(index.as_basic_value().into_int_value(), one, "next_index")
        .unwrap();

    // Check loop condition
    let condition = builder
        .build_int_compare(IntPredicate::SLT, next_index, size, "loop_cond")
        .unwrap();

    // Update phi nodes
    let loop_end = builder.get_insert_block().unwrap();
    index.add_incoming(&[(&next_index, loop_end)]);
    sum.add_incoming(&[(&new_sum, loop_end)]);
    builder
        .build_conditional_branch(condition, loop_block, exit_block)
        .unwrap();

    // Exit block
    builder.position_at_end(exit_block);
    builder.build_return(Some(&sum.as_basic_value())).unwrap();

    function
}

/// Creates a global variable
pub fn create_global_variable<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> inkwell::values::GlobalValue<'ctx> {
    let i32_type = context.i32_type();
    let global = module.add_global(i32_type, Some(AddressSpace::default()), "global_counter");
    global.set_initializer(&i32_type.const_int(0, false));
    global.set_linkage(inkwell::module::Linkage::Internal);
    global
}

/// Creates a function that uses a global variable
pub fn create_global_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[], false);
    let function = module.add_function("increment_global", fn_type, None);

    // Create or get global variable
    let global = module
        .get_global("global_counter")
        .unwrap_or_else(|| create_global_variable(context, module));

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    // Load global value
    let global_ptr = global.as_pointer_value();
    let current = builder.build_load(i32_type, global_ptr, "current").unwrap();

    // Increment
    let one = i32_type.const_int(1, false);
    let incremented = builder
        .build_int_add(current.into_int_value(), one, "incremented")
        .unwrap();

    // Store back to global
    builder.build_store(global_ptr, incremented).unwrap();

    // Return new value
    builder.build_return(Some(&incremented)).unwrap();

    function
}

/// Creates a recursive factorial function
pub fn create_recursive_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into()], false);
    let function = module.add_function("factorial", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    let recurse = context.append_basic_block(function, "recurse");
    let base = context.append_basic_block(function, "base");

    // Entry block
    builder.position_at_end(entry);
    let n = function.get_nth_param(0).unwrap().into_int_value();
    let one = i32_type.const_int(1, false);
    let is_base = builder
        .build_int_compare(IntPredicate::SLE, n, one, "is_base")
        .unwrap();
    builder
        .build_conditional_branch(is_base, base, recurse)
        .unwrap();

    // Base case: return 1
    builder.position_at_end(base);
    builder.build_return(Some(&one)).unwrap();

    // Recursive case: return n * factorial(n-1)
    builder.position_at_end(recurse);
    let n_minus_1 = builder.build_int_sub(n, one, "n_minus_1").unwrap();
    let rec_result = builder
        .build_call(function, &[n_minus_1.into()], "rec_result")
        .unwrap();
    let result = builder
        .build_int_mul(
            n,
            rec_result
                .try_as_basic_value()
                .left()
                .unwrap()
                .into_int_value(),
            "result",
        )
        .unwrap();
    builder.build_return(Some(&result)).unwrap();

    function
}

/// Creates a struct type and a function that uses it
pub fn create_struct_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let f64_type = context.f64_type();

    // Define a Point struct with x and y fields
    let field_types = [i32_type.into(), f64_type.into()];
    let struct_type = context.struct_type(&field_types, false);

    let fn_type = f64_type.fn_type(&[struct_type.into()], false);
    let function = module.add_function("get_point_y", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    // Get the struct parameter
    let point = function.get_nth_param(0).unwrap().into_struct_value();

    // Extract the y field (index 1)
    let y_field = builder.build_extract_value(point, 1, "y_field").unwrap();
    builder.build_return(Some(&y_field)).unwrap();

    function
}

/// Runs optimization passes on a module using the modern pass manager (LLVM 18)
pub fn optimize_module<'ctx>(module: &Module<'ctx>) -> Result<(), String> {
    // First verify the module is valid
    module.verify().map_err(|e| e.to_string())?;

    // Initialize targets for optimization
    Target::initialize_all(&InitializationConfig::default());
    let target_triple = TargetMachine::get_default_triple();
    let target = Target::from_triple(&target_triple)
        .map_err(|e| format!("Failed to create target: {}", e))?;
    let target_machine = target
        .create_target_machine(
            &target_triple,
            "generic",
            "",
            OptimizationLevel::Default,
            RelocMode::Default,
            CodeModel::Default,
        )
        .ok_or("Failed to create target machine")?;

    // Common optimization passes
    let passes = [
        "instcombine", // Combine instructions
        "reassociate", // Reassociate expressions
        "gvn",         // Global value numbering
        "simplifycfg", // Simplify control flow graph
        "mem2reg",     // Promote memory to registers
    ];

    let pass_builder_options = PassBuilderOptions::create();
    pass_builder_options.set_loop_vectorization(true);
    pass_builder_options.set_loop_unrolling(true);
    pass_builder_options.set_merge_functions(true);

    module
        .run_passes(&passes.join(","), &target_machine, pass_builder_options)
        .map_err(|e| e.to_string())
}

/// Runs specific optimization passes on a module
pub fn run_custom_passes<'ctx>(module: &Module<'ctx>, passes: &[&str]) -> Result<(), String> {
    // Verify module first
    module.verify().map_err(|e| e.to_string())?;

    // Initialize targets
    Target::initialize_all(&InitializationConfig::default());
    let target_triple = TargetMachine::get_default_triple();
    let target = Target::from_triple(&target_triple)
        .map_err(|e| format!("Failed to create target: {}", e))?;
    let target_machine = target
        .create_target_machine(
            &target_triple,
            "generic",
            "",
            OptimizationLevel::None,
            RelocMode::Default,
            CodeModel::Default,
        )
        .ok_or("Failed to create target machine")?;

    module
        .run_passes(
            &passes.join(","),
            &target_machine,
            PassBuilderOptions::create(),
        )
        .map_err(|e| e.to_string())
}

/// Writes LLVM IR to a file
pub fn write_ir_to_file<'ctx>(module: &Module<'ctx>, path: &Path) -> Result<(), String> {
    module
        .print_to_file(path)
        .map_err(|e| format!("Failed to write IR: {}", e))
}

/// Compiles module to object file
pub fn compile_to_object_file<'ctx>(
    module: &Module<'ctx>,
    path: &Path,
) -> Result<(), Box<dyn Error>> {
    Target::initialize_native(&InitializationConfig::default())?;
    let target_triple = TargetMachine::get_default_triple();
    let target = Target::from_triple(&target_triple)?;
    let target_machine = target
        .create_target_machine(
            &target_triple,
            "generic",
            "",
            OptimizationLevel::Default,
            RelocMode::Default,
            CodeModel::Default,
        )
        .ok_or("Could not create target machine")?;

    target_machine.write_to_file(module, FileType::Object, path)?;
    Ok(())
}

/// Verifies that a module is valid
pub fn verify_module<'ctx>(module: &Module<'ctx>) -> Result<(), String> {
    module.verify().map_err(|e| e.to_string())
}

/// Helper to create a function type
pub fn create_function_type<'ctx>(
    context: &'ctx Context,
    param_types: Vec<BasicMetadataTypeEnum<'ctx>>,
    return_type: Option<BasicTypeEnum<'ctx>>,
    is_var_args: bool,
) -> inkwell::types::FunctionType<'ctx> {
    match return_type {
        Some(ret) => ret.fn_type(&param_types, is_var_args),
        None => context.void_type().fn_type(&param_types, is_var_args),
    }
}

/// Simple JIT execution example
pub fn create_execution_engine<'ctx>(
    module: &Module<'ctx>,
) -> Result<inkwell::execution_engine::ExecutionEngine<'ctx>, String> {
    module
        .create_jit_execution_engine(OptimizationLevel::None)
        .map_err(|e| e.to_string())
}

/// Example of JIT compiling and executing a function
pub fn jit_compile_and_execute(context: &Context) -> Result<u64, Box<dyn Error>> {
    let module = context.create_module("jit_example");
    let builder = context.create_builder();

    // Create a simple sum function: sum(x, y, z) = x + y + z
    let i64_type = context.i64_type();
    let fn_type = i64_type.fn_type(&[i64_type.into(), i64_type.into(), i64_type.into()], false);
    let function = module.add_function("sum", fn_type, None);
    let basic_block = context.append_basic_block(function, "entry");

    builder.position_at_end(basic_block);
    let x = function.get_nth_param(0).unwrap().into_int_value();
    let y = function.get_nth_param(1).unwrap().into_int_value();
    let z = function.get_nth_param(2).unwrap().into_int_value();

    let sum = builder.build_int_add(x, y, "sum").unwrap();
    let sum = builder.build_int_add(sum, z, "sum").unwrap();
    builder.build_return(Some(&sum)).unwrap();

    // Create execution engine
    let execution_engine = module.create_jit_execution_engine(OptimizationLevel::None)?;

    // Get the compiled function
    type SumFunc = unsafe extern "C" fn(u64, u64, u64) -> u64;
    let sum_fn = unsafe { execution_engine.get_function::<SumFunc>("sum")? };

    // Execute the function
    let result = unsafe { sum_fn.call(1, 2, 3) };
    Ok(result)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_function() {
        let context = Context::create();
        let module = context.create_module("test");
        let function = create_add_function(&context, &module);

        assert_eq!(function.get_name().to_str().unwrap(), "add");
        assert_eq!(function.count_params(), 2);
        assert!(verify_module(&module).is_ok());
    }

    #[test]
    fn test_constant_function() {
        let context = Context::create();
        let module = context.create_module("test");
        let function = create_constant_function(&context, &module);

        assert_eq!(function.get_name().to_str().unwrap(), "get_constant");
        assert_eq!(function.count_params(), 0);
        assert!(verify_module(&module).is_ok());
    }

    #[test]
    fn test_conditional_function() {
        let context = Context::create();
        let module = context.create_module("test");
        let function = create_conditional_function(&context, &module);

        assert_eq!(function.count_basic_blocks(), 4); // entry, then, else, merge
        assert!(verify_module(&module).is_ok());
    }

    #[test]
    fn test_loop_function() {
        let context = Context::create();
        let module = context.create_module("test");
        let function = create_loop_function(&context, &module);

        assert_eq!(function.count_basic_blocks(), 3); // entry, loop, exit
        assert!(verify_module(&module).is_ok());
    }

    #[test]
    fn test_global_variable() {
        let context = Context::create();
        let module = context.create_module("test");
        let global = create_global_variable(&context, &module);

        assert_eq!(global.get_name().to_str().unwrap(), "global_counter");
        assert!(verify_module(&module).is_ok());
    }

    #[test]
    fn test_recursive_function() {
        let context = Context::create();
        let module = context.create_module("test");
        let function = create_recursive_function(&context, &module);

        assert_eq!(function.get_name().to_str().unwrap(), "factorial");
        assert!(verify_module(&module).is_ok());
    }

    #[test]
    fn test_optimization() {
        let context = Context::create();
        let module = context.create_module("test");

        // Create several functions
        create_add_function(&context, &module);
        create_multiply_function(&context, &module);
        create_constant_function(&context, &module);

        // Apply optimizations
        assert!(optimize_module(&module).is_ok());

        // Module should still be valid after optimization
        assert!(verify_module(&module).is_ok());
    }

    #[test]
    fn test_custom_passes() {
        let context = Context::create();
        let module = context.create_module("test");

        // Create a simple function
        create_add_function(&context, &module);

        // Run specific optimization passes
        let passes = ["instcombine", "simplifycfg"];
        assert!(run_custom_passes(&module, &passes).is_ok());

        // Module should still be valid
        assert!(verify_module(&module).is_ok());
    }

    #[test]
    fn test_jit_execution() {
        let context = Context::create();

        // Test JIT compilation and execution
        match jit_compile_and_execute(&context) {
            Ok(result) => assert_eq!(result, 6), // 1 + 2 + 3 = 6
            Err(e) => panic!("JIT execution failed: {}", e),
        }
    }
}

/// Demonstrates integer comparison operations
pub fn create_comparison_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let bool_type = context.bool_type();
    let fn_type = bool_type.fn_type(&[i32_type.into(), i32_type.into()], false);
    let function = module.add_function("compare_ints", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    let x = function.get_nth_param(0).unwrap().into_int_value();
    let y = function.get_nth_param(1).unwrap().into_int_value();

    // Compare x > y
    let comparison = builder
        .build_int_compare(IntPredicate::SGT, x, y, "cmp")
        .unwrap();
    builder.build_return(Some(&comparison)).unwrap();

    function
}
}
Constants are compile-time values that LLVM can optimize aggressively. Comparison operations produce boolean results used for control flow decisions. LLVM supports both signed and unsigned integer comparisons with various predicates.
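The signed/unsigned distinction matters because LLVM integers carry no sign of their own: predicates like SGT and UGT interpret the same bit pattern differently. A plain-Rust sketch (independent of inkwell) of what the two predicates mean for one 32-bit value:

```rust
fn main() {
    // The same 32-bit pattern, read under LLVM's two comparison regimes:
    // SGT treats the bits as signed (two's complement), UGT as unsigned.
    let bits: u32 = 0xFFFF_FFFF;
    let as_signed = bits as i32; // -1
    let as_unsigned = bits; // 4_294_967_295

    // SGT semantics: -1 > 0 is false
    println!("sgt: {}", as_signed > 0);
    // UGT semantics: 4294967295 > 0 is true
    println!("ugt: {}", as_unsigned > 0);
}
```

This is why `create_comparison_function` above must commit to `IntPredicate::SGT` rather than a generic "greater than": the choice of predicate, not the operand type, decides how the bits are ordered.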
Control Flow
Implementing conditional branches and phi nodes:
#![allow(unused)]
fn main() {
use std::error::Error;
use std::path::Path;

use inkwell::context::Context;
use inkwell::module::Module;
use inkwell::passes::PassBuilderOptions;
use inkwell::targets::{
    CodeModel, FileType, InitializationConfig, RelocMode, Target, TargetMachine,
};
use inkwell::types::{BasicMetadataTypeEnum, BasicType, BasicTypeEnum};
use inkwell::values::FunctionValue;
use inkwell::{AddressSpace, IntPredicate, OptimizationLevel};

/// Creates a basic LLVM context
pub fn create_context() -> Context {
    Context::create()
}

/// Creates a simple function that adds two integers
pub fn create_add_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false);
    let function = module.add_function("add", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    // Get function parameters
    let x = function.get_nth_param(0).unwrap().into_int_value();
    let y = function.get_nth_param(1).unwrap().into_int_value();

    // Build addition and return
    let sum = builder.build_int_add(x, y, "sum").unwrap();
    builder.build_return(Some(&sum)).unwrap();

    function
}

/// Creates a function that multiplies two integers
pub fn create_multiply_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false);
    let function = module.add_function("multiply", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    let x = function.get_nth_param(0).unwrap().into_int_value();
    let y = function.get_nth_param(1).unwrap().into_int_value();

    let product = builder.build_int_mul(x, y, "product").unwrap();
    builder.build_return(Some(&product)).unwrap();

    function
}

/// Creates a function with a constant return value
pub fn create_constant_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[], false);
    let function = module.add_function("get_constant", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    let constant = i32_type.const_int(42, false);
    builder.build_return(Some(&constant)).unwrap();

    function
}

/// Demonstrates integer comparison operations
pub fn create_comparison_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let bool_type = context.bool_type();
    let fn_type = bool_type.fn_type(&[i32_type.into(), i32_type.into()], false);
    let function = module.add_function("compare_ints", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    let x = function.get_nth_param(0).unwrap().into_int_value();
    let y = function.get_nth_param(1).unwrap().into_int_value();

    // Compare x > y
    let comparison = builder
        .build_int_compare(IntPredicate::SGT, x, y, "cmp")
        .unwrap();
    builder.build_return(Some(&comparison)).unwrap();

    function
}

/// Creates a simple loop that counts from 0 to n
pub fn create_loop_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into()], false);
    let function = module.add_function("count_loop", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    let loop_block = context.append_basic_block(function, "loop");
    let exit_block = context.append_basic_block(function, "exit");

    // Entry block: initialize counter
    builder.position_at_end(entry);
    let n = function.get_nth_param(0).unwrap().into_int_value();
    let zero = i32_type.const_int(0, false);
    builder.build_unconditional_branch(loop_block).unwrap();

    // Loop block
    builder.position_at_end(loop_block);
    let counter = builder.build_phi(i32_type, "counter").unwrap();
    counter.add_incoming(&[(&zero, entry)]);

    // Increment counter
    let one = i32_type.const_int(1, false);
    let next_counter = builder
        .build_int_add(
            counter.as_basic_value().into_int_value(),
            one,
            "next_counter",
        )
        .unwrap();

    // Check loop condition
    let condition = builder
        .build_int_compare(IntPredicate::SLT, next_counter, n, "loop_cond")
        .unwrap();

    // Add incoming value for next iteration
    let loop_block_end = builder.get_insert_block().unwrap();
    counter.add_incoming(&[(&next_counter, loop_block_end)]);
    builder
        .build_conditional_branch(condition, loop_block, exit_block)
        .unwrap();

    // Exit block
    builder.position_at_end(exit_block);
    builder
        .build_return(Some(&counter.as_basic_value()))
        .unwrap();

    function
}

/// Creates a function that allocates and uses local variables (stack
/// allocation)
pub fn create_alloca_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false);
    let function = module.add_function("use_alloca", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    // Get parameters
    let x = function.get_nth_param(0).unwrap().into_int_value();
    let y = function.get_nth_param(1).unwrap().into_int_value();

    // Allocate stack variables
    let x_ptr = builder.build_alloca(i32_type, "x_ptr").unwrap();
    let y_ptr = builder.build_alloca(i32_type, "y_ptr").unwrap();
    let result_ptr = builder.build_alloca(i32_type, "result_ptr").unwrap();

    // Store values
    builder.build_store(x_ptr, x).unwrap();
    builder.build_store(y_ptr, y).unwrap();

    // Load values
    let x_val = builder.build_load(i32_type, x_ptr, "x_val").unwrap();
    let y_val = builder.build_load(i32_type, y_ptr, "y_val").unwrap();

    // Compute and store result
    let sum = builder
        .build_int_add(x_val.into_int_value(), y_val.into_int_value(), "sum")
        .unwrap();
    builder.build_store(result_ptr, sum).unwrap();

    // Load and return result
    let result = builder.build_load(i32_type, result_ptr, "result").unwrap();
    builder.build_return(Some(&result)).unwrap();

    function
}

/// Creates a function that works with arrays
pub fn create_array_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let i32_ptr_type = context.ptr_type(AddressSpace::default());
    let fn_type = i32_type.fn_type(&[i32_ptr_type.into(), i32_type.into()], false);
    let function = module.add_function("sum_array", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    let loop_block = context.append_basic_block(function, "loop");
    let exit_block = context.append_basic_block(function, "exit");

    // Entry block
    builder.position_at_end(entry);
    let array_ptr = function.get_nth_param(0).unwrap().into_pointer_value();
    let size = function.get_nth_param(1).unwrap().into_int_value();
    let zero = i32_type.const_int(0, false);
    builder.build_unconditional_branch(loop_block).unwrap();

    // Loop block
    builder.position_at_end(loop_block);
    let index = builder.build_phi(i32_type, "index").unwrap();
    let sum = builder.build_phi(i32_type, "sum").unwrap();
    index.add_incoming(&[(&zero, entry)]);
    sum.add_incoming(&[(&zero, entry)]);

    // Load array element
    let elem_ptr = unsafe {
        builder
            .build_gep(
                i32_type,
                array_ptr,
                &[index.as_basic_value().into_int_value()],
                "elem_ptr",
            )
            .unwrap()
    };
    let elem = builder.build_load(i32_type, elem_ptr, "elem").unwrap();

    // Update sum
    let new_sum = builder
        .build_int_add(
            sum.as_basic_value().into_int_value(),
            elem.into_int_value(),
            "new_sum",
        )
        .unwrap();

    // Update index
    let one = i32_type.const_int(1, false);
    let next_index = builder
        .build_int_add(index.as_basic_value().into_int_value(), one, "next_index")
        .unwrap();

    // Check loop condition
    let condition = builder
        .build_int_compare(IntPredicate::SLT, next_index, size, "loop_cond")
        .unwrap();

    // Update phi nodes
    let loop_end = builder.get_insert_block().unwrap();
    index.add_incoming(&[(&next_index, loop_end)]);
    sum.add_incoming(&[(&new_sum, loop_end)]);
    builder
        .build_conditional_branch(condition, loop_block, exit_block)
        .unwrap();

    // Exit block
    builder.position_at_end(exit_block);
    builder.build_return(Some(&sum.as_basic_value())).unwrap();

    function
}

/// Creates a global variable
pub fn create_global_variable<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> inkwell::values::GlobalValue<'ctx> {
    let i32_type = context.i32_type();
    let global = module.add_global(i32_type, Some(AddressSpace::default()), "global_counter");
    global.set_initializer(&i32_type.const_int(0, false));
    global.set_linkage(inkwell::module::Linkage::Internal);
    global
}

/// Creates a function that uses a global variable
pub fn create_global_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[], false);
    let function = module.add_function("increment_global", fn_type, None);

    // Create or get global variable
    let global = module
        .get_global("global_counter")
        .unwrap_or_else(|| create_global_variable(context, module));

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    // Load global value
    let global_ptr = global.as_pointer_value();
    let current = builder.build_load(i32_type, global_ptr, "current").unwrap();

    // Increment
    let one = i32_type.const_int(1, false);
    let incremented = builder
        .build_int_add(current.into_int_value(), one, "incremented")
        .unwrap();

    // Store back to global
    builder.build_store(global_ptr, incremented).unwrap();

    // Return new value
    builder.build_return(Some(&incremented)).unwrap();

    function
}

/// Creates a recursive factorial function
pub fn create_recursive_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into()], false);
    let function = module.add_function("factorial", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    let recurse = context.append_basic_block(function, "recurse");
    let base = context.append_basic_block(function, "base");

    // Entry block
    builder.position_at_end(entry);
    let n = function.get_nth_param(0).unwrap().into_int_value();
    let one = i32_type.const_int(1, false);
    let is_base = builder
        .build_int_compare(IntPredicate::SLE, n, one, "is_base")
        .unwrap();
    builder
        .build_conditional_branch(is_base, base, recurse)
        .unwrap();

    // Base case: return 1
    builder.position_at_end(base);
    builder.build_return(Some(&one)).unwrap();

    // Recursive case: return n * factorial(n-1)
    builder.position_at_end(recurse);
    let n_minus_1 = builder.build_int_sub(n, one, "n_minus_1").unwrap();
    let rec_result = builder
        .build_call(function, &[n_minus_1.into()], "rec_result")
        .unwrap();
    let result = builder
        .build_int_mul(
            n,
            rec_result
                .try_as_basic_value()
                .left()
                .unwrap()
                .into_int_value(),
            "result",
        )
        .unwrap();
    builder.build_return(Some(&result)).unwrap();

    function
}

/// Creates a struct type and a function that uses it
pub fn create_struct_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let f64_type = context.f64_type();

    // Define a Point struct with x and y fields
    let field_types = [i32_type.into(), f64_type.into()];
    let struct_type = context.struct_type(&field_types, false);

    let fn_type = f64_type.fn_type(&[struct_type.into()], false);
    let function = module.add_function("get_point_y", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    // Get the struct parameter
    let point = function.get_nth_param(0).unwrap().into_struct_value();

    // Extract the y field (index 1)
    let y_field = builder.build_extract_value(point, 1, "y_field").unwrap();
    builder.build_return(Some(&y_field)).unwrap();

    function
}

/// Runs optimization passes on a module using the modern pass manager (LLVM 18)
pub fn optimize_module<'ctx>(module: &Module<'ctx>) -> Result<(), String> {
    // First verify the module is valid
    module.verify().map_err(|e| e.to_string())?;

    // Initialize targets for optimization
    Target::initialize_all(&InitializationConfig::default());
    let target_triple = TargetMachine::get_default_triple();
    let target = Target::from_triple(&target_triple)
        .map_err(|e| format!("Failed to create target: {}", e))?;
    let target_machine = target
        .create_target_machine(
            &target_triple,
            "generic",
            "",
            OptimizationLevel::Default,
            RelocMode::Default,
            CodeModel::Default,
        )
        .ok_or("Failed to create target machine")?;

    // Common optimization passes
    let passes = [
        "instcombine", // Combine instructions
        "reassociate", // Reassociate expressions
        "gvn",         // Global value numbering
        "simplifycfg", // Simplify control flow graph
        "mem2reg",     // Promote memory to registers
    ];

    let pass_builder_options = PassBuilderOptions::create();
    pass_builder_options.set_loop_vectorization(true);
    pass_builder_options.set_loop_unrolling(true);
    pass_builder_options.set_merge_functions(true);

    module
        .run_passes(&passes.join(","), &target_machine, pass_builder_options)
        .map_err(|e| e.to_string())
}

/// Runs specific optimization passes on a module
pub fn run_custom_passes<'ctx>(module: &Module<'ctx>, passes: &[&str]) -> Result<(), String> {
    // Verify module first
    module.verify().map_err(|e| e.to_string())?;

    // Initialize targets
    Target::initialize_all(&InitializationConfig::default());
    let target_triple = TargetMachine::get_default_triple();
    let target =
Target::from_triple(&target_triple) .map_err(|e| format!("Failed to create target: {}", e))?; let target_machine = target .create_target_machine( &target_triple, "generic", "", OptimizationLevel::None, RelocMode::Default, CodeModel::Default, ) .ok_or("Failed to create target machine")?; module .run_passes( &passes.join(","), &target_machine, PassBuilderOptions::create(), ) .map_err(|e| e.to_string()) } /// Writes LLVM IR to a file pub fn write_ir_to_file<'ctx>(module: &Module<'ctx>, path: &Path) -> Result<(), String> { module .print_to_file(path) .map_err(|e| format!("Failed to write IR: {}", e)) } /// Compiles module to object file pub fn compile_to_object_file<'ctx>( module: &Module<'ctx>, path: &Path, ) -> Result<(), Box<dyn Error>> { Target::initialize_native(&InitializationConfig::default())?; let target_triple = TargetMachine::get_default_triple(); let target = Target::from_triple(&target_triple)?; let target_machine = target .create_target_machine( &target_triple, "generic", "", OptimizationLevel::Default, RelocMode::Default, CodeModel::Default, ) .ok_or("Could not create target machine")?; target_machine.write_to_file(module, FileType::Object, path)?; Ok(()) } /// Verifies that a module is valid pub fn verify_module<'ctx>(module: &Module<'ctx>) -> Result<(), String> { module.verify().map_err(|e| e.to_string()) } /// Helper to create a function type pub fn create_function_type<'ctx>( context: &'ctx Context, param_types: Vec<BasicMetadataTypeEnum<'ctx>>, return_type: Option<BasicTypeEnum<'ctx>>, is_var_args: bool, ) -> inkwell::types::FunctionType<'ctx> { match return_type { Some(ret) => ret.fn_type(¶m_types, is_var_args), None => context.void_type().fn_type(¶m_types, is_var_args), } } /// Simple JIT execution example pub fn create_execution_engine<'ctx>( module: &Module<'ctx>, ) -> Result<inkwell::execution_engine::ExecutionEngine<'ctx>, String> { module .create_jit_execution_engine(OptimizationLevel::None) .map_err(|e| e.to_string()) } /// Example of JIT 
compiling and executing a function pub fn jit_compile_and_execute(context: &Context) -> Result<u64, Box<dyn Error>> { let module = context.create_module("jit_example"); let builder = context.create_builder(); // Create a simple sum function: sum(x, y, z) = x + y + z let i64_type = context.i64_type(); let fn_type = i64_type.fn_type(&[i64_type.into(), i64_type.into(), i64_type.into()], false); let function = module.add_function("sum", fn_type, None); let basic_block = context.append_basic_block(function, "entry"); builder.position_at_end(basic_block); let x = function.get_nth_param(0).unwrap().into_int_value(); let y = function.get_nth_param(1).unwrap().into_int_value(); let z = function.get_nth_param(2).unwrap().into_int_value(); let sum = builder.build_int_add(x, y, "sum").unwrap(); let sum = builder.build_int_add(sum, z, "sum").unwrap(); builder.build_return(Some(&sum)).unwrap(); // Create execution engine let execution_engine = module.create_jit_execution_engine(OptimizationLevel::None)?; // Get the compiled function type SumFunc = unsafe extern "C" fn(u64, u64, u64) -> u64; let sum_fn = unsafe { execution_engine.get_function::<SumFunc>("sum")? 
}; // Execute the function let result = unsafe { sum_fn.call(1, 2, 3) }; Ok(result) } #[cfg(test)] mod tests { use super::*; #[test] fn test_add_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_add_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "add"); assert_eq!(function.count_params(), 2); assert!(verify_module(&module).is_ok()); } #[test] fn test_constant_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_constant_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "get_constant"); assert_eq!(function.count_params(), 0); assert!(verify_module(&module).is_ok()); } #[test] fn test_conditional_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_conditional_function(&context, &module); assert_eq!(function.count_basic_blocks(), 4); // entry, then, else, merge assert!(verify_module(&module).is_ok()); } #[test] fn test_loop_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_loop_function(&context, &module); assert_eq!(function.count_basic_blocks(), 3); // entry, loop, exit assert!(verify_module(&module).is_ok()); } #[test] fn test_global_variable() { let context = Context::create(); let module = context.create_module("test"); let global = create_global_variable(&context, &module); assert_eq!(global.get_name().to_str().unwrap(), "global_counter"); assert!(verify_module(&module).is_ok()); } #[test] fn test_recursive_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_recursive_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "factorial"); assert!(verify_module(&module).is_ok()); } #[test] fn test_optimization() { let context = Context::create(); let module = 
context.create_module("test"); // Create several functions create_add_function(&context, &module); create_multiply_function(&context, &module); create_constant_function(&context, &module); // Apply optimizations assert!(optimize_module(&module).is_ok()); // Module should still be valid after optimization assert!(verify_module(&module).is_ok()); } #[test] fn test_custom_passes() { let context = Context::create(); let module = context.create_module("test"); // Create a simple function create_add_function(&context, &module); // Run specific optimization passes let passes = ["instcombine", "simplifycfg"]; assert!(run_custom_passes(&module, &passes).is_ok()); // Module should still be valid assert!(verify_module(&module).is_ok()); } #[test] fn test_jit_execution() { let context = Context::create(); // Test JIT compilation and execution match jit_compile_and_execute(&context) { Ok(result) => assert_eq!(result, 6), // 1 + 2 + 3 = 6 Err(e) => panic!("JIT execution failed: {}", e), } } } /// Creates a function with conditional branching (if-then-else) pub fn create_conditional_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[i32_type.into()], false); let function = module.add_function("conditional", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); let then_block = context.append_basic_block(function, "then"); let else_block = context.append_basic_block(function, "else"); let merge_block = context.append_basic_block(function, "merge"); // Entry block builder.position_at_end(entry); let x = function.get_nth_param(0).unwrap().into_int_value(); let zero = i32_type.const_int(0, false); let condition = builder .build_int_compare(IntPredicate::SGT, x, zero, "cond") .unwrap(); builder .build_conditional_branch(condition, then_block, else_block) .unwrap(); // Then block: return x * 2 
builder.position_at_end(then_block); let two = i32_type.const_int(2, false); let then_val = builder.build_int_mul(x, two, "then_val").unwrap(); builder.build_unconditional_branch(merge_block).unwrap(); // Else block: return x * -1 builder.position_at_end(else_block); let neg_one = i32_type.const_int(-1i64 as u64, true); let else_val = builder.build_int_mul(x, neg_one, "else_val").unwrap(); builder.build_unconditional_branch(merge_block).unwrap(); // Merge block with phi node builder.position_at_end(merge_block); let phi = builder.build_phi(i32_type, "result").unwrap(); phi.add_incoming(&[(&then_val, then_block), (&else_val, else_block)]); builder.build_return(Some(&phi.as_basic_value())).unwrap(); function } }
Control flow in LLVM uses explicit basic blocks connected by branch instructions. Conditional branches test a boolean condition and jump to one of two target blocks. Phi nodes implement SSA form by selecting a value according to which predecessor block control arrived from. This explicit representation enables sophisticated control flow optimizations.
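As a mental model (not part of the inkwell API), the `conditional` function built above computes exactly what the following plain Rust `if` expression computes: the phi node in the merge block corresponds to the single value the `if` yields, chosen by which branch executed.

```rust
// Plain-Rust model of the generated "conditional" function.
fn conditional(x: i32) -> i32 {
    if x > 0 {
        x * 2 // value flowing into the phi from the "then" block
    } else {
        x * -1 // value flowing into the phi from the "else" block
    }
}

fn main() {
    assert_eq!(conditional(5), 10);
    assert_eq!(conditional(-3), 3);
    assert_eq!(conditional(0), 0);
}
```

This correspondence is why frontends lower structured `if`/`else` directly to the diamond-shaped block pattern shown in the builder code.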
Loop Construction
Building loops with phi nodes:
#![allow(unused)]
fn main() {
use inkwell::context::Context;
use inkwell::module::Module;
use inkwell::values::FunctionValue;
use inkwell::IntPredicate;

/// Creates a simple loop that counts from 0 to n
pub fn create_loop_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into()], false);
    let function = module.add_function("count_loop", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    let loop_block = context.append_basic_block(function, "loop");
    let exit_block = context.append_basic_block(function, "exit");

    // Entry block: initialize counter
    builder.position_at_end(entry);
    let n = function.get_nth_param(0).unwrap().into_int_value();
    let zero = i32_type.const_int(0, false);
    builder.build_unconditional_branch(loop_block).unwrap();

    // Loop block
    builder.position_at_end(loop_block);
    let counter = builder.build_phi(i32_type, "counter").unwrap();
    counter.add_incoming(&[(&zero, entry)]);

    // Increment counter
    let one = i32_type.const_int(1, false);
    let next_counter = builder
        .build_int_add(
            counter.as_basic_value().into_int_value(),
            one,
            "next_counter",
        )
        .unwrap();

    // Check loop condition
    let condition = builder
        .build_int_compare(IntPredicate::SLT, next_counter, n, "loop_cond")
        .unwrap();

    // Add incoming value for next iteration
    let loop_block_end = builder.get_insert_block().unwrap();
    counter.add_incoming(&[(&next_counter, loop_block_end)]);
    builder
        .build_conditional_branch(condition, loop_block, exit_block)
        .unwrap();

    // Exit block
    builder.position_at_end(exit_block);
    builder
        .build_return(Some(&counter.as_basic_value()))
        .unwrap();

    function
}
}
Loops in LLVM use phi nodes to manage loop variables in SSA form. The loop structure consists of an entry block, a loop block containing the phi node and loop body, and an exit block. The phi node receives different values depending on whether control flow comes from the initial entry or from the loop itself.
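As a sanity check on the control flow above, the same computation can be written directly in Rust; the rebinding of the mutable `counter` variable plays the role of the phi node's two incoming edges. Note that because the exit block returns the phi value from *before* the final increment, `count_loop(n)` evaluates to `n - 1` for positive `n` and `0` otherwise.

```rust
// Plain-Rust analogy of the IR built by create_loop_function.
// This is an illustration of the semantics, not inkwell code.
fn count_loop(n: i32) -> i32 {
    let mut counter = 0; // phi incoming edge from the entry block
    loop {
        let next_counter = counter + 1;
        if next_counter < n {
            counter = next_counter; // phi incoming edge from the loop block
        } else {
            return counter; // exit block returns the phi value
        }
    }
}
```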
Stack Allocation
Using alloca for local variables:
#![allow(unused)]
fn main() {
use inkwell::context::Context;
use inkwell::module::Module;
use inkwell::values::FunctionValue;

// (Helper functions repeated from the earlier listings are omitted.)

/// Creates a function that allocates and uses local variables (stack
/// allocation)
pub fn create_alloca_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false);
    let function = module.add_function("use_alloca", fn_type, None);
    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    // Get parameters
    let x = function.get_nth_param(0).unwrap().into_int_value();
    let y = function.get_nth_param(1).unwrap().into_int_value();

    // Allocate stack variables
    let x_ptr = builder.build_alloca(i32_type, "x_ptr").unwrap();
    let y_ptr = builder.build_alloca(i32_type, "y_ptr").unwrap();
    let result_ptr = builder.build_alloca(i32_type, "result_ptr").unwrap();

    // Store values
    builder.build_store(x_ptr, x).unwrap();
    builder.build_store(y_ptr, y).unwrap();

    // Load values
    let x_val = builder.build_load(i32_type, x_ptr, "x_val").unwrap();
    let y_val = builder.build_load(i32_type, y_ptr, "y_val").unwrap();

    // Compute and store result
    let sum = builder
        .build_int_add(x_val.into_int_value(), y_val.into_int_value(), "sum")
        .unwrap();
    builder.build_store(result_ptr, sum).unwrap();

    // Load and return result
    let result = builder.build_load(i32_type, result_ptr, "result").unwrap();
    builder.build_return(Some(&result)).unwrap();

    function
}
}
The alloca instruction creates stack storage for mutable variables. Load and store instructions access these variables. This pattern is commonly used before mem2reg optimization, which promotes allocas to SSA registers when possible.
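What mem2reg recovers can be sketched without any LLVM machinery at all. The first function below mimics the alloca pattern in plain Rust (the array stands in for the three stack slots; this is an analogy, not inkwell code), and the second is the SSA form the pass reduces it to:

```rust
// The alloca pattern: every local lives in memory, accessed by load/store.
fn use_alloca_memory(x: i32, y: i32) -> i32 {
    let mut slots = [0i32; 3];       // stand-ins for x_ptr, y_ptr, result_ptr
    slots[0] = x;                    // store x
    slots[1] = y;                    // store y
    slots[2] = slots[0] + slots[1];  // load both, add, store result
    slots[2]                         // load and return result
}

// What mem2reg produces: the same function with the stack slots
// promoted to plain SSA values.
fn use_alloca_ssa(x: i32, y: i32) -> i32 {
    x + y
}
```

Frontends often emit the naive alloca form for every mutable variable and rely on mem2reg to clean it up, which is far simpler than constructing phi nodes by hand.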
Array Operations
Working with arrays and pointers:
#![allow(unused)]
fn main() {
use std::error::Error;
use std::path::Path;

use inkwell::context::Context;
use inkwell::module::Module;
use inkwell::passes::PassBuilderOptions;
use inkwell::targets::{
    CodeModel, FileType, InitializationConfig, RelocMode, Target, TargetMachine,
};
use inkwell::types::{BasicMetadataTypeEnum, BasicType, BasicTypeEnum};
use inkwell::values::FunctionValue;
use inkwell::{AddressSpace, IntPredicate, OptimizationLevel};

// (Helper functions repeated from the earlier listings -- create_add_function,
// create_loop_function, create_alloca_function, the optimization helpers, and
// so on -- are omitted.)

/// Writes LLVM IR to a file
pub fn write_ir_to_file<'ctx>(module: &Module<'ctx>, path: &Path) -> Result<(),
String> { module .print_to_file(path) .map_err(|e| format!("Failed to write IR: {}", e)) } /// Compiles module to object file pub fn compile_to_object_file<'ctx>( module: &Module<'ctx>, path: &Path, ) -> Result<(), Box<dyn Error>> { Target::initialize_native(&InitializationConfig::default())?; let target_triple = TargetMachine::get_default_triple(); let target = Target::from_triple(&target_triple)?; let target_machine = target .create_target_machine( &target_triple, "generic", "", OptimizationLevel::Default, RelocMode::Default, CodeModel::Default, ) .ok_or("Could not create target machine")?; target_machine.write_to_file(module, FileType::Object, path)?; Ok(()) } /// Verifies that a module is valid pub fn verify_module<'ctx>(module: &Module<'ctx>) -> Result<(), String> { module.verify().map_err(|e| e.to_string()) } /// Helper to create a function type pub fn create_function_type<'ctx>( context: &'ctx Context, param_types: Vec<BasicMetadataTypeEnum<'ctx>>, return_type: Option<BasicTypeEnum<'ctx>>, is_var_args: bool, ) -> inkwell::types::FunctionType<'ctx> { match return_type { Some(ret) => ret.fn_type(¶m_types, is_var_args), None => context.void_type().fn_type(¶m_types, is_var_args), } } /// Simple JIT execution example pub fn create_execution_engine<'ctx>( module: &Module<'ctx>, ) -> Result<inkwell::execution_engine::ExecutionEngine<'ctx>, String> { module .create_jit_execution_engine(OptimizationLevel::None) .map_err(|e| e.to_string()) } /// Example of JIT compiling and executing a function pub fn jit_compile_and_execute(context: &Context) -> Result<u64, Box<dyn Error>> { let module = context.create_module("jit_example"); let builder = context.create_builder(); // Create a simple sum function: sum(x, y, z) = x + y + z let i64_type = context.i64_type(); let fn_type = i64_type.fn_type(&[i64_type.into(), i64_type.into(), i64_type.into()], false); let function = module.add_function("sum", fn_type, None); let basic_block = context.append_basic_block(function, "entry"); 
builder.position_at_end(basic_block); let x = function.get_nth_param(0).unwrap().into_int_value(); let y = function.get_nth_param(1).unwrap().into_int_value(); let z = function.get_nth_param(2).unwrap().into_int_value(); let sum = builder.build_int_add(x, y, "sum").unwrap(); let sum = builder.build_int_add(sum, z, "sum").unwrap(); builder.build_return(Some(&sum)).unwrap(); // Create execution engine let execution_engine = module.create_jit_execution_engine(OptimizationLevel::None)?; // Get the compiled function type SumFunc = unsafe extern "C" fn(u64, u64, u64) -> u64; let sum_fn = unsafe { execution_engine.get_function::<SumFunc>("sum")? }; // Execute the function let result = unsafe { sum_fn.call(1, 2, 3) }; Ok(result) } #[cfg(test)] mod tests { use super::*; #[test] fn test_add_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_add_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "add"); assert_eq!(function.count_params(), 2); assert!(verify_module(&module).is_ok()); } #[test] fn test_constant_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_constant_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "get_constant"); assert_eq!(function.count_params(), 0); assert!(verify_module(&module).is_ok()); } #[test] fn test_conditional_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_conditional_function(&context, &module); assert_eq!(function.count_basic_blocks(), 4); // entry, then, else, merge assert!(verify_module(&module).is_ok()); } #[test] fn test_loop_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_loop_function(&context, &module); assert_eq!(function.count_basic_blocks(), 3); // entry, loop, exit assert!(verify_module(&module).is_ok()); } #[test] fn 
test_global_variable() { let context = Context::create(); let module = context.create_module("test"); let global = create_global_variable(&context, &module); assert_eq!(global.get_name().to_str().unwrap(), "global_counter"); assert!(verify_module(&module).is_ok()); } #[test] fn test_recursive_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_recursive_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "factorial"); assert!(verify_module(&module).is_ok()); } #[test] fn test_optimization() { let context = Context::create(); let module = context.create_module("test"); // Create several functions create_add_function(&context, &module); create_multiply_function(&context, &module); create_constant_function(&context, &module); // Apply optimizations assert!(optimize_module(&module).is_ok()); // Module should still be valid after optimization assert!(verify_module(&module).is_ok()); } #[test] fn test_custom_passes() { let context = Context::create(); let module = context.create_module("test"); // Create a simple function create_add_function(&context, &module); // Run specific optimization passes let passes = ["instcombine", "simplifycfg"]; assert!(run_custom_passes(&module, &passes).is_ok()); // Module should still be valid assert!(verify_module(&module).is_ok()); } #[test] fn test_jit_execution() { let context = Context::create(); // Test JIT compilation and execution match jit_compile_and_execute(&context) { Ok(result) => assert_eq!(result, 6), // 1 + 2 + 3 = 6 Err(e) => panic!("JIT execution failed: {}", e), } } } /// Creates a function that works with arrays pub fn create_array_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let i32_ptr_type = context.ptr_type(AddressSpace::default()); let fn_type = i32_type.fn_type(&[i32_ptr_type.into(), i32_type.into()], false); let function = 
module.add_function("sum_array", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); let loop_block = context.append_basic_block(function, "loop"); let exit_block = context.append_basic_block(function, "exit"); // Entry block builder.position_at_end(entry); let array_ptr = function.get_nth_param(0).unwrap().into_pointer_value(); let size = function.get_nth_param(1).unwrap().into_int_value(); let zero = i32_type.const_int(0, false); builder.build_unconditional_branch(loop_block).unwrap(); // Loop block builder.position_at_end(loop_block); let index = builder.build_phi(i32_type, "index").unwrap(); let sum = builder.build_phi(i32_type, "sum").unwrap(); index.add_incoming(&[(&zero, entry)]); sum.add_incoming(&[(&zero, entry)]); // Load array element let elem_ptr = unsafe { builder .build_gep( i32_type, array_ptr, &[index.as_basic_value().into_int_value()], "elem_ptr", ) .unwrap() }; let elem = builder.build_load(i32_type, elem_ptr, "elem").unwrap(); // Update sum let new_sum = builder .build_int_add( sum.as_basic_value().into_int_value(), elem.into_int_value(), "new_sum", ) .unwrap(); // Update index let one = i32_type.const_int(1, false); let next_index = builder .build_int_add(index.as_basic_value().into_int_value(), one, "next_index") .unwrap(); // Check loop condition let condition = builder .build_int_compare(IntPredicate::SLT, next_index, size, "loop_cond") .unwrap(); // Update phi nodes let loop_end = builder.get_insert_block().unwrap(); index.add_incoming(&[(&next_index, loop_end)]); sum.add_incoming(&[(&new_sum, loop_end)]); builder .build_conditional_branch(condition, loop_block, exit_block) .unwrap(); // Exit block builder.position_at_end(exit_block); builder.build_return(Some(&sum.as_basic_value())).unwrap(); function } }
Array operations use the GEP (GetElementPtr) instruction to compute the addresses of array elements. GEP performs pointer arithmetic in a type-safe manner, scaling each index by the element size and accounting for array dimensions; it only computes an address and never dereferences memory itself, which is why a separate load or store follows it.
Structure Types
Defining and manipulating aggregate types:
#![allow(unused)]
fn main() {
use std::error::Error;
use std::path::Path;

use inkwell::context::Context;
use inkwell::module::Module;
use inkwell::passes::PassBuilderOptions;
use inkwell::targets::{
    CodeModel, FileType, InitializationConfig, RelocMode, Target, TargetMachine,
};
use inkwell::types::{BasicMetadataTypeEnum, BasicType, BasicTypeEnum};
use inkwell::values::FunctionValue;
use inkwell::{AddressSpace, IntPredicate, OptimizationLevel};

/// Creates a struct type and a function that uses it
pub fn create_struct_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let f64_type = context.f64_type();

    // Define a Point struct with x and y fields
    let field_types = [i32_type.into(), f64_type.into()];
    let struct_type = context.struct_type(&field_types, false);

    let fn_type = f64_type.fn_type(&[struct_type.into()], false);
    let function = module.add_function("get_point_y", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    // Get the struct parameter
    let point = function.get_nth_param(0).unwrap().into_struct_value();

    // Extract the y field (index 1)
    let y_field = builder.build_extract_value(point, 1, "y_field").unwrap();
    builder.build_return(Some(&y_field)).unwrap();

    function
}
}
Structures in LLVM are aggregate types whose fields are addressed by index rather than by name. The extractvalue instruction (emitted here via build_extract_value) reads a field directly out of a struct value held in an SSA register, without going through memory; its counterpart insertvalue (build_insert_value in inkwell) builds struct values up field by field. This example shows how to work with heterogeneous data types in LLVM.
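For reference, get_point_y lowers to IR along these lines (a sketch; value names and attributes such as alignment vary with the LLVM version):

```llvm
define double @get_point_y({ i32, double } %point) {
entry:
  %y_field = extractvalue { i32, double } %point, 1
  ret double %y_field
}
```

Note that the struct is passed by value here; real ABIs often pass aggregates indirectly, which LLVM's target lowering handles after IR generation.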
Global Variables
Creating and using module-level variables:
#![allow(unused)]
fn main() {
use inkwell::context::Context;
use inkwell::module::Module;
use inkwell::values::FunctionValue;
use inkwell::AddressSpace;

// ... (the remaining module functions appear in the full listing below)

/// Creates a global variable
pub fn create_global_variable<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> inkwell::values::GlobalValue<'ctx> {
    let i32_type = context.i32_type();
    let global = module.add_global(i32_type, Some(AddressSpace::default()), "global_counter");
    global.set_initializer(&i32_type.const_int(0, false));
    global.set_linkage(inkwell::module::Linkage::Internal);
    global
}

/// Creates a function that uses a global variable
pub fn create_global_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[], false);
    let function = module.add_function("increment_global", fn_type, None);

    // Create or get the global variable
    let global = module
        .get_global("global_counter")
        .unwrap_or_else(|| create_global_variable(context, module));

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    // Load the current value of the global
    let global_ptr = global.as_pointer_value();
    let current = builder.build_load(i32_type, global_ptr, "current").unwrap();

    // Increment
    let one = i32_type.const_int(1, false);
    let incremented = builder
        .build_int_add(current.into_int_value(), one, "incremented")
        .unwrap();

    // Store back to the global and return the new value
    builder.build_store(global_ptr, incremented).unwrap();
    builder.build_return(Some(&incremented)).unwrap();

    function
}
}
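Assuming the global-variable functions above, the emitted module-level state looks roughly like this IR (modern opaque-pointer syntax; exact alignment annotations and formatting depend on the LLVM version):

```llvm
@global_counter = internal global i32 0

define i32 @increment_global() {
entry:
  %current = load i32, ptr @global_counter
  %incremented = add i32 %current, 1
  store i32 %incremented, ptr @global_counter
  ret i32 %incremented
}
```

The internal linkage set via set_linkage keeps the counter private to the module, so the optimizer is free to promote or eliminate it if all uses are visible.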
#![allow(unused)] fn main() { use std::error::Error; use std::path::Path; use inkwell::context::Context; use inkwell::module::Module; use inkwell::passes::PassBuilderOptions; use inkwell::targets::{ CodeModel, FileType, InitializationConfig, RelocMode, Target, TargetMachine, }; use inkwell::types::{BasicMetadataTypeEnum, BasicType, BasicTypeEnum}; use inkwell::values::FunctionValue; use inkwell::{AddressSpace, IntPredicate, OptimizationLevel}; /// Creates a basic LLVM context pub fn create_context() -> Context { Context::create() } /// Creates a simple function that adds two integers pub fn create_add_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false); let function = module.add_function("add", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); // Get function parameters let x = function.get_nth_param(0).unwrap().into_int_value(); let y = function.get_nth_param(1).unwrap().into_int_value(); // Build addition and return let sum = builder.build_int_add(x, y, "sum").unwrap(); builder.build_return(Some(&sum)).unwrap(); function } /// Creates a function that multiplies two integers pub fn create_multiply_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false); let function = module.add_function("multiply", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); let x = function.get_nth_param(0).unwrap().into_int_value(); let y = function.get_nth_param(1).unwrap().into_int_value(); let product = builder.build_int_mul(x, y, "product").unwrap(); builder.build_return(Some(&product)).unwrap(); function } 
/// Creates a function with a constant return value pub fn create_constant_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[], false); let function = module.add_function("get_constant", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); let constant = i32_type.const_int(42, false); builder.build_return(Some(&constant)).unwrap(); function } /// Demonstrates integer comparison operations pub fn create_comparison_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let bool_type = context.bool_type(); let fn_type = bool_type.fn_type(&[i32_type.into(), i32_type.into()], false); let function = module.add_function("compare_ints", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); let x = function.get_nth_param(0).unwrap().into_int_value(); let y = function.get_nth_param(1).unwrap().into_int_value(); // Compare x > y let comparison = builder .build_int_compare(IntPredicate::SGT, x, y, "cmp") .unwrap(); builder.build_return(Some(&comparison)).unwrap(); function } /// Creates a function with conditional branching (if-then-else) pub fn create_conditional_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[i32_type.into()], false); let function = module.add_function("conditional", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); let then_block = context.append_basic_block(function, "then"); let else_block = context.append_basic_block(function, "else"); let merge_block = context.append_basic_block(function, "merge"); // Entry block 
builder.position_at_end(entry); let x = function.get_nth_param(0).unwrap().into_int_value(); let zero = i32_type.const_int(0, false); let condition = builder .build_int_compare(IntPredicate::SGT, x, zero, "cond") .unwrap(); builder .build_conditional_branch(condition, then_block, else_block) .unwrap(); // Then block: return x * 2 builder.position_at_end(then_block); let two = i32_type.const_int(2, false); let then_val = builder.build_int_mul(x, two, "then_val").unwrap(); builder.build_unconditional_branch(merge_block).unwrap(); // Else block: return x * -1 builder.position_at_end(else_block); let neg_one = i32_type.const_int(-1i64 as u64, true); let else_val = builder.build_int_mul(x, neg_one, "else_val").unwrap(); builder.build_unconditional_branch(merge_block).unwrap(); // Merge block with phi node builder.position_at_end(merge_block); let phi = builder.build_phi(i32_type, "result").unwrap(); phi.add_incoming(&[(&then_val, then_block), (&else_val, else_block)]); builder.build_return(Some(&phi.as_basic_value())).unwrap(); function } /// Creates a simple loop that counts from 0 to n pub fn create_loop_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[i32_type.into()], false); let function = module.add_function("count_loop", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); let loop_block = context.append_basic_block(function, "loop"); let exit_block = context.append_basic_block(function, "exit"); // Entry block: initialize counter builder.position_at_end(entry); let n = function.get_nth_param(0).unwrap().into_int_value(); let zero = i32_type.const_int(0, false); builder.build_unconditional_branch(loop_block).unwrap(); // Loop block builder.position_at_end(loop_block); let counter = builder.build_phi(i32_type, "counter").unwrap(); counter.add_incoming(&[(&zero, entry)]); // Increment counter let 
one = i32_type.const_int(1, false); let next_counter = builder .build_int_add( counter.as_basic_value().into_int_value(), one, "next_counter", ) .unwrap(); // Check loop condition let condition = builder .build_int_compare(IntPredicate::SLT, next_counter, n, "loop_cond") .unwrap(); // Add incoming value for next iteration let loop_block_end = builder.get_insert_block().unwrap(); counter.add_incoming(&[(&next_counter, loop_block_end)]); builder .build_conditional_branch(condition, loop_block, exit_block) .unwrap(); // Exit block builder.position_at_end(exit_block); builder .build_return(Some(&counter.as_basic_value())) .unwrap(); function } /// Creates a function that allocates and uses local variables (stack /// allocation) pub fn create_alloca_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false); let function = module.add_function("use_alloca", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); // Get parameters let x = function.get_nth_param(0).unwrap().into_int_value(); let y = function.get_nth_param(1).unwrap().into_int_value(); // Allocate stack variables let x_ptr = builder.build_alloca(i32_type, "x_ptr").unwrap(); let y_ptr = builder.build_alloca(i32_type, "y_ptr").unwrap(); let result_ptr = builder.build_alloca(i32_type, "result_ptr").unwrap(); // Store values builder.build_store(x_ptr, x).unwrap(); builder.build_store(y_ptr, y).unwrap(); // Load values let x_val = builder.build_load(i32_type, x_ptr, "x_val").unwrap(); let y_val = builder.build_load(i32_type, y_ptr, "y_val").unwrap(); // Compute and store result let sum = builder .build_int_add(x_val.into_int_value(), y_val.into_int_value(), "sum") .unwrap(); builder.build_store(result_ptr, sum).unwrap(); // Load and return result let result = 
builder.build_load(i32_type, result_ptr, "result").unwrap(); builder.build_return(Some(&result)).unwrap(); function } /// Creates a function that works with arrays pub fn create_array_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let i32_ptr_type = context.ptr_type(AddressSpace::default()); let fn_type = i32_type.fn_type(&[i32_ptr_type.into(), i32_type.into()], false); let function = module.add_function("sum_array", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); let loop_block = context.append_basic_block(function, "loop"); let exit_block = context.append_basic_block(function, "exit"); // Entry block builder.position_at_end(entry); let array_ptr = function.get_nth_param(0).unwrap().into_pointer_value(); let size = function.get_nth_param(1).unwrap().into_int_value(); let zero = i32_type.const_int(0, false); builder.build_unconditional_branch(loop_block).unwrap(); // Loop block builder.position_at_end(loop_block); let index = builder.build_phi(i32_type, "index").unwrap(); let sum = builder.build_phi(i32_type, "sum").unwrap(); index.add_incoming(&[(&zero, entry)]); sum.add_incoming(&[(&zero, entry)]); // Load array element let elem_ptr = unsafe { builder .build_gep( i32_type, array_ptr, &[index.as_basic_value().into_int_value()], "elem_ptr", ) .unwrap() }; let elem = builder.build_load(i32_type, elem_ptr, "elem").unwrap(); // Update sum let new_sum = builder .build_int_add( sum.as_basic_value().into_int_value(), elem.into_int_value(), "new_sum", ) .unwrap(); // Update index let one = i32_type.const_int(1, false); let next_index = builder .build_int_add(index.as_basic_value().into_int_value(), one, "next_index") .unwrap(); // Check loop condition let condition = builder .build_int_compare(IntPredicate::SLT, next_index, size, "loop_cond") .unwrap(); // Update phi nodes let loop_end = builder.get_insert_block().unwrap(); 
index.add_incoming(&[(&next_index, loop_end)]); sum.add_incoming(&[(&new_sum, loop_end)]); builder .build_conditional_branch(condition, loop_block, exit_block) .unwrap(); // Exit block builder.position_at_end(exit_block); builder.build_return(Some(&sum.as_basic_value())).unwrap(); function } /// Creates a global variable pub fn create_global_variable<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> inkwell::values::GlobalValue<'ctx> { let i32_type = context.i32_type(); let global = module.add_global(i32_type, Some(AddressSpace::default()), "global_counter"); global.set_initializer(&i32_type.const_int(0, false)); global.set_linkage(inkwell::module::Linkage::Internal); global } /// Creates a recursive factorial function pub fn create_recursive_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[i32_type.into()], false); let function = module.add_function("factorial", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); let recurse = context.append_basic_block(function, "recurse"); let base = context.append_basic_block(function, "base"); // Entry block builder.position_at_end(entry); let n = function.get_nth_param(0).unwrap().into_int_value(); let one = i32_type.const_int(1, false); let is_base = builder .build_int_compare(IntPredicate::SLE, n, one, "is_base") .unwrap(); builder .build_conditional_branch(is_base, base, recurse) .unwrap(); // Base case: return 1 builder.position_at_end(base); builder.build_return(Some(&one)).unwrap(); // Recursive case: return n * factorial(n-1) builder.position_at_end(recurse); let n_minus_1 = builder.build_int_sub(n, one, "n_minus_1").unwrap(); let rec_result = builder .build_call(function, &[n_minus_1.into()], "rec_result") .unwrap(); let result = builder .build_int_mul( n, rec_result .try_as_basic_value() .left() .unwrap() .into_int_value(), "result", ) 
.unwrap(); builder.build_return(Some(&result)).unwrap(); function } /// Creates a struct type and a function that uses it pub fn create_struct_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let f64_type = context.f64_type(); // Define a Point struct with x and y fields let field_types = [i32_type.into(), f64_type.into()]; let struct_type = context.struct_type(&field_types, false); let fn_type = f64_type.fn_type(&[struct_type.into()], false); let function = module.add_function("get_point_y", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); // Get the struct parameter let point = function.get_nth_param(0).unwrap().into_struct_value(); // Extract the y field (index 1) let y_field = builder.build_extract_value(point, 1, "y_field").unwrap(); builder.build_return(Some(&y_field)).unwrap(); function } /// Runs optimization passes on a module using the modern pass manager (LLVM 18) pub fn optimize_module<'ctx>(module: &Module<'ctx>) -> Result<(), String> { // First verify the module is valid module.verify().map_err(|e| e.to_string())?; // Initialize targets for optimization Target::initialize_all(&InitializationConfig::default()); let target_triple = TargetMachine::get_default_triple(); let target = Target::from_triple(&target_triple) .map_err(|e| format!("Failed to create target: {}", e))?; let target_machine = target .create_target_machine( &target_triple, "generic", "", OptimizationLevel::Default, RelocMode::Default, CodeModel::Default, ) .ok_or("Failed to create target machine")?; // Common optimization passes let passes = [ "instcombine", // Combine instructions "reassociate", // Reassociate expressions "gvn", // Global value numbering "simplifycfg", // Simplify control flow graph "mem2reg", // Promote memory to registers ]; let pass_builder_options = PassBuilderOptions::create(); 
pass_builder_options.set_loop_vectorization(true); pass_builder_options.set_loop_unrolling(true); pass_builder_options.set_merge_functions(true); module .run_passes(&passes.join(","), &target_machine, pass_builder_options) .map_err(|e| e.to_string()) } /// Runs specific optimization passes on a module pub fn run_custom_passes<'ctx>(module: &Module<'ctx>, passes: &[&str]) -> Result<(), String> { // Verify module first module.verify().map_err(|e| e.to_string())?; // Initialize targets Target::initialize_all(&InitializationConfig::default()); let target_triple = TargetMachine::get_default_triple(); let target = Target::from_triple(&target_triple) .map_err(|e| format!("Failed to create target: {}", e))?; let target_machine = target .create_target_machine( &target_triple, "generic", "", OptimizationLevel::None, RelocMode::Default, CodeModel::Default, ) .ok_or("Failed to create target machine")?; module .run_passes( &passes.join(","), &target_machine, PassBuilderOptions::create(), ) .map_err(|e| e.to_string()) } /// Writes LLVM IR to a file pub fn write_ir_to_file<'ctx>(module: &Module<'ctx>, path: &Path) -> Result<(), String> { module .print_to_file(path) .map_err(|e| format!("Failed to write IR: {}", e)) } /// Compiles module to object file pub fn compile_to_object_file<'ctx>( module: &Module<'ctx>, path: &Path, ) -> Result<(), Box<dyn Error>> { Target::initialize_native(&InitializationConfig::default())?; let target_triple = TargetMachine::get_default_triple(); let target = Target::from_triple(&target_triple)?; let target_machine = target .create_target_machine( &target_triple, "generic", "", OptimizationLevel::Default, RelocMode::Default, CodeModel::Default, ) .ok_or("Could not create target machine")?; target_machine.write_to_file(module, FileType::Object, path)?; Ok(()) } /// Verifies that a module is valid pub fn verify_module<'ctx>(module: &Module<'ctx>) -> Result<(), String> { module.verify().map_err(|e| e.to_string()) } /// Helper to create a function type pub 
fn create_function_type<'ctx>( context: &'ctx Context, param_types: Vec<BasicMetadataTypeEnum<'ctx>>, return_type: Option<BasicTypeEnum<'ctx>>, is_var_args: bool, ) -> inkwell::types::FunctionType<'ctx> { match return_type { Some(ret) => ret.fn_type(&param_types, is_var_args), None => context.void_type().fn_type(&param_types, is_var_args), } } /// Simple JIT execution example pub fn create_execution_engine<'ctx>( module: &Module<'ctx>, ) -> Result<inkwell::execution_engine::ExecutionEngine<'ctx>, String> { module .create_jit_execution_engine(OptimizationLevel::None) .map_err(|e| e.to_string()) } /// Example of JIT compiling and executing a function pub fn jit_compile_and_execute(context: &Context) -> Result<u64, Box<dyn Error>> { let module = context.create_module("jit_example"); let builder = context.create_builder(); // Create a simple sum function: sum(x, y, z) = x + y + z let i64_type = context.i64_type(); let fn_type = i64_type.fn_type(&[i64_type.into(), i64_type.into(), i64_type.into()], false); let function = module.add_function("sum", fn_type, None); let basic_block = context.append_basic_block(function, "entry"); builder.position_at_end(basic_block); let x = function.get_nth_param(0).unwrap().into_int_value(); let y = function.get_nth_param(1).unwrap().into_int_value(); let z = function.get_nth_param(2).unwrap().into_int_value(); let sum = builder.build_int_add(x, y, "sum").unwrap(); let sum = builder.build_int_add(sum, z, "sum").unwrap(); builder.build_return(Some(&sum)).unwrap(); // Create execution engine let execution_engine = module.create_jit_execution_engine(OptimizationLevel::None)?; // Get the compiled function type SumFunc = unsafe extern "C" fn(u64, u64, u64) -> u64; let sum_fn = unsafe { execution_engine.get_function::<SumFunc>("sum")? 
}; // Execute the function let result = unsafe { sum_fn.call(1, 2, 3) }; Ok(result) } #[cfg(test)] mod tests { use super::*; #[test] fn test_add_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_add_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "add"); assert_eq!(function.count_params(), 2); assert!(verify_module(&module).is_ok()); } #[test] fn test_constant_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_constant_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "get_constant"); assert_eq!(function.count_params(), 0); assert!(verify_module(&module).is_ok()); } #[test] fn test_conditional_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_conditional_function(&context, &module); assert_eq!(function.count_basic_blocks(), 4); // entry, then, else, merge assert!(verify_module(&module).is_ok()); } #[test] fn test_loop_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_loop_function(&context, &module); assert_eq!(function.count_basic_blocks(), 3); // entry, loop, exit assert!(verify_module(&module).is_ok()); } #[test] fn test_global_variable() { let context = Context::create(); let module = context.create_module("test"); let global = create_global_variable(&context, &module); assert_eq!(global.get_name().to_str().unwrap(), "global_counter"); assert!(verify_module(&module).is_ok()); } #[test] fn test_recursive_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_recursive_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "factorial"); assert!(verify_module(&module).is_ok()); } #[test] fn test_optimization() { let context = Context::create(); let module = 
context.create_module("test"); // Create several functions create_add_function(&context, &module); create_multiply_function(&context, &module); create_constant_function(&context, &module); // Apply optimizations assert!(optimize_module(&module).is_ok()); // Module should still be valid after optimization assert!(verify_module(&module).is_ok()); } #[test] fn test_custom_passes() { let context = Context::create(); let module = context.create_module("test"); // Create a simple function create_add_function(&context, &module); // Run specific optimization passes let passes = ["instcombine", "simplifycfg"]; assert!(run_custom_passes(&module, &passes).is_ok()); // Module should still be valid assert!(verify_module(&module).is_ok()); } #[test] fn test_jit_execution() { let context = Context::create(); // Test JIT compilation and execution match jit_compile_and_execute(&context) { Ok(result) => assert_eq!(result, 6), // 1 + 2 + 3 = 6 Err(e) => panic!("JIT execution failed: {}", e), } } } /// Creates a function that uses a global variable pub fn create_global_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[], false); let function = module.add_function("increment_global", fn_type, None); // Create or get global variable let global = module .get_global("global_counter") .unwrap_or_else(|| create_global_variable(context, module)); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); // Load global value let global_ptr = global.as_pointer_value(); let current = builder.build_load(i32_type, global_ptr, "current").unwrap(); // Increment let one = i32_type.const_int(1, false); let incremented = builder .build_int_add(current.into_int_value(), one, "incremented") .unwrap(); // Store back to global builder.build_store(global_ptr, incremented).unwrap(); // Return new value 
builder.build_return(Some(&incremented)).unwrap(); function } }
Global variables exist at module scope and can be accessed by all functions. They support various linkage types controlling visibility and sharing across compilation units.
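The generated `increment_global` function loads the counter, adds one, stores the result back, and returns the new value. As a rough plain-Rust analogue of that semantics (an illustrative sketch for intuition, not anything inkwell emits or requires):

```rust
use std::sync::atomic::{AtomicI32, Ordering};

// Module-scope counter, initialized to zero like the IR global above.
static GLOBAL_COUNTER: AtomicI32 = AtomicI32::new(0);

// Mirrors the generated IR: load the global, add one, store it back,
// and return the incremented value.
fn increment_global() -> i32 {
    GLOBAL_COUNTER.fetch_add(1, Ordering::SeqCst) + 1
}

fn main() {
    let first = increment_global();
    let second = increment_global();
    assert_eq!(second, first + 1);
    println!("counter = {}", GLOBAL_COUNTER.load(Ordering::SeqCst));
}
```

Internal linkage, as used in the IR version, is the analogue of keeping the static private to this module rather than exporting it to other compilation units.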
Recursive Functions
Implementing recursion:
#![allow(unused)] fn main() { use std::error::Error; use std::path::Path; use inkwell::context::Context; use inkwell::module::Module; use inkwell::passes::PassBuilderOptions; use inkwell::targets::{ CodeModel, FileType, InitializationConfig, RelocMode, Target, TargetMachine, }; use inkwell::types::{BasicMetadataTypeEnum, BasicType, BasicTypeEnum}; use inkwell::values::FunctionValue; use inkwell::{AddressSpace, IntPredicate, OptimizationLevel}; /// Creates a basic LLVM context pub fn create_context() -> Context { Context::create() } /// Creates a simple function that adds two integers pub fn create_add_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false); let function = module.add_function("add", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); // Get function parameters let x = function.get_nth_param(0).unwrap().into_int_value(); let y = function.get_nth_param(1).unwrap().into_int_value(); // Build addition and return let sum = builder.build_int_add(x, y, "sum").unwrap(); builder.build_return(Some(&sum)).unwrap(); function } /// Creates a function that multiplies two integers pub fn create_multiply_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false); let function = module.add_function("multiply", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); let x = function.get_nth_param(0).unwrap().into_int_value(); let y = function.get_nth_param(1).unwrap().into_int_value(); let product = builder.build_int_mul(x, y, "product").unwrap(); builder.build_return(Some(&product)).unwrap(); function } 
/// Creates a function with a constant return value pub fn create_constant_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[], false); let function = module.add_function("get_constant", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); let constant = i32_type.const_int(42, false); builder.build_return(Some(&constant)).unwrap(); function } /// Demonstrates integer comparison operations pub fn create_comparison_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let bool_type = context.bool_type(); let fn_type = bool_type.fn_type(&[i32_type.into(), i32_type.into()], false); let function = module.add_function("compare_ints", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); let x = function.get_nth_param(0).unwrap().into_int_value(); let y = function.get_nth_param(1).unwrap().into_int_value(); // Compare x > y let comparison = builder .build_int_compare(IntPredicate::SGT, x, y, "cmp") .unwrap(); builder.build_return(Some(&comparison)).unwrap(); function } /// Creates a function with conditional branching (if-then-else) pub fn create_conditional_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[i32_type.into()], false); let function = module.add_function("conditional", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); let then_block = context.append_basic_block(function, "then"); let else_block = context.append_basic_block(function, "else"); let merge_block = context.append_basic_block(function, "merge"); // Entry block 
builder.position_at_end(entry); let x = function.get_nth_param(0).unwrap().into_int_value(); let zero = i32_type.const_int(0, false); let condition = builder .build_int_compare(IntPredicate::SGT, x, zero, "cond") .unwrap(); builder .build_conditional_branch(condition, then_block, else_block) .unwrap(); // Then block: return x * 2 builder.position_at_end(then_block); let two = i32_type.const_int(2, false); let then_val = builder.build_int_mul(x, two, "then_val").unwrap(); builder.build_unconditional_branch(merge_block).unwrap(); // Else block: return x * -1 builder.position_at_end(else_block); let neg_one = i32_type.const_int(-1i64 as u64, true); let else_val = builder.build_int_mul(x, neg_one, "else_val").unwrap(); builder.build_unconditional_branch(merge_block).unwrap(); // Merge block with phi node builder.position_at_end(merge_block); let phi = builder.build_phi(i32_type, "result").unwrap(); phi.add_incoming(&[(&then_val, then_block), (&else_val, else_block)]); builder.build_return(Some(&phi.as_basic_value())).unwrap(); function } /// Creates a simple loop that counts from 0 to n pub fn create_loop_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[i32_type.into()], false); let function = module.add_function("count_loop", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); let loop_block = context.append_basic_block(function, "loop"); let exit_block = context.append_basic_block(function, "exit"); // Entry block: initialize counter builder.position_at_end(entry); let n = function.get_nth_param(0).unwrap().into_int_value(); let zero = i32_type.const_int(0, false); builder.build_unconditional_branch(loop_block).unwrap(); // Loop block builder.position_at_end(loop_block); let counter = builder.build_phi(i32_type, "counter").unwrap(); counter.add_incoming(&[(&zero, entry)]); // Increment counter let 
one = i32_type.const_int(1, false); let next_counter = builder .build_int_add( counter.as_basic_value().into_int_value(), one, "next_counter", ) .unwrap(); // Check loop condition let condition = builder .build_int_compare(IntPredicate::SLT, next_counter, n, "loop_cond") .unwrap(); // Add incoming value for next iteration let loop_block_end = builder.get_insert_block().unwrap(); counter.add_incoming(&[(&next_counter, loop_block_end)]); builder .build_conditional_branch(condition, loop_block, exit_block) .unwrap(); // Exit block builder.position_at_end(exit_block); builder .build_return(Some(&counter.as_basic_value())) .unwrap(); function } /// Creates a function that allocates and uses local variables (stack /// allocation) pub fn create_alloca_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false); let function = module.add_function("use_alloca", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); // Get parameters let x = function.get_nth_param(0).unwrap().into_int_value(); let y = function.get_nth_param(1).unwrap().into_int_value(); // Allocate stack variables let x_ptr = builder.build_alloca(i32_type, "x_ptr").unwrap(); let y_ptr = builder.build_alloca(i32_type, "y_ptr").unwrap(); let result_ptr = builder.build_alloca(i32_type, "result_ptr").unwrap(); // Store values builder.build_store(x_ptr, x).unwrap(); builder.build_store(y_ptr, y).unwrap(); // Load values let x_val = builder.build_load(i32_type, x_ptr, "x_val").unwrap(); let y_val = builder.build_load(i32_type, y_ptr, "y_val").unwrap(); // Compute and store result let sum = builder .build_int_add(x_val.into_int_value(), y_val.into_int_value(), "sum") .unwrap(); builder.build_store(result_ptr, sum).unwrap(); // Load and return result let result = 
builder.build_load(i32_type, result_ptr, "result").unwrap(); builder.build_return(Some(&result)).unwrap(); function } /// Creates a function that works with arrays pub fn create_array_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let i32_ptr_type = context.ptr_type(AddressSpace::default()); let fn_type = i32_type.fn_type(&[i32_ptr_type.into(), i32_type.into()], false); let function = module.add_function("sum_array", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); let loop_block = context.append_basic_block(function, "loop"); let exit_block = context.append_basic_block(function, "exit"); // Entry block builder.position_at_end(entry); let array_ptr = function.get_nth_param(0).unwrap().into_pointer_value(); let size = function.get_nth_param(1).unwrap().into_int_value(); let zero = i32_type.const_int(0, false); builder.build_unconditional_branch(loop_block).unwrap(); // Loop block builder.position_at_end(loop_block); let index = builder.build_phi(i32_type, "index").unwrap(); let sum = builder.build_phi(i32_type, "sum").unwrap(); index.add_incoming(&[(&zero, entry)]); sum.add_incoming(&[(&zero, entry)]); // Load array element let elem_ptr = unsafe { builder .build_gep( i32_type, array_ptr, &[index.as_basic_value().into_int_value()], "elem_ptr", ) .unwrap() }; let elem = builder.build_load(i32_type, elem_ptr, "elem").unwrap(); // Update sum let new_sum = builder .build_int_add( sum.as_basic_value().into_int_value(), elem.into_int_value(), "new_sum", ) .unwrap(); // Update index let one = i32_type.const_int(1, false); let next_index = builder .build_int_add(index.as_basic_value().into_int_value(), one, "next_index") .unwrap(); // Check loop condition let condition = builder .build_int_compare(IntPredicate::SLT, next_index, size, "loop_cond") .unwrap(); // Update phi nodes let loop_end = builder.get_insert_block().unwrap(); 
index.add_incoming(&[(&next_index, loop_end)]); sum.add_incoming(&[(&new_sum, loop_end)]); builder .build_conditional_branch(condition, loop_block, exit_block) .unwrap(); // Exit block: return new_sum so the final element is included builder.position_at_end(exit_block); builder.build_return(Some(&new_sum)).unwrap(); function } /// Creates a global variable pub fn create_global_variable<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> inkwell::values::GlobalValue<'ctx> { let i32_type = context.i32_type(); let global = module.add_global(i32_type, Some(AddressSpace::default()), "global_counter"); global.set_initializer(&i32_type.const_int(0, false)); global.set_linkage(inkwell::module::Linkage::Internal); global } /// Creates a function that uses a global variable pub fn create_global_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[], false); let function = module.add_function("increment_global", fn_type, None); // Create or get global variable let global = module .get_global("global_counter") .unwrap_or_else(|| create_global_variable(context, module)); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); // Load global value let global_ptr = global.as_pointer_value(); let current = builder.build_load(i32_type, global_ptr, "current").unwrap(); // Increment let one = i32_type.const_int(1, false); let incremented = builder .build_int_add(current.into_int_value(), one, "incremented") .unwrap(); // Store back to global builder.build_store(global_ptr, incremented).unwrap(); // Return new value builder.build_return(Some(&incremented)).unwrap(); function } /// Creates a struct type and a function that uses it pub fn create_struct_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let f64_type = context.f64_type(); // Define a Point struct with 
x and y fields let field_types = [i32_type.into(), f64_type.into()]; let struct_type = context.struct_type(&field_types, false); let fn_type = f64_type.fn_type(&[struct_type.into()], false); let function = module.add_function("get_point_y", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); // Get the struct parameter let point = function.get_nth_param(0).unwrap().into_struct_value(); // Extract the y field (index 1) let y_field = builder.build_extract_value(point, 1, "y_field").unwrap(); builder.build_return(Some(&y_field)).unwrap(); function } /// Runs optimization passes on a module using the modern pass manager (LLVM 18) pub fn optimize_module<'ctx>(module: &Module<'ctx>) -> Result<(), String> { // First verify the module is valid module.verify().map_err(|e| e.to_string())?; // Initialize targets for optimization Target::initialize_all(&InitializationConfig::default()); let target_triple = TargetMachine::get_default_triple(); let target = Target::from_triple(&target_triple) .map_err(|e| format!("Failed to create target: {}", e))?; let target_machine = target .create_target_machine( &target_triple, "generic", "", OptimizationLevel::Default, RelocMode::Default, CodeModel::Default, ) .ok_or("Failed to create target machine")?; // Common optimization passes let passes = [ "instcombine", // Combine instructions "reassociate", // Reassociate expressions "gvn", // Global value numbering "simplifycfg", // Simplify control flow graph "mem2reg", // Promote memory to registers ]; let pass_builder_options = PassBuilderOptions::create(); pass_builder_options.set_loop_vectorization(true); pass_builder_options.set_loop_unrolling(true); pass_builder_options.set_merge_functions(true); module .run_passes(&passes.join(","), &target_machine, pass_builder_options) .map_err(|e| e.to_string()) } /// Runs specific optimization passes on a module pub fn run_custom_passes<'ctx>(module: 
&Module<'ctx>, passes: &[&str]) -> Result<(), String> { // Verify module first module.verify().map_err(|e| e.to_string())?; // Initialize targets Target::initialize_all(&InitializationConfig::default()); let target_triple = TargetMachine::get_default_triple(); let target = Target::from_triple(&target_triple) .map_err(|e| format!("Failed to create target: {}", e))?; let target_machine = target .create_target_machine( &target_triple, "generic", "", OptimizationLevel::None, RelocMode::Default, CodeModel::Default, ) .ok_or("Failed to create target machine")?; module .run_passes( &passes.join(","), &target_machine, PassBuilderOptions::create(), ) .map_err(|e| e.to_string()) } /// Writes LLVM IR to a file pub fn write_ir_to_file<'ctx>(module: &Module<'ctx>, path: &Path) -> Result<(), String> { module .print_to_file(path) .map_err(|e| format!("Failed to write IR: {}", e)) } /// Compiles module to object file pub fn compile_to_object_file<'ctx>( module: &Module<'ctx>, path: &Path, ) -> Result<(), Box<dyn Error>> { Target::initialize_native(&InitializationConfig::default())?; let target_triple = TargetMachine::get_default_triple(); let target = Target::from_triple(&target_triple)?; let target_machine = target .create_target_machine( &target_triple, "generic", "", OptimizationLevel::Default, RelocMode::Default, CodeModel::Default, ) .ok_or("Could not create target machine")?; target_machine.write_to_file(module, FileType::Object, path)?; Ok(()) } /// Verifies that a module is valid pub fn verify_module<'ctx>(module: &Module<'ctx>) -> Result<(), String> { module.verify().map_err(|e| e.to_string()) } /// Helper to create a function type pub fn create_function_type<'ctx>( context: &'ctx Context, param_types: Vec<BasicMetadataTypeEnum<'ctx>>, return_type: Option<BasicTypeEnum<'ctx>>, is_var_args: bool, ) -> inkwell::types::FunctionType<'ctx> { match return_type { Some(ret) => ret.fn_type(&param_types, is_var_args), None => context.void_type().fn_type(&param_types, is_var_args), } } /// 
Simple JIT execution example pub fn create_execution_engine<'ctx>( module: &Module<'ctx>, ) -> Result<inkwell::execution_engine::ExecutionEngine<'ctx>, String> { module .create_jit_execution_engine(OptimizationLevel::None) .map_err(|e| e.to_string()) } /// Example of JIT compiling and executing a function pub fn jit_compile_and_execute(context: &Context) -> Result<u64, Box<dyn Error>> { let module = context.create_module("jit_example"); let builder = context.create_builder(); // Create a simple sum function: sum(x, y, z) = x + y + z let i64_type = context.i64_type(); let fn_type = i64_type.fn_type(&[i64_type.into(), i64_type.into(), i64_type.into()], false); let function = module.add_function("sum", fn_type, None); let basic_block = context.append_basic_block(function, "entry"); builder.position_at_end(basic_block); let x = function.get_nth_param(0).unwrap().into_int_value(); let y = function.get_nth_param(1).unwrap().into_int_value(); let z = function.get_nth_param(2).unwrap().into_int_value(); let sum = builder.build_int_add(x, y, "sum").unwrap(); let sum = builder.build_int_add(sum, z, "sum").unwrap(); builder.build_return(Some(&sum)).unwrap(); // Create execution engine let execution_engine = module.create_jit_execution_engine(OptimizationLevel::None)?; // Get the compiled function type SumFunc = unsafe extern "C" fn(u64, u64, u64) -> u64; let sum_fn = unsafe { execution_engine.get_function::<SumFunc>("sum")? 
}; // Execute the function let result = unsafe { sum_fn.call(1, 2, 3) }; Ok(result) } #[cfg(test)] mod tests { use super::*; #[test] fn test_add_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_add_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "add"); assert_eq!(function.count_params(), 2); assert!(verify_module(&module).is_ok()); } #[test] fn test_constant_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_constant_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "get_constant"); assert_eq!(function.count_params(), 0); assert!(verify_module(&module).is_ok()); } #[test] fn test_conditional_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_conditional_function(&context, &module); assert_eq!(function.count_basic_blocks(), 4); // entry, then, else, merge assert!(verify_module(&module).is_ok()); } #[test] fn test_loop_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_loop_function(&context, &module); assert_eq!(function.count_basic_blocks(), 3); // entry, loop, exit assert!(verify_module(&module).is_ok()); } #[test] fn test_global_variable() { let context = Context::create(); let module = context.create_module("test"); let global = create_global_variable(&context, &module); assert_eq!(global.get_name().to_str().unwrap(), "global_counter"); assert!(verify_module(&module).is_ok()); } #[test] fn test_recursive_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_recursive_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "factorial"); assert!(verify_module(&module).is_ok()); } #[test] fn test_optimization() { let context = Context::create(); let module = 
context.create_module("test"); // Create several functions create_add_function(&context, &module); create_multiply_function(&context, &module); create_constant_function(&context, &module); // Apply optimizations assert!(optimize_module(&module).is_ok()); // Module should still be valid after optimization assert!(verify_module(&module).is_ok()); } #[test] fn test_custom_passes() { let context = Context::create(); let module = context.create_module("test"); // Create a simple function create_add_function(&context, &module); // Run specific optimization passes let passes = ["instcombine", "simplifycfg"]; assert!(run_custom_passes(&module, &passes).is_ok()); // Module should still be valid assert!(verify_module(&module).is_ok()); } #[test] fn test_jit_execution() { let context = Context::create(); // Test JIT compilation and execution match jit_compile_and_execute(&context) { Ok(result) => assert_eq!(result, 6), // 1 + 2 + 3 = 6 Err(e) => panic!("JIT execution failed: {}", e), } } } /// Creates a recursive factorial function pub fn create_recursive_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[i32_type.into()], false); let function = module.add_function("factorial", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); let recurse = context.append_basic_block(function, "recurse"); let base = context.append_basic_block(function, "base"); // Entry block builder.position_at_end(entry); let n = function.get_nth_param(0).unwrap().into_int_value(); let one = i32_type.const_int(1, false); let is_base = builder .build_int_compare(IntPredicate::SLE, n, one, "is_base") .unwrap(); builder .build_conditional_branch(is_base, base, recurse) .unwrap(); // Base case: return 1 builder.position_at_end(base); builder.build_return(Some(&one)).unwrap(); // Recursive case: return n * factorial(n-1) 
builder.position_at_end(recurse); let n_minus_1 = builder.build_int_sub(n, one, "n_minus_1").unwrap(); let rec_result = builder .build_call(function, &[n_minus_1.into()], "rec_result") .unwrap(); let result = builder .build_int_mul( n, rec_result .try_as_basic_value() .left() .unwrap() .into_int_value(), "result", ) .unwrap(); builder.build_return(Some(&result)).unwrap(); function } }
Recursive calls in LLVM work like any other function call: a function calls itself by passing its own function value as the callee. This example implements factorial with a base case and a recursive case.
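As a sanity check on the IR's semantics, here is a direct Rust transcription of the generated factorial; this is a reference model of what the IR computes, not code the compiler emits:

```rust
// Reference model of the factorial IR above: the SLE comparison against 1
// selects the base block; the recurse block multiplies n by factorial(n - 1).
fn factorial(n: i32) -> i32 {
    if n <= 1 {
        1 // base block: the SLE predicate means 0 and negatives also return 1
    } else {
        n * factorial(n - 1) // recurse block
    }
}

fn main() {
    assert_eq!(factorial(5), 120);
    assert_eq!(factorial(0), 1);
}
```

Note that the IR uses a signed comparison (`IntPredicate::SLE`), so negative inputs fall into the base case rather than recursing forever.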
Optimization
Applying LLVM’s optimization passes using the modern pass manager:
#![allow(unused)]
fn main() {
use std::error::Error;
use std::path::Path;

use inkwell::context::Context;
use inkwell::module::Module;
use inkwell::passes::PassBuilderOptions;
use inkwell::targets::{
    CodeModel, FileType, InitializationConfig, RelocMode, Target, TargetMachine,
};
use inkwell::types::{BasicMetadataTypeEnum, BasicType, BasicTypeEnum};
use inkwell::values::FunctionValue;
use inkwell::{AddressSpace, IntPredicate, OptimizationLevel};

/// Creates a basic LLVM context
pub fn create_context() -> Context {
    Context::create()
}

/// Creates a simple function that adds two integers
pub fn create_add_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false);
    let function = module.add_function("add", fn_type, None);
    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    // Get function parameters
    let x = function.get_nth_param(0).unwrap().into_int_value();
    let y = function.get_nth_param(1).unwrap().into_int_value();

    // Build addition and return
    let sum = builder.build_int_add(x, y, "sum").unwrap();
    builder.build_return(Some(&sum)).unwrap();

    function
}

/// Creates a function that multiplies two integers
pub fn create_multiply_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false);
    let function = module.add_function("multiply", fn_type, None);
    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    let x = function.get_nth_param(0).unwrap().into_int_value();
    let y = function.get_nth_param(1).unwrap().into_int_value();
    let product = builder.build_int_mul(x, y, "product").unwrap();
    builder.build_return(Some(&product)).unwrap();

    function
}

/// Creates a function with a constant return value
pub fn create_constant_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[], false);
    let function = module.add_function("get_constant", fn_type, None);
    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    let constant = i32_type.const_int(42, false);
    builder.build_return(Some(&constant)).unwrap();

    function
}

/// Demonstrates integer comparison operations
pub fn create_comparison_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let bool_type = context.bool_type();
    let fn_type = bool_type.fn_type(&[i32_type.into(), i32_type.into()], false);
    let function = module.add_function("compare_ints", fn_type, None);
    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    let x = function.get_nth_param(0).unwrap().into_int_value();
    let y = function.get_nth_param(1).unwrap().into_int_value();

    // Compare x > y
    let comparison = builder
        .build_int_compare(IntPredicate::SGT, x, y, "cmp")
        .unwrap();
    builder.build_return(Some(&comparison)).unwrap();

    function
}

/// Creates a function with conditional branching (if-then-else)
pub fn create_conditional_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into()], false);
    let function = module.add_function("conditional", fn_type, None);
    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    let then_block = context.append_basic_block(function, "then");
    let else_block = context.append_basic_block(function, "else");
    let merge_block = context.append_basic_block(function, "merge");

    // Entry block
    builder.position_at_end(entry);
    let x = function.get_nth_param(0).unwrap().into_int_value();
    let zero = i32_type.const_int(0, false);
    let condition = builder
        .build_int_compare(IntPredicate::SGT, x, zero, "cond")
        .unwrap();
    builder
        .build_conditional_branch(condition, then_block, else_block)
        .unwrap();

    // Then block: return x * 2
    builder.position_at_end(then_block);
    let two = i32_type.const_int(2, false);
    let then_val = builder.build_int_mul(x, two, "then_val").unwrap();
    builder.build_unconditional_branch(merge_block).unwrap();

    // Else block: return x * -1
    builder.position_at_end(else_block);
    let neg_one = i32_type.const_int(-1i64 as u64, true);
    let else_val = builder.build_int_mul(x, neg_one, "else_val").unwrap();
    builder.build_unconditional_branch(merge_block).unwrap();

    // Merge block with phi node
    builder.position_at_end(merge_block);
    let phi = builder.build_phi(i32_type, "result").unwrap();
    phi.add_incoming(&[(&then_val, then_block), (&else_val, else_block)]);
    builder.build_return(Some(&phi.as_basic_value())).unwrap();

    function
}

/// Creates a simple loop that counts from 0 to n
pub fn create_loop_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into()], false);
    let function = module.add_function("count_loop", fn_type, None);
    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    let loop_block = context.append_basic_block(function, "loop");
    let exit_block = context.append_basic_block(function, "exit");

    // Entry block: initialize counter
    builder.position_at_end(entry);
    let n = function.get_nth_param(0).unwrap().into_int_value();
    let zero = i32_type.const_int(0, false);
    builder.build_unconditional_branch(loop_block).unwrap();

    // Loop block
    builder.position_at_end(loop_block);
    let counter = builder.build_phi(i32_type, "counter").unwrap();
    counter.add_incoming(&[(&zero, entry)]);

    // Increment counter
    let one = i32_type.const_int(1, false);
    let next_counter = builder
        .build_int_add(
            counter.as_basic_value().into_int_value(),
            one,
            "next_counter",
        )
        .unwrap();

    // Check loop condition
    let condition = builder
        .build_int_compare(IntPredicate::SLT, next_counter, n, "loop_cond")
        .unwrap();

    // Add incoming value for next iteration
    let loop_block_end = builder.get_insert_block().unwrap();
    counter.add_incoming(&[(&next_counter, loop_block_end)]);
    builder
        .build_conditional_branch(condition, loop_block, exit_block)
        .unwrap();

    // Exit block
    builder.position_at_end(exit_block);
    builder
        .build_return(Some(&counter.as_basic_value()))
        .unwrap();

    function
}

/// Creates a function that allocates and uses local variables (stack
/// allocation)
pub fn create_alloca_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false);
    let function = module.add_function("use_alloca", fn_type, None);
    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    // Get parameters
    let x = function.get_nth_param(0).unwrap().into_int_value();
    let y = function.get_nth_param(1).unwrap().into_int_value();

    // Allocate stack variables
    let x_ptr = builder.build_alloca(i32_type, "x_ptr").unwrap();
    let y_ptr = builder.build_alloca(i32_type, "y_ptr").unwrap();
    let result_ptr = builder.build_alloca(i32_type, "result_ptr").unwrap();

    // Store values
    builder.build_store(x_ptr, x).unwrap();
    builder.build_store(y_ptr, y).unwrap();

    // Load values
    let x_val = builder.build_load(i32_type, x_ptr, "x_val").unwrap();
    let y_val = builder.build_load(i32_type, y_ptr, "y_val").unwrap();

    // Compute and store result
    let sum = builder
        .build_int_add(x_val.into_int_value(), y_val.into_int_value(), "sum")
        .unwrap();
    builder.build_store(result_ptr, sum).unwrap();

    // Load and return result
    let result = builder.build_load(i32_type, result_ptr, "result").unwrap();
    builder.build_return(Some(&result)).unwrap();

    function
}

/// Creates a function that works with arrays
pub fn create_array_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let i32_ptr_type = context.ptr_type(AddressSpace::default());
    let fn_type = i32_type.fn_type(&[i32_ptr_type.into(), i32_type.into()], false);
    let function = module.add_function("sum_array", fn_type, None);
    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    let loop_block = context.append_basic_block(function, "loop");
    let exit_block = context.append_basic_block(function, "exit");

    // Entry block
    builder.position_at_end(entry);
    let array_ptr = function.get_nth_param(0).unwrap().into_pointer_value();
    let size = function.get_nth_param(1).unwrap().into_int_value();
    let zero = i32_type.const_int(0, false);
    builder.build_unconditional_branch(loop_block).unwrap();

    // Loop block
    builder.position_at_end(loop_block);
    let index = builder.build_phi(i32_type, "index").unwrap();
    let sum = builder.build_phi(i32_type, "sum").unwrap();
    index.add_incoming(&[(&zero, entry)]);
    sum.add_incoming(&[(&zero, entry)]);

    // Load array element
    let elem_ptr = unsafe {
        builder
            .build_gep(
                i32_type,
                array_ptr,
                &[index.as_basic_value().into_int_value()],
                "elem_ptr",
            )
            .unwrap()
    };
    let elem = builder.build_load(i32_type, elem_ptr, "elem").unwrap();

    // Update sum
    let new_sum = builder
        .build_int_add(
            sum.as_basic_value().into_int_value(),
            elem.into_int_value(),
            "new_sum",
        )
        .unwrap();

    // Update index
    let one = i32_type.const_int(1, false);
    let next_index = builder
        .build_int_add(index.as_basic_value().into_int_value(), one, "next_index")
        .unwrap();

    // Check loop condition
    let condition = builder
        .build_int_compare(IntPredicate::SLT, next_index, size, "loop_cond")
        .unwrap();

    // Update phi nodes
    let loop_end = builder.get_insert_block().unwrap();
    index.add_incoming(&[(&next_index, loop_end)]);
    sum.add_incoming(&[(&new_sum, loop_end)]);
    builder
        .build_conditional_branch(condition, loop_block, exit_block)
        .unwrap();

    // Exit block
    builder.position_at_end(exit_block);
    builder.build_return(Some(&sum.as_basic_value())).unwrap();

    function
}

/// Creates a global variable
pub fn create_global_variable<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> inkwell::values::GlobalValue<'ctx> {
    let i32_type = context.i32_type();
    let global = module.add_global(i32_type, Some(AddressSpace::default()), "global_counter");
    global.set_initializer(&i32_type.const_int(0, false));
    global.set_linkage(inkwell::module::Linkage::Internal);
    global
}

/// Creates a function that uses a global variable
pub fn create_global_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[], false);
    let function = module.add_function("increment_global", fn_type, None);

    // Create or get global variable
    let global = module
        .get_global("global_counter")
        .unwrap_or_else(|| create_global_variable(context, module));

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    // Load global value
    let global_ptr = global.as_pointer_value();
    let current = builder.build_load(i32_type, global_ptr, "current").unwrap();

    // Increment
    let one = i32_type.const_int(1, false);
    let incremented = builder
        .build_int_add(current.into_int_value(), one, "incremented")
        .unwrap();

    // Store back to global
    builder.build_store(global_ptr, incremented).unwrap();

    // Return new value
    builder.build_return(Some(&incremented)).unwrap();

    function
}

/// Creates a recursive factorial function
pub fn create_recursive_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into()], false);
    let function = module.add_function("factorial", fn_type, None);
    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    let recurse = context.append_basic_block(function, "recurse");
    let base = context.append_basic_block(function, "base");

    // Entry block
    builder.position_at_end(entry);
    let n = function.get_nth_param(0).unwrap().into_int_value();
    let one = i32_type.const_int(1, false);
    let is_base = builder
        .build_int_compare(IntPredicate::SLE, n, one, "is_base")
        .unwrap();
    builder
        .build_conditional_branch(is_base, base, recurse)
        .unwrap();

    // Base case: return 1
    builder.position_at_end(base);
    builder.build_return(Some(&one)).unwrap();

    // Recursive case: return n * factorial(n-1)
    builder.position_at_end(recurse);
    let n_minus_1 = builder.build_int_sub(n, one, "n_minus_1").unwrap();
    let rec_result = builder
        .build_call(function, &[n_minus_1.into()], "rec_result")
        .unwrap();
    let result = builder
        .build_int_mul(
            n,
            rec_result
                .try_as_basic_value()
                .left()
                .unwrap()
                .into_int_value(),
            "result",
        )
        .unwrap();
    builder.build_return(Some(&result)).unwrap();

    function
}

/// Creates a struct type and a function that uses it
pub fn create_struct_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let f64_type = context.f64_type();

    // Define a Point struct with x and y fields
    let field_types = [i32_type.into(), f64_type.into()];
    let struct_type = context.struct_type(&field_types, false);
    let fn_type = f64_type.fn_type(&[struct_type.into()], false);
    let function = module.add_function("get_point_y", fn_type, None);
    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    // Get the struct parameter
    let point = function.get_nth_param(0).unwrap().into_struct_value();

    // Extract the y field (index 1)
    let y_field = builder.build_extract_value(point, 1, "y_field").unwrap();
    builder.build_return(Some(&y_field)).unwrap();

    function
}

/// Runs specific optimization passes on a module
pub fn run_custom_passes<'ctx>(module: &Module<'ctx>, passes: &[&str]) -> Result<(), String> {
    // Verify module first
    module.verify().map_err(|e| e.to_string())?;

    // Initialize targets
    Target::initialize_all(&InitializationConfig::default());
    let target_triple = TargetMachine::get_default_triple();
    let target = Target::from_triple(&target_triple)
        .map_err(|e| format!("Failed to create target: {}", e))?;
    let target_machine = target
        .create_target_machine(
            &target_triple,
            "generic",
            "",
            OptimizationLevel::None,
            RelocMode::Default,
            CodeModel::Default,
        )
        .ok_or("Failed to create target machine")?;

    module
        .run_passes(
            &passes.join(","),
            &target_machine,
            PassBuilderOptions::create(),
        )
        .map_err(|e| e.to_string())
}

/// Writes LLVM IR to a file
pub fn write_ir_to_file<'ctx>(module: &Module<'ctx>, path: &Path) -> Result<(), String> {
    module
        .print_to_file(path)
        .map_err(|e| format!("Failed to write IR: {}", e))
}

/// Compiles module to object file
pub fn compile_to_object_file<'ctx>(
    module: &Module<'ctx>,
    path: &Path,
) -> Result<(), Box<dyn Error>> {
    Target::initialize_native(&InitializationConfig::default())?;
    let target_triple = TargetMachine::get_default_triple();
    let target = Target::from_triple(&target_triple)?;
    let target_machine = target
        .create_target_machine(
            &target_triple,
            "generic",
            "",
            OptimizationLevel::Default,
            RelocMode::Default,
            CodeModel::Default,
        )
        .ok_or("Could not create target machine")?;
    target_machine.write_to_file(module, FileType::Object, path)?;
    Ok(())
}

/// Verifies that a module is valid
pub fn verify_module<'ctx>(module: &Module<'ctx>) -> Result<(), String> {
    module.verify().map_err(|e| e.to_string())
}

/// Helper to create a function type
pub fn create_function_type<'ctx>(
    context: &'ctx Context,
    param_types: Vec<BasicMetadataTypeEnum<'ctx>>,
    return_type: Option<BasicTypeEnum<'ctx>>,
    is_var_args: bool,
) -> inkwell::types::FunctionType<'ctx> {
    match return_type {
        Some(ret) => ret.fn_type(&param_types, is_var_args),
        None => context.void_type().fn_type(&param_types, is_var_args),
    }
}

/// Simple JIT execution example
pub fn create_execution_engine<'ctx>(
    module: &Module<'ctx>,
) -> Result<inkwell::execution_engine::ExecutionEngine<'ctx>, String> {
    module
        .create_jit_execution_engine(OptimizationLevel::None)
        .map_err(|e| e.to_string())
}

/// Example of JIT compiling and executing a function
pub fn jit_compile_and_execute(context: &Context) -> Result<u64, Box<dyn Error>> {
    let module = context.create_module("jit_example");
    let builder = context.create_builder();

    // Create a simple sum function: sum(x, y, z) = x + y + z
    let i64_type = context.i64_type();
    let fn_type = i64_type.fn_type(&[i64_type.into(), i64_type.into(), i64_type.into()], false);
    let function = module.add_function("sum", fn_type, None);
    let basic_block = context.append_basic_block(function, "entry");
    builder.position_at_end(basic_block);

    let x = function.get_nth_param(0).unwrap().into_int_value();
    let y = function.get_nth_param(1).unwrap().into_int_value();
    let z = function.get_nth_param(2).unwrap().into_int_value();

    let sum = builder.build_int_add(x, y, "sum").unwrap();
    let sum = builder.build_int_add(sum, z, "sum").unwrap();
    builder.build_return(Some(&sum)).unwrap();

    // Create execution engine
    let execution_engine = module.create_jit_execution_engine(OptimizationLevel::None)?;

    // Get the compiled function
    type SumFunc = unsafe extern "C" fn(u64, u64, u64) -> u64;
    let sum_fn = unsafe { execution_engine.get_function::<SumFunc>("sum")? };

    // Execute the function
    let result = unsafe { sum_fn.call(1, 2, 3) };
    Ok(result)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_function() {
        let context = Context::create();
        let module = context.create_module("test");
        let function = create_add_function(&context, &module);
        assert_eq!(function.get_name().to_str().unwrap(), "add");
        assert_eq!(function.count_params(), 2);
        assert!(verify_module(&module).is_ok());
    }

    #[test]
    fn test_constant_function() {
        let context = Context::create();
        let module = context.create_module("test");
        let function = create_constant_function(&context, &module);
        assert_eq!(function.get_name().to_str().unwrap(), "get_constant");
        assert_eq!(function.count_params(), 0);
        assert!(verify_module(&module).is_ok());
    }

    #[test]
    fn test_conditional_function() {
        let context = Context::create();
        let module = context.create_module("test");
        let function = create_conditional_function(&context, &module);
        assert_eq!(function.count_basic_blocks(), 4); // entry, then, else, merge
        assert!(verify_module(&module).is_ok());
    }

    #[test]
    fn test_loop_function() {
        let context = Context::create();
        let module = context.create_module("test");
        let function = create_loop_function(&context, &module);
        assert_eq!(function.count_basic_blocks(), 3); // entry, loop, exit
        assert!(verify_module(&module).is_ok());
    }

    #[test]
    fn test_global_variable() {
        let context = Context::create();
        let module = context.create_module("test");
        let global = create_global_variable(&context, &module);
        assert_eq!(global.get_name().to_str().unwrap(), "global_counter");
        assert!(verify_module(&module).is_ok());
    }

    #[test]
    fn test_recursive_function() {
        let context = Context::create();
        let module = context.create_module("test");
        let function = create_recursive_function(&context, &module);
        assert_eq!(function.get_name().to_str().unwrap(), "factorial");
        assert!(verify_module(&module).is_ok());
    }

    #[test]
    fn test_optimization() {
        let context = Context::create();
        let module = context.create_module("test");
        // Create several functions
        create_add_function(&context, &module);
        create_multiply_function(&context, &module);
        create_constant_function(&context, &module);
        // Apply optimizations
        assert!(optimize_module(&module).is_ok());
        // Module should still be valid after optimization
        assert!(verify_module(&module).is_ok());
    }

    #[test]
    fn test_custom_passes() {
        let context = Context::create();
        let module = context.create_module("test");
        // Create a simple function
        create_add_function(&context, &module);
        // Run specific optimization passes
        let passes = ["instcombine", "simplifycfg"];
        assert!(run_custom_passes(&module, &passes).is_ok());
        // Module should still be valid
        assert!(verify_module(&module).is_ok());
    }

    #[test]
    fn test_jit_execution() {
        let context = Context::create();
        // Test JIT compilation and execution
        match jit_compile_and_execute(&context) {
            Ok(result) => assert_eq!(result, 6), // 1 + 2 + 3 = 6
            Err(e) => panic!("JIT execution failed: {}", e),
        }
    }
}

/// Runs optimization passes on a module using the modern pass manager (LLVM 18)
pub fn optimize_module<'ctx>(module: &Module<'ctx>) -> Result<(), String> {
    // First verify the module is valid
    module.verify().map_err(|e| e.to_string())?;

    // Initialize targets for optimization
    Target::initialize_all(&InitializationConfig::default());
    let target_triple = TargetMachine::get_default_triple();
    let target = Target::from_triple(&target_triple)
        .map_err(|e| format!("Failed to create target: {}", e))?;
    let target_machine = target
        .create_target_machine(
            &target_triple,
            "generic",
            "",
            OptimizationLevel::Default,
            RelocMode::Default,
            CodeModel::Default,
        )
        .ok_or("Failed to create target machine")?;

    // Common optimization passes
    let passes = [
        "instcombine", // Combine instructions
        "reassociate", // Reassociate expressions
        "gvn",         // Global value numbering
        "simplifycfg", // Simplify control flow graph
        "mem2reg",     // Promote memory to registers
    ];

    let pass_builder_options = PassBuilderOptions::create();
    pass_builder_options.set_loop_vectorization(true);
    pass_builder_options.set_loop_unrolling(true);
    pass_builder_options.set_merge_functions(true);

    module
        .run_passes(&passes.join(","), &target_machine, pass_builder_options)
        .map_err(|e| e.to_string())
}
}
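run_passes takes a single comma-separated pipeline string, the same textual pass-pipeline syntax accepted by opt -passes=. The helpers above build that string by joining pass names; a minimal, dependency-free sketch of just that step:

```rust
fn main() {
    // The pass list used by optimize_module above.
    let passes = ["instcombine", "reassociate", "gvn", "simplifycfg", "mem2reg"];
    // run_passes expects one comma-separated pipeline string.
    let pipeline = passes.join(",");
    assert_eq!(pipeline, "instcombine,reassociate,gvn,simplifycfg,mem2reg");
    println!("{}", pipeline);
}
```

Because the pipeline is just a string, pass selection can be made configurable at runtime (for example, driven by a compiler's -O flag) without touching the codegen code.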
create target machine")?; target_machine.write_to_file(module, FileType::Object, path)?; Ok(()) } /// Verifies that a module is valid pub fn verify_module<'ctx>(module: &Module<'ctx>) -> Result<(), String> { module.verify().map_err(|e| e.to_string()) } /// Helper to create a function type pub fn create_function_type<'ctx>( context: &'ctx Context, param_types: Vec<BasicMetadataTypeEnum<'ctx>>, return_type: Option<BasicTypeEnum<'ctx>>, is_var_args: bool, ) -> inkwell::types::FunctionType<'ctx> { match return_type { Some(ret) => ret.fn_type(¶m_types, is_var_args), None => context.void_type().fn_type(¶m_types, is_var_args), } } /// Simple JIT execution example pub fn create_execution_engine<'ctx>( module: &Module<'ctx>, ) -> Result<inkwell::execution_engine::ExecutionEngine<'ctx>, String> { module .create_jit_execution_engine(OptimizationLevel::None) .map_err(|e| e.to_string()) } /// Example of JIT compiling and executing a function pub fn jit_compile_and_execute(context: &Context) -> Result<u64, Box<dyn Error>> { let module = context.create_module("jit_example"); let builder = context.create_builder(); // Create a simple sum function: sum(x, y, z) = x + y + z let i64_type = context.i64_type(); let fn_type = i64_type.fn_type(&[i64_type.into(), i64_type.into(), i64_type.into()], false); let function = module.add_function("sum", fn_type, None); let basic_block = context.append_basic_block(function, "entry"); builder.position_at_end(basic_block); let x = function.get_nth_param(0).unwrap().into_int_value(); let y = function.get_nth_param(1).unwrap().into_int_value(); let z = function.get_nth_param(2).unwrap().into_int_value(); let sum = builder.build_int_add(x, y, "sum").unwrap(); let sum = builder.build_int_add(sum, z, "sum").unwrap(); builder.build_return(Some(&sum)).unwrap(); // Create execution engine let execution_engine = module.create_jit_execution_engine(OptimizationLevel::None)?; // Get the compiled function type SumFunc = unsafe extern "C" fn(u64, u64, u64) -> 
u64; let sum_fn = unsafe { execution_engine.get_function::<SumFunc>("sum")? }; // Execute the function let result = unsafe { sum_fn.call(1, 2, 3) }; Ok(result) } #[cfg(test)] mod tests { use super::*; #[test] fn test_add_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_add_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "add"); assert_eq!(function.count_params(), 2); assert!(verify_module(&module).is_ok()); } #[test] fn test_constant_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_constant_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "get_constant"); assert_eq!(function.count_params(), 0); assert!(verify_module(&module).is_ok()); } #[test] fn test_conditional_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_conditional_function(&context, &module); assert_eq!(function.count_basic_blocks(), 4); // entry, then, else, merge assert!(verify_module(&module).is_ok()); } #[test] fn test_loop_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_loop_function(&context, &module); assert_eq!(function.count_basic_blocks(), 3); // entry, loop, exit assert!(verify_module(&module).is_ok()); } #[test] fn test_global_variable() { let context = Context::create(); let module = context.create_module("test"); let global = create_global_variable(&context, &module); assert_eq!(global.get_name().to_str().unwrap(), "global_counter"); assert!(verify_module(&module).is_ok()); } #[test] fn test_recursive_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_recursive_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "factorial"); assert!(verify_module(&module).is_ok()); } #[test] fn 
test_optimization() { let context = Context::create(); let module = context.create_module("test"); // Create several functions create_add_function(&context, &module); create_multiply_function(&context, &module); create_constant_function(&context, &module); // Apply optimizations assert!(optimize_module(&module).is_ok()); // Module should still be valid after optimization assert!(verify_module(&module).is_ok()); } #[test] fn test_custom_passes() { let context = Context::create(); let module = context.create_module("test"); // Create a simple function create_add_function(&context, &module); // Run specific optimization passes let passes = ["instcombine", "simplifycfg"]; assert!(run_custom_passes(&module, &passes).is_ok()); // Module should still be valid assert!(verify_module(&module).is_ok()); } #[test] fn test_jit_execution() { let context = Context::create(); // Test JIT compilation and execution match jit_compile_and_execute(&context) { Ok(result) => assert_eq!(result, 6), // 1 + 2 + 3 = 6 Err(e) => panic!("JIT execution failed: {}", e), } } } /// Runs specific optimization passes on a module pub fn run_custom_passes<'ctx>(module: &Module<'ctx>, passes: &[&str]) -> Result<(), String> { // Verify module first module.verify().map_err(|e| e.to_string())?; // Initialize targets Target::initialize_all(&InitializationConfig::default()); let target_triple = TargetMachine::get_default_triple(); let target = Target::from_triple(&target_triple) .map_err(|e| format!("Failed to create target: {}", e))?; let target_machine = target .create_target_machine( &target_triple, "generic", "", OptimizationLevel::None, RelocMode::Default, CodeModel::Default, ) .ok_or("Failed to create target machine")?; module .run_passes( &passes.join(","), &target_machine, PassBuilderOptions::create(), ) .map_err(|e| e.to_string()) } }
LLVM's modern pass manager (used here via LLVM 18) exposes a string-based interface for specifying optimization pipelines. Common passes include instcombine, reassociate, gvn, simplifycfg, and mem2reg. PassBuilderOptions provides fine-grained control over optimization behavior such as loop vectorization, loop unrolling, and function merging.
JIT Compilation
Just-in-time compilation for immediate execution:
#![allow(unused)]
fn main() {
use std::error::Error;

use inkwell::context::Context;
use inkwell::module::Module;
use inkwell::OptimizationLevel;

/// Simple JIT execution example
pub fn create_execution_engine<'ctx>(
    module: &Module<'ctx>,
) -> Result<inkwell::execution_engine::ExecutionEngine<'ctx>, String> {
    module
        .create_jit_execution_engine(OptimizationLevel::None)
        .map_err(|e| e.to_string())
}

/// Example of JIT compiling and executing a function
pub fn jit_compile_and_execute(context: &Context) -> Result<u64, Box<dyn Error>> {
    let module = context.create_module("jit_example");
    let builder = context.create_builder();

    // Create a simple sum function: sum(x, y, z) = x + y + z
    let i64_type = context.i64_type();
    let fn_type = i64_type.fn_type(&[i64_type.into(), i64_type.into(), i64_type.into()], false);
    let function = module.add_function("sum", fn_type, None);
    let basic_block = context.append_basic_block(function, "entry");
    builder.position_at_end(basic_block);

    let x = function.get_nth_param(0).unwrap().into_int_value();
    let y = function.get_nth_param(1).unwrap().into_int_value();
    let z = function.get_nth_param(2).unwrap().into_int_value();

    let sum = builder.build_int_add(x, y, "sum").unwrap();
    let sum = builder.build_int_add(sum, z, "sum").unwrap();
    builder.build_return(Some(&sum)).unwrap();

    // Create execution engine
    let execution_engine = module.create_jit_execution_engine(OptimizationLevel::None)?;

    // Get the compiled function
    type SumFunc = unsafe extern "C" fn(u64, u64, u64) -> u64;
    let sum_fn = unsafe { execution_engine.get_function::<SumFunc>("sum")? };

    // Execute the function
    let result = unsafe { sum_fn.call(1, 2, 3) };
    Ok(result)
}
}
one = i32_type.const_int(1, false); let next_counter = builder .build_int_add( counter.as_basic_value().into_int_value(), one, "next_counter", ) .unwrap(); // Check loop condition let condition = builder .build_int_compare(IntPredicate::SLT, next_counter, n, "loop_cond") .unwrap(); // Add incoming value for next iteration let loop_block_end = builder.get_insert_block().unwrap(); counter.add_incoming(&[(&next_counter, loop_block_end)]); builder .build_conditional_branch(condition, loop_block, exit_block) .unwrap(); // Exit block builder.position_at_end(exit_block); builder .build_return(Some(&counter.as_basic_value())) .unwrap(); function } /// Creates a function that allocates and uses local variables (stack /// allocation) pub fn create_alloca_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false); let function = module.add_function("use_alloca", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); // Get parameters let x = function.get_nth_param(0).unwrap().into_int_value(); let y = function.get_nth_param(1).unwrap().into_int_value(); // Allocate stack variables let x_ptr = builder.build_alloca(i32_type, "x_ptr").unwrap(); let y_ptr = builder.build_alloca(i32_type, "y_ptr").unwrap(); let result_ptr = builder.build_alloca(i32_type, "result_ptr").unwrap(); // Store values builder.build_store(x_ptr, x).unwrap(); builder.build_store(y_ptr, y).unwrap(); // Load values let x_val = builder.build_load(i32_type, x_ptr, "x_val").unwrap(); let y_val = builder.build_load(i32_type, y_ptr, "y_val").unwrap(); // Compute and store result let sum = builder .build_int_add(x_val.into_int_value(), y_val.into_int_value(), "sum") .unwrap(); builder.build_store(result_ptr, sum).unwrap(); // Load and return result let result = 
builder.build_load(i32_type, result_ptr, "result").unwrap(); builder.build_return(Some(&result)).unwrap(); function } /// Creates a function that works with arrays pub fn create_array_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let i32_ptr_type = context.ptr_type(AddressSpace::default()); let fn_type = i32_type.fn_type(&[i32_ptr_type.into(), i32_type.into()], false); let function = module.add_function("sum_array", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); let loop_block = context.append_basic_block(function, "loop"); let exit_block = context.append_basic_block(function, "exit"); // Entry block builder.position_at_end(entry); let array_ptr = function.get_nth_param(0).unwrap().into_pointer_value(); let size = function.get_nth_param(1).unwrap().into_int_value(); let zero = i32_type.const_int(0, false); builder.build_unconditional_branch(loop_block).unwrap(); // Loop block builder.position_at_end(loop_block); let index = builder.build_phi(i32_type, "index").unwrap(); let sum = builder.build_phi(i32_type, "sum").unwrap(); index.add_incoming(&[(&zero, entry)]); sum.add_incoming(&[(&zero, entry)]); // Load array element let elem_ptr = unsafe { builder .build_gep( i32_type, array_ptr, &[index.as_basic_value().into_int_value()], "elem_ptr", ) .unwrap() }; let elem = builder.build_load(i32_type, elem_ptr, "elem").unwrap(); // Update sum let new_sum = builder .build_int_add( sum.as_basic_value().into_int_value(), elem.into_int_value(), "new_sum", ) .unwrap(); // Update index let one = i32_type.const_int(1, false); let next_index = builder .build_int_add(index.as_basic_value().into_int_value(), one, "next_index") .unwrap(); // Check loop condition let condition = builder .build_int_compare(IntPredicate::SLT, next_index, size, "loop_cond") .unwrap(); // Update phi nodes let loop_end = builder.get_insert_block().unwrap(); 
index.add_incoming(&[(&next_index, loop_end)]); sum.add_incoming(&[(&new_sum, loop_end)]); builder .build_conditional_branch(condition, loop_block, exit_block) .unwrap(); // Exit block builder.position_at_end(exit_block); builder.build_return(Some(&sum.as_basic_value())).unwrap(); function } /// Creates a global variable pub fn create_global_variable<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> inkwell::values::GlobalValue<'ctx> { let i32_type = context.i32_type(); let global = module.add_global(i32_type, Some(AddressSpace::default()), "global_counter"); global.set_initializer(&i32_type.const_int(0, false)); global.set_linkage(inkwell::module::Linkage::Internal); global } /// Creates a function that uses a global variable pub fn create_global_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[], false); let function = module.add_function("increment_global", fn_type, None); // Create or get global variable let global = module .get_global("global_counter") .unwrap_or_else(|| create_global_variable(context, module)); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); // Load global value let global_ptr = global.as_pointer_value(); let current = builder.build_load(i32_type, global_ptr, "current").unwrap(); // Increment let one = i32_type.const_int(1, false); let incremented = builder .build_int_add(current.into_int_value(), one, "incremented") .unwrap(); // Store back to global builder.build_store(global_ptr, incremented).unwrap(); // Return new value builder.build_return(Some(&incremented)).unwrap(); function } /// Creates a recursive factorial function pub fn create_recursive_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[i32_type.into()], false); let function 
= module.add_function("factorial", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); let recurse = context.append_basic_block(function, "recurse"); let base = context.append_basic_block(function, "base"); // Entry block builder.position_at_end(entry); let n = function.get_nth_param(0).unwrap().into_int_value(); let one = i32_type.const_int(1, false); let is_base = builder .build_int_compare(IntPredicate::SLE, n, one, "is_base") .unwrap(); builder .build_conditional_branch(is_base, base, recurse) .unwrap(); // Base case: return 1 builder.position_at_end(base); builder.build_return(Some(&one)).unwrap(); // Recursive case: return n * factorial(n-1) builder.position_at_end(recurse); let n_minus_1 = builder.build_int_sub(n, one, "n_minus_1").unwrap(); let rec_result = builder .build_call(function, &[n_minus_1.into()], "rec_result") .unwrap(); let result = builder .build_int_mul( n, rec_result .try_as_basic_value() .left() .unwrap() .into_int_value(), "result", ) .unwrap(); builder.build_return(Some(&result)).unwrap(); function } /// Creates a struct type and a function that uses it pub fn create_struct_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let f64_type = context.f64_type(); // Define a Point struct with x and y fields let field_types = [i32_type.into(), f64_type.into()]; let struct_type = context.struct_type(&field_types, false); let fn_type = f64_type.fn_type(&[struct_type.into()], false); let function = module.add_function("get_point_y", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); // Get the struct parameter let point = function.get_nth_param(0).unwrap().into_struct_value(); // Extract the y field (index 1) let y_field = builder.build_extract_value(point, 1, "y_field").unwrap(); 
builder.build_return(Some(&y_field)).unwrap(); function } /// Runs optimization passes on a module using the modern pass manager (LLVM 18) pub fn optimize_module<'ctx>(module: &Module<'ctx>) -> Result<(), String> { // First verify the module is valid module.verify().map_err(|e| e.to_string())?; // Initialize targets for optimization Target::initialize_all(&InitializationConfig::default()); let target_triple = TargetMachine::get_default_triple(); let target = Target::from_triple(&target_triple) .map_err(|e| format!("Failed to create target: {}", e))?; let target_machine = target .create_target_machine( &target_triple, "generic", "", OptimizationLevel::Default, RelocMode::Default, CodeModel::Default, ) .ok_or("Failed to create target machine")?; // Common optimization passes let passes = [ "instcombine", // Combine instructions "reassociate", // Reassociate expressions "gvn", // Global value numbering "simplifycfg", // Simplify control flow graph "mem2reg", // Promote memory to registers ]; let pass_builder_options = PassBuilderOptions::create(); pass_builder_options.set_loop_vectorization(true); pass_builder_options.set_loop_unrolling(true); pass_builder_options.set_merge_functions(true); module .run_passes(&passes.join(","), &target_machine, pass_builder_options) .map_err(|e| e.to_string()) } /// Runs specific optimization passes on a module pub fn run_custom_passes<'ctx>(module: &Module<'ctx>, passes: &[&str]) -> Result<(), String> { // Verify module first module.verify().map_err(|e| e.to_string())?; // Initialize targets Target::initialize_all(&InitializationConfig::default()); let target_triple = TargetMachine::get_default_triple(); let target = Target::from_triple(&target_triple) .map_err(|e| format!("Failed to create target: {}", e))?; let target_machine = target .create_target_machine( &target_triple, "generic", "", OptimizationLevel::None, RelocMode::Default, CodeModel::Default, ) .ok_or("Failed to create target machine")?; module .run_passes( 
&passes.join(","), &target_machine, PassBuilderOptions::create(), ) .map_err(|e| e.to_string()) } /// Writes LLVM IR to a file pub fn write_ir_to_file<'ctx>(module: &Module<'ctx>, path: &Path) -> Result<(), String> { module .print_to_file(path) .map_err(|e| format!("Failed to write IR: {}", e)) } /// Compiles module to object file pub fn compile_to_object_file<'ctx>( module: &Module<'ctx>, path: &Path, ) -> Result<(), Box<dyn Error>> { Target::initialize_native(&InitializationConfig::default())?; let target_triple = TargetMachine::get_default_triple(); let target = Target::from_triple(&target_triple)?; let target_machine = target .create_target_machine( &target_triple, "generic", "", OptimizationLevel::Default, RelocMode::Default, CodeModel::Default, ) .ok_or("Could not create target machine")?; target_machine.write_to_file(module, FileType::Object, path)?; Ok(()) } /// Verifies that a module is valid pub fn verify_module<'ctx>(module: &Module<'ctx>) -> Result<(), String> { module.verify().map_err(|e| e.to_string()) } /// Helper to create a function type pub fn create_function_type<'ctx>( context: &'ctx Context, param_types: Vec<BasicMetadataTypeEnum<'ctx>>, return_type: Option<BasicTypeEnum<'ctx>>, is_var_args: bool, ) -> inkwell::types::FunctionType<'ctx> { match return_type { Some(ret) => ret.fn_type(¶m_types, is_var_args), None => context.void_type().fn_type(¶m_types, is_var_args), } } /// Simple JIT execution example pub fn create_execution_engine<'ctx>( module: &Module<'ctx>, ) -> Result<inkwell::execution_engine::ExecutionEngine<'ctx>, String> { module .create_jit_execution_engine(OptimizationLevel::None) .map_err(|e| e.to_string()) } #[cfg(test)] mod tests { use super::*; #[test] fn test_add_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_add_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "add"); assert_eq!(function.count_params(), 2); 
assert!(verify_module(&module).is_ok()); } #[test] fn test_constant_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_constant_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "get_constant"); assert_eq!(function.count_params(), 0); assert!(verify_module(&module).is_ok()); } #[test] fn test_conditional_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_conditional_function(&context, &module); assert_eq!(function.count_basic_blocks(), 4); // entry, then, else, merge assert!(verify_module(&module).is_ok()); } #[test] fn test_loop_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_loop_function(&context, &module); assert_eq!(function.count_basic_blocks(), 3); // entry, loop, exit assert!(verify_module(&module).is_ok()); } #[test] fn test_global_variable() { let context = Context::create(); let module = context.create_module("test"); let global = create_global_variable(&context, &module); assert_eq!(global.get_name().to_str().unwrap(), "global_counter"); assert!(verify_module(&module).is_ok()); } #[test] fn test_recursive_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_recursive_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "factorial"); assert!(verify_module(&module).is_ok()); } #[test] fn test_optimization() { let context = Context::create(); let module = context.create_module("test"); // Create several functions create_add_function(&context, &module); create_multiply_function(&context, &module); create_constant_function(&context, &module); // Apply optimizations assert!(optimize_module(&module).is_ok()); // Module should still be valid after optimization assert!(verify_module(&module).is_ok()); } #[test] fn test_custom_passes() { let context = 
Context::create(); let module = context.create_module("test"); // Create a simple function create_add_function(&context, &module); // Run specific optimization passes let passes = ["instcombine", "simplifycfg"]; assert!(run_custom_passes(&module, &passes).is_ok()); // Module should still be valid assert!(verify_module(&module).is_ok()); } #[test] fn test_jit_execution() { let context = Context::create(); // Test JIT compilation and execution match jit_compile_and_execute(&context) { Ok(result) => assert_eq!(result, 6), // 1 + 2 + 3 = 6 Err(e) => panic!("JIT execution failed: {}", e), } } } /// Example of JIT compiling and executing a function pub fn jit_compile_and_execute(context: &Context) -> Result<u64, Box<dyn Error>> { let module = context.create_module("jit_example"); let builder = context.create_builder(); // Create a simple sum function: sum(x, y, z) = x + y + z let i64_type = context.i64_type(); let fn_type = i64_type.fn_type(&[i64_type.into(), i64_type.into(), i64_type.into()], false); let function = module.add_function("sum", fn_type, None); let basic_block = context.append_basic_block(function, "entry"); builder.position_at_end(basic_block); let x = function.get_nth_param(0).unwrap().into_int_value(); let y = function.get_nth_param(1).unwrap().into_int_value(); let z = function.get_nth_param(2).unwrap().into_int_value(); let sum = builder.build_int_add(x, y, "sum").unwrap(); let sum = builder.build_int_add(sum, z, "sum").unwrap(); builder.build_return(Some(&sum)).unwrap(); // Create execution engine let execution_engine = module.create_jit_execution_engine(OptimizationLevel::None)?; // Get the compiled function type SumFunc = unsafe extern "C" fn(u64, u64, u64) -> u64; let sum_fn = unsafe { execution_engine.get_function::<SumFunc>("sum")? }; // Execute the function let result = unsafe { sum_fn.call(1, 2, 3) }; Ok(result) } }
The execution engine provides JIT compilation capabilities, compiling LLVM IR to machine code in memory for immediate execution. This enables dynamic code generation scenarios like REPLs, runtime specialization, and adaptive optimization.
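The typed handle that `get_function` returns ultimately wraps a raw code address, and calling it is only sound if the asserted signature matches the generated code exactly. That contract can be sketched in plain Rust without inkwell (a minimal sketch: `native_sum` and `call_through_raw_pointer` are hypothetical stand-ins for a JIT-compiled symbol and the engine's lookup):

```rust
// Sketch of the safety contract behind JIT function lookup. A native
// function stands in for the machine code the JIT would emit.
extern "C" fn native_sum(x: u64, y: u64, z: u64) -> u64 {
    x + y + z
}

// The caller asserts the exact ABI and signature by casting the raw
// address to a typed `unsafe extern "C" fn` pointer.
type SumFunc = unsafe extern "C" fn(u64, u64, u64) -> u64;

fn call_through_raw_pointer(addr: usize) -> u64 {
    let f: SumFunc = unsafe { std::mem::transmute::<usize, SumFunc>(addr) };
    // Calling through the pointer is unsafe: nothing checks the
    // signature at runtime, so a mismatch is undefined behavior.
    unsafe { f(1, 2, 3) }
}

fn main() {
    let result = call_through_raw_pointer(native_sum as usize);
    println!("{}", result); // prints 6
}
```

This is why inkwell's `get_function` is itself `unsafe`: the type parameter is a promise the library cannot verify against the emitted code.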
Code Emission
Generating object files and LLVM IR:
/// Writes LLVM IR to a file
pub fn write_ir_to_file<'ctx>(module: &Module<'ctx>, path: &Path) -> Result<(), String> {
    module
        .print_to_file(path)
        .map_err(|e| format!("Failed to write IR: {}", e))
}

/// Compiles module to object file
pub fn compile_to_object_file<'ctx>(
    module: &Module<'ctx>,
    path: &Path,
) -> Result<(), Box<dyn Error>> {
    Target::initialize_native(&InitializationConfig::default())?;
    let target_triple = TargetMachine::get_default_triple();
    let target = Target::from_triple(&target_triple)?;
    let target_machine = target
        .create_target_machine(
            &target_triple,
            "generic",
            "",
            OptimizationLevel::Default,
            RelocMode::Default,
            CodeModel::Default,
        )
        .ok_or("Could not create target machine")?;
    target_machine.write_to_file(module, FileType::Object, path)?;
    Ok(())
}
builder.build_return(Some(&sum)).unwrap(); // Create execution engine let execution_engine = module.create_jit_execution_engine(OptimizationLevel::None)?; // Get the compiled function type SumFunc = unsafe extern "C" fn(u64, u64, u64) -> u64; let sum_fn = unsafe { execution_engine.get_function::<SumFunc>("sum")? }; // Execute the function let result = unsafe { sum_fn.call(1, 2, 3) }; Ok(result) } #[cfg(test)] mod tests { use super::*; #[test] fn test_add_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_add_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "add"); assert_eq!(function.count_params(), 2); assert!(verify_module(&module).is_ok()); } #[test] fn test_constant_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_constant_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "get_constant"); assert_eq!(function.count_params(), 0); assert!(verify_module(&module).is_ok()); } #[test] fn test_conditional_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_conditional_function(&context, &module); assert_eq!(function.count_basic_blocks(), 4); // entry, then, else, merge assert!(verify_module(&module).is_ok()); } #[test] fn test_loop_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_loop_function(&context, &module); assert_eq!(function.count_basic_blocks(), 3); // entry, loop, exit assert!(verify_module(&module).is_ok()); } #[test] fn test_global_variable() { let context = Context::create(); let module = context.create_module("test"); let global = create_global_variable(&context, &module); assert_eq!(global.get_name().to_str().unwrap(), "global_counter"); assert!(verify_module(&module).is_ok()); } #[test] fn test_recursive_function() { let context = Context::create(); let 
module = context.create_module("test"); let function = create_recursive_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "factorial"); assert!(verify_module(&module).is_ok()); } #[test] fn test_optimization() { let context = Context::create(); let module = context.create_module("test"); // Create several functions create_add_function(&context, &module); create_multiply_function(&context, &module); create_constant_function(&context, &module); // Apply optimizations assert!(optimize_module(&module).is_ok()); // Module should still be valid after optimization assert!(verify_module(&module).is_ok()); } #[test] fn test_custom_passes() { let context = Context::create(); let module = context.create_module("test"); // Create a simple function create_add_function(&context, &module); // Run specific optimization passes let passes = ["instcombine", "simplifycfg"]; assert!(run_custom_passes(&module, &passes).is_ok()); // Module should still be valid assert!(verify_module(&module).is_ok()); } #[test] fn test_jit_execution() { let context = Context::create(); // Test JIT compilation and execution match jit_compile_and_execute(&context) { Ok(result) => assert_eq!(result, 6), // 1 + 2 + 3 = 6 Err(e) => panic!("JIT execution failed: {}", e), } } } /// Compiles module to object file pub fn compile_to_object_file<'ctx>( module: &Module<'ctx>, path: &Path, ) -> Result<(), Box<dyn Error>> { Target::initialize_native(&InitializationConfig::default())?; let target_triple = TargetMachine::get_default_triple(); let target = Target::from_triple(&target_triple)?; let target_machine = target .create_target_machine( &target_triple, "generic", "", OptimizationLevel::Default, RelocMode::Default, CodeModel::Default, ) .ok_or("Could not create target machine")?; target_machine.write_to_file(module, FileType::Object, path)?; Ok(()) } }
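Taken together, these helpers cover the whole inkwell workflow: build IR, verify it, run the modern pass manager, and emit artifacts. A minimal driver might wire them up as below; this is a sketch that assumes the helper functions above are in scope, an LLVM 18 toolchain is installed, and the output paths `demo.ll` and `demo.o` are placeholders.

```rust
use std::error::Error;
use std::path::Path;

use inkwell::context::Context;

fn run_pipeline() -> Result<(), Box<dyn Error>> {
    // One context owns all types and values; the module collects functions
    let context = Context::create();
    let module = context.create_module("demo");

    // Generate IR using the helpers defined above
    create_add_function(&context, &module);
    create_loop_function(&context, &module);

    // Verify before optimizing; optimize_module runs the LLVM 18 pass manager
    verify_module(&module)?;
    optimize_module(&module)?;

    // Emit textual IR and a native object file (hypothetical output paths)
    write_ir_to_file(&module, Path::new("demo.ll"))?;
    compile_to_object_file(&module, Path::new("demo.o"))?;
    Ok(())
}
```

Note that both `Result<(), String>` and `Result<(), Box<dyn Error>>` helpers compose with `?` here, since `Box<dyn Error>` implements `From<String>`.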
#![allow(unused)]
fn main() {
use std::error::Error;
use std::path::Path;

use inkwell::context::Context;
use inkwell::module::Module;
use inkwell::passes::PassBuilderOptions;
use inkwell::targets::{
    CodeModel, FileType, InitializationConfig, RelocMode, Target, TargetMachine,
};
use inkwell::types::{BasicMetadataTypeEnum, BasicType, BasicTypeEnum};
use inkwell::values::FunctionValue;
use inkwell::{AddressSpace, IntPredicate, OptimizationLevel};

/// Creates a basic LLVM context
pub fn create_context() -> Context {
    Context::create()
}

/// Creates a simple function that adds two integers
pub fn create_add_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false);
    let function = module.add_function("add", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    // Get function parameters
    let x = function.get_nth_param(0).unwrap().into_int_value();
    let y = function.get_nth_param(1).unwrap().into_int_value();

    // Build addition and return
    let sum = builder.build_int_add(x, y, "sum").unwrap();
    builder.build_return(Some(&sum)).unwrap();

    function
}

/// Creates a function that multiplies two integers
pub fn create_multiply_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false);
    let function = module.add_function("multiply", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    let x = function.get_nth_param(0).unwrap().into_int_value();
    let y = function.get_nth_param(1).unwrap().into_int_value();

    let product = builder.build_int_mul(x, y, "product").unwrap();
    builder.build_return(Some(&product)).unwrap();

    function
}

/// Creates a function with a constant return value
pub fn create_constant_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[], false);
    let function = module.add_function("get_constant", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    let constant = i32_type.const_int(42, false);
    builder.build_return(Some(&constant)).unwrap();

    function
}

/// Demonstrates integer comparison operations
pub fn create_comparison_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let bool_type = context.bool_type();
    let fn_type = bool_type.fn_type(&[i32_type.into(), i32_type.into()], false);
    let function = module.add_function("compare_ints", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    builder.position_at_end(entry);

    let x = function.get_nth_param(0).unwrap().into_int_value();
    let y = function.get_nth_param(1).unwrap().into_int_value();

    // Compare x > y
    let comparison = builder
        .build_int_compare(IntPredicate::SGT, x, y, "cmp")
        .unwrap();
    builder.build_return(Some(&comparison)).unwrap();

    function
}

/// Creates a function with conditional branching (if-then-else)
pub fn create_conditional_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into()], false);
    let function = module.add_function("conditional", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    let then_block = context.append_basic_block(function, "then");
    let else_block = context.append_basic_block(function, "else");
    let merge_block = context.append_basic_block(function, "merge");

    // Entry block
    builder.position_at_end(entry);
    let x = function.get_nth_param(0).unwrap().into_int_value();
    let zero = i32_type.const_int(0, false);
    let condition = builder
        .build_int_compare(IntPredicate::SGT, x, zero, "cond")
        .unwrap();
    builder
        .build_conditional_branch(condition, then_block, else_block)
        .unwrap();

    // Then block: return x * 2
    builder.position_at_end(then_block);
    let two = i32_type.const_int(2, false);
    let then_val = builder.build_int_mul(x, two, "then_val").unwrap();
    builder.build_unconditional_branch(merge_block).unwrap();

    // Else block: return x * -1
    builder.position_at_end(else_block);
    let neg_one = i32_type.const_int(-1i64 as u64, true);
    let else_val = builder.build_int_mul(x, neg_one, "else_val").unwrap();
    builder.build_unconditional_branch(merge_block).unwrap();

    // Merge block with phi node
    builder.position_at_end(merge_block);
    let phi = builder.build_phi(i32_type, "result").unwrap();
    phi.add_incoming(&[(&then_val, then_block), (&else_val, else_block)]);
    builder.build_return(Some(&phi.as_basic_value())).unwrap();

    function
}

/// Creates a simple loop that counts from 0 to n
pub fn create_loop_function<'ctx>(
    context: &'ctx Context,
    module: &Module<'ctx>,
) -> FunctionValue<'ctx> {
    let i32_type = context.i32_type();
    let fn_type = i32_type.fn_type(&[i32_type.into()], false);
    let function = module.add_function("count_loop", fn_type, None);

    let builder = context.create_builder();
    let entry = context.append_basic_block(function, "entry");
    let loop_block = context.append_basic_block(function, "loop");
    let exit_block = context.append_basic_block(function, "exit");

    // Entry block: initialize counter
    builder.position_at_end(entry);
    let n = function.get_nth_param(0).unwrap().into_int_value();
    let zero = i32_type.const_int(0, false);
    builder.build_unconditional_branch(loop_block).unwrap();

    // Loop block
    builder.position_at_end(loop_block);
    let counter = builder.build_phi(i32_type, "counter").unwrap();
    counter.add_incoming(&[(&zero, entry)]);

    // Increment counter
    let
one = i32_type.const_int(1, false); let next_counter = builder .build_int_add( counter.as_basic_value().into_int_value(), one, "next_counter", ) .unwrap(); // Check loop condition let condition = builder .build_int_compare(IntPredicate::SLT, next_counter, n, "loop_cond") .unwrap(); // Add incoming value for next iteration let loop_block_end = builder.get_insert_block().unwrap(); counter.add_incoming(&[(&next_counter, loop_block_end)]); builder .build_conditional_branch(condition, loop_block, exit_block) .unwrap(); // Exit block builder.position_at_end(exit_block); builder .build_return(Some(&counter.as_basic_value())) .unwrap(); function } /// Creates a function that allocates and uses local variables (stack /// allocation) pub fn create_alloca_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[i32_type.into(), i32_type.into()], false); let function = module.add_function("use_alloca", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); // Get parameters let x = function.get_nth_param(0).unwrap().into_int_value(); let y = function.get_nth_param(1).unwrap().into_int_value(); // Allocate stack variables let x_ptr = builder.build_alloca(i32_type, "x_ptr").unwrap(); let y_ptr = builder.build_alloca(i32_type, "y_ptr").unwrap(); let result_ptr = builder.build_alloca(i32_type, "result_ptr").unwrap(); // Store values builder.build_store(x_ptr, x).unwrap(); builder.build_store(y_ptr, y).unwrap(); // Load values let x_val = builder.build_load(i32_type, x_ptr, "x_val").unwrap(); let y_val = builder.build_load(i32_type, y_ptr, "y_val").unwrap(); // Compute and store result let sum = builder .build_int_add(x_val.into_int_value(), y_val.into_int_value(), "sum") .unwrap(); builder.build_store(result_ptr, sum).unwrap(); // Load and return result let result = 
builder.build_load(i32_type, result_ptr, "result").unwrap(); builder.build_return(Some(&result)).unwrap(); function } /// Creates a function that works with arrays pub fn create_array_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let i32_ptr_type = context.ptr_type(AddressSpace::default()); let fn_type = i32_type.fn_type(&[i32_ptr_type.into(), i32_type.into()], false); let function = module.add_function("sum_array", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); let loop_block = context.append_basic_block(function, "loop"); let exit_block = context.append_basic_block(function, "exit"); // Entry block builder.position_at_end(entry); let array_ptr = function.get_nth_param(0).unwrap().into_pointer_value(); let size = function.get_nth_param(1).unwrap().into_int_value(); let zero = i32_type.const_int(0, false); builder.build_unconditional_branch(loop_block).unwrap(); // Loop block builder.position_at_end(loop_block); let index = builder.build_phi(i32_type, "index").unwrap(); let sum = builder.build_phi(i32_type, "sum").unwrap(); index.add_incoming(&[(&zero, entry)]); sum.add_incoming(&[(&zero, entry)]); // Load array element let elem_ptr = unsafe { builder .build_gep( i32_type, array_ptr, &[index.as_basic_value().into_int_value()], "elem_ptr", ) .unwrap() }; let elem = builder.build_load(i32_type, elem_ptr, "elem").unwrap(); // Update sum let new_sum = builder .build_int_add( sum.as_basic_value().into_int_value(), elem.into_int_value(), "new_sum", ) .unwrap(); // Update index let one = i32_type.const_int(1, false); let next_index = builder .build_int_add(index.as_basic_value().into_int_value(), one, "next_index") .unwrap(); // Check loop condition let condition = builder .build_int_compare(IntPredicate::SLT, next_index, size, "loop_cond") .unwrap(); // Update phi nodes let loop_end = builder.get_insert_block().unwrap(); 
index.add_incoming(&[(&next_index, loop_end)]); sum.add_incoming(&[(&new_sum, loop_end)]); builder .build_conditional_branch(condition, loop_block, exit_block) .unwrap(); // Exit block builder.position_at_end(exit_block); builder.build_return(Some(&sum.as_basic_value())).unwrap(); function } /// Creates a global variable pub fn create_global_variable<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> inkwell::values::GlobalValue<'ctx> { let i32_type = context.i32_type(); let global = module.add_global(i32_type, Some(AddressSpace::default()), "global_counter"); global.set_initializer(&i32_type.const_int(0, false)); global.set_linkage(inkwell::module::Linkage::Internal); global } /// Creates a function that uses a global variable pub fn create_global_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[], false); let function = module.add_function("increment_global", fn_type, None); // Create or get global variable let global = module .get_global("global_counter") .unwrap_or_else(|| create_global_variable(context, module)); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); // Load global value let global_ptr = global.as_pointer_value(); let current = builder.build_load(i32_type, global_ptr, "current").unwrap(); // Increment let one = i32_type.const_int(1, false); let incremented = builder .build_int_add(current.into_int_value(), one, "incremented") .unwrap(); // Store back to global builder.build_store(global_ptr, incremented).unwrap(); // Return new value builder.build_return(Some(&incremented)).unwrap(); function } /// Creates a recursive factorial function pub fn create_recursive_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let fn_type = i32_type.fn_type(&[i32_type.into()], false); let function 
= module.add_function("factorial", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); let recurse = context.append_basic_block(function, "recurse"); let base = context.append_basic_block(function, "base"); // Entry block builder.position_at_end(entry); let n = function.get_nth_param(0).unwrap().into_int_value(); let one = i32_type.const_int(1, false); let is_base = builder .build_int_compare(IntPredicate::SLE, n, one, "is_base") .unwrap(); builder .build_conditional_branch(is_base, base, recurse) .unwrap(); // Base case: return 1 builder.position_at_end(base); builder.build_return(Some(&one)).unwrap(); // Recursive case: return n * factorial(n-1) builder.position_at_end(recurse); let n_minus_1 = builder.build_int_sub(n, one, "n_minus_1").unwrap(); let rec_result = builder .build_call(function, &[n_minus_1.into()], "rec_result") .unwrap(); let result = builder .build_int_mul( n, rec_result .try_as_basic_value() .left() .unwrap() .into_int_value(), "result", ) .unwrap(); builder.build_return(Some(&result)).unwrap(); function } /// Creates a struct type and a function that uses it pub fn create_struct_function<'ctx>( context: &'ctx Context, module: &Module<'ctx>, ) -> FunctionValue<'ctx> { let i32_type = context.i32_type(); let f64_type = context.f64_type(); // Define a Point struct with x and y fields let field_types = [i32_type.into(), f64_type.into()]; let struct_type = context.struct_type(&field_types, false); let fn_type = f64_type.fn_type(&[struct_type.into()], false); let function = module.add_function("get_point_y", fn_type, None); let builder = context.create_builder(); let entry = context.append_basic_block(function, "entry"); builder.position_at_end(entry); // Get the struct parameter let point = function.get_nth_param(0).unwrap().into_struct_value(); // Extract the y field (index 1) let y_field = builder.build_extract_value(point, 1, "y_field").unwrap(); 
builder.build_return(Some(&y_field)).unwrap(); function } /// Runs optimization passes on a module using the modern pass manager (LLVM 18) pub fn optimize_module<'ctx>(module: &Module<'ctx>) -> Result<(), String> { // First verify the module is valid module.verify().map_err(|e| e.to_string())?; // Initialize targets for optimization Target::initialize_all(&InitializationConfig::default()); let target_triple = TargetMachine::get_default_triple(); let target = Target::from_triple(&target_triple) .map_err(|e| format!("Failed to create target: {}", e))?; let target_machine = target .create_target_machine( &target_triple, "generic", "", OptimizationLevel::Default, RelocMode::Default, CodeModel::Default, ) .ok_or("Failed to create target machine")?; // Common optimization passes let passes = [ "instcombine", // Combine instructions "reassociate", // Reassociate expressions "gvn", // Global value numbering "simplifycfg", // Simplify control flow graph "mem2reg", // Promote memory to registers ]; let pass_builder_options = PassBuilderOptions::create(); pass_builder_options.set_loop_vectorization(true); pass_builder_options.set_loop_unrolling(true); pass_builder_options.set_merge_functions(true); module .run_passes(&passes.join(","), &target_machine, pass_builder_options) .map_err(|e| e.to_string()) } /// Runs specific optimization passes on a module pub fn run_custom_passes<'ctx>(module: &Module<'ctx>, passes: &[&str]) -> Result<(), String> { // Verify module first module.verify().map_err(|e| e.to_string())?; // Initialize targets Target::initialize_all(&InitializationConfig::default()); let target_triple = TargetMachine::get_default_triple(); let target = Target::from_triple(&target_triple) .map_err(|e| format!("Failed to create target: {}", e))?; let target_machine = target .create_target_machine( &target_triple, "generic", "", OptimizationLevel::None, RelocMode::Default, CodeModel::Default, ) .ok_or("Failed to create target machine")?; module .run_passes( 
&passes.join(","), &target_machine, PassBuilderOptions::create(), ) .map_err(|e| e.to_string()) } /// Writes LLVM IR to a file pub fn write_ir_to_file<'ctx>(module: &Module<'ctx>, path: &Path) -> Result<(), String> { module .print_to_file(path) .map_err(|e| format!("Failed to write IR: {}", e)) } /// Compiles module to object file pub fn compile_to_object_file<'ctx>( module: &Module<'ctx>, path: &Path, ) -> Result<(), Box<dyn Error>> { Target::initialize_native(&InitializationConfig::default())?; let target_triple = TargetMachine::get_default_triple(); let target = Target::from_triple(&target_triple)?; let target_machine = target .create_target_machine( &target_triple, "generic", "", OptimizationLevel::Default, RelocMode::Default, CodeModel::Default, ) .ok_or("Could not create target machine")?; target_machine.write_to_file(module, FileType::Object, path)?; Ok(()) } /// Helper to create a function type pub fn create_function_type<'ctx>( context: &'ctx Context, param_types: Vec<BasicMetadataTypeEnum<'ctx>>, return_type: Option<BasicTypeEnum<'ctx>>, is_var_args: bool, ) -> inkwell::types::FunctionType<'ctx> { match return_type { Some(ret) => ret.fn_type(¶m_types, is_var_args), None => context.void_type().fn_type(¶m_types, is_var_args), } } /// Simple JIT execution example pub fn create_execution_engine<'ctx>( module: &Module<'ctx>, ) -> Result<inkwell::execution_engine::ExecutionEngine<'ctx>, String> { module .create_jit_execution_engine(OptimizationLevel::None) .map_err(|e| e.to_string()) } /// Example of JIT compiling and executing a function pub fn jit_compile_and_execute(context: &Context) -> Result<u64, Box<dyn Error>> { let module = context.create_module("jit_example"); let builder = context.create_builder(); // Create a simple sum function: sum(x, y, z) = x + y + z let i64_type = context.i64_type(); let fn_type = i64_type.fn_type(&[i64_type.into(), i64_type.into(), i64_type.into()], false); let function = module.add_function("sum", fn_type, None); let 
basic_block = context.append_basic_block(function, "entry"); builder.position_at_end(basic_block); let x = function.get_nth_param(0).unwrap().into_int_value(); let y = function.get_nth_param(1).unwrap().into_int_value(); let z = function.get_nth_param(2).unwrap().into_int_value(); let sum = builder.build_int_add(x, y, "sum").unwrap(); let sum = builder.build_int_add(sum, z, "sum").unwrap(); builder.build_return(Some(&sum)).unwrap(); // Create execution engine let execution_engine = module.create_jit_execution_engine(OptimizationLevel::None)?; // Get the compiled function type SumFunc = unsafe extern "C" fn(u64, u64, u64) -> u64; let sum_fn = unsafe { execution_engine.get_function::<SumFunc>("sum")? }; // Execute the function let result = unsafe { sum_fn.call(1, 2, 3) }; Ok(result) } #[cfg(test)] mod tests { use super::*; #[test] fn test_add_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_add_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "add"); assert_eq!(function.count_params(), 2); assert!(verify_module(&module).is_ok()); } #[test] fn test_constant_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_constant_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "get_constant"); assert_eq!(function.count_params(), 0); assert!(verify_module(&module).is_ok()); } #[test] fn test_conditional_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_conditional_function(&context, &module); assert_eq!(function.count_basic_blocks(), 4); // entry, then, else, merge assert!(verify_module(&module).is_ok()); } #[test] fn test_loop_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_loop_function(&context, &module); assert_eq!(function.count_basic_blocks(), 3); // entry, loop, exit 
assert!(verify_module(&module).is_ok()); } #[test] fn test_global_variable() { let context = Context::create(); let module = context.create_module("test"); let global = create_global_variable(&context, &module); assert_eq!(global.get_name().to_str().unwrap(), "global_counter"); assert!(verify_module(&module).is_ok()); } #[test] fn test_recursive_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_recursive_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "factorial"); assert!(verify_module(&module).is_ok()); } #[test] fn test_optimization() { let context = Context::create(); let module = context.create_module("test"); // Create several functions create_add_function(&context, &module); create_multiply_function(&context, &module); create_constant_function(&context, &module); // Apply optimizations assert!(optimize_module(&module).is_ok()); // Module should still be valid after optimization assert!(verify_module(&module).is_ok()); } #[test] fn test_custom_passes() { let context = Context::create(); let module = context.create_module("test"); // Create a simple function create_add_function(&context, &module); // Run specific optimization passes let passes = ["instcombine", "simplifycfg"]; assert!(run_custom_passes(&module, &passes).is_ok()); // Module should still be valid assert!(verify_module(&module).is_ok()); } #[test] fn test_jit_execution() { let context = Context::create(); // Test JIT compilation and execution match jit_compile_and_execute(&context) { Ok(result) => assert_eq!(result, 6), // 1 + 2 + 3 = 6 Err(e) => panic!("JIT execution failed: {}", e), } } } /// Verifies that a module is valid pub fn verify_module<'ctx>(module: &Module<'ctx>) -> Result<(), String> { module.verify().map_err(|e| e.to_string()) } }
LLVM can emit code in several formats, including native object files and textual LLVM IR. The target machine encapsulates platform-specific code generation details, and module verification ensures the generated IR is well-formed before optimization or code generation.
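The same `TargetMachine::write_to_file` call that produces object files can also emit human-readable assembly by selecting `FileType::Assembly`, which is handy for inspecting what the backend generates. The sketch below mirrors `compile_to_object_file`; the helper name `compile_to_assembly` is ours, not part of the inkwell API.

```rust
use std::error::Error;
use std::path::Path;

use inkwell::module::Module;
use inkwell::targets::{
    CodeModel, FileType, InitializationConfig, RelocMode, Target, TargetMachine,
};
use inkwell::OptimizationLevel;

/// Hypothetical helper: emits textual assembly instead of an object file.
pub fn compile_to_assembly<'ctx>(
    module: &Module<'ctx>,
    path: &Path,
) -> Result<(), Box<dyn Error>> {
    Target::initialize_native(&InitializationConfig::default())?;
    let target_triple = TargetMachine::get_default_triple();
    let target = Target::from_triple(&target_triple)?;
    let target_machine = target
        .create_target_machine(
            &target_triple,
            "generic",
            "",
            OptimizationLevel::Default,
            RelocMode::Default,
            CodeModel::Default,
        )
        .ok_or("Could not create target machine")?;
    // FileType::Assembly selects textual assembly output; FileType::Object
    // (used in compile_to_object_file) selects a relocatable object file.
    target_machine.write_to_file(module, FileType::Assembly, path)?;
    Ok(())
}
```

The only difference from the object-file path is the `FileType` argument, so both helpers can share the same target-machine setup in a real codebase.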
Helper Functions
Utility for creating function types:
```rust
/// Helper to create a function type
pub fn create_function_type<'ctx>(
    context: &'ctx Context,
    param_types: Vec<BasicMetadataTypeEnum<'ctx>>,
    return_type: Option<BasicTypeEnum<'ctx>>,
    is_var_args: bool,
) -> inkwell::types::FunctionType<'ctx> {
    match return_type {
        Some(ret) => ret.fn_type(&param_types, is_var_args),
        None => context.void_type().fn_type(&param_types, is_var_args),
    }
}
```
assert_eq!(global.get_name().to_str().unwrap(), "global_counter"); assert!(verify_module(&module).is_ok()); } #[test] fn test_recursive_function() { let context = Context::create(); let module = context.create_module("test"); let function = create_recursive_function(&context, &module); assert_eq!(function.get_name().to_str().unwrap(), "factorial"); assert!(verify_module(&module).is_ok()); } #[test] fn test_optimization() { let context = Context::create(); let module = context.create_module("test"); // Create several functions create_add_function(&context, &module); create_multiply_function(&context, &module); create_constant_function(&context, &module); // Apply optimizations assert!(optimize_module(&module).is_ok()); // Module should still be valid after optimization assert!(verify_module(&module).is_ok()); } #[test] fn test_custom_passes() { let context = Context::create(); let module = context.create_module("test"); // Create a simple function create_add_function(&context, &module); // Run specific optimization passes let passes = ["instcombine", "simplifycfg"]; assert!(run_custom_passes(&module, &passes).is_ok()); // Module should still be valid assert!(verify_module(&module).is_ok()); } #[test] fn test_jit_execution() { let context = Context::create(); // Test JIT compilation and execution match jit_compile_and_execute(&context) { Ok(result) => assert_eq!(result, 6), // 1 + 2 + 3 = 6 Err(e) => panic!("JIT execution failed: {}", e), } } } /// Helper to create a function type pub fn create_function_type<'ctx>( context: &'ctx Context, param_types: Vec<BasicMetadataTypeEnum<'ctx>>, return_type: Option<BasicTypeEnum<'ctx>>, is_var_args: bool, ) -> inkwell::types::FunctionType<'ctx> { match return_type { Some(ret) => ret.fn_type(&param_types, is_var_args), None => context.void_type().fn_type(&param_types, is_var_args), } } }
This helper simplifies creating function types with proper handling of void returns and variadic arguments.
Best Practices
Maintain clear separation between your language’s AST and LLVM IR generation. Build an intermediate representation that bridges your language semantics and LLVM’s model. This separation simplifies both frontend development and backend optimization.
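As a minimal sketch of what such a bridging layer can look like (all names here are hypothetical, not taken from any crate), a small three-address intermediate form decouples the surface AST from the eventual LLVM builder calls; each instruction then maps almost one-to-one onto LLVM instructions:

```rust
// Hypothetical surface AST for an expression language.
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Num(i64),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
}

// Bridging IR: three-address code over virtual registers. Each instruction
// defines exactly one register, mirroring LLVM's SSA discipline.
#[derive(Debug, PartialEq)]
enum Inst {
    Const { dst: usize, value: i64 },
    Add { dst: usize, lhs: usize, rhs: usize },
    Mul { dst: usize, lhs: usize, rhs: usize },
}

/// Lower the AST into the bridging IR, returning the register that holds
/// the final result. Register numbers coincide with instruction indices.
fn lower(expr: &Expr, out: &mut Vec<Inst>) -> usize {
    match expr {
        Expr::Num(v) => {
            let dst = out.len();
            out.push(Inst::Const { dst, value: *v });
            dst
        }
        Expr::Add(l, r) => {
            let (lhs, rhs) = (lower(l, out), lower(r, out));
            let dst = out.len();
            out.push(Inst::Add { dst, lhs, rhs });
            dst
        }
        Expr::Mul(l, r) => {
            let (lhs, rhs) = (lower(l, out), lower(r, out));
            let dst = out.len();
            out.push(Inst::Mul { dst, lhs, rhs });
            dst
        }
    }
}

fn main() {
    // (2 + 3) * 4
    let ast = Expr::Mul(
        Box::new(Expr::Add(Box::new(Expr::Num(2)), Box::new(Expr::Num(3)))),
        Box::new(Expr::Num(4)),
    );
    let mut code = Vec::new();
    let result = lower(&ast, &mut code);
    println!("result in r{result}: {code:?}");
}
```

With this split, the LLVM backend only ever walks the flat `Inst` list, and language-level concerns (scoping, sugar, type inference) stay on the AST side.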
Use LLVM’s type system to enforce invariants at compile time. Rich type information enables better optimization and catches errors early. Note that LLVM 15 and later use opaque pointers exclusively, so pointee types can no longer be recovered from pointer values; carry type information explicitly on the instructions that need it, such as loads, stores, and GEPs.
Leverage LLVM’s SSA form by minimizing mutable state. Use phi nodes instead of memory operations when possible. SSA form enables powerful optimizations like constant propagation and dead code elimination.
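To make the contrast concrete, both hand-written functions below compute `cond ? a : b`: the first round-trips the value through a stack slot, while the second expresses the choice directly as a phi node that optimizers can reason about without memory analysis (illustrative IR, not generated by the code above):

```llvm
; Memory-based lowering: a store on each path, then a load at the join.
define i32 @select_mem(i1 %cond, i32 %a, i32 %b) {
entry:
  %slot = alloca i32
  br i1 %cond, label %then, label %else
then:
  store i32 %a, ptr %slot
  br label %merge
else:
  store i32 %b, ptr %slot
  br label %merge
merge:
  %r = load i32, ptr %slot
  ret i32 %r
}

; SSA lowering: a phi selects the incoming value per predecessor block.
define i32 @select_phi(i1 %cond, i32 %a, i32 %b) {
entry:
  br i1 %cond, label %then, label %else
then:
  br label %merge
else:
  br label %merge
merge:
  %r = phi i32 [ %a, %then ], [ %b, %else ]
  ret i32 %r
}
```

The `mem2reg` pass will usually rewrite the first form into the second, but emitting phi nodes directly keeps the IR simpler and compilation faster.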
Structure code generation to emit IR suitable for optimization. Avoid patterns that inhibit optimization like excessive memory operations or complex control flow. Simple, regular IR patterns optimize better than clever, complicated constructions.
Enable appropriate optimization levels based on use case. Debug builds benefit from minimal optimization for faster compilation and better debugging. Release builds should use higher optimization levels for maximum performance.
Use LLVM intrinsics for operations with hardware support. Intrinsics for mathematical functions, atomic operations, and SIMD instructions generate better code than manual implementations. LLVM recognizes and optimizes intrinsic patterns.
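For instance, a square-root computed through the `llvm.sqrt.f64` intrinsic lowers to a single hardware instruction on most targets, whereas a call into a hand-rolled routine would not (illustrative IR):

```llvm
declare double @llvm.sqrt.f64(double)

; hypot2(x, y) = sqrt(x*x + y*y), using the sqrt intrinsic.
define double @hypot2(double %x, double %y) {
entry:
  %xx = fmul double %x, %x
  %yy = fmul double %y, %y
  %s  = fadd double %xx, %yy
  %r  = call double @llvm.sqrt.f64(double %s)
  ret double %r
}
```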
Profile and analyze generated code to identify optimization opportunities. LLVM provides extensive analysis passes that reveal performance bottlenecks. Use this information to guide both frontend improvements and optimization pass selection.
melior
The melior crate provides safe Rust bindings to MLIR (Multi-Level Intermediate Representation), enabling compiler developers to leverage MLIR’s powerful infrastructure for building optimizing compilers and domain-specific code generators. MLIR represents computations as a graph of operations organized into regions and blocks, supporting multiple levels of abstraction from high-level tensor operations to low-level machine instructions. The melior bindings expose MLIR’s dialect system, allowing developers to work with various IR representations including arithmetic operations, control flow, tensor computations, and LLVM IR generation.
The architecture of melior wraps MLIR’s C API with idiomatic Rust abstractions, providing type safety and memory safety guarantees while maintaining the flexibility of MLIR’s extensible design. The crate supports all standard MLIR dialects including func, arith, scf, tensor, memref, and llvm, enabling progressive lowering from high-level abstractions to executable code. This multi-level approach allows compilers to perform optimizations at the most appropriate abstraction level, improving both compilation time and generated code quality.
Installation on macOS
To use melior on macOS, you need to install LLVM/MLIR 20 via Homebrew:
brew install llvm@20
The build scripts need to know where LLVM is installed; you can obtain the installation prefix with:
brew --prefix llvm@20
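The melior build (via its `mlir-sys` and `tblgen` dependencies) locates LLVM through versioned environment variables; the exact names depend on your melior release, so check its README. For LLVM 20 they typically look like the following (assumed variable names, shown here as a sketch):

```shell
# Assumed variable names for LLVM 20; confirm against the melior README.
export MLIR_SYS_200_PREFIX=$(brew --prefix llvm@20)
export TABLEGEN_200_PREFIX=$(brew --prefix llvm@20)
```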
Add melior to your Cargo.toml:
[dependencies]
melior = "0.25"
Basic Usage
Creating an MLIR context with all dialects loaded:
#![allow(unused)] fn main() { use melior::dialect::{arith, func, DialectRegistry}; use melior::ir::attribute::{IntegerAttribute, StringAttribute, TypeAttribute}; use melior::ir::operation::OperationLike; use melior::ir::r#type::{FunctionType, IntegerType}; use melior::ir::*; use melior::pass::{gpu, transform, PassManager}; use melior::utility::{register_all_dialects, register_all_llvm_translations, register_all_passes}; use melior::{Context, Error}; /// Creates a simple function that adds two integers pub fn create_add_function(context: &Context) -> Result<Module<'_>, Error> { let location = Location::unknown(context); let module = Module::new(location); let index_type = Type::index(context); module.body().append_operation(func::func( context, StringAttribute::new(context, "add"), TypeAttribute::new( FunctionType::new(context, &[index_type, index_type], &[index_type]).into(), ), { let block = Block::new(&[(index_type, location), (index_type, location)]); let sum = block .append_operation(arith::addi( block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), location, )) .result(0) .unwrap(); block.append_operation(func::r#return(&[sum.into()], location)); let region = Region::new(); region.append_block(block); region }, &[], location, )); Ok(module) } /// Creates a function that multiplies two 32-bit integers pub fn create_multiply_function(context: &Context) -> Result<Module<'_>, Error> { let location = Location::unknown(context); let module = Module::new(location); let i32_type = IntegerType::new(context, 32).into(); module.body().append_operation(func::func( context, StringAttribute::new(context, "multiply"), TypeAttribute::new(FunctionType::new(context, &[i32_type, i32_type], &[i32_type]).into()), { let block = Block::new(&[(i32_type, location), (i32_type, location)]); let product = block .append_operation(arith::muli( block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), location, )) .result(0) .unwrap(); 
block.append_operation(func::r#return(&[product.into()], location)); let region = Region::new(); region.append_block(block); region }, &[], location, )); Ok(module) } /// Creates a constant integer value pub fn create_constant(context: &Context, value: i64) -> Result<Module<'_>, Error> { let location = Location::unknown(context); let module = Module::new(location); let i64_type = IntegerType::new(context, 64).into(); module.body().append_operation(func::func( context, StringAttribute::new(context, "get_constant"), TypeAttribute::new(FunctionType::new(context, &[], &[i64_type]).into()), { let block = Block::new(&[]); let constant = block .append_operation(arith::constant( context, IntegerAttribute::new(i64_type, value).into(), location, )) .result(0) .unwrap(); block.append_operation(func::r#return(&[constant.into()], location)); let region = Region::new(); region.append_block(block); region }, &[], location, )); Ok(module) } /// Shows how to verify MLIR modules pub fn verify_module(module: &Module<'_>) -> bool { module.as_operation().verify() } /// Shows how to print MLIR to string pub fn module_to_string(module: &Module<'_>) -> String { format!("{}", module.as_operation()) } /// Apply canonicalization transforms to simplify IR pub fn apply_canonicalization(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_canonicalizer()); pass_manager.run(module) } /// Apply CSE (Common Subexpression Elimination) to remove redundant /// computations pub fn apply_cse(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_cse()); pass_manager.run(module) } /// Apply inlining transformation to inline function calls pub fn apply_inlining(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); 
pass_manager.add_pass(transform::create_inliner()); pass_manager.run(module) } /// Apply loop-invariant code motion to optimize loops pub fn apply_licm(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_loop_invariant_code_motion()); pass_manager.run(module) } /// Apply SCCP (Sparse Conditional Constant Propagation) for constant folding pub fn apply_sccp(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_sccp()); pass_manager.run(module) } /// Apply a pipeline of optimization passes pub fn optimize_module(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); // Standard optimization pipeline pass_manager.add_pass(transform::create_canonicalizer()); pass_manager.add_pass(transform::create_cse()); pass_manager.add_pass(transform::create_sccp()); pass_manager.add_pass(transform::create_inliner()); pass_manager.add_pass(transform::create_canonicalizer()); // Run again after inlining pass_manager.run(module) } /// Example of applying symbol DCE (Dead Code Elimination) pub fn apply_symbol_dce(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_symbol_dce()); pass_manager.run(module) } /// Apply mem2reg transformation to promote memory to registers pub fn apply_mem2reg(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_mem_2_reg()); pass_manager.run(module) } /// Convert parallel loops to GPU kernels pub fn convert_to_gpu(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(gpu::create_gpu_kernel_outlining()); pass_manager.run(module) 
} /// Strip debug information from the module pub fn strip_debug_info(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_strip_debug_info()); pass_manager.run(module) } /// Type builder helper for creating complex types pub struct TypeBuilder<'c> { context: &'c Context, } impl<'c> TypeBuilder<'c> { pub fn new(context: &'c Context) -> Self { Self { context } } pub fn i32(&self) -> Type<'c> { IntegerType::new(self.context, 32).into() } pub fn i64(&self) -> Type<'c> { IntegerType::new(self.context, 64).into() } pub fn index(&self) -> Type<'c> { Type::index(self.context) } pub fn function(&self, inputs: &[Type<'c>], outputs: &[Type<'c>]) -> FunctionType<'c> { FunctionType::new(self.context, inputs, outputs) } } /// Attribute builder helper for creating attributes pub struct AttributeBuilder<'c> { context: &'c Context, } impl<'c> AttributeBuilder<'c> { pub fn new(context: &'c Context) -> Self { Self { context } } pub fn string(&self, value: &str) -> StringAttribute<'c> { StringAttribute::new(self.context, value) } pub fn integer(&self, ty: Type<'c>, value: i64) -> IntegerAttribute<'c> { IntegerAttribute::new(ty, value) } pub fn type_attr(&self, ty: Type<'c>) -> TypeAttribute<'c> { TypeAttribute::new(ty) } } /// Custom pass builder for creating transformation pipelines pub struct PassPipeline<'c> { pass_manager: PassManager<'c>, } impl<'c> PassPipeline<'c> { pub fn new(context: &'c Context) -> Self { Self { pass_manager: PassManager::new(context), } } /// Add a canonicalization pass pub fn canonicalize(self) -> Self { self.pass_manager .add_pass(transform::create_canonicalizer()); self } /// Add a CSE pass pub fn eliminate_common_subexpressions(self) -> Self { self.pass_manager.add_pass(transform::create_cse()); self } /// Add an inlining pass pub fn inline_functions(self) -> Self { self.pass_manager.add_pass(transform::create_inliner()); self } /// Add SCCP for constant 
propagation pub fn propagate_constants(self) -> Self { self.pass_manager.add_pass(transform::create_sccp()); self } /// Add loop optimizations pub fn optimize_loops(self) -> Self { self.pass_manager .add_pass(transform::create_loop_invariant_code_motion()); self } /// Run the pipeline on a module pub fn run(self, module: &mut Module<'c>) -> Result<(), Error> { self.pass_manager.run(module) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_add_function() { let context = create_test_context(); let module = create_add_function(&context).unwrap(); assert!(verify_module(&module)); let ir = module_to_string(&module); assert!(ir.contains("func.func @add")); assert!(ir.contains("arith.addi")); } #[test] fn test_multiply_function() { let context = create_test_context(); let module = create_multiply_function(&context).unwrap(); assert!(verify_module(&module)); let ir = module_to_string(&module); assert!(ir.contains("func.func @multiply")); assert!(ir.contains("arith.muli")); } #[test] fn test_constant_creation() { let context = create_test_context(); let module = create_constant(&context, 42).unwrap(); assert!(verify_module(&module)); let ir = module_to_string(&module); assert!(ir.contains("arith.constant")); assert!(ir.contains("42")); } #[test] fn test_type_builder() { let context = create_test_context(); let types = TypeBuilder::new(&context); let i32 = types.i32(); let i64 = types.i64(); let func_type = types.function(&[i32, i32], &[i64]); assert_eq!(func_type.input(0).unwrap(), i32); assert_eq!(func_type.input(1).unwrap(), i32); assert_eq!(func_type.result(0).unwrap(), i64); } #[test] fn test_attribute_builder() { let context = create_test_context(); let types = TypeBuilder::new(&context); let attrs = AttributeBuilder::new(&context); let name = attrs.string("test_function"); // StringAttribute doesn't expose string value directly in the API assert!(name.is_string()); let value = attrs.integer(types.i32(), 100); assert_eq!(value.value(), 100); } #[test] fn 
test_canonicalization() { let context = create_test_context(); let mut module = create_add_function(&context).unwrap(); let _before = module_to_string(&module); apply_canonicalization(&context, &mut module).unwrap(); let after = module_to_string(&module); assert!(verify_module(&module)); // Canonicalization should preserve or simplify the IR assert!(!after.is_empty()); } #[test] fn test_optimization_pipeline() { let context = create_test_context(); let mut module = create_multiply_function(&context).unwrap(); // Apply optimization pipeline optimize_module(&context, &mut module).unwrap(); assert!(verify_module(&module)); let ir = module_to_string(&module); assert!(ir.contains("func.func @multiply")); } #[test] fn test_pass_pipeline_builder() { let context = create_test_context(); let mut module = create_constant(&context, 100).unwrap(); let pipeline = PassPipeline::new(&context) .canonicalize() .eliminate_common_subexpressions() .propagate_constants(); pipeline.run(&mut module).unwrap(); assert!(verify_module(&module)); } } /// Creates a test context with all dialects loaded pub fn create_test_context() -> Context { let context = Context::new(); let registry = DialectRegistry::new(); register_all_dialects(&registry); register_all_passes(); context.append_dialect_registry(&registry); context.load_all_available_dialects(); register_all_llvm_translations(&context); context } }
The context manages dialect registration and configuration. MLIR requires dialects to be loaded before their operations can be used. The create_test_context function loads all standard dialects and LLVM translations for immediate use.
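For reference, printing the module produced by `create_add_function` yields MLIR text along these lines (argument and SSA value names may differ slightly between MLIR versions):

```mlir
module {
  func.func @add(%arg0: index, %arg1: index) -> index {
    %0 = arith.addi %arg0, %arg1 : index
    return %0 : index
  }
}
```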
Function Creation
Building simple arithmetic functions:
#![allow(unused)] fn main() { use melior::dialect::{arith, func, DialectRegistry}; use melior::ir::attribute::{IntegerAttribute, StringAttribute, TypeAttribute}; use melior::ir::operation::OperationLike; use melior::ir::r#type::{FunctionType, IntegerType}; use melior::ir::*; use melior::pass::{gpu, transform, PassManager}; use melior::utility::{register_all_dialects, register_all_llvm_translations, register_all_passes}; use melior::{Context, Error}; /// Creates a test context with all dialects loaded pub fn create_test_context() -> Context { let context = Context::new(); let registry = DialectRegistry::new(); register_all_dialects(&registry); register_all_passes(); context.append_dialect_registry(&registry); context.load_all_available_dialects(); register_all_llvm_translations(&context); context } /// Creates a function that multiplies two 32-bit integers pub fn create_multiply_function(context: &Context) -> Result<Module<'_>, Error> { let location = Location::unknown(context); let module = Module::new(location); let i32_type = IntegerType::new(context, 32).into(); module.body().append_operation(func::func( context, StringAttribute::new(context, "multiply"), TypeAttribute::new(FunctionType::new(context, &[i32_type, i32_type], &[i32_type]).into()), { let block = Block::new(&[(i32_type, location), (i32_type, location)]); let product = block .append_operation(arith::muli( block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), location, )) .result(0) .unwrap(); block.append_operation(func::r#return(&[product.into()], location)); let region = Region::new(); region.append_block(block); region }, &[], location, )); Ok(module) } /// Creates a constant integer value pub fn create_constant(context: &Context, value: i64) -> Result<Module<'_>, Error> { let location = Location::unknown(context); let module = Module::new(location); let i64_type = IntegerType::new(context, 64).into(); module.body().append_operation(func::func( context, StringAttribute::new(context, 
"get_constant"), TypeAttribute::new(FunctionType::new(context, &[], &[i64_type]).into()), { let block = Block::new(&[]); let constant = block .append_operation(arith::constant( context, IntegerAttribute::new(i64_type, value).into(), location, )) .result(0) .unwrap(); block.append_operation(func::r#return(&[constant.into()], location)); let region = Region::new(); region.append_block(block); region }, &[], location, )); Ok(module) } /// Shows how to verify MLIR modules pub fn verify_module(module: &Module<'_>) -> bool { module.as_operation().verify() } /// Shows how to print MLIR to string pub fn module_to_string(module: &Module<'_>) -> String { format!("{}", module.as_operation()) } /// Apply canonicalization transforms to simplify IR pub fn apply_canonicalization(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_canonicalizer()); pass_manager.run(module) } /// Apply CSE (Common Subexpression Elimination) to remove redundant /// computations pub fn apply_cse(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_cse()); pass_manager.run(module) } /// Apply inlining transformation to inline function calls pub fn apply_inlining(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_inliner()); pass_manager.run(module) } /// Apply loop-invariant code motion to optimize loops pub fn apply_licm(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_loop_invariant_code_motion()); pass_manager.run(module) } /// Apply SCCP (Sparse Conditional Constant Propagation) for constant folding pub fn apply_sccp(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let 
pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_sccp()); pass_manager.run(module) } /// Apply a pipeline of optimization passes pub fn optimize_module(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); // Standard optimization pipeline pass_manager.add_pass(transform::create_canonicalizer()); pass_manager.add_pass(transform::create_cse()); pass_manager.add_pass(transform::create_sccp()); pass_manager.add_pass(transform::create_inliner()); pass_manager.add_pass(transform::create_canonicalizer()); // Run again after inlining pass_manager.run(module) } /// Example of applying symbol DCE (Dead Code Elimination) pub fn apply_symbol_dce(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_symbol_dce()); pass_manager.run(module) } /// Apply mem2reg transformation to promote memory to registers pub fn apply_mem2reg(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_mem_2_reg()); pass_manager.run(module) } /// Convert parallel loops to GPU kernels pub fn convert_to_gpu(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(gpu::create_gpu_kernel_outlining()); pass_manager.run(module) } /// Strip debug information from the module pub fn strip_debug_info(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_strip_debug_info()); pass_manager.run(module) } /// Type builder helper for creating complex types pub struct TypeBuilder<'c> { context: &'c Context, } impl<'c> TypeBuilder<'c> { pub fn new(context: &'c Context) -> Self { Self { context } } pub fn i32(&self) -> Type<'c> { IntegerType::new(self.context, 
32).into() } pub fn i64(&self) -> Type<'c> { IntegerType::new(self.context, 64).into() } pub fn index(&self) -> Type<'c> { Type::index(self.context) } pub fn function(&self, inputs: &[Type<'c>], outputs: &[Type<'c>]) -> FunctionType<'c> { FunctionType::new(self.context, inputs, outputs) } } /// Attribute builder helper for creating attributes pub struct AttributeBuilder<'c> { context: &'c Context, } impl<'c> AttributeBuilder<'c> { pub fn new(context: &'c Context) -> Self { Self { context } } pub fn string(&self, value: &str) -> StringAttribute<'c> { StringAttribute::new(self.context, value) } pub fn integer(&self, ty: Type<'c>, value: i64) -> IntegerAttribute<'c> { IntegerAttribute::new(ty, value) } pub fn type_attr(&self, ty: Type<'c>) -> TypeAttribute<'c> { TypeAttribute::new(ty) } } /// Custom pass builder for creating transformation pipelines pub struct PassPipeline<'c> { pass_manager: PassManager<'c>, } impl<'c> PassPipeline<'c> { pub fn new(context: &'c Context) -> Self { Self { pass_manager: PassManager::new(context), } } /// Add a canonicalization pass pub fn canonicalize(self) -> Self { self.pass_manager .add_pass(transform::create_canonicalizer()); self } /// Add a CSE pass pub fn eliminate_common_subexpressions(self) -> Self { self.pass_manager.add_pass(transform::create_cse()); self } /// Add an inlining pass pub fn inline_functions(self) -> Self { self.pass_manager.add_pass(transform::create_inliner()); self } /// Add SCCP for constant propagation pub fn propagate_constants(self) -> Self { self.pass_manager.add_pass(transform::create_sccp()); self } /// Add loop optimizations pub fn optimize_loops(self) -> Self { self.pass_manager .add_pass(transform::create_loop_invariant_code_motion()); self } /// Run the pipeline on a module pub fn run(self, module: &mut Module<'c>) -> Result<(), Error> { self.pass_manager.run(module) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_add_function() { let context = create_test_context(); let module = 
create_add_function(&context).unwrap(); assert!(verify_module(&module)); let ir = module_to_string(&module); assert!(ir.contains("func.func @add")); assert!(ir.contains("arith.addi")); } #[test] fn test_multiply_function() { let context = create_test_context(); let module = create_multiply_function(&context).unwrap(); assert!(verify_module(&module)); let ir = module_to_string(&module); assert!(ir.contains("func.func @multiply")); assert!(ir.contains("arith.muli")); } #[test] fn test_constant_creation() { let context = create_test_context(); let module = create_constant(&context, 42).unwrap(); assert!(verify_module(&module)); let ir = module_to_string(&module); assert!(ir.contains("arith.constant")); assert!(ir.contains("42")); } #[test] fn test_type_builder() { let context = create_test_context(); let types = TypeBuilder::new(&context); let i32 = types.i32(); let i64 = types.i64(); let func_type = types.function(&[i32, i32], &[i64]); assert_eq!(func_type.input(0).unwrap(), i32); assert_eq!(func_type.input(1).unwrap(), i32); assert_eq!(func_type.result(0).unwrap(), i64); } #[test] fn test_attribute_builder() { let context = create_test_context(); let types = TypeBuilder::new(&context); let attrs = AttributeBuilder::new(&context); let name = attrs.string("test_function"); // StringAttribute doesn't expose string value directly in the API assert!(name.is_string()); let value = attrs.integer(types.i32(), 100); assert_eq!(value.value(), 100); } #[test] fn test_canonicalization() { let context = create_test_context(); let mut module = create_add_function(&context).unwrap(); let _before = module_to_string(&module); apply_canonicalization(&context, &mut module).unwrap(); let after = module_to_string(&module); assert!(verify_module(&module)); // Canonicalization should preserve or simplify the IR assert!(!after.is_empty()); } #[test] fn test_optimization_pipeline() { let context = create_test_context(); let mut module = create_multiply_function(&context).unwrap(); // Apply 
optimization pipeline optimize_module(&context, &mut module).unwrap(); assert!(verify_module(&module)); let ir = module_to_string(&module); assert!(ir.contains("func.func @multiply")); } #[test] fn test_pass_pipeline_builder() { let context = create_test_context(); let mut module = create_constant(&context, 100).unwrap(); let pipeline = PassPipeline::new(&context) .canonicalize() .eliminate_common_subexpressions() .propagate_constants(); pipeline.run(&mut module).unwrap(); assert!(verify_module(&module)); } } /// Creates a simple function that adds two integers pub fn create_add_function(context: &Context) -> Result<Module<'_>, Error> { let location = Location::unknown(context); let module = Module::new(location); let index_type = Type::index(context); module.body().append_operation(func::func( context, StringAttribute::new(context, "add"), TypeAttribute::new( FunctionType::new(context, &[index_type, index_type], &[index_type]).into(), ), { let block = Block::new(&[(index_type, location), (index_type, location)]); let sum = block .append_operation(arith::addi( block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), location, )) .result(0) .unwrap(); block.append_operation(func::r#return(&[sum.into()], location)); let region = Region::new(); region.append_block(block); region }, &[], location, )); Ok(module) } }
use melior::dialect::{arith, func};
use melior::ir::attribute::{StringAttribute, TypeAttribute};
use melior::ir::r#type::{FunctionType, IntegerType};
use melior::ir::*;
use melior::{Context, Error};

// Shared setup (create_test_context), the pass helpers, the builder types,
// and the test suite are identical to the other listings in this chapter.

/// Creates a function that multiplies two 32-bit integers
pub fn create_multiply_function(context: &Context) -> Result<Module<'_>, Error> {
    let location = Location::unknown(context);
    let module = Module::new(location);
    let i32_type = IntegerType::new(context, 32).into();
    module.body().append_operation(func::func(
        context,
        StringAttribute::new(context, "multiply"),
        TypeAttribute::new(FunctionType::new(context, &[i32_type, i32_type], &[i32_type]).into()),
        {
            let block = Block::new(&[(i32_type, location), (i32_type, location)]);
            let product = block
                .append_operation(arith::muli(
                    block.argument(0).unwrap().into(),
                    block.argument(1).unwrap().into(),
                    location,
                ))
                .result(0)
                .unwrap();
            block.append_operation(func::r#return(&[product.into()], location));
            let region = Region::new();
            region.append_block(block);
            region
        },
        &[],
        location,
    ));
    Ok(module)
}
Function creation involves specifying parameter types and return types using MLIR’s type system. The function body consists of a region containing basic blocks, with the entry block receiving function parameters as block arguments. This structure supports both simple functions and complex control flow patterns.
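For orientation, printing the module built by create_add_function should produce textual IR along these lines (the exact SSA value names are chosen by the MLIR printer and may differ):

```mlir
module {
  func.func @add(%arg0: index, %arg1: index) -> index {
    %0 = arith.addi %arg0, %arg1 : index
    return %0 : index
  }
}
```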
Arithmetic Operations
Creating constant values and arithmetic computations:
use melior::dialect::{arith, func};
use melior::ir::attribute::{IntegerAttribute, StringAttribute, TypeAttribute};
use melior::ir::r#type::{FunctionType, IntegerType};
use melior::ir::*;
use melior::{Context, Error};

// Shared setup, pass helpers, builders, and tests match the other listings
// in this chapter.

/// Creates a constant integer value
pub fn create_constant(context: &Context, value: i64) -> Result<Module<'_>, Error> {
    let location = Location::unknown(context);
    let module = Module::new(location);
    let i64_type = IntegerType::new(context, 64).into();
    module.body().append_operation(func::func(
        context,
        StringAttribute::new(context, "get_constant"),
        TypeAttribute::new(FunctionType::new(context, &[], &[i64_type]).into()),
        {
            let block = Block::new(&[]);
            let constant = block
                .append_operation(arith::constant(
                    context,
                    IntegerAttribute::new(i64_type, value).into(),
                    location,
                ))
                .result(0)
                .unwrap();
            block.append_operation(func::r#return(&[constant.into()], location));
            let region = Region::new();
            region.append_block(block);
            region
        },
        &[],
        location,
    ));
    Ok(module)
}
The arith dialect provides integer and floating-point arithmetic operations. Constants are materialized using arith::constant operations, and computations build expression trees through operation chaining. Each operation produces results that subsequent operations consume, creating a dataflow graph representation.
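As a concrete reference point, create_constant(&context, 42) should print IR roughly like the following (MLIR conventionally names constant results after their value, but the exact name may vary):

```mlir
module {
  func.func @get_constant() -> i64 {
    %c42_i64 = arith.constant 42 : i64
    return %c42_i64 : i64
  }
}
```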
Type and Attribute Builders
Helper utilities for creating MLIR types and attributes:
use melior::ir::attribute::{IntegerAttribute, StringAttribute, TypeAttribute};
use melior::ir::r#type::{FunctionType, IntegerType};
use melior::ir::*;
use melior::Context;

// Shared setup, the example functions, pass helpers, and tests match the
// other listings in this chapter.

/// Type builder helper for creating complex types
pub struct TypeBuilder<'c> {
    context: &'c Context,
}

impl<'c> TypeBuilder<'c> {
    pub fn new(context: &'c Context) -> Self {
        Self { context }
    }

    pub fn i32(&self) -> Type<'c> {
        IntegerType::new(self.context, 32).into()
    }

    pub fn i64(&self) -> Type<'c> {
        IntegerType::new(self.context, 64).into()
    }

    pub fn index(&self) -> Type<'c> {
        Type::index(self.context)
    }

    pub fn function(&self, inputs: &[Type<'c>], outputs: &[Type<'c>]) -> FunctionType<'c> {
        FunctionType::new(self.context, inputs, outputs)
    }
}

/// Attribute builder helper for creating attributes
pub struct AttributeBuilder<'c> {
    context: &'c Context,
}

impl<'c> AttributeBuilder<'c> {
    pub fn new(context: &'c Context) -> Self {
        Self { context }
    }

    pub fn string(&self, value: &str) -> StringAttribute<'c> {
        StringAttribute::new(self.context, value)
    }

    pub fn integer(&self, ty: Type<'c>, value: i64) -> IntegerAttribute<'c> {
        IntegerAttribute::new(ty, value)
    }

    pub fn type_attr(&self, ty: Type<'c>) -> TypeAttribute<'c> {
        TypeAttribute::new(ty)
    }
}
/// Type builder helper for creating complex types
pub struct TypeBuilder<'c> {
    context: &'c Context,
}

impl<'c> TypeBuilder<'c> {
    pub fn new(context: &'c Context) -> Self {
        Self { context }
    }

    pub fn i32(&self) -> Type<'c> {
        IntegerType::new(self.context, 32).into()
    }

    pub fn i64(&self) -> Type<'c> {
        IntegerType::new(self.context, 64).into()
    }

    pub fn index(&self) -> Type<'c> {
        Type::index(self.context)
    }

    pub fn function(&self, inputs: &[Type<'c>], outputs: &[Type<'c>]) -> FunctionType<'c> {
        FunctionType::new(self.context, inputs, outputs)
    }
}
/// Attribute builder helper for creating attributes
pub struct AttributeBuilder<'c> {
    context: &'c Context,
}

impl<'c> AttributeBuilder<'c> {
    pub fn new(context: &'c Context) -> Self {
        Self { context }
    }

    pub fn string(&self, value: &str) -> StringAttribute<'c> {
        StringAttribute::new(self.context, value)
    }

    pub fn integer(&self, ty: Type<'c>, value: i64) -> IntegerAttribute<'c> {
        IntegerAttribute::new(ty, value)
    }

    pub fn type_attr(&self, ty: Type<'c>) -> TypeAttribute<'c> {
        TypeAttribute::new(ty)
    }
}
These builders provide convenient methods for creating common MLIR types and attributes without manually constructing them each time. The TypeBuilder handles integer types, index types, and function types. The AttributeBuilder creates string attributes, integer attributes, and type attributes.
Module Operations
Utility functions for working with MLIR modules:
#![allow(unused)]
fn main() {
use melior::ir::operation::OperationLike;
use melior::ir::Module;

/// Shows how to verify MLIR modules
pub fn verify_module(module: &Module<'_>) -> bool {
    module.as_operation().verify()
}

/// Shows how to print MLIR to string
pub fn module_to_string(module: &Module<'_>) -> String {
    format!("{}", module.as_operation())
}
}
These utilities help verify module correctness and convert modules to their textual MLIR representation for debugging and inspection.
MLIR Transformations and Optimization Passes
MLIR’s power comes from its transformation infrastructure. The PassManager orchestrates optimization passes that transform and optimize IR at different abstraction levels.
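Before looking at the melior calls themselves, the shape of this infrastructure can be sketched in plain Rust. The `ToyModule`, `ToyPassManager`, and pass types below are illustrative inventions, not melior's API: the point is only that passes share a common interface and run in registration order over a module, which is the same contract `PassManager` provides.

```rust
// A toy "module": named constant bindings standing in for IR.
type ToyModule = Vec<(String, i64)>;

// Every pass exposes a name and a rewrite over the module.
trait Pass {
    fn name(&self) -> &str;
    fn run(&self, module: &mut ToyModule);
}

/// Drops zero-valued bindings (a stand-in for dead-code elimination).
struct RemoveZeros;
impl Pass for RemoveZeros {
    fn name(&self) -> &str { "remove-zeros" }
    fn run(&self, module: &mut ToyModule) {
        module.retain(|(_, v)| *v != 0);
    }
}

/// Doubles every constant (a stand-in for an arithmetic rewrite).
struct DoubleConstants;
impl Pass for DoubleConstants {
    fn name(&self) -> &str { "double-constants" }
    fn run(&self, module: &mut ToyModule) {
        for (_, v) in module.iter_mut() {
            *v *= 2;
        }
    }
}

struct ToyPassManager {
    passes: Vec<Box<dyn Pass>>,
}

impl ToyPassManager {
    fn new() -> Self { Self { passes: Vec::new() } }
    fn add_pass(&mut self, pass: Box<dyn Pass>) { self.passes.push(pass); }

    /// Run every registered pass over the module, in order.
    fn run(&self, module: &mut ToyModule) {
        for pass in &self.passes {
            println!("running {}", pass.name());
            pass.run(module);
        }
    }
}

fn main() {
    let mut module: ToyModule = vec![("a".into(), 0), ("b".into(), 21)];
    let mut pm = ToyPassManager::new();
    pm.add_pass(Box::new(RemoveZeros));
    pm.add_pass(Box::new(DoubleConstants));
    pm.run(&mut module);
    // "a" is removed, then "b" is doubled: [("b", 42)]
    assert_eq!(module, vec![("b".to_string(), 42)]);
}
```

Pass ordering matters here exactly as it does in MLIR: reversing the two passes would double `"a"` before removing it, which is why `optimize_module` above runs the canonicalizer again after inlining.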
Basic Transformations
The canonicalizer simplifies IR by applying local, pattern-based rewrites:
#![allow(unused)] fn main() { use melior::dialect::{arith, func, DialectRegistry}; use melior::ir::attribute::{IntegerAttribute, StringAttribute, TypeAttribute}; use melior::ir::operation::OperationLike; use melior::ir::r#type::{FunctionType, IntegerType}; use melior::ir::*; use melior::pass::{gpu, transform, PassManager}; use melior::utility::{register_all_dialects, register_all_llvm_translations, register_all_passes}; use melior::{Context, Error}; /// Creates a test context with all dialects loaded pub fn create_test_context() -> Context { let context = Context::new(); let registry = DialectRegistry::new(); register_all_dialects(®istry); register_all_passes(); context.append_dialect_registry(®istry); context.load_all_available_dialects(); register_all_llvm_translations(&context); context } /// Creates a simple function that adds two integers pub fn create_add_function(context: &Context) -> Result<Module<'_>, Error> { let location = Location::unknown(context); let module = Module::new(location); let index_type = Type::index(context); module.body().append_operation(func::func( context, StringAttribute::new(context, "add"), TypeAttribute::new( FunctionType::new(context, &[index_type, index_type], &[index_type]).into(), ), { let block = Block::new(&[(index_type, location), (index_type, location)]); let sum = block .append_operation(arith::addi( block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), location, )) .result(0) .unwrap(); block.append_operation(func::r#return(&[sum.into()], location)); let region = Region::new(); region.append_block(block); region }, &[], location, )); Ok(module) } /// Creates a function that multiplies two 32-bit integers pub fn create_multiply_function(context: &Context) -> Result<Module<'_>, Error> { let location = Location::unknown(context); let module = Module::new(location); let i32_type = IntegerType::new(context, 32).into(); module.body().append_operation(func::func( context, StringAttribute::new(context, 
"multiply"), TypeAttribute::new(FunctionType::new(context, &[i32_type, i32_type], &[i32_type]).into()), { let block = Block::new(&[(i32_type, location), (i32_type, location)]); let product = block .append_operation(arith::muli( block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), location, )) .result(0) .unwrap(); block.append_operation(func::r#return(&[product.into()], location)); let region = Region::new(); region.append_block(block); region }, &[], location, )); Ok(module) } /// Creates a constant integer value pub fn create_constant(context: &Context, value: i64) -> Result<Module<'_>, Error> { let location = Location::unknown(context); let module = Module::new(location); let i64_type = IntegerType::new(context, 64).into(); module.body().append_operation(func::func( context, StringAttribute::new(context, "get_constant"), TypeAttribute::new(FunctionType::new(context, &[], &[i64_type]).into()), { let block = Block::new(&[]); let constant = block .append_operation(arith::constant( context, IntegerAttribute::new(i64_type, value).into(), location, )) .result(0) .unwrap(); block.append_operation(func::r#return(&[constant.into()], location)); let region = Region::new(); region.append_block(block); region }, &[], location, )); Ok(module) } /// Shows how to verify MLIR modules pub fn verify_module(module: &Module<'_>) -> bool { module.as_operation().verify() } /// Shows how to print MLIR to string pub fn module_to_string(module: &Module<'_>) -> String { format!("{}", module.as_operation()) } /// Apply CSE (Common Subexpression Elimination) to remove redundant /// computations pub fn apply_cse(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_cse()); pass_manager.run(module) } /// Apply inlining transformation to inline function calls pub fn apply_inlining(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = 
PassManager::new(context);
    pass_manager.add_pass(transform::create_inliner());
    pass_manager.run(module)
}

/// Apply loop-invariant code motion to optimize loops
pub fn apply_licm(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
    let pass_manager = PassManager::new(context);
    pass_manager.add_pass(transform::create_loop_invariant_code_motion());
    pass_manager.run(module)
}

/// Apply SCCP (Sparse Conditional Constant Propagation) for constant folding
pub fn apply_sccp(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
    let pass_manager = PassManager::new(context);
    pass_manager.add_pass(transform::create_sccp());
    pass_manager.run(module)
}

/// Apply a pipeline of optimization passes
pub fn optimize_module(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
    let pass_manager = PassManager::new(context);
    // Standard optimization pipeline
    pass_manager.add_pass(transform::create_canonicalizer());
    pass_manager.add_pass(transform::create_cse());
    pass_manager.add_pass(transform::create_sccp());
    pass_manager.add_pass(transform::create_inliner());
    pass_manager.add_pass(transform::create_canonicalizer()); // Run again after inlining
    pass_manager.run(module)
}

/// Example of applying symbol DCE (Dead Code Elimination)
pub fn apply_symbol_dce(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
    let pass_manager = PassManager::new(context);
    pass_manager.add_pass(transform::create_symbol_dce());
    pass_manager.run(module)
}

/// Apply mem2reg transformation to promote memory to registers
pub fn apply_mem2reg(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
    let pass_manager = PassManager::new(context);
    pass_manager.add_pass(transform::create_mem_2_reg());
    pass_manager.run(module)
}

/// Convert parallel loops to GPU kernels
pub fn convert_to_gpu(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
    let pass_manager = PassManager::new(context);
    pass_manager.add_pass(gpu::create_gpu_kernel_outlining());
    pass_manager.run(module)
}

/// Strip debug information from the module
pub fn strip_debug_info(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
    let pass_manager = PassManager::new(context);
    pass_manager.add_pass(transform::create_strip_debug_info());
    pass_manager.run(module)
}

/// Type builder helper for creating complex types
pub struct TypeBuilder<'c> {
    context: &'c Context,
}

impl<'c> TypeBuilder<'c> {
    pub fn new(context: &'c Context) -> Self {
        Self { context }
    }

    pub fn i32(&self) -> Type<'c> {
        IntegerType::new(self.context, 32).into()
    }

    pub fn i64(&self) -> Type<'c> {
        IntegerType::new(self.context, 64).into()
    }

    pub fn index(&self) -> Type<'c> {
        Type::index(self.context)
    }

    pub fn function(&self, inputs: &[Type<'c>], outputs: &[Type<'c>]) -> FunctionType<'c> {
        FunctionType::new(self.context, inputs, outputs)
    }
}

/// Attribute builder helper for creating attributes
pub struct AttributeBuilder<'c> {
    context: &'c Context,
}

impl<'c> AttributeBuilder<'c> {
    pub fn new(context: &'c Context) -> Self {
        Self { context }
    }

    pub fn string(&self, value: &str) -> StringAttribute<'c> {
        StringAttribute::new(self.context, value)
    }

    pub fn integer(&self, ty: Type<'c>, value: i64) -> IntegerAttribute<'c> {
        IntegerAttribute::new(ty, value)
    }

    pub fn type_attr(&self, ty: Type<'c>) -> TypeAttribute<'c> {
        TypeAttribute::new(ty)
    }
}

/// Custom pass builder for creating transformation pipelines
pub struct PassPipeline<'c> {
    pass_manager: PassManager<'c>,
}

impl<'c> PassPipeline<'c> {
    pub fn new(context: &'c Context) -> Self {
        Self {
            pass_manager: PassManager::new(context),
        }
    }

    /// Add a canonicalization pass
    pub fn canonicalize(self) -> Self {
        self.pass_manager
            .add_pass(transform::create_canonicalizer());
        self
    }

    /// Add a CSE pass
    pub fn eliminate_common_subexpressions(self) -> Self {
        self.pass_manager.add_pass(transform::create_cse());
        self
    }

    /// Add an inlining pass
    pub fn inline_functions(self) -> Self {
        self.pass_manager.add_pass(transform::create_inliner());
        self
    }

    /// Add SCCP for constant propagation
    pub fn propagate_constants(self) -> Self {
        self.pass_manager.add_pass(transform::create_sccp());
        self
    }

    /// Add loop optimizations
    pub fn optimize_loops(self) -> Self {
        self.pass_manager
            .add_pass(transform::create_loop_invariant_code_motion());
        self
    }

    /// Run the pipeline on a module
    pub fn run(self, module: &mut Module<'c>) -> Result<(), Error> {
        self.pass_manager.run(module)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_function() {
        let context = create_test_context();
        let module = create_add_function(&context).unwrap();
        assert!(verify_module(&module));
        let ir = module_to_string(&module);
        assert!(ir.contains("func.func @add"));
        assert!(ir.contains("arith.addi"));
    }

    #[test]
    fn test_multiply_function() {
        let context = create_test_context();
        let module = create_multiply_function(&context).unwrap();
        assert!(verify_module(&module));
        let ir = module_to_string(&module);
        assert!(ir.contains("func.func @multiply"));
        assert!(ir.contains("arith.muli"));
    }

    #[test]
    fn test_constant_creation() {
        let context = create_test_context();
        let module = create_constant(&context, 42).unwrap();
        assert!(verify_module(&module));
        let ir = module_to_string(&module);
        assert!(ir.contains("arith.constant"));
        assert!(ir.contains("42"));
    }

    #[test]
    fn test_type_builder() {
        let context = create_test_context();
        let types = TypeBuilder::new(&context);
        let i32 = types.i32();
        let i64 = types.i64();
        let func_type = types.function(&[i32, i32], &[i64]);
        assert_eq!(func_type.input(0).unwrap(), i32);
        assert_eq!(func_type.input(1).unwrap(), i32);
        assert_eq!(func_type.result(0).unwrap(), i64);
    }

    #[test]
    fn test_attribute_builder() {
        let context = create_test_context();
        let types = TypeBuilder::new(&context);
        let attrs = AttributeBuilder::new(&context);
        let name = attrs.string("test_function");
        // StringAttribute doesn't expose string value directly in the API
        assert!(name.is_string());
        let value = attrs.integer(types.i32(), 100);
        assert_eq!(value.value(), 100);
    }

    #[test]
    fn test_canonicalization() {
        let context = create_test_context();
        let mut module = create_add_function(&context).unwrap();
        let _before = module_to_string(&module);
        apply_canonicalization(&context, &mut module).unwrap();
        let after = module_to_string(&module);
        assert!(verify_module(&module));
        // Canonicalization should preserve or simplify the IR
        assert!(!after.is_empty());
    }

    #[test]
    fn test_optimization_pipeline() {
        let context = create_test_context();
        let mut module = create_multiply_function(&context).unwrap();
        // Apply optimization pipeline
        optimize_module(&context, &mut module).unwrap();
        assert!(verify_module(&module));
        let ir = module_to_string(&module);
        assert!(ir.contains("func.func @multiply"));
    }

    #[test]
    fn test_pass_pipeline_builder() {
        let context = create_test_context();
        let mut module = create_constant(&context, 100).unwrap();
        let pipeline = PassPipeline::new(&context)
            .canonicalize()
            .eliminate_common_subexpressions()
            .propagate_constants();
        pipeline.run(&mut module).unwrap();
        assert!(verify_module(&module));
    }
}

/// Apply canonicalization transforms to simplify IR
pub fn apply_canonicalization(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
    let pass_manager = PassManager::new(context);
    pass_manager.add_pass(transform::create_canonicalizer());
    pass_manager.run(module)
}
}
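To make the effect of canonicalization concrete, here is a hand-written illustration in MLIR's textual form (not output captured from these examples): an addition with a zero constant is removed by the `arith.addi` fold patterns.

```mlir
// Before canonicalization: %x + 0 is computed explicitly.
func.func @plus_zero(%x: i32) -> i32 {
  %c0 = arith.constant 0 : i32
  %sum = arith.addi %x, %c0 : i32
  func.return %sum : i32
}

// After canonicalization: the fold addi(x, 0) -> x eliminates the
// addition, and the now-dead constant is cleaned up as well.
func.func @plus_zero(%x: i32) -> i32 {
  func.return %x : i32
}
```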
Common Subexpression Elimination (CSE) removes redundant computations:
#![allow(unused)]
fn main() {
// Shared setup (imports, create_test_context, the example functions, the
// other pass helpers, builders, and tests) is identical to the listing above.

/// Apply CSE (Common Subexpression Elimination) to remove redundant
/// computations
pub fn apply_cse(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
    let pass_manager = PassManager::new(context);
    pass_manager.add_pass(transform::create_cse());
    pass_manager.run(module)
}
}
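As a hand-written illustration of what the pass does (not output generated by the code above), CSE merges identical pure operations with the same operands:

```mlir
// Before CSE: the same addition is computed twice.
func.func @twice(%a: i32, %b: i32) -> i32 {
  %0 = arith.addi %a, %b : i32
  %1 = arith.addi %a, %b : i32
  %2 = arith.muli %0, %1 : i32
  func.return %2 : i32
}

// After CSE: the duplicate is replaced by the first result.
func.func @twice(%a: i32, %b: i32) -> i32 {
  %0 = arith.addi %a, %b : i32
  %1 = arith.muli %0, %0 : i32
  func.return %1 : i32
}
```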
Sparse Conditional Constant Propagation (SCCP) performs constant folding and dead code elimination:
#![allow(unused)]
fn main() {
// Shared setup (imports, create_test_context, the example functions, the
// other pass helpers, builders, and tests) is identical to the listings above.

/// Apply SCCP (Sparse Conditional Constant Propagation) for constant folding
pub fn apply_sccp(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
    let pass_manager = PassManager::new(context);
    pass_manager.add_pass(transform::create_sccp());
    pass_manager.run(module)
}
}
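The following hand-written IR (again an illustration, not captured output) shows the kind of simplification SCCP enables: both operands are known constants, so the result is too.

```mlir
// Before SCCP: the operands are compile-time constants.
func.func @five() -> i64 {
  %c2 = arith.constant 2 : i64
  %c3 = arith.constant 3 : i64
  %sum = arith.addi %c2, %c3 : i64
  func.return %sum : i64
}

// After SCCP, with a follow-up canonicalization to drop the dead
// constants: the sum is materialized directly.
func.func @five() -> i64 {
  %c5 = arith.constant 5 : i64
  func.return %c5 : i64
}
```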
Function Optimizations
Inlining replaces function calls with their bodies:
#![allow(unused)]
fn main() {
// Shared setup (imports, create_test_context, the example functions, and the
// individual pass helpers) is identical to the listings above.

/// Custom pass builder for creating transformation pipelines
pub struct PassPipeline<'c> {
    pass_manager: PassManager<'c>,
}

impl<'c> PassPipeline<'c> {
    pub fn new(context: &'c Context) -> Self {
        Self {
            pass_manager: PassManager::new(context),
        }
    }

    /// Add a canonicalization pass
    pub fn canonicalize(self) -> Self {
        self.pass_manager
            .add_pass(transform::create_canonicalizer());
        self
    }

    /// Add a CSE pass
    pub fn eliminate_common_subexpressions(self) -> Self {
        self.pass_manager.add_pass(transform::create_cse());
        self
    }

    /// Add an inlining pass
    pub fn inline_functions(self) -> Self {
        self.pass_manager.add_pass(transform::create_inliner());
        self
    }
/// Add SCCP for constant propagation pub fn propagate_constants(self) -> Self { self.pass_manager.add_pass(transform::create_sccp()); self } /// Add loop optimizations pub fn optimize_loops(self) -> Self { self.pass_manager .add_pass(transform::create_loop_invariant_code_motion()); self } /// Run the pipeline on a module pub fn run(self, module: &mut Module<'c>) -> Result<(), Error> { self.pass_manager.run(module) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_add_function() { let context = create_test_context(); let module = create_add_function(&context).unwrap(); assert!(verify_module(&module)); let ir = module_to_string(&module); assert!(ir.contains("func.func @add")); assert!(ir.contains("arith.addi")); } #[test] fn test_multiply_function() { let context = create_test_context(); let module = create_multiply_function(&context).unwrap(); assert!(verify_module(&module)); let ir = module_to_string(&module); assert!(ir.contains("func.func @multiply")); assert!(ir.contains("arith.muli")); } #[test] fn test_constant_creation() { let context = create_test_context(); let module = create_constant(&context, 42).unwrap(); assert!(verify_module(&module)); let ir = module_to_string(&module); assert!(ir.contains("arith.constant")); assert!(ir.contains("42")); } #[test] fn test_type_builder() { let context = create_test_context(); let types = TypeBuilder::new(&context); let i32 = types.i32(); let i64 = types.i64(); let func_type = types.function(&[i32, i32], &[i64]); assert_eq!(func_type.input(0).unwrap(), i32); assert_eq!(func_type.input(1).unwrap(), i32); assert_eq!(func_type.result(0).unwrap(), i64); } #[test] fn test_attribute_builder() { let context = create_test_context(); let types = TypeBuilder::new(&context); let attrs = AttributeBuilder::new(&context); let name = attrs.string("test_function"); // StringAttribute doesn't expose string value directly in the API assert!(name.is_string()); let value = attrs.integer(types.i32(), 100); 
assert_eq!(value.value(), 100); } #[test] fn test_canonicalization() { let context = create_test_context(); let mut module = create_add_function(&context).unwrap(); let _before = module_to_string(&module); apply_canonicalization(&context, &mut module).unwrap(); let after = module_to_string(&module); assert!(verify_module(&module)); // Canonicalization should preserve or simplify the IR assert!(!after.is_empty()); } #[test] fn test_optimization_pipeline() { let context = create_test_context(); let mut module = create_multiply_function(&context).unwrap(); // Apply optimization pipeline optimize_module(&context, &mut module).unwrap(); assert!(verify_module(&module)); let ir = module_to_string(&module); assert!(ir.contains("func.func @multiply")); } #[test] fn test_pass_pipeline_builder() { let context = create_test_context(); let mut module = create_constant(&context, 100).unwrap(); let pipeline = PassPipeline::new(&context) .canonicalize() .eliminate_common_subexpressions() .propagate_constants(); pipeline.run(&mut module).unwrap(); assert!(verify_module(&module)); } } /// Apply inlining transformation to inline function calls pub fn apply_inlining(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_inliner()); pass_manager.run(module) } }
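To make the effect concrete, here is a hypothetical before/after sketch of what the inliner does at the IR level (the function names are illustrative, not taken from the examples above):

```mlir
// Before inlining: @caller invokes @square through func.call.
func.func @square(%x: i32) -> i32 {
  %0 = arith.muli %x, %x : i32
  return %0 : i32
}
func.func @caller(%a: i32) -> i32 {
  %0 = func.call @square(%a) : (i32) -> i32
  return %0 : i32
}

// After the inliner pass: the call site is replaced by the callee's body.
func.func @caller(%a: i32) -> i32 {
  %0 = arith.muli %a, %a : i32
  return %0 : i32
}
```

Note that the now-unused `@square` definition is not deleted by the inliner itself; removing dead symbols is the job of symbol DCE, described next.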
Symbol DCE removes private functions and other symbols that are never referenced (public symbols are assumed to be externally visible and are kept):
```rust
use melior::ir::Module;
use melior::pass::{transform, PassManager};
use melior::{Context, Error};

/// Example of applying symbol DCE (Dead Code Elimination)
pub fn apply_symbol_dce(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
    let pass_manager = PassManager::new(context);
    pass_manager.add_pass(transform::create_symbol_dce());
    pass_manager.run(module)
}
```
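As a hypothetical sketch of the effect on a module (the symbol names here are illustrative):

```mlir
module {
  // Private and never referenced: symbol DCE deletes this function.
  func.func private @unused_helper(%x: i32) -> i32 {
    return %x : i32
  }
  // Public symbols may be referenced externally, so @entry survives.
  func.func @entry() -> i32 {
    %c1 = arith.constant 1 : i32
    return %c1 : i32
  }
}
```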
Loop Optimizations
Loop-Invariant Code Motion (LICM) hoists invariant computations out of loops:
```rust
use melior::ir::Module;
use melior::pass::{transform, PassManager};
use melior::{Context, Error};

/// Apply loop-invariant code motion to optimize loops
pub fn apply_licm(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
    let pass_manager = PassManager::new(context);
    pass_manager.add_pass(transform::create_loop_invariant_code_motion());
    pass_manager.run(module)
}
```
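A hypothetical before/after sketch of LICM on an `scf.for` loop (the values `%a`, `%b`, `%v`, and `%buf` stand for definitions outside the loop):

```mlir
// Before LICM: %inv is recomputed every iteration even though it
// does not depend on the induction variable %i.
scf.for %i = %lb to %ub step %c1 {
  %inv = arith.muli %a, %b : index
  %addr = arith.addi %inv, %i : index
  memref.store %v, %buf[%addr] : memref<1024xf32>
}

// After LICM: the side-effect-free, loop-invariant multiply is hoisted.
%inv = arith.muli %a, %b : index
scf.for %i = %lb to %ub step %c1 {
  %addr = arith.addi %inv, %i : index
  memref.store %v, %buf[%addr] : memref<1024xf32>
}
```

Only operations that are free of side effects and whose operands are defined outside the loop are eligible for hoisting; the store stays put.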
Memory Optimizations
Promote memory allocations to SSA registers using mem2reg:
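The transformation is easiest to see at the IR level. A hypothetical sketch:

```mlir
// Before mem2reg: the value round-trips through a stack slot.
func.func @roundtrip(%v: i32) -> i32 {
  %slot = memref.alloca() : memref<i32>
  memref.store %v, %slot[] : memref<i32>
  %r = memref.load %slot[] : memref<i32>
  return %r : i32
}

// After mem2reg: the load is forwarded to %v and the dead alloca is removed.
func.func @roundtrip(%v: i32) -> i32 {
  return %v : i32
}
```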
#![allow(unused)]
fn main() {
use melior::ir::Module;
use melior::pass::{transform, PassManager};
use melior::{Context, Error};

/// Apply mem2reg transformation to promote memory to registers
pub fn apply_mem2reg(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
    let pass_manager = PassManager::new(context);
    pass_manager.add_pass(transform::create_mem_2_reg());
    pass_manager.run(module)
}

/// Custom pass builder for creating transformation pipelines
pub struct PassPipeline<'c> {
    pass_manager: PassManager<'c>,
}

impl<'c> PassPipeline<'c> {
    pub fn new(context: &'c Context) -> Self {
        Self {
            pass_manager: PassManager::new(context),
        }
    }

    /// Add a canonicalization pass
    pub fn canonicalize(self) -> Self {
        self.pass_manager
            .add_pass(transform::create_canonicalizer());
        self
    }

    /// Add a CSE pass
    pub fn eliminate_common_subexpressions(self) -> Self {
        self.pass_manager.add_pass(transform::create_cse());
        self
    }

    /// Add an inlining pass
    pub fn inline_functions(self) -> Self {
        self.pass_manager.add_pass(transform::create_inliner());
        self
    }
/// Add SCCP for constant propagation pub fn propagate_constants(self) -> Self { self.pass_manager.add_pass(transform::create_sccp()); self } /// Add loop optimizations pub fn optimize_loops(self) -> Self { self.pass_manager .add_pass(transform::create_loop_invariant_code_motion()); self } /// Run the pipeline on a module pub fn run(self, module: &mut Module<'c>) -> Result<(), Error> { self.pass_manager.run(module) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_add_function() { let context = create_test_context(); let module = create_add_function(&context).unwrap(); assert!(verify_module(&module)); let ir = module_to_string(&module); assert!(ir.contains("func.func @add")); assert!(ir.contains("arith.addi")); } #[test] fn test_multiply_function() { let context = create_test_context(); let module = create_multiply_function(&context).unwrap(); assert!(verify_module(&module)); let ir = module_to_string(&module); assert!(ir.contains("func.func @multiply")); assert!(ir.contains("arith.muli")); } #[test] fn test_constant_creation() { let context = create_test_context(); let module = create_constant(&context, 42).unwrap(); assert!(verify_module(&module)); let ir = module_to_string(&module); assert!(ir.contains("arith.constant")); assert!(ir.contains("42")); } #[test] fn test_type_builder() { let context = create_test_context(); let types = TypeBuilder::new(&context); let i32 = types.i32(); let i64 = types.i64(); let func_type = types.function(&[i32, i32], &[i64]); assert_eq!(func_type.input(0).unwrap(), i32); assert_eq!(func_type.input(1).unwrap(), i32); assert_eq!(func_type.result(0).unwrap(), i64); } #[test] fn test_attribute_builder() { let context = create_test_context(); let types = TypeBuilder::new(&context); let attrs = AttributeBuilder::new(&context); let name = attrs.string("test_function"); // StringAttribute doesn't expose string value directly in the API assert!(name.is_string()); let value = attrs.integer(types.i32(), 100); 
assert_eq!(value.value(), 100); } #[test] fn test_canonicalization() { let context = create_test_context(); let mut module = create_add_function(&context).unwrap(); let _before = module_to_string(&module); apply_canonicalization(&context, &mut module).unwrap(); let after = module_to_string(&module); assert!(verify_module(&module)); // Canonicalization should preserve or simplify the IR assert!(!after.is_empty()); } #[test] fn test_optimization_pipeline() { let context = create_test_context(); let mut module = create_multiply_function(&context).unwrap(); // Apply optimization pipeline optimize_module(&context, &mut module).unwrap(); assert!(verify_module(&module)); let ir = module_to_string(&module); assert!(ir.contains("func.func @multiply")); } #[test] fn test_pass_pipeline_builder() { let context = create_test_context(); let mut module = create_constant(&context, 100).unwrap(); let pipeline = PassPipeline::new(&context) .canonicalize() .eliminate_common_subexpressions() .propagate_constants(); pipeline.run(&mut module).unwrap(); assert!(verify_module(&module)); } } /// Apply mem2reg transformation to promote memory to registers pub fn apply_mem2reg(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_mem_2_reg()); pass_manager.run(module) } }
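To make the effect of mem2reg concrete, here is an illustrative before/after sketch (hand-written IR, not output generated by this code): the pass promotes promotable stack allocations such as memref.alloca into SSA values, eliminating the intermediate loads and stores.

```mlir
// Before mem2reg: the argument round-trips through a stack slot.
func.func @passthrough(%arg0: i32) -> i32 {
  %slot = memref.alloca() : memref<i32>
  memref.store %arg0, %slot[] : memref<i32>
  %v = memref.load %slot[] : memref<i32>
  return %v : i32
}

// After mem2reg: the slot is promoted to an SSA value and folds away.
func.func @passthrough(%arg0: i32) -> i32 {
  return %arg0 : i32
}
```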
GPU Transformations
Convert parallel patterns to GPU kernels using GPU dialect lowering:
/// Convert parallel loops to GPU kernels
pub fn convert_to_gpu(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
    let pass_manager = PassManager::new(context);
    pass_manager.add_pass(gpu::create_gpu_kernel_outlining());
    pass_manager.run(module)
}
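As a rough illustration of what kernel outlining does (a hand-written sketch; the symbol names are illustrative): the body of each gpu.launch operation is extracted into a gpu.func marked as a kernel inside a nested gpu.module, and the launch site is rewritten to a gpu.launch_func that refers to the outlined kernel by symbol.

```mlir
// Before outlining: the kernel body is inline in the host function.
gpu.launch blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1)
           threads(%tx, %ty, %tz) in (%sx = %c1, %sy = %c1, %sz = %c1) {
  // ... kernel body ...
  gpu.terminator
}

// After outlining: the body lives in a separate gpu.module and the
// launch site references it by symbol.
gpu.launch_func @main_kernel::@main_kernel
    blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)

gpu.module @main_kernel {
  gpu.func @main_kernel() kernel {
    // ... kernel body ...
    gpu.return
  }
}
```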
Utility Passes
Strip debug information for release builds using the strip-debuginfo pass:
/// Strip debug information from the module
pub fn strip_debug_info(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
    let pass_manager = PassManager::new(context);
    pass_manager.add_pass(transform::create_strip_debug_info());
    pass_manager.run(module)
}
Optimization Pipelines
Combine multiple passes into an optimization pipeline:
```rust
use melior::ir::Module;
use melior::pass::{transform, PassManager};
use melior::{Context, Error};

/// Apply a pipeline of optimization passes
pub fn optimize_module(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
    let pass_manager = PassManager::new(context);

    // Standard optimization pipeline
    pass_manager.add_pass(transform::create_canonicalizer());
    pass_manager.add_pass(transform::create_cse());
    pass_manager.add_pass(transform::create_sccp());
    pass_manager.add_pass(transform::create_inliner());
    // Run canonicalization again after inlining
    pass_manager.add_pass(transform::create_canonicalizer());

    pass_manager.run(module)
}
```
Custom Pass Pipelines
Build fluent transformation pipelines with the PassPipeline builder:
```rust
use melior::ir::Module;
use melior::pass::{transform, PassManager};
use melior::{Context, Error};

/// Custom pass builder for creating transformation pipelines
pub struct PassPipeline<'c> {
    pass_manager: PassManager<'c>,
}

impl<'c> PassPipeline<'c> {
    pub fn new(context: &'c Context) -> Self {
        Self {
            pass_manager: PassManager::new(context),
        }
    }

    /// Add a canonicalization pass
    pub fn canonicalize(self) -> Self {
        self.pass_manager
            .add_pass(transform::create_canonicalizer());
        self
    }

    /// Add a CSE pass
    pub fn eliminate_common_subexpressions(self) -> Self {
        self.pass_manager.add_pass(transform::create_cse());
        self
    }

    /// Add an inlining pass
    pub fn inline_functions(self) -> Self {
        self.pass_manager.add_pass(transform::create_inliner());
        self
    }

    /// Add SCCP for constant propagation
    pub fn propagate_constants(self) -> Self {
        self.pass_manager.add_pass(transform::create_sccp());
        self
    }

    /// Add loop optimizations
    pub fn optimize_loops(self) -> Self {
        self.pass_manager
            .add_pass(transform::create_loop_invariant_code_motion());
        self
    }

    /// Run the pipeline on a module
    pub fn run(self, module: &mut Module<'c>) -> Result<(), Error> {
        self.pass_manager.run(module)
    }
}
```
The PassPipeline builder provides a fluent API for constructing custom optimization sequences. Each transformation method returns self, allowing method chaining. The pipeline executes passes in the order they were added, enabling precise control over optimization phases.
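The consuming-builder pattern behind PassPipeline can be sketched with std-only types. Each method takes `self` by value, records a pass, and returns `self`, so calls chain and passes run in insertion order. The names here are illustrative, not part of the melior API:

```rust
// A std-only sketch of a consuming builder: each method takes `self` by
// value and returns it, so calls chain and the pipeline preserves the
// order in which passes were added.
#[derive(Default)]
pub struct Pipeline {
    passes: Vec<&'static str>,
}

impl Pipeline {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn canonicalize(mut self) -> Self {
        self.passes.push("canonicalize");
        self
    }

    pub fn cse(mut self) -> Self {
        self.passes.push("cse");
        self
    }

    // "Running" here just reveals the recorded order.
    pub fn run(self) -> Vec<&'static str> {
        self.passes
    }
}

fn main() {
    let order = Pipeline::new().canonicalize().cse().run();
    assert_eq!(order, vec!["canonicalize", "cse"]);
}
```

Because `run` also consumes `self`, a pipeline cannot be accidentally reused after execution, which mirrors how the melior wrapper hands its module to the pass manager exactly once.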
Best Practices
Structure compilation as progressive lowering through multiple abstraction levels. Start with domain-specific representations and lower gradually to executable code. This approach enables optimizations at appropriate abstraction levels and improves compiler modularity.
Leverage MLIR’s dialect system to separate concerns. Use high-level dialects for domain logic, mid-level dialects for general optimizations, and low-level dialects for code generation. This separation enables reuse across different compilation pipelines.
Design custom operations to be composable and orthogonal. Avoid monolithic operations that combine multiple concepts. Instead, build complex behaviors from simple, well-defined operations that optimization passes can analyze and transform.
Use MLIR’s type system to enforce invariants. Rich types catch errors early and enable optimizations. Track properties like tensor dimensions, memory layouts, and value constraints through the type system rather than runtime checks.
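The same "encode invariants in types" advice applies in the host language. A std-only Rust sketch (illustrative only; MLIR expresses this through its own type system): a const-generic dimension makes mismatched shapes a compile-time error rather than a runtime check.

```rust
// A length-indexed vector: the dimension N lives in the type, so adding
// tensors of different static lengths simply does not type-check.
pub struct Tensor<const N: usize> {
    data: Vec<f64>,
}

impl<const N: usize> Tensor<N> {
    pub fn new(data: Vec<f64>) -> Self {
        assert_eq!(data.len(), N); // checked once, at construction
        Self { data }
    }

    // Both operands are forced to share N by the signature alone.
    pub fn add(&self, other: &Tensor<N>) -> Tensor<N> {
        Tensor::new(
            self.data
                .iter()
                .zip(&other.data)
                .map(|(a, b)| a + b)
                .collect(),
        )
    }
}

fn main() {
    let a = Tensor::<3>::new(vec![1.0, 2.0, 3.0]);
    let b = Tensor::<3>::new(vec![4.0, 5.0, 6.0]);
    let c = a.add(&b);
    assert_eq!(c.data, vec![5.0, 7.0, 9.0]);
    // let bad = Tensor::<2>::new(vec![1.0, 2.0]);
    // a.add(&bad); // rejected at compile time: Tensor<3> vs Tensor<2>
}
```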
Implement verification for custom operations. Verification catches IR inconsistencies early and provides better error messages. Well-verified IR enables aggressive optimizations without compromising correctness.
Build reusable transformation patterns. Patterns that match common idioms enable optimization across different contexts. Parameterize patterns to handle variations while maintaining transformation correctness.
codespan
Codespan provides fundamental span tracking and position management infrastructure for compiler diagnostics and source mapping. The crate offers precise byte-level position tracking with efficient conversion to line and column information, making it ideal for error reporting and source code analysis. Unlike higher-level diagnostic libraries, codespan focuses on the core span arithmetic and position calculations that underpin accurate source location tracking.
The library centers around immutable position types that represent byte offsets in source text. These types support arithmetic operations for calculating ranges and distances while maintaining type safety. The span abstraction represents a contiguous range of source text, enabling precise tracking of token locations, expression boundaries, and error positions throughout compilation.
Core Position Types
#![allow(unused)]
fn main() {
use std::fmt;

use codespan::{ByteIndex, ColumnIndex, LineIndex, Span};

/// A source file with span tracking
#[derive(Debug, Clone)]
pub struct SourceFile {
    name: String,
    contents: String,
    line_starts: Vec<ByteIndex>,
}

impl SourceFile {
    pub fn new(name: String, contents: String) -> Self {
        let line_starts = std::iter::once(ByteIndex::from(0))
            .chain(contents.char_indices().filter_map(|(i, c)| {
                if c == '\n' {
                    Some(ByteIndex::from(i as u32 + 1))
                } else {
                    None
                }
            }))
            .collect();
        Self {
            name,
            contents,
            line_starts,
        }
    }

    pub fn name(&self) -> &str {
        &self.name
    }

    pub fn contents(&self) -> &str {
        &self.contents
    }

    pub fn line_index(&self, byte_index: ByteIndex) -> LineIndex {
        match self.line_starts.binary_search(&byte_index) {
            Ok(line) => LineIndex::from(line as u32),
            Err(next_line) => LineIndex::from((next_line as u32).saturating_sub(1)),
        }
    }

    pub fn column_index(&self, byte_index: ByteIndex) -> ColumnIndex {
        let line_index = self.line_index(byte_index);
        let line_start = self.line_starts[line_index.to_usize()];
        let column_offset = byte_index - line_start;
        ColumnIndex::from(column_offset.to_usize() as u32)
    }

    pub fn location(&self, byte_index: ByteIndex) -> Location {
        Location {
            line: self.line_index(byte_index),
            column: self.column_index(byte_index),
        }
    }

    pub fn slice(&self, span: Span) -> &str {
        let start = span.start().to_usize();
        let end = span.end().to_usize();
        &self.contents[start..end]
    }
}

/// A location in a source file
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Location {
    pub line: LineIndex,
    pub column: ColumnIndex,
}

impl fmt::Display for Location {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "{}:{}",
            self.line.to_usize() + 1,
            self.column.to_usize() + 1
        )
    }
}
}
The SourceFile structure maintains the source text along with precomputed line start positions for efficient line and column calculation. The line starts vector enables binary search for position lookups, providing O(log n) complexity for location queries.
#![allow(unused)]
fn main() {
pub fn new(name: String, contents: String) -> Self {
    let line_starts = std::iter::once(ByteIndex::from(0))
        .chain(contents.char_indices().filter_map(|(i, c)| {
            if c == '\n' {
                Some(ByteIndex::from(i as u32 + 1))
            } else {
                None
            }
        }))
        .collect();
    Self {
        name,
        contents,
        line_starts,
    }
}
}
File construction scans the input once to identify line boundaries, building an index for subsequent position queries. This upfront computation trades initialization time for faster repeated lookups during compilation.
#![allow(unused)]
fn main() {
pub fn line_index(&self, byte_index: ByteIndex) -> LineIndex {
    match self.line_starts.binary_search(&byte_index) {
        Ok(line) => LineIndex::from(line as u32),
        Err(next_line) => LineIndex::from((next_line as u32).saturating_sub(1)),
    }
}
}
Line index calculation uses binary search on the precomputed line starts. When the search finds an exact match, that line contains the position. Otherwise, the position falls within the preceding line.
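The `Ok`/`Err` mapping can be seen with a std-only sketch on plain `usize` offsets (the `line_of` helper is illustrative, not part of codespan):

```rust
// Std-only illustration of the binary-search line lookup: line_starts
// holds the byte offset at which each line begins.
fn line_of(line_starts: &[usize], byte: usize) -> usize {
    match line_starts.binary_search(&byte) {
        // Exact hit: this byte is the first byte of that line.
        Ok(line) => line,
        // Miss: binary_search returns the insertion point, so the byte
        // falls on the preceding line.
        Err(next) => next.saturating_sub(1),
    }
}

fn main() {
    // "ab\ncd\ne" -> lines begin at bytes 0, 3, and 6.
    let starts = [0, 3, 6];
    assert_eq!(line_of(&starts, 0), 0); // 'a'
    assert_eq!(line_of(&starts, 2), 0); // '\n' still belongs to line 0
    assert_eq!(line_of(&starts, 3), 1); // 'c': exact match on a line start
    assert_eq!(line_of(&starts, 4), 1); // 'd'
    assert_eq!(line_of(&starts, 6), 2); // 'e'
}
```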
#![allow(unused)]
fn main() {
pub fn column_index(&self, byte_index: ByteIndex) -> ColumnIndex {
    let line_index = self.line_index(byte_index);
    let line_start = self.line_starts[line_index.to_usize()];
    let column_offset = byte_index - line_start;
    ColumnIndex::from(column_offset.to_usize() as u32)
}
}
Column calculation first determines the line, then computes the byte offset from the line start. This approach handles variable-width characters correctly by operating on byte positions rather than character counts.
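The distinction matters on multi-byte UTF-8 text, where byte columns and character columns diverge. A std-only sketch (the `char_column` helper is illustrative):

```rust
// Byte columns vs character columns on multi-byte UTF-8 text.
fn char_column(line: &str, byte_in_line: usize) -> usize {
    // Count the characters that precede the byte offset on this line.
    line[..byte_in_line].chars().count()
}

fn main() {
    let line = "世界 = 1"; // '世' and '界' are 3 bytes each in UTF-8
    let eq = line.find('=').unwrap();
    assert_eq!(eq, 7); // byte-based column: 3 + 3 + 1
    assert_eq!(char_column(line, eq), 3); // character-based column
}
```

Reporting byte columns is cheap and unambiguous; converting to character columns only when rendering a diagnostic keeps the hot path free of per-character scanning.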
Span Management
#![allow(unused)]
fn main() {
use std::collections::HashMap;

use codespan::{ByteIndex, Span};

/// A span manager for tracking multiple source files
pub struct SpanManager {
    files: Vec<SourceFile>,
    file_map: HashMap<String, usize>,
}

impl SpanManager {
    pub fn new() -> Self {
        Self {
            files: Vec::new(),
            file_map: HashMap::new(),
        }
    }

    pub fn add_file(&mut self, name: String, contents: String) -> FileId {
        let file_id = FileId(self.files.len());
        let file = SourceFile::new(name.clone(), contents);
        self.files.push(file);
        self.file_map.insert(name, file_id.0);
        file_id
    }

    pub fn get_file(&self, id: FileId) -> Option<&SourceFile> {
        self.files.get(id.0)
    }

    pub fn find_file(&self, name: &str) -> Option<FileId> {
        self.file_map.get(name).map(|&id| FileId(id))
    }

    pub fn create_span(&self, start: ByteIndex, end: ByteIndex) -> Span {
        Span::new(start, end)
    }

    pub fn merge_spans(&self, first: Span, second: Span) -> Span {
        let start = first.start().min(second.start());
        let end = first.end().max(second.end());
        Span::new(start, end)
    }
}

impl Default for SpanManager {
    fn default() -> Self {
        Self::new()
    }
}

/// File identifier
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FileId(usize);
}
The SpanManager coordinates multiple source files in a compilation unit. It provides centralized file registration and lookup while maintaining consistent file identifiers across the compilation pipeline.
#![allow(unused)]
fn main() {
pub fn add_file(&mut self, name: String, contents: String) -> FileId {
    let file_id = FileId(self.files.len());
    let file = SourceFile::new(name.clone(), contents);
    self.files.push(file);
    self.file_map.insert(name, file_id.0);
    file_id
}
}
File registration assigns sequential identifiers and maintains a name-based lookup table. This design supports both positional access for span resolution and name-based queries for import resolution.
#![allow(unused)]
fn main() {
pub fn merge_spans(&self, first: Span, second: Span) -> Span {
    let start = first.start().min(second.start());
    let end = first.end().max(second.end());
    Span::new(start, end)
}
}
Span merging combines multiple spans into their encompassing range. This operation proves essential for error reporting when an error spans multiple tokens or when synthesizing spans for derived expressions.
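The min/max logic is easy to check with plain `Range<usize>` spans in a std-only sketch (the `merge` helper mirrors `merge_spans` but is not part of codespan):

```rust
use std::ops::Range;

// The merged span covers both inputs, even when they are disjoint.
fn merge(a: Range<usize>, b: Range<usize>) -> Range<usize> {
    a.start.min(b.start)..a.end.max(b.end)
}

fn main() {
    // Overlapping spans, as when an error covers two adjacent tokens.
    assert_eq!(merge(10..20, 15..25), 10..25);
    // Disjoint spans: the merge also covers the gap between them, which
    // is what you want for a synthesized expression span like `a + b`.
    assert_eq!(merge(0..3, 8..12), 0..12);
}
```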
Token Representation
```rust
/// A token with span information
#[derive(Debug, Clone)]
pub struct Token<T> {
    pub kind: T,
    pub span: Span,
    pub file_id: FileId,
}

impl<T> Token<T> {
    pub fn new(kind: T, span: Span, file_id: FileId) -> Self {
        Self { kind, span, file_id }
    }
}
```
Tokens carry their span information throughout parsing and analysis. The generic type parameter allows reuse across different token representations while maintaining consistent span tracking.
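A minimal sketch of that reuse, with plain structs standing in for `codespan::Span` and `FileId` so it compiles on its own (the real code uses the codespan types directly): the same generic `Token` can be re-tagged with a different kind without losing its location. The `retag` helper is hypothetical, not part of the example crate.

```rust
// Stand-ins for codespan::Span and FileId so this sketch is self-contained.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Span { pub start: u32, pub end: u32 }

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct FileId(pub usize);

#[derive(Debug, Clone, PartialEq)]
pub struct Token<T> {
    pub kind: T,
    pub span: Span,
    pub file_id: FileId,
}

/// Hypothetical helper: swap a token's kind while preserving its span,
/// the operation a parser performs when lexer tokens become AST-level tokens.
pub fn retag<T, U>(token: Token<T>, kind: U) -> Token<U> {
    Token { kind, span: token.span, file_id: token.file_id }
}

fn main() {
    let raw = Token { kind: "42", span: Span { start: 8, end: 10 }, file_id: FileId(0) };
    let typed = retag(raw.clone(), 42i64);
    // The span survives the change of token representation.
    assert_eq!(typed.span, raw.span);
}
```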
```rust
/// Example token types for demonstration
#[derive(Debug, Clone, PartialEq)]
pub enum TokenKind {
    Identifier(String),
    Number(i64),
    String(String),
    Keyword(Keyword),
    Operator(Operator),
    Delimiter(Delimiter),
}

#[derive(Debug, Clone, PartialEq)]
pub enum Keyword {
    Let,
    If,
    Else,
    While,
    Function,
    Return,
}

#[derive(Debug, Clone, PartialEq)]
pub enum Operator {
    Plus,
    Minus,
    Star,
    Slash,
    Equal,
    NotEqual,
    Less,
    Greater,
    Assign,
}

#[derive(Debug, Clone, PartialEq)]
pub enum Delimiter {
    LeftParen,
    RightParen,
    LeftBrace,
    RightBrace,
    LeftBracket,
    RightBracket,
    Semicolon,
    Comma,
}
```
The token enumeration demonstrates typical language constructs that benefit from span tracking. Each variant can be associated with its source location for precise error reporting.
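As a sketch of what that buys you, a byte range alone is enough to recover the offending lexeme for an error message. Plain `usize` offsets stand in here for codespan's `ByteIndex`, mirroring what `SourceFile::slice` does:

```rust
/// Recover the source text a span covers; mirrors SourceFile::slice,
/// but with plain byte offsets instead of codespan's ByteIndex.
fn lexeme(source: &str, start: usize, end: usize) -> &str {
    &source[start..end]
}

fn main() {
    let source = "let x = 42;";
    // Suppose the lexer recorded span (8, 10) for the number literal.
    let text = lexeme(source, 8, 10);
    assert_eq!(text, "42");
    println!("error: unexpected token `{}`", text);
}
```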
Lexical Analysis
```rust
/// A simple lexer using codespan for span tracking
pub struct Lexer {
    input: String,
    position: usize,
    file_id: FileId,
}

impl Lexer {
    pub fn new(input: String, file_id: FileId) -> Self {
        Self { input, position: 0, file_id }
    }
}
```
The lexer maintains position state while scanning input text. It tracks byte positions rather than character indices to handle UTF-8 text correctly.
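To see why byte positions are the right currency, here is a small standalone sketch (plain Rust, no codespan types) showing that character numbers and byte offsets diverge as soon as the input contains non-ASCII text:

```rust
fn main() {
    let s = "aé漢";
    // char_indices yields *byte* offsets, not character numbers:
    // 'a' is 1 byte, 'é' is 2 bytes, '漢' is 3 bytes.
    let offsets: Vec<(usize, char)> = s.char_indices().collect();
    assert_eq!(offsets, vec![(0, 'a'), (1, 'é'), (3, '漢')]);

    // 3 characters, but 6 bytes: string slicing must use byte offsets,
    // which is exactly what the lexer's `position` field tracks.
    assert_eq!(s.chars().count(), 3);
    assert_eq!(s.len(), 6);
    println!("{:?}", offsets);
}
```

Indexing by character count would require an O(n) scan per lookup and would disagree with `&str` slicing; byte offsets are both cheap and directly usable.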
```rust
pub fn tokenize(&mut self) -> Vec<Token<TokenKind>> {
    let mut tokens = Vec::new();

    while !self.is_eof() {
        self.skip_whitespace();

        if self.is_eof() {
            break;
        }

        let start = ByteIndex::from(self.position as u32);
        if let Some(token) = self.scan_token() {
            let end = ByteIndex::from(self.position as u32);
            let span = Span::new(start, end);
            tokens.push(Token::new(token, span, self.file_id));
        }
    }

    tokens
}
```
Tokenization produces a stream of tokens with associated spans. Each token records its start and end positions, enabling accurate source mapping for error messages and debugging information.
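Because every token carries its byte span, the original lexeme text can always be recovered by slicing the source. A minimal sketch of that round trip, using plain usize offsets rather than the codespan types above:

```rust
// Sketch: recovering lexeme text from byte spans. Plain usize offsets
// stand in for codespan's ByteIndex; the slicing logic is identical.
#[derive(Debug)]
struct SpannedToken {
    start: usize, // byte offset of the first byte
    end: usize,   // byte offset one past the last byte
}

fn lexeme<'a>(source: &'a str, tok: &SpannedToken) -> &'a str {
    &source[tok.start..tok.end]
}

fn main() {
    let source = "let x = 42 + 3;";
    // Spans as a lexer like the one above would produce them.
    let tokens = [
        SpannedToken { start: 0, end: 3 },  // "let"
        SpannedToken { start: 4, end: 5 },  // "x"
        SpannedToken { start: 8, end: 10 }, // "42"
    ];
    for tok in &tokens {
        println!("{:?} -> {:?}", (tok.start, tok.end), lexeme(source, tok));
    }
    assert_eq!(lexeme(source, &tokens[2]), "42");
}
```

This is why tokens store spans rather than copies of their text: the source buffer is the single source of truth, and a span is just a cheap, copyable view into it.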
```rust
fn scan_token(&mut self) -> Option<TokenKind> {
    let start_char = self.current_char()?;

    match start_char {
        '+' => { self.advance(); Some(TokenKind::Operator(Operator::Plus)) }
        '-' => { self.advance(); Some(TokenKind::Operator(Operator::Minus)) }
        '*' => { self.advance(); Some(TokenKind::Operator(Operator::Star)) }
        '/' => { self.advance(); Some(TokenKind::Operator(Operator::Slash)) }
        '=' => {
            self.advance();
            if self.current_char() == Some('=') {
                self.advance();
                Some(TokenKind::Operator(Operator::Equal))
            } else {
                Some(TokenKind::Operator(Operator::Assign))
            }
        }
        '!' => {
            self.advance();
            if self.current_char() == Some('=') {
                self.advance();
                Some(TokenKind::Operator(Operator::NotEqual))
            } else {
                None
            }
        }
        '<' => { self.advance(); Some(TokenKind::Operator(Operator::Less)) }
        '>' => { self.advance(); Some(TokenKind::Operator(Operator::Greater)) }
        '(' => { self.advance(); Some(TokenKind::Delimiter(Delimiter::LeftParen)) }
        ')' => { self.advance(); Some(TokenKind::Delimiter(Delimiter::RightParen)) }
        '{' => { self.advance(); Some(TokenKind::Delimiter(Delimiter::LeftBrace)) }
        '}' => { self.advance(); Some(TokenKind::Delimiter(Delimiter::RightBrace)) }
        '[' => { self.advance(); Some(TokenKind::Delimiter(Delimiter::LeftBracket)) }
        ']' => { self.advance(); Some(TokenKind::Delimiter(Delimiter::RightBracket)) }
        ';' => { self.advance(); Some(TokenKind::Delimiter(Delimiter::Semicolon)) }
        ',' => { self.advance(); Some(TokenKind::Delimiter(Delimiter::Comma)) }
        '"' => self.scan_string(),
        c if c.is_ascii_digit() => self.scan_number(),
        c if c.is_ascii_alphabetic() || c == '_' => self.scan_identifier_or_keyword(),
        _ => {
            self.advance();
            None
        }
    }
}
```
Token scanning dispatches on the first character to identify token types. The function advances through the input while tracking byte positions for span construction.
```rust
fn scan_string(&mut self) -> Option<TokenKind> {
    self.advance(); // consume opening quote
    let start = self.position;

    while !self.is_eof() && self.current_char() != Some('"') {
        if self.current_char() == Some('\\') {
            self.advance(); // consume backslash
            if !self.is_eof() {
                self.advance(); // consume escaped character
            }
        } else {
            self.advance();
        }
    }

    if self.current_char() == Some('"') {
        let content = self.input[start..self.position].to_string();
        self.advance(); // consume closing quote
        Some(TokenKind::String(content))
    } else {
        None // unterminated string
    }
}
```
String scanning demonstrates escape-sequence handling while maintaining accurate span information. The lexer keeps advancing byte positions through escaped characters so that the resulting spans still reflect true source locations.
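Accurate spans are what make pointed diagnostics possible. A simplified caret-underline renderer, sketched with plain usize byte offsets (it assumes the span stays on one line and that each byte is one column wide, so it is ASCII-only; real diagnostics crates like ariadne handle the general case):

```rust
// Simplified sketch: render a caret underline for a byte span.
// Hypothetical helper, not part of codespan; single-line, ASCII-only.
fn underline(source: &str, start: usize, end: usize) -> String {
    // Find the boundaries of the line containing `start`.
    let line_start = source[..start].rfind('\n').map_or(0, |i| i + 1);
    let line_end = source[start..].find('\n').map_or(source.len(), |i| start + i);
    let line = &source[line_start..line_end];

    // Column of the span start within its line, and span width.
    let col = start - line_start;
    let width = end.saturating_sub(start).max(1);

    format!("{}\n{}{}", line, " ".repeat(col), "^".repeat(width))
}

fn main() {
    let src = "let x = 42 + ;";
    // Point at the unexpected ';' (bytes 13..14).
    println!("{}", underline(src, 13, 14));
}
```

With codespan, `col` and `width` would come from `Span` and `ColumnIndex` arithmetic rather than raw subtraction, but the rendering idea is the same.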
Position Arithmetic
```rust
/// Span arithmetic demonstrations
pub fn demonstrate_span_arithmetic() {
    let start = ByteIndex::from(10);
    let offset = ByteOffset::from(5);

    // Adding offset to index
    let new_position = start + offset;
    assert_eq!(new_position, ByteIndex::from(15));

    // Subtracting offset from index
    let prev_position = new_position - offset;
    assert_eq!(prev_position, start);

    // Creating spans
    let span = Span::new(start, new_position);
    assert_eq!(span.start(), start);
    assert_eq!(span.end(), new_position);
}
```
Span arithmetic operations enable position calculations throughout the compiler. ByteIndex represents absolute positions while ByteOffset represents relative distances, maintaining type safety in position calculations.
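The pattern behind ByteIndex and ByteOffset is the classic newtype trick: define arithmetic only for the combinations that make sense. A standalone sketch of the idea in plain Rust (no codespan dependency; `Pos` and `Offset` are illustrative names, not codespan's):

```rust
use std::ops::{Add, Sub};

// Absolute position in a buffer. Only offset arithmetic is defined,
// so `Pos + Pos` is a compile error instead of a latent bug.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct Pos(u32);

// Relative distance between positions; signed, so it can point backward.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Offset(i64);

impl Add<Offset> for Pos {
    type Output = Pos;
    fn add(self, rhs: Offset) -> Pos {
        Pos((self.0 as i64 + rhs.0) as u32)
    }
}

impl Sub<Offset> for Pos {
    type Output = Pos;
    fn sub(self, rhs: Offset) -> Pos {
        Pos((self.0 as i64 - rhs.0) as u32)
    }
}

// Subtracting two positions yields a relative distance.
impl Sub for Pos {
    type Output = Offset;
    fn sub(self, rhs: Pos) -> Offset {
        Offset(self.0 as i64 - rhs.0 as i64)
    }
}

fn main() {
    let start = Pos(10);
    let end = start + Offset(5);
    assert_eq!(end, Pos(15));
    assert_eq!(end - start, Offset(5));
    // let bogus = start + end; // rejected: Add<Pos> is not implemented
    println!("{:?} .. {:?}", start, end);
}
```

Mixing up absolute and relative quantities is a perennial source of off-by-one bugs in compilers; encoding the distinction in the type system moves those mistakes from runtime to compile time.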
```rust
/// Line offset calculations
pub fn demonstrate_line_offsets() {
    let line = LineIndex::from(5);
    let offset = LineOffset::from(3);

    // Moving forward by lines
    let new_line = line + offset;
    assert_eq!(new_line, LineIndex::from(8));

    // Moving backward by lines
    let prev_line = new_line - offset;
    assert_eq!(prev_line, line);
}
```
Line offset calculations mirror byte-level operations at the line granularity. These operations support navigation between error locations and related positions in diagnostic output.
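Line arithmetic like this is what lets a diagnostic show surrounding context. A sketch that extracts a window of lines around an error line from a precomputed line-start table, mirroring the `line_starts` field of `SourceFile` above (plain usize offsets; `context` is a hypothetical helper, not part of codespan):

```rust
// Byte offsets of the start of each line, as SourceFile computes them.
fn line_starts(src: &str) -> Vec<usize> {
    std::iter::once(0)
        .chain(src.match_indices('\n').map(|(i, _)| i + 1))
        .collect()
}

// Render lines [line - radius, line + radius] with 1-based line numbers.
fn context(src: &str, starts: &[usize], line: usize, radius: usize) -> Vec<String> {
    let lo = line.saturating_sub(radius);
    let hi = (line + radius).min(starts.len() - 1);
    (lo..=hi)
        .map(|l| {
            let s = starts[l];
            let e = starts.get(l + 1).map_or(src.len(), |&n| n);
            format!("{:>4} | {}", l + 1, src[s..e].trim_end_matches('\n'))
        })
        .collect()
}

fn main() {
    let src = "let x = 42;\nlet y = x + ;\nprint(y);";
    let starts = line_starts(src);
    // Show one line of context around the error on line 2 (index 1).
    for row in context(src, &starts, 1, 1) {
        println!("{}", row);
    }
}
```

With codespan, `lo` and `hi` would be computed with `LineIndex` and `LineOffset` arithmetic instead of raw usize math, which is exactly the kind of navigation the typed operations above are for.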
UTF-8 Support
/// Span arithmetic demonstrations pub fn demonstrate_span_arithmetic() { let start =
ByteIndex::from(10); let offset = ByteOffset::from(5); // Adding offset to index let new_position = start + offset; assert_eq!(new_position, ByteIndex::from(15)); // Subtracting offset from index let prev_position = new_position - offset; assert_eq!(prev_position, start); // Creating spans let span = Span::new(start, new_position); assert_eq!(span.start(), start); assert_eq!(span.end(), new_position); } /// Line offset calculations pub fn demonstrate_line_offsets() { let line = LineIndex::from(5); let offset = LineOffset::from(3); // Moving forward by lines let new_line = line + offset; assert_eq!(new_line, LineIndex::from(8)); // Moving backward by lines let prev_line = new_line - offset; assert_eq!(prev_line, line); } #[cfg(test)] mod tests { use super::*; #[test] fn test_source_file() { let source = "let x = 42;\nlet y = x + 1;\nprint(y);"; let file = SourceFile::new("test.lang".to_string(), source.to_string()); // Test line index calculation assert_eq!(file.line_index(ByteIndex::from(0)), LineIndex::from(0)); assert_eq!(file.line_index(ByteIndex::from(12)), LineIndex::from(1)); assert_eq!(file.line_index(ByteIndex::from(27)), LineIndex::from(2)); // Test column index calculation assert_eq!(file.column_index(ByteIndex::from(4)), ColumnIndex::from(4)); assert_eq!(file.column_index(ByteIndex::from(16)), ColumnIndex::from(4)); // Test location let loc = file.location(ByteIndex::from(16)); assert_eq!(loc.line, LineIndex::from(1)); assert_eq!(loc.column, ColumnIndex::from(4)); } #[test] fn test_span_manager() { let mut manager = SpanManager::new(); let file1 = manager.add_file("main.lang".to_string(), "let x = 10;".to_string()); let file2 = manager.add_file( "lib.lang".to_string(), "function add(a, b) { return a + b; }".to_string(), ); assert!(manager.get_file(file1).is_some()); assert!(manager.get_file(file2).is_some()); assert_eq!(manager.find_file("main.lang"), Some(file1)); assert_eq!(manager.find_file("lib.lang"), Some(file2)); } #[test] fn test_lexer() { let 
mut manager = SpanManager::new(); let file_id = manager.add_file("test.lang".to_string(), "let x = 42 + 3;".to_string()); let mut lexer = Lexer::new("let x = 42 + 3;".to_string(), file_id); let tokens = lexer.tokenize(); assert_eq!(tokens.len(), 7); // Check first token (let) assert_eq!(tokens[0].kind, TokenKind::Keyword(Keyword::Let)); assert_eq!(tokens[0].span.start(), ByteIndex::from(0)); assert_eq!(tokens[0].span.end(), ByteIndex::from(3)); // Check identifier assert_eq!(tokens[1].kind, TokenKind::Identifier("x".to_string())); // Check equals operator assert_eq!(tokens[2].kind, TokenKind::Operator(Operator::Assign)); // Check number assert_eq!(tokens[3].kind, TokenKind::Number(42)); // Check plus operator assert_eq!(tokens[4].kind, TokenKind::Operator(Operator::Plus)); // Check second number assert_eq!(tokens[5].kind, TokenKind::Number(3)); // Check semicolon assert_eq!(tokens[6].kind, TokenKind::Delimiter(Delimiter::Semicolon)); } #[test] fn test_span_arithmetic() { demonstrate_span_arithmetic(); } #[test] fn test_line_offsets() { demonstrate_line_offsets(); } #[test] fn test_utf8_tracking() { let text = "hello 世界!"; let positions = track_utf8_positions(text); // ASCII characters take 1 byte assert_eq!(positions[0], ('h', ByteIndex::from(0))); assert_eq!(positions[5], (' ', ByteIndex::from(5))); // Chinese characters take 3 bytes each assert_eq!(positions[6], ('世', ByteIndex::from(6))); assert_eq!(positions[7], ('界', ByteIndex::from(9))); assert_eq!(positions[8], ('!', ByteIndex::from(12))); } #[test] fn test_span_merging() { let manager = SpanManager::new(); let span1 = manager.create_span(ByteIndex::from(10), ByteIndex::from(20)); let span2 = manager.create_span(ByteIndex::from(15), ByteIndex::from(25)); let merged = manager.merge_spans(span1, span2); assert_eq!(merged.start(), ByteIndex::from(10)); assert_eq!(merged.end(), ByteIndex::from(25)); } } /// UTF-8 aware position tracking pub fn track_utf8_positions(text: &str) -> Vec<(char, ByteIndex)> { let mut 
positions = Vec::new(); let mut byte_pos = ByteIndex::from(0); for ch in text.chars() { positions.push((ch, byte_pos)); byte_pos += ByteOffset::from(ch.len_utf8() as i64); } positions } }
UTF-8 position tracking correctly handles variable-width characters. The function accumulates byte offsets based on actual character encoding lengths rather than assuming fixed-width characters.
Location Display
```rust
/// A location in a source file
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Location {
    pub line: LineIndex,
    pub column: ColumnIndex,
}

impl fmt::Display for Location {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "{}:{}",
            self.line.to_usize() + 1,
            self.column.to_usize() + 1
        )
    }
}
```
The Location structure provides human-readable position information. Line and column indices are zero-based internally but typically displayed with one-based numbering for user interfaces.
Testing Strategies
The test suite validates core positioning operations across various text encodings. UTF-8 handling receives particular attention given the complexity of multi-byte character sequences.
Binary search performance in line lookup benefits from representative test data. Large files with many lines exercise the logarithmic lookup behavior that makes the approach scalable.
Span merging tests verify the commutativity and associativity properties that simplify span combination logic. These algebraic properties enable optimization passes to freely reorganize span calculations.
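These properties are easy to check concretely. The sketch below models spans as plain `(start, end)` byte pairs rather than codespan's `Span` type, but the merge logic mirrors the `merge_spans` method shown earlier:

```rust
// Minimal sketch: spans as (start, end) byte ranges. Merging takes the
// minimum start and maximum end, which makes it commutative and associative.
fn merge(a: (u32, u32), b: (u32, u32)) -> (u32, u32) {
    (a.0.min(b.0), a.1.max(b.1))
}

fn main() {
    let (x, y, z) = ((10, 20), (15, 25), (5, 12));
    // Commutative: the order of the two spans does not matter.
    assert_eq!(merge(x, y), merge(y, x));
    // Associative: the grouping does not matter.
    assert_eq!(merge(merge(x, y), z), merge(x, merge(y, z)));
    // The merged span covers the full extent of its inputs.
    assert_eq!(merge(x, y), (10, 25));
}
```

Because any merge order yields the same covering span, a compiler can combine the spans of AST children in whatever order its traversal produces.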
Integration Patterns
Codespan integrates naturally with parser combinators and parser generators. Parser libraries can construct spans from their internal position tracking, while codespan provides the arithmetic operations for span manipulation.
Error reporting libraries like codespan-reporting build on these primitives to provide rich diagnostic output. The separation of concerns keeps span tracking focused and reusable across different diagnostic frameworks.
Incremental compilation benefits from stable span representations. Source positions remain valid across incremental updates when unchanged regions retain their byte offsets.
Performance Considerations
Line start precomputation trades memory for lookup speed. For typical source files, the line index overhead remains negligible compared to the source text itself.
Binary search for line lookup provides consistent performance regardless of file size. This scalability matters for generated code or concatenated source files that may contain thousands of lines.
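The strategy can be sketched with only the standard library: precompute line start offsets once, then binary-search them (via `partition_point`) to find the line containing any byte offset. The helper names here are illustrative, not codespan's API:

```rust
// Precompute the byte offset at which each line begins.
fn line_starts(text: &str) -> Vec<usize> {
    std::iter::once(0)
        .chain(text.match_indices('\n').map(|(i, _)| i + 1))
        .collect()
}

// O(log n) lookup: partition_point counts the line starts <= byte,
// so subtracting one gives the zero-based line index.
fn line_of(starts: &[usize], byte: usize) -> usize {
    starts.partition_point(|&s| s <= byte) - 1
}

fn main() {
    let src = "let x = 42;\nlet y = x + 1;\nprint(y);";
    let starts = line_starts(src);
    assert_eq!(starts, vec![0, 12, 27]);
    assert_eq!(line_of(&starts, 0), 0);
    assert_eq!(line_of(&starts, 16), 1); // column = 16 - 12 = 4
    assert_eq!(line_of(&starts, 27), 2);
}
```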
Span creation involves only copying two integers, making it efficient to track spans pervasively. The immutable nature of spans enables sharing without synchronization overhead.
Best Practices
Maintain span information throughout compilation rather than reconstructing it when needed. Early span loss complicates error reporting and prevents accurate source mapping.
Use typed position wrappers rather than raw integers to prevent unit confusion. The distinction between ByteIndex and ByteOffset catches common arithmetic errors at compile time.
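The idea can be illustrated with two hypothetical newtypes (not codespan's actual definitions): an absolute `Index` and a relative `Offset`, where the operator impls only permit meaningful combinations:

```rust
use std::ops::{Add, Sub};

// Hypothetical newtypes mirroring the ByteIndex/ByteOffset split:
// an index is an absolute position, an offset is a relative distance.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Index(u32);
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct Offset(i64);

// index + offset = index (moving a position by a distance)
impl Add<Offset> for Index {
    type Output = Index;
    fn add(self, rhs: Offset) -> Index {
        Index((self.0 as i64 + rhs.0) as u32)
    }
}

// index - index = offset (the distance between two positions)
impl Sub for Index {
    type Output = Offset;
    fn sub(self, rhs: Index) -> Offset {
        Offset(self.0 as i64 - rhs.0 as i64)
    }
}

fn main() {
    let a = Index(10);
    let b = a + Offset(5);
    assert_eq!(b, Index(15));
    assert_eq!(b - a, Offset(5));
    // `a + b` (index + index) has no Add impl, so the meaningless
    // operation is rejected at compile time.
}
```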
Prefer byte positions over character positions for internal representation. Byte positions provide unambiguous locations in UTF-8 text while character positions require encoding assumptions.
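A short std-only illustration: `char_indices` yields byte offsets, which remain valid slice boundaries even for multi-byte characters, whereas character counts diverge from byte positions as soon as non-ASCII text appears:

```rust
fn main() {
    let text = "hello 世界!";
    // Pair each character with its byte offset in the UTF-8 encoding.
    let offsets: Vec<(char, usize)> =
        text.char_indices().map(|(i, c)| (c, i)).collect();
    // ASCII characters occupy one byte each...
    assert_eq!(offsets[5], (' ', 5));
    // ...while each CJK character occupies three bytes.
    assert_eq!(offsets[6], ('世', 6));
    assert_eq!(offsets[7], ('界', 9));
    assert_eq!(offsets[8], ('!', 12));
    // Slicing by byte offset recovers the exact original text.
    assert_eq!(&text[6..12], "世界");
}
```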
Design token types to be lightweight since they carry span information. Large token payloads amplify memory usage when every token includes span data.
Codespan’s focused approach to span tracking provides a solid foundation for compiler diagnostics. The crate’s emphasis on correctness, particularly regarding UTF-8 handling, makes it suitable for production language implementations where accurate position tracking directly impacts developer experience.
rustyline
The rustyline crate provides a pure-Rust readline implementation for building command-line interfaces. In compiler development, interactive REPLs (Read-Eval-Print Loops) are essential tools for testing language features, debugging compilation passes, and providing an interactive development environment. Rustyline offers line editing, history, completion, syntax highlighting, and multi-line input validation, the features that make professional-quality REPLs possible.
A compiler REPL allows developers to experiment with language constructs, inspect intermediate representations, test type inference, and debug compilation errors interactively. Rustyline handles all the terminal interaction complexity, letting compiler authors focus on language semantics and compilation logic.
Basic REPL Structure
Creating a compiler REPL starts with defining commands and setting up the editor:
```rust
use std::borrow::Cow::{self, Borrowed, Owned};
use std::collections::{HashMap, HashSet};

use rustyline::completion::{Completer, Pair};
use rustyline::highlight::{CmdKind, Highlighter, MatchingBracketHighlighter};
use rustyline::hint::{Hinter, HistoryHinter};
use rustyline::history::DefaultHistory;
use rustyline::validate::{ValidationContext, ValidationResult, Validator};
use rustyline::{CompletionType, Config, Context, EditMode, Editor, Helper, Result};

#[derive(Debug, Clone)]
pub struct CompilerCommand {
    pub name: &'static str,
    pub description: &'static str,
    pub args: &'static str,
}

impl CompilerCommand {
    pub const COMMANDS: &'static [CompilerCommand] = &[
        CompilerCommand { name: "load", description: "Load a source file", args: "<filename>" },
        CompilerCommand { name: "compile", description: "Compile the current module", args: "[--optimize] [--debug]" },
        CompilerCommand { name: "run", description: "Run the compiled program", args: "[args...]" },
        CompilerCommand { name: "ast", description: "Show the AST", args: "[function_name]" },
        CompilerCommand { name: "ir", description: "Show intermediate representation", args: "[function_name]" },
        CompilerCommand { name: "symbols", description: "List all symbols", args: "[pattern]" },
        CompilerCommand { name: "type", description: "Show type of expression", args: "<expression>" },
        CompilerCommand { name: "help", description: "Show help", args: "[command]" },
        CompilerCommand { name: "quit", description: "Exit the REPL", args: "" },
    ];
}

pub struct CompilerREPL {
    pub commands: HashMap<String, CompilerCommand>,
    pub keywords: HashSet<&'static str>,
    pub history_file: String,
    pub completer: CommandCompleter,
    pub highlighter: MatchingBracketHighlighter,
    pub hinter: HistoryHinter,
    pub validator: CompilerValidator,
}

impl Helper for CompilerREPL {}

#[derive(Clone)]
pub struct CommandCompleter {
    commands: Vec<String>,
    keywords: Vec<&'static str>,
}

impl CommandCompleter {
    pub fn new() -> Self {
        let commands = CompilerCommand::COMMANDS
            .iter()
            .map(|cmd| cmd.name.to_string())
            .collect();
        let keywords = vec![
            "fn", "let", "const", "if", "else", "while", "for", "return",
            "struct", "enum", "impl", "trait", "pub", "mod", "use",
        ];
        Self { commands, keywords }
    }
}

impl Default for CommandCompleter {
    fn default() -> Self {
        Self::new()
    }
}

impl Completer for CommandCompleter {
    type Candidate = Pair;

    fn complete(&self, line: &str, pos: usize, _ctx: &Context<'_>) -> Result<(usize, Vec<Pair>)> {
        let line_before_cursor = &line[..pos];
        let words: Vec<&str> = line_before_cursor.split_whitespace().collect();

        if words.is_empty() || (words.len() == 1 && !line_before_cursor.ends_with(' ')) {
            // Complete command names at the start of the line.
            let prefix = words.first().unwrap_or(&"");
            let matches: Vec<Pair> = self
                .commands
                .iter()
                .filter(|cmd| cmd.starts_with(prefix))
                .map(|cmd| Pair { display: cmd.clone(), replacement: cmd.clone() })
                .collect();
            Ok((0, matches))
        } else {
            // Complete language keywords elsewhere.
            let last_word = words.last().unwrap_or(&"");
            let word_start = line_before_cursor.rfind(last_word).unwrap_or(pos);
            let matches: Vec<Pair> = self
                .keywords
                .iter()
                .filter(|kw| kw.starts_with(last_word))
                .map(|kw| Pair { display: kw.to_string(), replacement: kw.to_string() })
                .collect();
            Ok((word_start, matches))
        }
    }
}

impl Completer for CompilerREPL {
    type Candidate = Pair;

    fn complete(&self, line: &str, pos: usize, ctx: &Context<'_>) -> Result<(usize, Vec<Pair>)> {
        self.completer.complete(line, pos, ctx)
    }
}

impl Hinter for CompilerREPL {
    type Hint = String;

    fn hint(&self, line: &str, pos: usize, ctx: &Context<'_>) -> Option<String> {
        self.hinter.hint(line, pos, ctx)
    }
}

impl Highlighter for CompilerREPL {
    fn highlight_prompt<'b, 's: 'b, 'p: 'b>(
        &'s self,
        prompt: &'p str,
        default: bool,
    ) -> Cow<'b, str> {
        if default {
            Borrowed("compiler> ")
        } else {
            Borrowed(prompt)
        }
    }

    fn highlight_hint<'h>(&self, hint: &'h str) -> Cow<'h, str> {
        Owned(format!("\x1b[90m{}\x1b[0m", hint))
    }

    fn highlight<'l>(&self, line: &'l str, pos: usize) -> Cow<'l, str> {
        let mut highlighted = String::new();
        let words: Vec<&str> = line.split_whitespace().collect();

        // Highlight recognized commands in green.
        if let Some(first_word) = words.first() {
            if self.commands.contains_key(*first_word) {
                highlighted.push_str("\x1b[32m");
                highlighted.push_str(first_word);
                highlighted.push_str("\x1b[0m");
                if line.len() > first_word.len() {
                    highlighted.push_str(&line[first_word.len()..]);
                }
                return Owned(highlighted);
            }
        }

        // Highlight the bracket under (or just before) the cursor.
        for (i, ch) in line.chars().enumerate() {
            if ch == '(' || ch == ')' || ch == '{' || ch == '}' || ch == '[' || ch == ']' {
                if i == pos || i == pos - 1 {
                    highlighted.push_str("\x1b[1;33m");
                    highlighted.push(ch);
                    highlighted.push_str("\x1b[0m");
                } else {
                    highlighted.push(ch);
                }
            } else {
                highlighted.push(ch);
            }
        }
        Owned(highlighted)
    }

    fn highlight_char(&self, line: &str, pos: usize, kind: CmdKind) -> bool {
        self.highlighter.highlight_char(line, pos, kind)
    }
}

#[derive(Clone)]
pub struct CompilerValidator;

impl Validator for CompilerValidator {
    fn validate(&self, ctx: &mut ValidationContext) -> Result<ValidationResult> {
        let input = ctx.input();
        let mut stack = Vec::new();
        for ch in input.chars() {
            match ch {
                '(' | '{' | '[' => stack.push(ch),
                ')' => {
                    if stack.pop() != Some('(') {
                        return Ok(ValidationResult::Invalid(Some(
                            "Mismatched parentheses".into(),
                        )));
                    }
                }
                '}' => {
                    if stack.pop() != Some('{') {
                        return Ok(ValidationResult::Invalid(Some("Mismatched braces".into())));
                    }
                }
                ']' => {
                    if stack.pop() != Some('[') {
                        return Ok(ValidationResult::Invalid(Some(
                            "Mismatched brackets".into(),
                        )));
                    }
                }
                _ => {}
            }
        }
        if stack.is_empty() {
            Ok(ValidationResult::Valid(None))
        } else {
            // Unclosed brackets mean the input continues on the next line.
            Ok(ValidationResult::Incomplete)
        }
    }
}

impl Validator for CompilerREPL {
    fn validate(&self, ctx: &mut ValidationContext) -> Result<ValidationResult> {
        self.validator.validate(ctx)
    }
}

impl CompilerREPL {
    pub fn new() -> Self {
        let mut commands = HashMap::new();
        for cmd in CompilerCommand::COMMANDS {
            commands.insert(cmd.name.to_string(), cmd.clone());
        }
        let keywords = HashSet::from([
            "fn", "let", "const", "if", "else", "while", "for", "return",
            "struct", "enum", "impl", "trait", "pub", "mod", "use",
        ]);
        Self {
            commands,
            keywords,
            history_file: "compiler_history.txt".to_string(),
            completer: CommandCompleter::new(),
            highlighter: MatchingBracketHighlighter::new(),
            hinter: HistoryHinter::new(),
            validator: CompilerValidator,
        }
    }
}

impl Default for CompilerREPL {
    fn default() -> Self {
        Self::new()
    }
}

pub fn process_command(line: &str, repl: &CompilerREPL) -> bool {
    let parts: Vec<&str> = line.split_whitespace().collect();
    if parts.is_empty() {
        return true;
    }
    match parts[0] {
        "help" => {
            if parts.len() > 1 {
                if let Some(cmd) = repl.commands.get(parts[1]) {
                    println!("{} - {}", cmd.name, cmd.description);
                    println!("Usage: {} {}", cmd.name, cmd.args);
                } else {
                    println!("Unknown command: {}", parts[1]);
                }
            } else {
                println!("Available commands:");
                for cmd in CompilerCommand::COMMANDS {
                    println!("  {:10} - {}", cmd.name, cmd.description);
                }
            }
        }
        "quit" => return false,
        "load" => println!("Loading file: {:?}", parts.get(1)),
        "compile" => println!("Compiling with options: {:?}", &parts[1..]),
        "run" => println!("Running with arguments: {:?}", &parts[1..]),
        "ast" => println!("Showing AST for: {:?}", parts.get(1)),
        "ir" => println!("Showing IR for: {:?}", parts.get(1)),
        "symbols" => println!("Listing symbols matching: {:?}", parts.get(1)),
        "type" => println!("Type checking: {}", parts[1..].join(" ")),
        _ => println!(
            "Unknown command: {}. Type 'help' for available commands.",
            parts[0]
        ),
    }
    true
}

pub fn create_editor() -> Result<Editor<CompilerREPL, DefaultHistory>> {
    let config = Config::builder()
        .history_ignore_space(true)
        .completion_type(CompletionType::List)
        .edit_mode(EditMode::Emacs)
        .build();
    let helper = CompilerREPL::new();
    let mut editor = Editor::with_config(config)?;
    editor.set_helper(Some(helper));
    if editor.load_history("compiler_history.txt").is_err() {
        println!("No previous history.");
    }
    Ok(editor)
}
```
The configuration enables history tracking (lines beginning with a space are not recorded), list-style completion menus, and Emacs keybindings. The helper object supplies the line-editing features: completion, hinting, syntax highlighting, and input validation.
Command System
A well-designed compiler REPL provides commands for various compilation stages:
#[derive(Debug, Clone)]
pub struct CompilerCommand {
    pub name: &'static str,
    pub description: &'static str,
    pub args: &'static str,
}

impl CompilerCommand {
    pub const COMMANDS: &'static [CompilerCommand] = &[
        CompilerCommand { name: "load", description: "Load a source file", args: "<filename>" },
        CompilerCommand { name: "compile", description: "Compile the current module", args: "[--optimize] [--debug]" },
        CompilerCommand { name: "run", description: "Run the compiled program", args: "[args...]" },
        CompilerCommand { name: "ast", description: "Show the AST", args: "[function_name]" },
        CompilerCommand { name: "ir", description: "Show intermediate representation", args: "[function_name]" },
        CompilerCommand { name: "symbols", description: "List all symbols", args: "[pattern]" },
        CompilerCommand { name: "type", description: "Show type of expression", args: "<expression>" },
        CompilerCommand { name: "help", description: "Show help", args: "[command]" },
        CompilerCommand { name: "quit", description: "Exit the REPL", args: "" },
    ];
}

pub fn process_command(line: &str, repl: &CompilerREPL) -> bool {
    let parts: Vec<&str> = line.split_whitespace().collect();
    if parts.is_empty() {
        return true;
    }
    match parts[0] {
        "help" => {
            if parts.len() > 1 {
                if let Some(cmd) = repl.commands.get(parts[1]) {
                    println!("{} - {}", cmd.name, cmd.description);
                    println!("Usage: {} {}", cmd.name, cmd.args);
                } else {
                    println!("Unknown command: {}", parts[1]);
                }
            } else {
                println!("Available commands:");
                for cmd in CompilerCommand::COMMANDS {
                    println!("  {:10} - {}", cmd.name, cmd.description);
                }
            }
        }
        "quit" => return false,
        "load" => println!("Loading file: {:?}", parts.get(1)),
        "compile" => println!("Compiling with options: {:?}", &parts[1..]),
        "run" => println!("Running with arguments: {:?}", &parts[1..]),
        "ast" => println!("Showing AST for: {:?}", parts.get(1)),
        "ir" => println!("Showing IR for: {:?}", parts.get(1)),
        "symbols" => println!("Listing symbols matching: {:?}", parts.get(1)),
        "type" => println!("Type checking: {}", parts[1..].join(" ")),
        _ => println!(
            "Unknown command: {}. Type 'help' for available commands.",
            parts[0]
        ),
    }
    true
}
Commands allow users to load files, compile code, inspect ASTs and IR, query types, and manage the compilation context. This structure makes the REPL extensible and discoverable.
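The command-table pattern above can be reduced to its essentials. The following is a minimal, rustyline-free sketch under our own illustrative names (`Command`, `dispatch` are not part of the chapter's `CompilerREPL` API): dispatch is driven by a static table, and the boolean return signals whether the REPL should keep running, mirroring `process_command`.

```rust
// Illustrative sketch of table-driven command dispatch.
struct Command {
    name: &'static str,
    description: &'static str,
}

const COMMANDS: &[Command] = &[
    Command { name: "load", description: "Load a source file" },
    Command { name: "quit", description: "Exit the REPL" },
];

/// Returns false when the REPL should exit.
fn dispatch(line: &str) -> bool {
    let mut parts = line.split_whitespace();
    match parts.next() {
        None => true,          // blank line: keep running
        Some("quit") => false, // terminate the loop
        Some(cmd) => {
            match COMMANDS.iter().find(|c| c.name == cmd) {
                Some(c) => println!("{} - {}", c.name, c.description),
                None => println!("Unknown command: {}", cmd),
            }
            true
        }
    }
}
```

Because the table is data rather than a hard-coded `match` per feature, adding a command means adding one entry, which is what keeps the real REPL extensible.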
Completion Support
Intelligent completion improves REPL usability significantly:
#[derive(Clone)]
pub struct CommandCompleter {
    commands: Vec<String>,
    keywords: Vec<&'static str>,
}

impl CommandCompleter {
    pub fn new() -> Self {
        // Command names come from the static command table; keywords seed
        // completion inside expressions.
        let commands = CompilerCommand::COMMANDS
            .iter()
            .map(|cmd| cmd.name.to_string())
            .collect();
        let keywords = vec![
            "fn", "let", "const", "if", "else", "while", "for", "return",
            "struct", "enum", "impl", "trait", "pub", "mod", "use",
        ];
        Self { commands, keywords }
    }
}

impl Default for CommandCompleter {
    fn default() -> Self {
        Self::new()
    }
}
impl Completer for CommandCompleter {
    type Candidate = Pair;

    fn complete(
        &self,
        line: &str,
        pos: usize,
        _ctx: &Context<'_>,
    ) -> Result<(usize, Vec<Pair>)> {
        let line_before_cursor = &line[..pos];
        let words: Vec<&str> = line_before_cursor.split_whitespace().collect();

        if words.is_empty() || (words.len() == 1 && !line_before_cursor.ends_with(' ')) {
            // Complete commands at start of line
            let prefix = words.first().unwrap_or(&"");
            let matches: Vec<Pair> = self
                .commands
                .iter()
                .filter(|cmd| cmd.starts_with(prefix))
                .map(|cmd| Pair {
                    display: cmd.clone(),
                    replacement: cmd.clone(),
                })
                .collect();
            Ok((0, matches))
        } else {
            // Complete keywords within expressions
            let last_word = words.last().unwrap_or(&"");
            let word_start = line_before_cursor.rfind(last_word).unwrap_or(pos);
            let matches: Vec<Pair> = self
                .keywords
                .iter()
                .filter(|kw| kw.starts_with(last_word))
                .map(|kw| Pair {
                    display: kw.to_string(),
                    replacement: kw.to_string(),
                })
                .collect();
            Ok((word_start, matches))
        }
    }
}
The completer distinguishes between command completion (at the start of a line) and keyword completion (within expressions). This context-aware completion helps users discover commands and write code faster.
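The core of that context decision needs none of rustyline's machinery, so it can be exercised in isolation. This sketch (with an illustrative name, `complete_at`, and toy command/keyword lists) returns the byte offset the replacement should start at plus the candidate strings, just as `complete` returns `(usize, Vec<Pair>)`:

```rust
// Sketch of position-aware completion: command names for the first word,
// language keywords for later words.
fn complete_at(line: &str, pos: usize) -> (usize, Vec<String>) {
    const COMMANDS: &[&str] = &["load", "compile", "run", "quit"];
    const KEYWORDS: &[&str] = &["fn", "for", "let", "struct"];

    let before = &line[..pos];
    let words: Vec<&str> = before.split_whitespace().collect();

    if words.is_empty() || (words.len() == 1 && !before.ends_with(' ')) {
        // Still typing the first word: complete command names from offset 0.
        let prefix = words.first().copied().unwrap_or("");
        let matches = COMMANDS
            .iter()
            .filter(|c| c.starts_with(prefix))
            .map(|c| c.to_string())
            .collect();
        (0, matches)
    } else {
        // Past the command: complete keywords from the last word's start.
        let last = words.last().copied().unwrap_or("");
        let start = before.rfind(last).unwrap_or(pos);
        let matches = KEYWORDS
            .iter()
            .filter(|k| k.starts_with(last))
            .map(|k| k.to_string())
            .collect();
        (start, matches)
    }
}
```

Returning the word's start offset matters: rustyline replaces the text from that offset to the cursor, so getting it wrong would splice the candidate into the middle of a word.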
Syntax Highlighting
Visual feedback through syntax highlighting makes the REPL more pleasant to use:
impl Highlighter for CompilerREPL {
    fn highlight_prompt<'b, 's: 'b, 'p: 'b>(
        &'s self,
        prompt: &'p str,
        default: bool,
    ) -> Cow<'b, str> {
        if default {
            Borrowed("compiler> ")
        } else {
            Borrowed(prompt)
        }
    }

    fn highlight_hint<'h>(&self, hint: &'h str) -> Cow<'h, str> {
        // Dim the history hint so it reads as a suggestion, not input.
        Owned(format!("\x1b[90m{}\x1b[0m", hint))
    }

    fn highlight<'l>(&self, line: &'l str, pos: usize) -> Cow<'l, str> {
        let mut highlighted = String::new();
        let words: Vec<&str> = line.split_whitespace().collect();
        // Color a recognized command name green and pass the rest through.
        if let Some(first_word) = words.first() {
            if self.commands.contains_key(*first_word) {
                highlighted.push_str("\x1b[32m");
                highlighted.push_str(first_word);
                highlighted.push_str("\x1b[0m");
                if line.len() > first_word.len() {
                    highlighted.push_str(&line[first_word.len()..]);
                }
                return Owned(highlighted);
            }
        }
        // Otherwise emphasize the bracket at (or just before) the cursor.
        for (i, ch) in line.chars().enumerate() {
            if matches!(ch, '(' | ')' | '{' | '}' | '[' | ']') {
                // `i + 1 == pos` (rather than `i == pos - 1`) avoids a usize
                // underflow when the cursor is at position 0.
                if i == pos || i + 1 == pos {
                    highlighted.push_str("\x1b[1;33m");
                    highlighted.push(ch);
                    highlighted.push_str("\x1b[0m");
                } else {
                    highlighted.push(ch);
                }
            } else {
                highlighted.push(ch);
            }
        }
        Owned(highlighted)
    }

    fn highlight_char(&self, line: &str, pos: usize, kind: CmdKind) -> bool {
        self.highlighter.highlight_char(line, pos, kind)
    }
}
The highlighter colors commands differently from regular input and highlights matching brackets. This immediate visual feedback helps users spot syntax errors before execution.
Input Validation
Multi-line input support requires validation to determine when input is complete:
```rust
#![allow(unused)]
fn main() {
impl Validator for CompilerValidator {
    fn validate(&self, ctx: &mut ValidationContext) -> Result<ValidationResult> {
        let input = ctx.input();
        let mut stack = Vec::new();
        for ch in input.chars() {
            match ch {
                '(' | '{' | '[' => stack.push(ch),
                ')' => {
                    if stack.pop() != Some('(') {
                        return Ok(ValidationResult::Invalid(Some(
                            "Mismatched parentheses".into(),
                        )));
                    }
                }
                '}' => {
                    if stack.pop() != Some('{') {
                        return Ok(ValidationResult::Invalid(Some("Mismatched braces".into())));
                    }
                }
                ']' => {
                    if stack.pop() != Some('[') {
                        return Ok(ValidationResult::Invalid(Some(
                            "Mismatched brackets".into(),
                        )));
                    }
                }
                _ => {}
            }
        }
        if stack.is_empty() {
            Ok(ValidationResult::Valid(None))
        } else {
            Ok(ValidationResult::Incomplete)
        }
    }
}
}
```
The validator checks bracket matching to determine if more input is needed. This enables natural multi-line input for function definitions and complex expressions without requiring explicit continuation markers.
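The same completeness check can be exercised outside rustyline. The following is a standalone sketch of the bracket-matching logic; the `Status` enum and `check` function are illustrative names standing in for rustyline's `ValidationResult` and the `validate` method:

```rust
// Standalone sketch of the validator's bracket-matching logic.
// `Status` and `check` are illustrative, not rustyline APIs.
#[derive(Debug, PartialEq)]
enum Status {
    Valid,
    Invalid(&'static str),
    Incomplete,
}

fn check(input: &str) -> Status {
    let mut stack = Vec::new();
    for ch in input.chars() {
        match ch {
            '(' | '{' | '[' => stack.push(ch),
            ')' => {
                if stack.pop() != Some('(') {
                    return Status::Invalid("Mismatched parentheses");
                }
            }
            '}' => {
                if stack.pop() != Some('{') {
                    return Status::Invalid("Mismatched braces");
                }
            }
            ']' => {
                if stack.pop() != Some('[') {
                    return Status::Invalid("Mismatched brackets");
                }
            }
            _ => {}
        }
    }
    if stack.is_empty() {
        Status::Valid
    } else {
        Status::Incomplete
    }
}

fn main() {
    assert_eq!(check("fn id(x) { x }"), Status::Valid);
    assert_eq!(check("fn id(x) {"), Status::Incomplete);
    assert_eq!(check(")"), Status::Invalid("Mismatched parentheses"));
}
```

A complete, balanced line is accepted; an unclosed bracket signals that the REPL should keep reading; a stray closer fails fast with a message.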
Helper Integration
Rustyline uses a helper trait to combine all features:
```rust
#![allow(unused)]
fn main() {
use std::collections::{HashMap, HashSet};

use rustyline::completion::{Completer, Pair};
use rustyline::highlight::MatchingBracketHighlighter;
use rustyline::hint::{Hinter, HistoryHinter};
use rustyline::history::DefaultHistory;
use rustyline::validate::{ValidationContext, ValidationResult, Validator};
use rustyline::{CompletionType, Config, Context, EditMode, Editor, Helper, Result};

pub struct CompilerREPL {
    pub commands: HashMap<String, CompilerCommand>,
    pub keywords: HashSet<&'static str>,
    pub history_file: String,
    pub completer: CommandCompleter,
    pub highlighter: MatchingBracketHighlighter,
    pub hinter: HistoryHinter,
    pub validator: CompilerValidator,
}

impl Helper for CompilerREPL {}

// Each trait implementation simply delegates to the component stored in
// the struct. The Highlighter implementation is shown in the
// highlighting section above.
impl Completer for CompilerREPL {
    type Candidate = Pair;

    fn complete(&self, line: &str, pos: usize, ctx: &Context<'_>) -> Result<(usize, Vec<Pair>)> {
        self.completer.complete(line, pos, ctx)
    }
}

impl Hinter for CompilerREPL {
    type Hint = String;

    fn hint(&self, line: &str, pos: usize, ctx: &Context<'_>) -> Option<String> {
        self.hinter.hint(line, pos, ctx)
    }
}

impl Validator for CompilerREPL {
    fn validate(&self, ctx: &mut ValidationContext) -> Result<ValidationResult> {
        self.validator.validate(ctx)
    }
}

impl CompilerREPL {
    pub fn new() -> Self {
        let mut commands = HashMap::new();
        for cmd in CompilerCommand::COMMANDS {
            commands.insert(cmd.name.to_string(), cmd.clone());
        }
        let keywords = HashSet::from([
            "fn", "let", "const", "if", "else", "while", "for", "return",
            "struct", "enum", "impl", "trait", "pub", "mod", "use",
        ]);
        Self {
            commands,
            keywords,
            history_file: "compiler_history.txt".to_string(),
            completer: CommandCompleter::new(),
            highlighter: MatchingBracketHighlighter::new(),
            hinter: HistoryHinter::new(),
            validator: CompilerValidator,
        }
    }
}

impl Default for CompilerREPL {
    fn default() -> Self {
        Self::new()
    }
}

pub fn create_editor() -> Result<Editor<CompilerREPL, DefaultHistory>> {
    let config = Config::builder()
        .history_ignore_space(true)
        .completion_type(CompletionType::List)
        .edit_mode(EditMode::Emacs)
        .build();

    let helper = CompilerREPL::new();
    let mut editor = Editor::with_config(config)?;
    editor.set_helper(Some(helper));

    if editor.load_history("compiler_history.txt").is_err() {
        println!("No previous history.");
    }

    Ok(editor)
}
}
```
The helper struct implements all the necessary traits and maintains shared state like command definitions and configuration. This design keeps the implementation modular while providing a cohesive interface.
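With the helper installed, the main loop reduces to reading lines and dispatching them. Below is a minimal sketch of such a loop; in the real REPL the lines would come from `editor.readline("compiler> ")`, but here an iterator stands in so the dispatch logic is self-contained, and `dispatch` and `run_lines` are illustrative helpers, not rustyline API:

```rust
// Minimal REPL driver sketch. `dispatch` models process_command's
// return convention: true means keep reading, false means quit.
fn dispatch(line: &str) -> bool {
    let parts: Vec<&str> = line.split_whitespace().collect();
    match parts.first() {
        None => true,
        Some(&"quit") => false,
        Some(&"help") => {
            println!("commands: load compile run ast ir symbols type help quit");
            true
        }
        Some(cmd) => {
            println!("unknown command: {}", cmd);
            true
        }
    }
}

// Drive the loop over a line source; returns how many lines were handled.
fn run_lines<'a>(lines: impl IntoIterator<Item = &'a str>) -> usize {
    let mut processed = 0;
    for line in lines {
        processed += 1;
        if !dispatch(line) {
            break; // "quit" ends the session
        }
    }
    processed
}

fn main() {
    // Stops at "quit", so the trailing line is never dispatched.
    assert_eq!(run_lines(["help", "quit", "help"]), 2);
}
```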
Best Practices
Design commands that mirror your compiler’s architecture. If your compiler has distinct phases like parsing, type checking, and code generation, provide commands to inspect the output of each phase. This helps users understand how their code flows through the compiler.
Implement context-aware completion that understands your language’s syntax. Beyond simple keyword completion, consider completing function names, type names, and module paths based on the current compilation context. This requires integration with your compiler’s symbol tables.
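Symbol-aware completion can be sketched without rustyline at all. In the sketch below, `complete_symbol` is an illustrative helper and a plain `HashMap` stands in for the symbol table your name-resolution pass would populate:

```rust
use std::collections::HashMap;

// Complete a prefix against a symbol table. In a real REPL this table
// would be populated by the compiler's name-resolution pass.
fn complete_symbol(table: &HashMap<String, &str>, prefix: &str) -> Vec<String> {
    let mut matches: Vec<String> = table
        .keys()
        .filter(|name| name.starts_with(prefix))
        .cloned()
        .collect();
    matches.sort(); // deterministic ordering for display
    matches
}

fn main() {
    let mut table = HashMap::new();
    table.insert("parse_expr".to_string(), "fn(&str) -> Expr");
    table.insert("parse_stmt".to_string(), "fn(&str) -> Stmt");
    table.insert("typecheck".to_string(), "fn(&Expr) -> Type");

    assert_eq!(complete_symbol(&table, "parse"), vec!["parse_expr", "parse_stmt"]);
    assert_eq!(complete_symbol(&table, "ty"), vec!["typecheck"]);
}
```

The stored type signatures could also be surfaced in the completion display, the same way `Pair` separates `display` from `replacement`.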
Use validation to support natural multi-line input for your language. If your language uses indentation or keywords to delimit blocks, implement validation logic that understands these patterns. Users should be able to paste multi-line code naturally.
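For an indentation-delimited language, completeness is judged from line structure rather than a bracket stack. The rules below are a rough illustration (a header ending in `:` opens a block; a trailing blank line closes it), and `needs_more` is an invented helper, not a rustyline API:

```rust
// Sketch of completeness detection for a Python-like syntax.
// Illustrative rules: ':' opens a block; a blank line closes it.
fn needs_more(input: &str) -> bool {
    let lines: Vec<&str> = input.split('\n').collect();
    let last = *lines.last().unwrap_or(&"");
    if last.trim_end().ends_with(':') {
        return true; // a block header was just opened
    }
    let in_block = lines.iter().any(|l| l.trim_end().ends_with(':'));
    // Inside a block, keep reading until a blank line closes it.
    in_block && !last.trim().is_empty()
}

fn main() {
    assert!(needs_more("def f(x):"));
    assert!(needs_more("def f(x):\n    return x"));
    assert!(!needs_more("def f(x):\n    return x\n"));
    assert!(!needs_more("x = 1"));
}
```

A production validator would share the real parser's notion of block structure rather than re-implementing it with string heuristics.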
Provide rich error formatting in the REPL. When compilation errors occur, format them with source context, underlining, and helpful messages. The immediate feedback of a REPL makes it ideal for learning a language.
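As a minimal sketch of caret-style rendering, the function below formats a diagnostic with source context and an underline. The format and the `render_error` name are illustrative; crates like ariadne or codespan-reporting, recommended earlier in this guide, do this far better:

```rust
// Render a one-line diagnostic with source context and a caret underline.
// Purely illustrative; use ariadne or codespan-reporting in production.
fn render_error(source: &str, line_no: usize, col: usize, len: usize, msg: &str) -> String {
    let line = source.split('\n').nth(line_no).unwrap_or("");
    let mut out = String::new();
    out.push_str(&format!("error: {}\n", msg));
    out.push_str(&format!(" {:>3} | {}\n", line_no + 1, line));
    out.push_str(&format!("     | {}{}\n", " ".repeat(col), "^".repeat(len.max(1))));
    out
}

fn main() {
    let src = "let x = 1 + true;";
    let report = render_error(src, 0, 12, 4, "cannot add `bool` to integer");
    print!("{}", report);
    assert!(report.contains("^^^^"));
}
```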
Consider implementing a notebook mode that can save and replay REPL sessions. This is valuable for creating reproducible examples, tutorials, and bug reports. Store both input and output with enough context to replay the session.
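As a sketch of the idea, a session log can record (input, output) cells and replay them through an evaluator to confirm the session still reproduces. The `SessionLog` type and its methods are illustrative, not from any crate:

```rust
// Illustrative session log: records (input, output) cells so a REPL
// session can be saved and replayed deterministically.
#[derive(Default)]
struct SessionLog {
    cells: Vec<(String, String)>,
}

impl SessionLog {
    fn record(&mut self, input: &str, output: &str) {
        self.cells.push((input.to_string(), output.to_string()));
    }

    // Replay the session through an evaluator and report whether every
    // cell still produces the recorded output.
    fn replay<F: Fn(&str) -> String>(&self, eval: F) -> bool {
        self.cells.iter().all(|(input, output)| eval(input) == *output)
    }
}

fn main() {
    let mut log = SessionLog::default();
    let eval = |input: &str| format!("evaluated: {}", input);
    for input in ["1 + 2", "let x = 3"] {
        log.record(input, &eval(input));
    }
    assert!(log.replay(eval));
    assert!(!log.replay(|_| "different".to_string()));
}
```

Serializing the cells to disk (for example as JSON) would turn this into a replayable notebook file usable in bug reports.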
Add introspection commands that leverage your compiler’s internal representations. Commands to show type inference results, macro expansions, optimization decisions, and lowered code help users understand the compilation process.
The REPL can serve as more than just an interactive interpreter. It can be a powerful debugging and development tool that provides insight into every stage of compilation.
References
Foundational Compiler Textbooks
- Compilers: Principles, Techniques, and Tools (The Dragon Book) - The definitive classic textbook by Aho, Lam, Sethi, and Ullman. Comprehensive treatment of lexical analysis, parsing, semantic analysis, and code generation with formal foundations. Updated in 2006 with modern optimization techniques and garbage collection coverage.
- Engineering a Compiler (3rd Edition) - Keith Cooper and Linda Torczon’s practical engineering approach to compiler construction. Covers SSA forms, instruction scheduling, and graph-coloring register allocation.
- Modern Compiler Implementation in C/Java/ML - Andrew Appel’s series providing detailed coverage of all compiler phases with working implementations. Available in three language editions. Excellent treatment of advanced topics including object-oriented and functional language compilation.
- Introduction to Compilers and Language Design - Douglas Thain’s modern, accessible textbook offering a one-semester introduction. Enables building a simple compiler for a C-like language targeting X86 or ARM assembly with complete code examples.
Practical Hands-On Guides
- Crafting Interpreters - Robert Nystrom’s exceptional hands-on guide building two complete interpreters from scratch. Free online. Covers parsing, semantic analysis, garbage collection, and optimization with beautiful hand-drawn illustrations.
- Building an Optimizing Compiler - Robert Morgan’s advanced treatment of optimization techniques including data flow analysis, SSA form, and advanced optimization passes.
- Writing a C Compiler: Build a Real Programming Language from Scratch - Nora Sandler’s book providing a clear path through compiler construction complexities. Progressive approach from simple programs to advanced features using pseudocode for any-language implementation.
- Essentials of Compilation: An Incremental Approach in Python - Jeremy Siek’s unique incremental approach building a compiler progressively. Makes abstract concepts tangible through direct Python implementation, connecting language design decisions with compiler implementation.
- Let’s Build a Compiler - Jack Crenshaw’s practically-oriented tutorial demystifying compiler internals. Step-by-step approach presenting up-to-date techniques with detailed implementation guidance.
Specialized Topics
- Advanced Compiler Design and Implementation - Steven Muchnick’s comprehensive treatment of advanced compiler optimization techniques. Covers case studies of commercial compilers from Sun, IBM, DEC, and Intel. Introduces Informal Compiler Algorithm Notation (ICAN) for clear algorithm communication.
- Parsing Techniques: A Practical Guide (2nd Edition) - Dick Grune and Ceriel Jacobs’ definitive 622-page treatment of parsing techniques. Free PDF available. Covers all parsing methods with clear explanations and practical applicability.
- Types and Programming Languages - Benjamin Pierce’s definitive reference for understanding type systems. While not specifically a compiler book, it’s crucial for semantic analysis in compilers. Covers type checking, type inference, and advanced type system features.
- Garbage Collection: Algorithms for Automatic Dynamic Memory Management - Richard Jones and Rafael Lins’ comprehensive survey of garbage collection algorithms. Covers all major collection strategies including mark-sweep, copying, generational, and concurrent collectors.
LLVM Infrastructure
- Learn LLVM 17: A Beginner’s Guide - Kai Nacke and Amy Kwan’s hands-on guide to building and extending LLVM compilers. Covers frontend construction, backend development, IR generation and optimization, custom passes, and JIT compilation.
- LLVM Tutorial: Kaleidoscope - Official step-by-step tutorial building a simple language frontend with LLVM. Covers lexing, parsing, AST construction, LLVM IR generation, JIT compilation, and optimization.
- Clang Compiler Frontend - Ivan Murashko’s exploration of Clang internals with practical applications for static analysis and custom tooling. Covers AST operations, IDE integration, and performance optimization.
- LLVM’s Analysis and Transform Passes - Documentation of LLVM’s optimization passes, useful for understanding what optimizations production compilers implement.
MLIR Infrastructure
- MLIR Passes - Comprehensive documentation of MLIR’s transformation and analysis passes. Covers affine loop transformations, buffer optimizations, control flow simplifications, and dialect-specific passes for GPU, async, linalg, and other domains.
- MLIR Tutorial - Step-by-step guide building a compiler for the Toy language using MLIR. Demonstrates how to define dialects, implement lowering passes, and leverage MLIR’s infrastructure for optimization.
- MLIR Dialect Conversion - Guide to MLIR’s dialect conversion framework for progressive lowering between abstraction levels. Essential for understanding how to transform between different IR representations.
- MLIR Pattern Rewriting - Documentation on MLIR’s declarative pattern rewriting infrastructure. Shows how to express transformations as patterns for maintainable optimization passes.
Cranelift Resources
- Cranelift’s Instruction Selector DSL, ISLE: Term-Rewriting Made Practical - Deep dive into Cranelift’s instruction selection system using a custom term-rewriting DSL. Shows how to map IR operations to machine instructions systematically.
- Cranelift, Part 4: A New Register Allocator - Detailed exploration of Cranelift’s register allocation algorithm, covering live ranges, interference graphs, and the practical engineering of a production register allocator.
- Cranelift: Using E-Graphs for Verified, Cooperating Middle-End Optimizations - RFC describing how Cranelift uses e-graphs to solve the phase-ordering problem in compiler optimizations while maintaining correctness guarantees.
Language-Specific Implementation
- Compiling to Assembly from Scratch - Vladimir Keleshev’s modern approach using a TypeScript subset targeting ARM assembly. Covers both a baseline compiler and advanced extensions with complete source code.
- Implementing Functional Languages: A Tutorial - Simon Peyton Jones and David Lester’s guide to implementing non-strict functional languages. Free PDF. Includes complete working prototypes using lazy graph reduction.
- Write You a Haskell - My old tutorial on functional language implementation including parser, type inference, pattern matching, typeclasses, STG intermediate language, and native code generation.
- Compiler Construction - Niklaus Wirth’s concise, practical guide. Step-by-step approach through each compiler design stage focusing on practical implementation.
Parsing Tools and Techniques
- The Definitive ANTLR 4 Reference - Terence Parr’s essential guide to the ANTLR parser generator with LL(*) parsing technology. Covers grammar construction, tree construction, and StringTemplate code generation.
- ANTLR Mega Tutorial - Federico Tomassetti’s comprehensive tutorial covering ANTLR setup for multiple languages (JavaScript, Python, Java, C#), testing approaches, and advanced features.
- LR Parsing Theory and Practice - Excellent blog post demystifying the differences between LL and LR parsing with practical examples.
Courses
- CS 6120: Advanced Compilers: The Self-Guided Online Course - Cornell’s graduate-level compiler optimization course. Covers SSA form, dataflow analysis, loop optimizations, and modern optimization techniques with hands-on projects.
- Stanford CS 143: Compilers - Introduction to compiler construction covering lexical analysis through code generation. Includes programming assignments building a compiler for a Java-like language.
- IU P423/P523 Compilers - Jeremy Siek’s course using an incremental approach with Racket. Materials available on GitHub.
- KAIST CS420 Compiler Design - Modern treatment with a Rust implementation. Course materials and assignments available on GitHub.
Online Resources and Tutorials
- Basics of Compiler Design - Torben Mogensen’s free PDF textbook providing a solid introduction to compiler construction fundamentals.
- Compiler Design Tutorials - Collection of articles covering compiler topics from basic to advanced with code examples.
Research Papers and Academic Resources
- Static Single Assignment Form and the Control Dependence Graph - Cytron et al.’s seminal paper on SSA form, now standard in modern optimizing compilers.
- A Nanopass Framework for Compiler Education - Describes breaking compiler passes into tiny transformations, making complex optimizations easier to understand and verify.
- Linear Scan Register Allocation - Massimiliano Poletto and Vivek Sarkar’s influential paper on fast register allocation suitable for JIT compilers.
Rust-Specific Compiler Resources
- Rust Compiler Development Guide - The official guide to rustc internals. Essential reading for understanding how a production Rust compiler works.
- Salsa - Framework for incremental computation used by rust-analyzer. Demonstrates modern techniques for responsive compiler frontends.
- Make A Language - Series of blog posts walking through implementing a programming language in Rust, from lexing through type checking.
- Introduction to LLVM in Rust - Rust implementation of the LLVM Kaleidoscope tutorial demonstrating LLVM bindings.
Code Generation Resources
- Cranelift Code Generator - Production code generator written in Rust, designed for JIT and AOT compilation. Good example of modern compiler backend architecture.
- Introduction to LLVM - Tutorial on using LLVM’s C++ API to generate code, covering the basics of LLVM IR and the programmatic interface.
- Compiler Design in C - Allen Holub’s 924-page detailed coverage of real-world compiler implementation focusing on code generation.
Tools and Development Environments
- Tree-sitter - Parser generator creating incremental parsers suitable for editor integration. Good example of modern parsing technology beyond traditional compiler construction.
- ANTLR - Popular parser generator supporting multiple target languages. Extensive documentation and community resources.
- Compiler Explorer - Online tool for exploring compiler output across different compilers and optimization levels. Invaluable for understanding code generation.
- ANTLRWorks - GUI development environment for ANTLR grammars with visualization and debugging features.
Community and Support
- r/Compilers - Active subreddit for compiler construction discussions, project showcases, and questions.
- Compiler Jobs - Matthew Gaudet’s curated list of compiler engineering positions.
- LLVM Discourse - Official LLVM community forum for discussions about LLVM, Clang, and related projects.
- Rust Compiler Team - Information about contributing to the Rust compiler and joining the community.