melior
The `melior` crate provides safe Rust bindings to MLIR (Multi-Level Intermediate Representation). Compiler developers can use it to build optimizing compilers and domain-specific code generators on top of MLIR's infrastructure.

MLIR represents a program as nested operations: an operation can contain regions, each region contains blocks, and each block contains a sequence of operations. Operations, types, and attributes are grouped into dialects. Because of this, a single module can mix levels of abstraction, from high-level tensor operations down to low-level, machine-oriented operations.

melior exposes this dialect system, so Rust code can build IR for arithmetic, control flow, tensor computations, and LLVM IR generation.
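The nesting of operations, regions, and blocks can be pictured with a few plain Rust structs. The sketch below is a toy model that uses only the standard library. The `Operation`, `Region`, and `Block` types and the `walk` helper are hypothetical and are not melior's API. The sketch builds a module shaped like the `add` function from the Function Creation section. It then walks every nested operation and collects the dialect prefix of each operation name.

```rust
// Toy model of MLIR's IR nesting (hypothetical types, not melior's API):
// operations own regions, regions own blocks, blocks own operations.
use std::collections::BTreeSet;

#[derive(Debug, Default)]
pub struct Operation {
    /// Fully qualified name such as "arith.addi"; the prefix is the dialect.
    pub name: String,
    pub regions: Vec<Region>,
}

#[derive(Debug, Default)]
pub struct Region {
    pub blocks: Vec<Block>,
}

#[derive(Debug, Default)]
pub struct Block {
    pub operations: Vec<Operation>,
}

impl Operation {
    pub fn new(name: &str, regions: Vec<Region>) -> Self {
        Self { name: name.to_string(), regions }
    }

    /// The dialect namespace is the part of the name before the first '.'.
    pub fn dialect(&self) -> &str {
        self.name.split('.').next().unwrap_or("")
    }

    /// Visits this operation, then every operation nested inside it (pre-order).
    pub fn walk(&self, visit: &mut dyn FnMut(&Operation)) {
        visit(self);
        for region in &self.regions {
            for block in &region.blocks {
                for op in &block.operations {
                    op.walk(visit);
                }
            }
        }
    }
}

/// Builds a module containing one function whose body adds two values and returns.
pub fn example_module() -> Operation {
    let body = Block {
        operations: vec![
            Operation::new("arith.addi", vec![]),
            Operation::new("func.return", vec![]),
        ],
    };
    let func = Operation::new("func.func", vec![Region { blocks: vec![body] }]);
    Operation::new(
        "builtin.module",
        vec![Region { blocks: vec![Block { operations: vec![func] }] }],
    )
}

fn main() {
    let module = example_module();
    let mut dialects = BTreeSet::new();
    let mut count = 0;
    module.walk(&mut |op: &Operation| {
        count += 1;
        dialects.insert(op.dialect().to_string());
    });
    println!("{count} operations from dialects {dialects:?}");
}
```

A single module can therefore hold operations from several dialects (`builtin`, `func`, and `arith` here). This mixing is what lets MLIR lower a program one dialect at a time.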
melior wraps MLIR's C API in idiomatic Rust abstractions.

IR handles such as `Module<'c>`, `Type<'c>`, and `Operation<'c>` carry a lifetime tied to the `Context` that owns them. The borrow checker therefore rejects code that would use IR after its context has been dropped, while MLIR's extensible design remains available.

The crate provides operation builders for the standard MLIR dialects, including func, arith, scf, tensor, memref, and llvm. This enables progressive lowering from high-level abstractions to executable code. Lowering in stages lets a compiler run each optimization at the level where it is easiest to express: for example, loop transformations on scf loops before control flow is flattened into LLVM IR.
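Tying IR handles to a context with a lifetime is an ordinary Rust pattern. The stdlib-only sketch below mimics it with a hypothetical `Context` that uniques integer widths and an `IntegerType<'c>` handle that borrows that context. These are illustrative stand-ins, not melior's types. Equality is identity (same context, same uniqued slot), which mirrors how MLIR uniques types inside a context.

```rust
use std::cell::RefCell;

/// Hypothetical toy context: it owns (uniques) every integer type created in it.
#[derive(Default)]
pub struct Context {
    integer_widths: RefCell<Vec<u32>>,
}

/// Hypothetical lightweight handle that borrows the context it came from,
/// in the spirit of melior's lifetime-carrying `Type<'c>`.
#[derive(Clone, Copy)]
pub struct IntegerType<'c> {
    context: &'c Context,
    index: usize,
}

impl<'c> IntegerType<'c> {
    /// Returns the unique handle for `i{width}`, creating it on first use.
    pub fn new(context: &'c Context, width: u32) -> Self {
        let mut widths = context.integer_widths.borrow_mut();
        let index = match widths.iter().position(|&w| w == width) {
            Some(index) => index,
            None => {
                widths.push(width);
                widths.len() - 1
            }
        };
        Self { context, index }
    }

    /// Reads the width back from the owning context.
    pub fn width(&self) -> u32 {
        self.context.integer_widths.borrow()[self.index]
    }
}

/// Uniqued types compare by identity: same context and same slot.
impl PartialEq for IntegerType<'_> {
    fn eq(&self, other: &Self) -> bool {
        std::ptr::eq(self.context, other.context) && self.index == other.index
    }
}

fn main() {
    let context = Context::default();
    let a = IntegerType::new(&context, 32);
    let b = IntegerType::new(&context, 32);
    // drop(context); // rejected by the borrow checker: `a` and `b` are still used below
    println!("i{} uniqued: {}", a.width(), a == b);
}
```

Because every handle borrows the context, the compiler enforces at zero runtime cost that no handle outlives the object that owns its data.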
Installation on macOS
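To use melior on macOS, install LLVM/MLIR 20 with Homebrew:

```sh
brew install llvm@20
```

The Homebrew LLVM formula is keg-only, so the build may not find it on your `PATH`. Print its installation prefix with:

```sh
brew --prefix llvm@20
```

Point the build at that prefix. The `mlir-sys` and TableGen build scripts that melior depends on are expected to read it from `MLIR_SYS_200_PREFIX` and `TABLEGEN_200_PREFIX` (for example, `export MLIR_SYS_200_PREFIX=$(brew --prefix llvm@20)`). Confirm the variable names against the melior README for your version.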
Add melior to your Cargo.toml:
```toml
[dependencies]
melior = "0.25"
```
Basic Usage
Creating an MLIR context with all dialects loaded:
```rust
use melior::{
    dialect::DialectRegistry,
    utility::{register_all_dialects, register_all_llvm_translations, register_all_passes},
    Context,
};

/// Creates a test context with all dialects loaded
pub fn create_test_context() -> Context {
    let context = Context::new();

    // Register every upstream dialect and pass with MLIR.
    let registry = DialectRegistry::new();
    register_all_dialects(&registry);
    register_all_passes();

    // Make the registered dialects available to this context, then load them.
    context.append_dialect_registry(&registry);
    context.load_all_available_dialects();

    // Register the translations from MLIR dialects to LLVM IR.
    register_all_llvm_translations(&context);

    context
}
```
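The `Context` uniques types and attributes and manages dialect registration and configuration. A dialect must be loaded into the context before its operations can be created. The `create_test_context` function does four things:

- registers every upstream dialect and pass;
- loads those dialects into the context;
- registers the LLVM IR translations;
- returns a context ready for immediate use.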
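The two-step "register, then load" sequence in `create_test_context` can be modeled in a few lines. The sketch below is a simplified, stdlib-only toy model. `DialectRegistry`, `Context`, and their methods are hypothetical stand-ins that only echo the names of the melior calls above. In the model, an operation name is recognized only after its dialect is loaded. Real MLIR can also load registered dialects lazily (for example, while parsing), and the model ignores passes and LLVM translations.

```rust
use std::collections::BTreeSet;

/// Hypothetical toy registry: dialect names that *may* be loaded later.
#[derive(Default)]
pub struct DialectRegistry {
    names: BTreeSet<String>,
}

impl DialectRegistry {
    pub fn insert(&mut self, name: &str) {
        self.names.insert(name.to_string());
    }
}

/// Hypothetical toy context tracking which dialects are available and which are loaded.
#[derive(Default)]
pub struct Context {
    available: BTreeSet<String>,
    loaded: BTreeSet<String>,
}

impl Context {
    /// Makes the registry's dialects available without loading them yet.
    pub fn append_dialect_registry(&mut self, registry: &DialectRegistry) {
        self.available.extend(registry.names.iter().cloned());
    }

    /// Loads every dialect that has been made available.
    pub fn load_all_available_dialects(&mut self) {
        self.loaded.extend(self.available.iter().cloned());
    }

    /// Simplification: an operation such as "arith.addi" is recognized
    /// only when its dialect prefix ("arith") has been loaded.
    pub fn is_registered_operation(&self, name: &str) -> bool {
        match name.split_once('.') {
            Some((dialect, _)) => self.loaded.contains(dialect),
            None => false,
        }
    }
}

fn main() {
    let mut registry = DialectRegistry::default();
    registry.insert("arith");
    registry.insert("func");

    let mut context = Context::default();
    context.append_dialect_registry(&registry);
    println!("before load: {}", context.is_registered_operation("arith.addi"));
    context.load_all_available_dialects();
    println!("after load: {}", context.is_registered_operation("arith.addi"));
}
```

Keeping registration and loading separate lets a program describe every dialect it might need up front, then decide when to pay the cost of loading them.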
Function Creation
Building simple arithmetic functions:
```rust
use melior::dialect::{arith, func};
use melior::ir::attribute::{StringAttribute, TypeAttribute};
use melior::ir::r#type::FunctionType;
use melior::ir::*;
use melior::{Context, Error};

/// Creates a simple function that adds two `index` values
pub fn create_add_function(context: &Context) -> Result<Module<'_>, Error> {
    let location = Location::unknown(context);
    let module = Module::new(location);
    let index_type = Type::index(context);

    module.body().append_operation(func::func(
        context,
        StringAttribute::new(context, "add"),
        TypeAttribute::new(
            FunctionType::new(context, &[index_type, index_type], &[index_type]).into(),
        ),
        {
            // The entry block's arguments are the function's parameters.
            let block = Block::new(&[(index_type, location), (index_type, location)]);

            let sum = block
                .append_operation(arith::addi(
                    block.argument(0).unwrap().into(),
                    block.argument(1).unwrap().into(),
                    location,
                ))
                .result(0)
                .unwrap();

            block.append_operation(func::r#return(&[sum.into()], location));

            // A function body is a region holding the entry block.
            let region = Region::new();
            region.append_block(block);
            region
        },
        &[],
        location,
    ));

    Ok(module)
}
```
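It helps to know what the module built by `create_add_function` looks like as text. When printed, it should read roughly as the string produced below; exact formatting can vary between MLIR versions. `binary_function_ir` is a hypothetical helper used only for illustration and is not part of melior. It formats MLIR text for a two-argument function whose body is a single binary operation.

```rust
/// Hypothetical helper (not part of melior): renders MLIR text for a function
/// `@name(%arg0: ty, %arg1: ty) -> ty` whose body applies one binary `op`.
pub fn binary_function_ir(name: &str, op: &str, ty: &str) -> String {
    let lines = [
        "module {".to_string(),
        format!("  func.func @{name}(%arg0: {ty}, %arg1: {ty}) -> {ty} {{"),
        // The op's result is the first SSA value in the body, so it is %0.
        format!("    %0 = {op} %arg0, %arg1 : {ty}"),
        // Inside func.func, func.return is printed in short form as `return`.
        format!("    return %0 : {ty}"),
        "  }".to_string(),
        "}".to_string(),
    ];
    lines.join("\n")
}

fn main() {
    println!("{}", binary_function_ir("add", "arith.addi", "index"));
}
```

Text like this can be useful for checking builder code. melior's `Module::parse(&context, &text)` should be able to read it back into a `Module` for comparison.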
#![allow(unused)] fn main() { use melior::dialect::{arith, func, DialectRegistry}; use melior::ir::attribute::{IntegerAttribute, StringAttribute, TypeAttribute}; use melior::ir::operation::OperationLike; use melior::ir::r#type::{FunctionType, IntegerType}; use melior::ir::*; use melior::pass::{gpu, transform, PassManager}; use melior::utility::{register_all_dialects, register_all_llvm_translations, register_all_passes}; use melior::{Context, Error}; /// Creates a test context with all dialects loaded pub fn create_test_context() -> Context { let context = Context::new(); let registry = DialectRegistry::new(); register_all_dialects(&registry); register_all_passes(); context.append_dialect_registry(&registry); context.load_all_available_dialects(); register_all_llvm_translations(&context); context } /// Creates a simple function that adds two integers pub fn create_add_function(context: &Context) -> Result<Module<'_>, Error> { let location = Location::unknown(context); let module = Module::new(location); let index_type = Type::index(context); module.body().append_operation(func::func( context, StringAttribute::new(context, "add"), TypeAttribute::new( FunctionType::new(context, &[index_type, index_type], &[index_type]).into(), ), { let block = Block::new(&[(index_type, location), (index_type, location)]); let sum = block .append_operation(arith::addi( block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), location, )) .result(0) .unwrap(); block.append_operation(func::r#return(&[sum.into()], location)); let region = Region::new(); region.append_block(block); region }, &[], location, )); Ok(module) } /// Creates a constant integer value pub fn create_constant(context: &Context, value: i64) -> Result<Module<'_>, Error> { let location = Location::unknown(context); let module = Module::new(location); let i64_type = IntegerType::new(context, 64).into(); module.body().append_operation(func::func( context, StringAttribute::new(context, "get_constant"),
TypeAttribute::new(FunctionType::new(context, &[], &[i64_type]).into()), { let block = Block::new(&[]); let constant = block .append_operation(arith::constant( context, IntegerAttribute::new(i64_type, value).into(), location, )) .result(0) .unwrap(); block.append_operation(func::r#return(&[constant.into()], location)); let region = Region::new(); region.append_block(block); region }, &[], location, )); Ok(module) } /// Shows how to verify MLIR modules pub fn verify_module(module: &Module<'_>) -> bool { module.as_operation().verify() } /// Shows how to print MLIR to string pub fn module_to_string(module: &Module<'_>) -> String { format!("{}", module.as_operation()) } /// Apply canonicalization transforms to simplify IR pub fn apply_canonicalization(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_canonicalizer()); pass_manager.run(module) } /// Apply CSE (Common Subexpression Elimination) to remove redundant /// computations pub fn apply_cse(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_cse()); pass_manager.run(module) } /// Apply inlining transformation to inline function calls pub fn apply_inlining(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_inliner()); pass_manager.run(module) } /// Apply loop-invariant code motion to optimize loops pub fn apply_licm(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_loop_invariant_code_motion()); pass_manager.run(module) } /// Apply SCCP (Sparse Conditional Constant Propagation) for constant folding pub fn apply_sccp(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = 
PassManager::new(context);
        pass_manager.add_pass(transform::create_sccp());
        pass_manager.run(module)
    }

    /// Apply a pipeline of optimization passes
    pub fn optimize_module(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
        let pass_manager = PassManager::new(context);
        // Standard optimization pipeline
        pass_manager.add_pass(transform::create_canonicalizer());
        pass_manager.add_pass(transform::create_cse());
        pass_manager.add_pass(transform::create_sccp());
        pass_manager.add_pass(transform::create_inliner());
        pass_manager.add_pass(transform::create_canonicalizer()); // Run again after inlining
        pass_manager.run(module)
    }

    /// Example of applying symbol DCE (Dead Code Elimination)
    pub fn apply_symbol_dce(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
        let pass_manager = PassManager::new(context);
        pass_manager.add_pass(transform::create_symbol_dce());
        pass_manager.run(module)
    }

    /// Apply mem2reg transformation to promote memory to registers
    pub fn apply_mem2reg(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
        let pass_manager = PassManager::new(context);
        pass_manager.add_pass(transform::create_mem_2_reg());
        pass_manager.run(module)
    }

    /// Outline `gpu.launch` bodies into separate GPU kernel functions
    pub fn convert_to_gpu(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
        let pass_manager = PassManager::new(context);
        pass_manager.add_pass(gpu::create_gpu_kernel_outlining());
        pass_manager.run(module)
    }

    /// Strip debug information from the module
    pub fn strip_debug_info(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
        let pass_manager = PassManager::new(context);
        pass_manager.add_pass(transform::create_strip_debug_info());
        pass_manager.run(module)
    }

    /// Type builder helper for creating complex types
    pub struct TypeBuilder<'c> {
        context: &'c Context,
    }

    impl<'c> TypeBuilder<'c> {
        pub fn new(context: &'c Context) -> Self {
            Self { context }
        }

        pub fn i32(&self) -> Type<'c> {
            IntegerType::new(self.context, 32).into()
        }

        pub fn i64(&self) -> Type<'c> {
            IntegerType::new(self.context, 64).into()
        }

        pub fn index(&self) -> Type<'c> {
            Type::index(self.context)
        }

        pub fn function(&self, inputs: &[Type<'c>], outputs: &[Type<'c>]) -> FunctionType<'c> {
            FunctionType::new(self.context, inputs, outputs)
        }
    }

    /// Attribute builder helper for creating attributes
    pub struct AttributeBuilder<'c> {
        context: &'c Context,
    }

    impl<'c> AttributeBuilder<'c> {
        pub fn new(context: &'c Context) -> Self {
            Self { context }
        }

        pub fn string(&self, value: &str) -> StringAttribute<'c> {
            StringAttribute::new(self.context, value)
        }

        pub fn integer(&self, ty: Type<'c>, value: i64) -> IntegerAttribute<'c> {
            IntegerAttribute::new(ty, value)
        }

        pub fn type_attr(&self, ty: Type<'c>) -> TypeAttribute<'c> {
            TypeAttribute::new(ty)
        }
    }

    /// Custom pass builder for creating transformation pipelines
    pub struct PassPipeline<'c> {
        pass_manager: PassManager<'c>,
    }

    impl<'c> PassPipeline<'c> {
        pub fn new(context: &'c Context) -> Self {
            Self {
                pass_manager: PassManager::new(context),
            }
        }

        /// Add a canonicalization pass
        pub fn canonicalize(self) -> Self {
            self.pass_manager.add_pass(transform::create_canonicalizer());
            self
        }

        /// Add a CSE pass
        pub fn eliminate_common_subexpressions(self) -> Self {
            self.pass_manager.add_pass(transform::create_cse());
            self
        }

        /// Add an inlining pass
        pub fn inline_functions(self) -> Self {
            self.pass_manager.add_pass(transform::create_inliner());
            self
        }

        /// Add SCCP for constant propagation
        pub fn propagate_constants(self) -> Self {
            self.pass_manager.add_pass(transform::create_sccp());
            self
        }

        /// Add loop-invariant code motion
        pub fn optimize_loops(self) -> Self {
            self.pass_manager.add_pass(transform::create_loop_invariant_code_motion());
            self
        }

        /// Run the pipeline on a module
        pub fn run(self, module: &mut Module<'c>) -> Result<(), Error> {
            self.pass_manager.run(module)
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_add_function() {
            let context = create_test_context();
            let module = create_add_function(&context).unwrap();
            assert!(verify_module(&module));
            let ir = module_to_string(&module);
            assert!(ir.contains("func.func @add"));
            assert!(ir.contains("arith.addi"));
        }

        #[test]
        fn test_multiply_function() {
            let context = create_test_context();
            let module = create_multiply_function(&context).unwrap();
            assert!(verify_module(&module));
            let ir = module_to_string(&module);
            assert!(ir.contains("func.func @multiply"));
            assert!(ir.contains("arith.muli"));
        }

        #[test]
        fn test_constant_creation() {
            let context = create_test_context();
            let module = create_constant(&context, 42).unwrap();
            assert!(verify_module(&module));
            let ir = module_to_string(&module);
            assert!(ir.contains("arith.constant"));
            assert!(ir.contains("42"));
        }

        #[test]
        fn test_type_builder() {
            let context = create_test_context();
            let types = TypeBuilder::new(&context);
            let i32 = types.i32();
            let i64 = types.i64();
            let func_type = types.function(&[i32, i32], &[i64]);
            assert_eq!(func_type.input(0).unwrap(), i32);
            assert_eq!(func_type.input(1).unwrap(), i32);
            assert_eq!(func_type.result(0).unwrap(), i64);
        }

        #[test]
        fn test_attribute_builder() {
            let context = create_test_context();
            let types = TypeBuilder::new(&context);
            let attrs = AttributeBuilder::new(&context);

            let name = attrs.string("test_function");
            // StringAttribute doesn't expose string value directly in the API
            assert!(name.is_string());

            let value = attrs.integer(types.i32(), 100);
            assert_eq!(value.value(), 100);
        }

        #[test]
        fn test_canonicalization() {
            let context = create_test_context();
            let mut module = create_add_function(&context).unwrap();

            let _before = module_to_string(&module);
            apply_canonicalization(&context, &mut module).unwrap();
            let after = module_to_string(&module);

            assert!(verify_module(&module));
            // Canonicalization should preserve or simplify the IR
            assert!(!after.is_empty());
        }

        #[test]
        fn test_optimization_pipeline() {
            let context = create_test_context();
            let mut module = create_multiply_function(&context).unwrap();

            // Apply optimization pipeline
            optimize_module(&context, &mut module).unwrap();

            assert!(verify_module(&module));
            let ir = module_to_string(&module);
            assert!(ir.contains("func.func @multiply"));
        }

        #[test]
        fn test_pass_pipeline_builder() {
            let context = create_test_context();
            let mut module = create_constant(&context, 100).unwrap();

            let pipeline = PassPipeline::new(&context)
                .canonicalize()
                .eliminate_common_subexpressions()
                .propagate_constants();

            pipeline.run(&mut module).unwrap();
            assert!(verify_module(&module));
        }
    }

    /// Creates a function that multiplies two 32-bit integers
    pub fn create_multiply_function(context: &Context) -> Result<Module<'_>, Error> {
        let location = Location::unknown(context);
        let module = Module::new(location);
        let i32_type = IntegerType::new(context, 32).into();

        module.body().append_operation(func::func(
            context,
            StringAttribute::new(context, "multiply"),
            TypeAttribute::new(FunctionType::new(context, &[i32_type, i32_type], &[i32_type]).into()),
            {
                let block = Block::new(&[(i32_type, location), (i32_type, location)]);
                let product = block
                    .append_operation(arith::muli(
                        block.argument(0).unwrap().into(),
                        block.argument(1).unwrap().into(),
                        location,
                    ))
                    .result(0)
                    .unwrap();

                block.append_operation(func::r#return(&[product.into()], location));

                let region = Region::new();
                region.append_block(block);
                region
            },
            &[],
            location,
        ));

        Ok(module)
    }
}
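Function creation involves specifying parameter types and return types using MLIR’s type system. The function body is a region containing basic blocks, and the entry block receives the function parameters as block arguments. A single block is enough for straight-line code like these examples. Unstructured control flow (for example, cf.br and cf.cond_br) adds further blocks to the same region, while structured operations such as scf.for carry their own nested regions.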
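The link between a function's signature and its entry block can be illustrated without an MLIR installation. The sketch below is a hypothetical, std-only Rust model. The `Ty`, `Block`, `Region` and `Function` types are illustrative assumptions, not melior types. It mirrors one rule that MLIR's `func.func` verifier enforces: the entry block's argument types must match the function's input types, one for one and in order.

```rust
// Hypothetical, std-only model of the nesting described above. These types
// are illustrative only and are NOT melior's API. A function owns a region,
// the region owns blocks, and the entry block's arguments stand in for the
// function parameters.

/// A tiny stand-in for MLIR builtin types.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Ty {
    Index,
    I32,
}

/// A basic block. Only its argument types are modelled here.
#[derive(Debug)]
pub struct Block {
    pub args: Vec<Ty>,
}

/// A region is an ordered list of blocks. The first one is the entry block.
#[derive(Debug)]
pub struct Region {
    pub blocks: Vec<Block>,
}

/// A function is a signature plus a body region.
#[derive(Debug)]
pub struct Function {
    pub name: String,
    pub inputs: Vec<Ty>,
    pub results: Vec<Ty>,
    pub body: Region,
}

impl Function {
    /// Checks that the entry block's argument types equal the function's
    /// input types. This sketch only models functions that have a body.
    pub fn verify(&self) -> Result<(), String> {
        let entry = self
            .body
            .blocks
            .first()
            .ok_or_else(|| format!("@{}: body has no entry block", self.name))?;
        if entry.args != self.inputs {
            return Err(format!(
                "@{}: entry block arguments {:?} do not match inputs {:?}",
                self.name, entry.args, self.inputs
            ));
        }
        Ok(())
    }
}
```

In `create_add_function`, the call `Block::new(&[(index_type, location), (index_type, location)])` plays the role of the entry block in this model. That is why its argument list repeats the function's input types.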
Arithmetic Operations
Creating constant values and arithmetic computations:
use melior::dialect::{arith, func};
use melior::ir::attribute::{IntegerAttribute, StringAttribute, TypeAttribute};
use melior::ir::operation::OperationLike;
use melior::ir::r#type::{FunctionType, IntegerType};
use melior::ir::*;
use melior::{Context, Error};

/// Creates a constant integer value
pub fn create_constant(context: &Context, value: i64) -> Result<Module<'_>, Error> {
    let location = Location::unknown(context);
    let module = Module::new(location);
    let i64_type = IntegerType::new(context, 64).into();

    module.body().append_operation(func::func(
        context,
        StringAttribute::new(context, "get_constant"),
        TypeAttribute::new(FunctionType::new(context, &[], &[i64_type]).into()),
        {
            let block = Block::new(&[]);
            let constant = block
                .append_operation(arith::constant(
                    context,
                    IntegerAttribute::new(i64_type, value).into(),
                    location,
                ))
                .result(0)
                .unwrap();

            block.append_operation(func::r#return(&[constant.into()], location));

            let region = Region::new();
            region.append_block(block);
            region
        },
        &[],
        location,
    ));

    Ok(module)
}
The arith dialect provides integer and floating-point arithmetic operations. Constants are materialized with arith::constant. Larger computations are built by chaining operations: each operation produces SSA result values that later operations consume as operands. Because one value can feed several consumers, the result is a dataflow graph rather than a strict expression tree.
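The dataflow view can be illustrated the same way. The following is a hypothetical, std-only model in the spirit of SSA form. The `Op` enum and `evaluate` function are assumptions made for illustration and are not part of melior. Each operation defines exactly one value, and operands may only refer to values defined earlier in the block.

```rust
// Hypothetical, std-only model of SSA dataflow inside one block. This is
// NOT melior's API. Value `i` is the result of `ops[i]`, and operands are
// indices of earlier values.

#[derive(Clone, Copy, Debug)]
pub enum Op {
    /// Block argument `n` (a function parameter for the entry block).
    Arg(usize),
    /// Like `arith.constant`: materializes a literal.
    Constant(i64),
    /// Like `arith.addi`: two's-complement (wrapping) addition.
    AddI(usize, usize),
    /// Like `arith.muli`: two's-complement (wrapping) multiplication.
    MulI(usize, usize),
}

/// Evaluates the ops in order and returns every defined value.
/// Fails if an operand refers to a value that is not yet defined.
pub fn evaluate(ops: &[Op], args: &[i64]) -> Result<Vec<i64>, String> {
    let mut values: Vec<i64> = Vec::with_capacity(ops.len());
    for (i, op) in ops.iter().enumerate() {
        // Look up an operand, rejecting uses that come before the definition.
        let get = |v: usize| -> Result<i64, String> {
            values
                .get(v)
                .copied()
                .ok_or_else(|| format!("op {i} uses %{v} before it is defined"))
        };
        let result = match *op {
            Op::Arg(n) => args
                .get(n)
                .copied()
                .ok_or_else(|| format!("op {i}: missing argument {n}"))?,
            Op::Constant(c) => c,
            Op::AddI(a, b) => get(a)?.wrapping_add(get(b)?),
            Op::MulI(a, b) => get(a)?.wrapping_mul(get(b)?),
        };
        values.push(result);
    }
    Ok(values)
}
```

As an example, `(x * y) + 42` becomes `[Arg(0), Arg(1), MulI(0, 1), Constant(42), AddI(2, 3)]`. Evaluated with arguments `[3, 5]`, it yields the values `[3, 5, 15, 42, 57]`. Each value is defined once and then consumed by later operations, which is the dataflow shape the melior examples build with `arith::constant`, `arith::muli` and `arith::addi`.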
Type and Attribute Builders
Helper utilities for creating MLIR types and attributes:
use melior::Context;

/// Type builder helper for creating complex types
pub struct TypeBuilder<'c> {
    context: &'c Context,
}
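The builders hold a `&'c Context` and return values tagged with the same lifetime. The hypothetical, std-only sketch below illustrates why. The names `ToyContext`, `ToyType` and `ToyTypeBuilder` are invented for illustration and do not describe how melior is implemented. The context owns an interning table, so requesting the same type twice yields equal handles. Because every handle borrows the context, the compiler rejects any handle that would outlive it. MLIR itself uniques types and attributes inside its context, which is why comparisons such as `assert_eq!(func_type.input(0).unwrap(), i32)` in the tests can succeed.

```rust
// Hypothetical sketch: a context-owned interning table plus lifetime-bound
// handles. Illustrative only; this is NOT melior's implementation.
use std::cell::RefCell;
use std::marker::PhantomData;

#[derive(Clone, Debug, PartialEq, Eq)]
enum TypeKind {
    Integer(u32),
    Index,
}

#[derive(Default)]
pub struct ToyContext {
    // Each distinct type is stored exactly once. Its position is its identity.
    types: RefCell<Vec<TypeKind>>,
}

/// A cheap, copyable handle to a uniqued type, valid only while `'c` lives.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ToyType<'c> {
    id: usize,
    _context: PhantomData<&'c ToyContext>,
}

impl ToyContext {
    /// Returns the existing handle for `kind`, or stores it and returns a new one.
    fn intern(&self, kind: TypeKind) -> ToyType<'_> {
        let mut types = self.types.borrow_mut();
        let id = match types.iter().position(|existing| *existing == kind) {
            Some(id) => id,
            None => {
                types.push(kind);
                types.len() - 1
            }
        };
        ToyType { id, _context: PhantomData }
    }

    /// Number of distinct types created so far.
    pub fn type_count(&self) -> usize {
        self.types.borrow().len()
    }
}

/// Same shape as the `TypeBuilder` above: it only borrows the context.
pub struct ToyTypeBuilder<'c> {
    context: &'c ToyContext,
}

impl<'c> ToyTypeBuilder<'c> {
    pub fn new(context: &'c ToyContext) -> Self {
        Self { context }
    }

    pub fn integer(&self, width: u32) -> ToyType<'c> {
        self.context.intern(TypeKind::Integer(width))
    }

    pub fn index(&self) -> ToyType<'c> {
        self.context.intern(TypeKind::Index)
    }
}
```

A function that creates a `ToyContext` locally and tries to return one of its `ToyType` handles fails to compile, because the handle borrows the context. melior's `Type<'c>` carries the same kind of lifetime to tie each type to its `Context`.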
#![allow(unused)]
fn main() {
    use melior::dialect::{arith, func, DialectRegistry};
    use melior::ir::attribute::{IntegerAttribute, StringAttribute, TypeAttribute};
    use melior::ir::operation::OperationLike;
    use melior::ir::r#type::{FunctionType, IntegerType};
    use melior::ir::*;
    use melior::pass::{gpu, transform, PassManager};
    use melior::utility::{register_all_dialects, register_all_llvm_translations, register_all_passes};
    use melior::{Context, Error};

    /// Creates a test context with all dialects loaded
    pub fn create_test_context() -> Context {
        let context = Context::new();
        let registry = DialectRegistry::new();
        register_all_dialects(&registry);
        register_all_passes();
        context.append_dialect_registry(&registry);
        context.load_all_available_dialects();
        register_all_llvm_translations(&context);
        context
    }

    /// Creates a simple function that adds two integers
    pub fn create_add_function(context: &Context) -> Result<Module<'_>, Error> {
        let location = Location::unknown(context);
        let module = Module::new(location);
        let index_type = Type::index(context);

        module.body().append_operation(func::func(
            context,
            StringAttribute::new(context, "add"),
            TypeAttribute::new(
                FunctionType::new(context, &[index_type, index_type], &[index_type]).into(),
            ),
            {
                let block = Block::new(&[(index_type, location), (index_type, location)]);
                let sum = block
                    .append_operation(arith::addi(
                        block.argument(0).unwrap().into(),
                        block.argument(1).unwrap().into(),
                        location,
                    ))
                    .result(0)
                    .unwrap();

                block.append_operation(func::r#return(&[sum.into()], location));

                let region = Region::new();
                region.append_block(block);
                region
            },
            &[],
            location,
        ));

        Ok(module)
    }

    /// Creates a function that multiplies two 32-bit integers
    pub fn create_multiply_function(context: &Context) -> Result<Module<'_>, Error> {
        let location = Location::unknown(context);
        let module = Module::new(location);
        let i32_type = IntegerType::new(context, 32).into();

        module.body().append_operation(func::func(
            context,
            StringAttribute::new(context, "multiply"),
            TypeAttribute::new(FunctionType::new(context, &[i32_type, i32_type], &[i32_type]).into()),
            {
                let block = Block::new(&[(i32_type, location), (i32_type, location)]);
                let product = block
                    .append_operation(arith::muli(
                        block.argument(0).unwrap().into(),
                        block.argument(1).unwrap().into(),
                        location,
                    ))
                    .result(0)
                    .unwrap();

                block.append_operation(func::r#return(&[product.into()], location));

                let region = Region::new();
                region.append_block(block);
                region
            },
            &[],
            location,
        ));

        Ok(module)
    }

    /// Creates a constant integer value
    pub fn create_constant(context: &Context, value: i64) -> Result<Module<'_>, Error> {
        let location = Location::unknown(context);
        let module = Module::new(location);
        let i64_type = IntegerType::new(context, 64).into();

        module.body().append_operation(func::func(
            context,
            StringAttribute::new(context, "get_constant"),
            TypeAttribute::new(FunctionType::new(context, &[], &[i64_type]).into()),
            {
                let block = Block::new(&[]);
                let constant = block
                    .append_operation(arith::constant(
                        context,
                        IntegerAttribute::new(i64_type, value).into(),
                        location,
                    ))
                    .result(0)
                    .unwrap();

                block.append_operation(func::r#return(&[constant.into()], location));

                let region = Region::new();
                region.append_block(block);
                region
            },
            &[],
            location,
        ));

        Ok(module)
    }

    /// Shows how to verify MLIR modules
    pub fn verify_module(module: &Module<'_>) -> bool {
        module.as_operation().verify()
    }

    /// Shows how to print MLIR to string
    pub fn module_to_string(module: &Module<'_>) -> String {
        format!("{}", module.as_operation())
    }

    /// Apply canonicalization transforms to simplify IR
    pub fn apply_canonicalization(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
        let pass_manager = PassManager::new(context);
        pass_manager.add_pass(transform::create_canonicalizer());
        pass_manager.run(module)
    }

    /// Apply CSE (Common Subexpression Elimination) to remove redundant
    /// computations
    pub fn apply_cse(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
        let pass_manager = PassManager::new(context);
        pass_manager.add_pass(transform::create_cse());
        pass_manager.run(module)
    }

    /// Apply inlining transformation to inline function calls
    pub fn apply_inlining(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
        let pass_manager = PassManager::new(context);
        pass_manager.add_pass(transform::create_inliner());
        pass_manager.run(module)
    }

    /// Apply loop-invariant code motion to optimize loops
    pub fn apply_licm(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
        let pass_manager = PassManager::new(context);
        pass_manager.add_pass(transform::create_loop_invariant_code_motion());
        pass_manager.run(module)
    }

    /// Apply SCCP (Sparse Conditional Constant Propagation) for constant folding
    pub fn apply_sccp(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
        let pass_manager = PassManager::new(context);
        pass_manager.add_pass(transform::create_sccp());
        pass_manager.run(module)
    }

    /// Apply a pipeline of optimization passes
    pub fn optimize_module(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
        let pass_manager = PassManager::new(context);
        // Standard optimization pipeline
        pass_manager.add_pass(transform::create_canonicalizer());
        pass_manager.add_pass(transform::create_cse());
        pass_manager.add_pass(transform::create_sccp());
        pass_manager.add_pass(transform::create_inliner());
        pass_manager.add_pass(transform::create_canonicalizer()); // Run again after inlining
        pass_manager.run(module)
    }

    /// Example of applying symbol DCE (Dead Code Elimination)
    pub fn apply_symbol_dce(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
        let pass_manager = PassManager::new(context);
        pass_manager.add_pass(transform::create_symbol_dce());
        pass_manager.run(module)
    }

    /// Apply mem2reg transformation to promote memory to registers
    pub fn apply_mem2reg(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
        let pass_manager = PassManager::new(context);
        pass_manager.add_pass(transform::create_mem_2_reg());
        pass_manager.run(module)
    }

    /// Outline `gpu.launch` bodies into separate GPU kernel functions
    pub fn convert_to_gpu(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
        let pass_manager = PassManager::new(context);
        pass_manager.add_pass(gpu::create_gpu_kernel_outlining());
        pass_manager.run(module)
    }

    /// Strip debug information from the module
    pub fn strip_debug_info(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
        let pass_manager = PassManager::new(context);
        pass_manager.add_pass(transform::create_strip_debug_info());
        pass_manager.run(module)
    }

    /// Type builder helper for creating complex types
    pub struct TypeBuilder<'c> {
        context: &'c Context,
    }

    /// Attribute builder helper for creating attributes
    pub struct AttributeBuilder<'c> {
        context: &'c Context,
    }

    impl<'c> AttributeBuilder<'c> {
        pub fn new(context: &'c Context) -> Self {
            Self { context }
        }

        pub fn string(&self, value: &str) -> StringAttribute<'c> {
            StringAttribute::new(self.context, value)
        }

        pub fn integer(&self, ty: Type<'c>, value: i64) -> IntegerAttribute<'c> {
            IntegerAttribute::new(ty, value)
        }

        pub fn type_attr(&self, ty: Type<'c>) -> TypeAttribute<'c> {
            TypeAttribute::new(ty)
        }
    }

    /// Custom pass builder for creating transformation pipelines
    pub struct PassPipeline<'c> {
        pass_manager: PassManager<'c>,
    }

    impl<'c> PassPipeline<'c> {
        pub fn new(context: &'c Context) -> Self {
            Self {
                pass_manager: PassManager::new(context),
            }
        }

        /// Add a canonicalization pass
        pub fn canonicalize(self) -> Self {
            self.pass_manager.add_pass(transform::create_canonicalizer());
            self
        }

        /// Add a CSE pass
        pub fn eliminate_common_subexpressions(self) -> Self {
            self.pass_manager.add_pass(transform::create_cse());
            self
        }

        /// Add an inlining pass
        pub fn inline_functions(self) -> Self {
            self.pass_manager.add_pass(transform::create_inliner());
            self
        }

        /// Add SCCP for constant propagation
        pub fn propagate_constants(self) -> Self {
            self.pass_manager.add_pass(transform::create_sccp());
            self
        }

        /// Add loop-invariant code motion
        pub fn optimize_loops(self) -> Self {
            self.pass_manager.add_pass(transform::create_loop_invariant_code_motion());
            self
        }

        /// Run the pipeline on a module
        pub fn run(self, module: &mut Module<'c>) -> Result<(), Error> {
            self.pass_manager.run(module)
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_add_function() {
            let context = create_test_context();
            let module = create_add_function(&context).unwrap();
            assert!(verify_module(&module));
            let ir = module_to_string(&module);
            assert!(ir.contains("func.func @add"));
            assert!(ir.contains("arith.addi"));
        }

        #[test]
        fn test_multiply_function() {
            let context = create_test_context();
            let module = create_multiply_function(&context).unwrap();
            assert!(verify_module(&module));
            let ir = module_to_string(&module);
            assert!(ir.contains("func.func @multiply"));
            assert!(ir.contains("arith.muli"));
        }

        #[test]
        fn test_constant_creation() {
            let context = create_test_context();
            let module = create_constant(&context, 42).unwrap();
            assert!(verify_module(&module));
            let ir = module_to_string(&module);
            assert!(ir.contains("arith.constant"));
            assert!(ir.contains("42"));
        }

        #[test]
        fn test_type_builder() {
            let context = create_test_context();
            let types = TypeBuilder::new(&context);
            let i32 = types.i32();
            let i64 = types.i64();
            let func_type = types.function(&[i32, i32], &[i64]);
            assert_eq!(func_type.input(0).unwrap(), i32);
            assert_eq!(func_type.input(1).unwrap(), i32);
            assert_eq!(func_type.result(0).unwrap(), i64);
        }

        #[test]
        fn test_attribute_builder() {
            let context = create_test_context();
            let types = TypeBuilder::new(&context);
            let attrs = AttributeBuilder::new(&context);

            let name = attrs.string("test_function");
            // StringAttribute doesn't expose string value directly in the API
            assert!(name.is_string());

            let value = attrs.integer(types.i32(), 100);
            assert_eq!(value.value(), 100);
        }

        #[test]
        fn test_canonicalization() {
            let context =
create_test_context(); let mut module = create_add_function(&context).unwrap(); let _before = module_to_string(&module); apply_canonicalization(&context, &mut module).unwrap(); let after = module_to_string(&module); assert!(verify_module(&module)); // Canonicalization should preserve or simplify the IR assert!(!after.is_empty()); } #[test] fn test_optimization_pipeline() { let context = create_test_context(); let mut module = create_multiply_function(&context).unwrap(); // Apply optimization pipeline optimize_module(&context, &mut module).unwrap(); assert!(verify_module(&module)); let ir = module_to_string(&module); assert!(ir.contains("func.func @multiply")); } #[test] fn test_pass_pipeline_builder() { let context = create_test_context(); let mut module = create_constant(&context, 100).unwrap(); let pipeline = PassPipeline::new(&context) .canonicalize() .eliminate_common_subexpressions() .propagate_constants(); pipeline.run(&mut module).unwrap(); assert!(verify_module(&module)); } } impl<'c> TypeBuilder<'c> { pub fn new(context: &'c Context) -> Self { Self { context } } pub fn i32(&self) -> Type<'c> { IntegerType::new(self.context, 32).into() } pub fn i64(&self) -> Type<'c> { IntegerType::new(self.context, 64).into() } pub fn index(&self) -> Type<'c> { Type::index(self.context) } pub fn function(&self, inputs: &[Type<'c>], outputs: &[Type<'c>]) -> FunctionType<'c> { FunctionType::new(self.context, inputs, outputs) } } }
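Every step of the PassPipeline builder defined above takes self by value and returns it, so passes are added by chaining calls. The std-only sketch below is a rough illustration, not melior code. PipelineSketch and render are hypothetical names. Instead of creating real passes, it records the MLIR pass argument names and renders them in the textual pipeline syntax that mlir-opt accepts through --pass-pipeline. This makes it easy to inspect the order that optimize_module uses.

```rust
/// Hypothetical, std-only model of the consuming PassPipeline builder.
/// It does not call melior; it only records MLIR pass argument names.
#[derive(Debug, Default)]
pub struct PipelineSketch {
    passes: Vec<&'static str>,
}

impl PipelineSketch {
    pub fn new() -> Self {
        Self::default()
    }

    /// Each step takes `self` by value and hands it back, so calls chain.
    fn push(mut self, pass: &'static str) -> Self {
        self.passes.push(pass);
        self
    }

    pub fn canonicalize(self) -> Self {
        self.push("canonicalize")
    }

    pub fn eliminate_common_subexpressions(self) -> Self {
        self.push("cse")
    }

    pub fn propagate_constants(self) -> Self {
        self.push("sccp")
    }

    pub fn inline_functions(self) -> Self {
        self.push("inline")
    }

    pub fn optimize_loops(self) -> Self {
        self.push("loop-invariant-code-motion")
    }

    /// Renders the recorded passes nested under the top-level builtin.module,
    /// consuming the builder so it cannot be extended afterwards.
    pub fn render(self) -> String {
        format!("builtin.module({})", self.passes.join(","))
    }
}
```

Chaining canonicalize, eliminate_common_subexpressions, propagate_constants, inline_functions and canonicalize (the same order as optimize_module) renders builtin.module(canonicalize,cse,sccp,inline,canonicalize). The second canonicalize appears because the inliner can expose new simplification opportunities.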
```rust
use melior::dialect::DialectRegistry;
use melior::utility::{register_all_dialects, register_all_llvm_translations, register_all_passes};
use melior::Context;

/// Creates a test context with all dialects loaded
pub fn create_test_context() -> Context {
    let context = Context::new();
    let registry = DialectRegistry::new();
    // Register every upstream dialect and pass, then make the dialects available
    register_all_dialects(&registry);
    register_all_passes();
    context.append_dialect_registry(&registry);
    context.load_all_available_dialects();
    // Register the translations used when exporting the llvm dialect to LLVM IR
    register_all_llvm_translations(&context);
    context
}
```
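The apply_cse and optimize_module helpers defined earlier on this page schedule MLIR's common subexpression elimination pass through transform::create_cse(). The std-only sketch below shows the same idea on a hypothetical straight-line toy IR. It is for intuition only: Inst and local_cse are illustrative names, not melior or MLIR APIs. The sketch first rewrites each instruction's operands to the values that survive. Any instruction whose opcode and operands then match an earlier one reuses that earlier value.

This toy makes several simplifying assumptions. It works on a single block, treats every instruction as side-effect free, and ignores commutativity. MLIR's pass also respects side effects and works across blocks along the dominance tree.

```rust
use std::collections::HashMap;

/// Hypothetical toy IR: instruction i defines value %i, and operands are
/// indices of earlier values (straight-line SSA, so operands come first).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Inst {
    Arg(u32),
    Const(i64),
    AddI(usize, usize),
    MulI(usize, usize),
}

/// Local CSE: returns the deduplicated instructions plus a map from each
/// original value index to its index in the deduplicated list.
pub fn local_cse(insts: &[Inst]) -> (Vec<Inst>, Vec<usize>) {
    let mut seen: HashMap<Inst, usize> = HashMap::new();
    let mut out: Vec<Inst> = Vec::new();
    let mut remap: Vec<usize> = Vec::with_capacity(insts.len());
    for inst in insts {
        // Rewrite operands to surviving values so equivalent expressions
        // produce identical hash keys.
        let key = match inst {
            Inst::AddI(a, b) => Inst::AddI(remap[*a], remap[*b]),
            Inst::MulI(a, b) => Inst::MulI(remap[*a], remap[*b]),
            other => other.clone(),
        };
        // Reuse the earlier value if this key was seen, else emit it.
        let idx = *seen.entry(key.clone()).or_insert_with(|| {
            out.push(key);
            out.len() - 1
        });
        remap.push(idx);
    }
    (out, remap)
}
```

Suppose %2 and %3 both compute addi %0, %1. Then %3 is mapped onto %2, and a later mul %2, %3 becomes mul %2, %2.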
```rust
use melior::ir::attribute::{IntegerAttribute, StringAttribute, TypeAttribute};
use melior::ir::r#type::{FunctionType, IntegerType};
use melior::ir::Type;
use melior::Context;

/// Type builder helper for creating common types
pub struct TypeBuilder<'c> {
    context: &'c Context,
}

impl<'c> TypeBuilder<'c> {
    pub fn new(context: &'c Context) -> Self {
        Self { context }
    }

    /// 32-bit signless integer type (i32)
    pub fn i32(&self) -> Type<'c> {
        IntegerType::new(self.context, 32).into()
    }

    /// 64-bit signless integer type (i64)
    pub fn i64(&self) -> Type<'c> {
        IntegerType::new(self.context, 64).into()
    }

    /// Target-dependent index type
    pub fn index(&self) -> Type<'c> {
        Type::index(self.context)
    }

    /// Function type with the given input and result types
    pub fn function(&self, inputs: &[Type<'c>], outputs: &[Type<'c>]) -> FunctionType<'c> {
        FunctionType::new(self.context, inputs, outputs)
    }
}

/// Attribute builder helper for creating attributes
pub struct AttributeBuilder<'c> {
    context: &'c Context,
}

impl<'c> AttributeBuilder<'c> {
    pub fn new(context: &'c Context) -> Self {
        Self { context }
    }

    pub fn string(&self, value: &str) -> StringAttribute<'c> {
        StringAttribute::new(self.context, value)
    }

    pub fn integer(&self, ty: Type<'c>, value: i64) -> IntegerAttribute<'c> {
        IntegerAttribute::new(ty, value)
    }

    pub fn type_attr(&self, ty: Type<'c>) -> TypeAttribute<'c> {
        TypeAttribute::new(ty)
    }
}
```
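These builders provide shorthand methods for creating common MLIR types and attributes, so you do not have to repeat the full constructor calls each time.

- TypeBuilder covers the 32- and 64-bit integer types, the index type, and function types.
- AttributeBuilder creates string, integer, and type attributes.

Both builders hold a borrowed &'c Context. As a result, every type or attribute they return is tied to the context's lifetime 'c and cannot outlive it.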
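MLIR uniques types inside the context, so requesting the same type twice yields handles that compare equal. This is why test_type_builder can use assert_eq! on the inputs of a function type.

The std-only toy below mimics that design. ToyContext, ToyType and ToyTypeBuilder are hypothetical names; this is not melior's or MLIR's implementation. The context interns a type description to an integer id behind a RefCell. The builder hands out Copy handles that borrow the context, so the borrow checker rejects any handle that would outlive it.

```rust
use std::cell::RefCell;
use std::collections::HashMap;
use std::marker::PhantomData;

/// Hypothetical stand-in for an MLIR context that uniques type descriptions.
#[derive(Default)]
pub struct ToyContext {
    ids: RefCell<HashMap<String, usize>>,
}

impl ToyContext {
    /// Returns the unique id for `key`, allocating a new one on first use.
    fn intern(&self, key: String) -> usize {
        let mut ids = self.ids.borrow_mut();
        let next = ids.len();
        *ids.entry(key).or_insert(next)
    }
}

/// Cheap handle tied to the context's lifetime, loosely like Type<'c>.
/// Equality compares the uniqued id, not the type's structure.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ToyType<'c> {
    id: usize,
    _context: PhantomData<&'c ToyContext>,
}

fn join_ids(types: &[ToyType<'_>]) -> String {
    types.iter().map(|t| t.id.to_string()).collect::<Vec<_>>().join(",")
}

/// Hypothetical builder mirroring TypeBuilder's shape.
pub struct ToyTypeBuilder<'c> {
    context: &'c ToyContext,
}

impl<'c> ToyTypeBuilder<'c> {
    pub fn new(context: &'c ToyContext) -> Self {
        Self { context }
    }

    fn get(&self, key: String) -> ToyType<'c> {
        ToyType { id: self.context.intern(key), _context: PhantomData }
    }

    pub fn integer(&self, width: u32) -> ToyType<'c> {
        self.get(format!("i{width}"))
    }

    pub fn index(&self) -> ToyType<'c> {
        self.get("index".to_string())
    }

    /// Function types are keyed by the ids of their already-uniqued parts.
    pub fn function(&self, inputs: &[ToyType<'c>], outputs: &[ToyType<'c>]) -> ToyType<'c> {
        self.get(format!("({}) -> ({})", join_ids(inputs), join_ids(outputs)))
    }
}
```

Because function types are keyed on the ids of their parts, building (i32, i32) -> i64 twice returns equal handles. This is the toy analogue of the equality checks in test_type_builder.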
Module Operations
Utility functions for working with MLIR modules:
#![allow(unused)] fn main() { use melior::dialect::{arith, func, DialectRegistry}; use melior::ir::attribute::{IntegerAttribute, StringAttribute, TypeAttribute}; use melior::ir::operation::OperationLike; use melior::ir::r#type::{FunctionType, IntegerType}; use melior::ir::*; use melior::pass::{gpu, transform, PassManager}; use melior::utility::{register_all_dialects, register_all_llvm_translations, register_all_passes}; use melior::{Context, Error}; /// Creates a test context with all dialects loaded pub fn create_test_context() -> Context { let context = Context::new(); let registry = DialectRegistry::new(); register_all_dialects(&registry); register_all_passes(); context.append_dialect_registry(&registry); context.load_all_available_dialects(); register_all_llvm_translations(&context); context } /// Creates a simple function that adds two integers pub fn create_add_function(context: &Context) -> Result<Module<'_>, Error> { let location = Location::unknown(context); let module = Module::new(location); let index_type = Type::index(context); module.body().append_operation(func::func( context, StringAttribute::new(context, "add"), TypeAttribute::new( FunctionType::new(context, &[index_type, index_type], &[index_type]).into(), ), { let block = Block::new(&[(index_type, location), (index_type, location)]); let sum = block .append_operation(arith::addi( block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), location, )) .result(0) .unwrap(); block.append_operation(func::r#return(&[sum.into()], location)); let region = Region::new(); region.append_block(block); region }, &[], location, )); Ok(module) } /// Creates a function that multiplies two 32-bit integers pub fn create_multiply_function(context: &Context) -> Result<Module<'_>, Error> { let location = Location::unknown(context); let module = Module::new(location); let i32_type = IntegerType::new(context, 32).into(); module.body().append_operation(func::func( context, StringAttribute::new(context, 
"multiply"), TypeAttribute::new(FunctionType::new(context, &[i32_type, i32_type], &[i32_type]).into()), { let block = Block::new(&[(i32_type, location), (i32_type, location)]); let product = block .append_operation(arith::muli( block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), location, )) .result(0) .unwrap(); block.append_operation(func::r#return(&[product.into()], location)); let region = Region::new(); region.append_block(block); region }, &[], location, )); Ok(module) } /// Creates a constant integer value pub fn create_constant(context: &Context, value: i64) -> Result<Module<'_>, Error> { let location = Location::unknown(context); let module = Module::new(location); let i64_type = IntegerType::new(context, 64).into(); module.body().append_operation(func::func( context, StringAttribute::new(context, "get_constant"), TypeAttribute::new(FunctionType::new(context, &[], &[i64_type]).into()), { let block = Block::new(&[]); let constant = block .append_operation(arith::constant( context, IntegerAttribute::new(i64_type, value).into(), location, )) .result(0) .unwrap(); block.append_operation(func::r#return(&[constant.into()], location)); let region = Region::new(); region.append_block(block); region }, &[], location, )); Ok(module) } /// Shows how to print MLIR to string pub fn module_to_string(module: &Module<'_>) -> String { format!("{}", module.as_operation()) } /// Apply canonicalization transforms to simplify IR pub fn apply_canonicalization(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_canonicalizer()); pass_manager.run(module) } /// Apply CSE (Common Subexpression Elimination) to remove redundant /// computations pub fn apply_cse(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_cse()); pass_manager.run(module) } /// Apply inlining 
transformation to inline function calls pub fn apply_inlining(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_inliner()); pass_manager.run(module) } /// Apply loop-invariant code motion to optimize loops pub fn apply_licm(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_loop_invariant_code_motion()); pass_manager.run(module) } /// Apply SCCP (Sparse Conditional Constant Propagation) for constant folding pub fn apply_sccp(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_sccp()); pass_manager.run(module) } /// Apply a pipeline of optimization passes pub fn optimize_module(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); // Standard optimization pipeline pass_manager.add_pass(transform::create_canonicalizer()); pass_manager.add_pass(transform::create_cse()); pass_manager.add_pass(transform::create_sccp()); pass_manager.add_pass(transform::create_inliner()); pass_manager.add_pass(transform::create_canonicalizer()); // Run again after inlining pass_manager.run(module) } /// Example of applying symbol DCE (Dead Code Elimination) pub fn apply_symbol_dce(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_symbol_dce()); pass_manager.run(module) } /// Apply mem2reg transformation to promote memory to registers pub fn apply_mem2reg(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_mem_2_reg()); pass_manager.run(module) } /// Convert parallel loops to GPU kernels pub fn convert_to_gpu(context: &Context, 
module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(gpu::create_gpu_kernel_outlining()); pass_manager.run(module) } /// Strip debug information from the module pub fn strip_debug_info(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_strip_debug_info()); pass_manager.run(module) } /// Type builder helper for creating complex types pub struct TypeBuilder<'c> { context: &'c Context, } impl<'c> TypeBuilder<'c> { pub fn new(context: &'c Context) -> Self { Self { context } } pub fn i32(&self) -> Type<'c> { IntegerType::new(self.context, 32).into() } pub fn i64(&self) -> Type<'c> { IntegerType::new(self.context, 64).into() } pub fn index(&self) -> Type<'c> { Type::index(self.context) } pub fn function(&self, inputs: &[Type<'c>], outputs: &[Type<'c>]) -> FunctionType<'c> { FunctionType::new(self.context, inputs, outputs) } } /// Attribute builder helper for creating attributes pub struct AttributeBuilder<'c> { context: &'c Context, } impl<'c> AttributeBuilder<'c> { pub fn new(context: &'c Context) -> Self { Self { context } } pub fn string(&self, value: &str) -> StringAttribute<'c> { StringAttribute::new(self.context, value) } pub fn integer(&self, ty: Type<'c>, value: i64) -> IntegerAttribute<'c> { IntegerAttribute::new(ty, value) } pub fn type_attr(&self, ty: Type<'c>) -> TypeAttribute<'c> { TypeAttribute::new(ty) } } /// Custom pass builder for creating transformation pipelines pub struct PassPipeline<'c> { pass_manager: PassManager<'c>, } impl<'c> PassPipeline<'c> { pub fn new(context: &'c Context) -> Self { Self { pass_manager: PassManager::new(context), } } /// Add a canonicalization pass pub fn canonicalize(self) -> Self { self.pass_manager .add_pass(transform::create_canonicalizer()); self } /// Add a CSE pass pub fn eliminate_common_subexpressions(self) -> Self { 
self.pass_manager.add_pass(transform::create_cse()); self } /// Add an inlining pass pub fn inline_functions(self) -> Self { self.pass_manager.add_pass(transform::create_inliner()); self } /// Add SCCP for constant propagation pub fn propagate_constants(self) -> Self { self.pass_manager.add_pass(transform::create_sccp()); self } /// Add loop optimizations pub fn optimize_loops(self) -> Self { self.pass_manager .add_pass(transform::create_loop_invariant_code_motion()); self } /// Run the pipeline on a module pub fn run(self, module: &mut Module<'c>) -> Result<(), Error> { self.pass_manager.run(module) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_add_function() { let context = create_test_context(); let module = create_add_function(&context).unwrap(); assert!(verify_module(&module)); let ir = module_to_string(&module); assert!(ir.contains("func.func @add")); assert!(ir.contains("arith.addi")); } #[test] fn test_multiply_function() { let context = create_test_context(); let module = create_multiply_function(&context).unwrap(); assert!(verify_module(&module)); let ir = module_to_string(&module); assert!(ir.contains("func.func @multiply")); assert!(ir.contains("arith.muli")); } #[test] fn test_constant_creation() { let context = create_test_context(); let module = create_constant(&context, 42).unwrap(); assert!(verify_module(&module)); let ir = module_to_string(&module); assert!(ir.contains("arith.constant")); assert!(ir.contains("42")); } #[test] fn test_type_builder() { let context = create_test_context(); let types = TypeBuilder::new(&context); let i32 = types.i32(); let i64 = types.i64(); let func_type = types.function(&[i32, i32], &[i64]); assert_eq!(func_type.input(0).unwrap(), i32); assert_eq!(func_type.input(1).unwrap(), i32); assert_eq!(func_type.result(0).unwrap(), i64); } #[test] fn test_attribute_builder() { let context = create_test_context(); let types = TypeBuilder::new(&context); let attrs = AttributeBuilder::new(&context); let name = 
attrs.string("test_function"); // StringAttribute doesn't expose string value directly in the API assert!(name.is_string()); let value = attrs.integer(types.i32(), 100); assert_eq!(value.value(), 100); } #[test] fn test_canonicalization() { let context = create_test_context(); let mut module = create_add_function(&context).unwrap(); let _before = module_to_string(&module); apply_canonicalization(&context, &mut module).unwrap(); let after = module_to_string(&module); assert!(verify_module(&module)); // Canonicalization should preserve or simplify the IR assert!(!after.is_empty()); } #[test] fn test_optimization_pipeline() { let context = create_test_context(); let mut module = create_multiply_function(&context).unwrap(); // Apply optimization pipeline optimize_module(&context, &mut module).unwrap(); assert!(verify_module(&module)); let ir = module_to_string(&module); assert!(ir.contains("func.func @multiply")); } #[test] fn test_pass_pipeline_builder() { let context = create_test_context(); let mut module = create_constant(&context, 100).unwrap(); let pipeline = PassPipeline::new(&context) .canonicalize() .eliminate_common_subexpressions() .propagate_constants(); pipeline.run(&mut module).unwrap(); assert!(verify_module(&module)); } } /// Shows how to verify MLIR modules pub fn verify_module(module: &Module<'_>) -> bool { module.as_operation().verify() } }
use melior::ir::Module;

/// Shows how to print MLIR to string
pub fn module_to_string(module: &Module<'_>) -> String {
    // OperationRef implements Display, which emits the textual MLIR form
    format!("{}", module.as_operation())
}
These utilities check a module with MLIR’s verifier (verify_module returns true only when the IR is well-formed) and render it in MLIR’s textual form (module_to_string) for debugging and inspection.
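MLIR’s verifier checks structural invariants. Two examples are that every operand is defined before it is used, and that an operation’s types agree with its declared signature. The sketch below is not melior code. It is a std-only toy built on hypothetical Ty, Op and Func types. On a tiny single-block IR, it mimics what verify_module and module_to_string do for a real module.

```rust
use std::collections::HashMap;
use std::fmt;

/// Toy value types (hypothetical; loosely mirrors MLIR's i32 and index).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Ty {
    I32,
    Index,
}

impl fmt::Display for Ty {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Ty::I32 => write!(f, "i32"),
            Ty::Index => write!(f, "index"),
        }
    }
}

/// Toy single-block operations; values are numbered and arguments come first.
enum Op {
    Add { result: usize, lhs: usize, rhs: usize },
    Return { value: usize },
}

struct Func {
    name: String,
    args: Vec<Ty>,
    ret: Ty,
    body: Vec<Op>,
}

impl Func {
    /// Checks a few MLIR-style invariants: operands are defined before use,
    /// both add operands share a type, no value is defined twice, and the
    /// block ends in exactly one return of the declared result type.
    fn verify(&self) -> bool {
        let mut types: HashMap<usize, Ty> = self.args.iter().copied().enumerate().collect();
        for (i, op) in self.body.iter().enumerate() {
            let is_last = i + 1 == self.body.len();
            match op {
                Op::Add { result, lhs, rhs } => {
                    let (Some(&a), Some(&b)) = (types.get(lhs), types.get(rhs)) else {
                        return false; // operand used before it is defined
                    };
                    // An add cannot terminate the block, mix types, or redefine a value.
                    if is_last || a != b || types.insert(*result, a).is_some() {
                        return false;
                    }
                }
                Op::Return { value } => {
                    if !is_last || types.get(value) != Some(&self.ret) {
                        return false;
                    }
                }
            }
        }
        !self.body.is_empty()
    }
}

impl fmt::Display for Func {
    /// Prints a simplified, MLIR-like textual form.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let args: Vec<String> = self.args.iter().enumerate().map(|(i, t)| format!("%{i}: {t}")).collect();
        writeln!(f, "func.func @{}({}) -> {} {{", self.name, args.join(", "), self.ret)?;
        for op in &self.body {
            match op {
                Op::Add { result, lhs, rhs } => writeln!(f, "  %{result} = arith.addi %{lhs}, %{rhs}")?,
                Op::Return { value } => writeln!(f, "  return %{value}")?,
            }
        }
        write!(f, "}}")
    }
}

fn main() {
    let add = Func {
        name: "add".to_string(),
        args: vec![Ty::Index, Ty::Index],
        ret: Ty::Index,
        body: vec![Op::Add { result: 2, lhs: 0, rhs: 1 }, Op::Return { value: 2 }],
    };
    assert!(add.verify());
    println!("{add}");

    // Returning an i32 from a function declared to return index is rejected.
    let bad = Func {
        name: "bad".to_string(),
        args: vec![Ty::I32],
        ret: Ty::Index,
        body: vec![Op::Return { value: 0 }],
    };
    assert!(!bad.verify());
}
```

The printed form is deliberately simplified. Real MLIR output names function arguments %arg0, %arg1, and so on, and annotates each operation with its types.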
MLIR Transformations and Optimization Passes
Much of MLIR’s power comes from its transformation infrastructure. A PassManager runs an ordered sequence of passes over a module, and each pass rewrites the IR at the abstraction level it targets.
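The pass-manager model can be sketched in plain Rust. This toy uses hypothetical ToyModule, Pass and ToyPassManager types rather than melior’s. Passes run in the order they were added, and the module is re-checked after each one. This mirrors how MLIR’s pass manager runs the verifier between passes by default. The check here is only a stand-in that requires a trailing return.

```rust
/// Toy module: a flat list of operation names (hypothetical; not melior's Module).
struct ToyModule {
    ops: Vec<String>,
}

/// A pass rewrites the module in place or reports an error.
trait Pass {
    fn name(&self) -> &str;
    fn run(&self, module: &mut ToyModule) -> Result<(), String>;
}

/// Deletes placeholder "nop" operations.
struct RemoveNops;

impl Pass for RemoveNops {
    fn name(&self) -> &str {
        "remove-nops"
    }
    fn run(&self, module: &mut ToyModule) -> Result<(), String> {
        module.ops.retain(|op| op != "nop");
        Ok(())
    }
}

/// A deliberately buggy pass that deletes the block terminator.
struct DropTerminator;

impl Pass for DropTerminator {
    fn name(&self) -> &str {
        "drop-terminator"
    }
    fn run(&self, module: &mut ToyModule) -> Result<(), String> {
        module.ops.pop();
        Ok(())
    }
}

/// Runs passes in insertion order and re-verifies the module after each one.
struct ToyPassManager {
    passes: Vec<Box<dyn Pass>>,
}

impl ToyPassManager {
    fn new() -> Self {
        Self { passes: Vec::new() }
    }

    fn add_pass(&mut self, pass: Box<dyn Pass>) {
        self.passes.push(pass);
    }

    /// Stand-in verifier: the module must end with a "return" terminator.
    fn verify(module: &ToyModule) -> bool {
        module.ops.last().map(String::as_str) == Some("return")
    }

    /// Stops at the first pass that fails or leaves the module invalid.
    fn run(&self, module: &mut ToyModule) -> Result<(), String> {
        for pass in &self.passes {
            pass.run(module)?;
            if !Self::verify(module) {
                return Err(format!("verification failed after pass `{}`", pass.name()));
            }
        }
        Ok(())
    }
}

fn main() {
    let mut module = ToyModule {
        ops: vec!["nop".into(), "arith.addi".into(), "nop".into(), "return".into()],
    };
    let mut pm = ToyPassManager::new();
    pm.add_pass(Box::new(RemoveNops));
    pm.run(&mut module).unwrap();
    assert_eq!(module.ops, vec!["arith.addi", "return"]);

    // Adding a pass that breaks the IR makes the whole pipeline fail.
    pm.add_pass(Box::new(DropTerminator));
    let err = pm.run(&mut module).unwrap_err();
    println!("{err}");
}
```

With melior, the corresponding calls are PassManager::add_pass and PassManager::run, as used by the helper functions in this chapter.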
Basic Transformations
The canonicalizer simplifies IR by applying local, pattern-based rewrites, such as folding constants and erasing operations whose results are unused:
use melior::ir::Module;
use melior::pass::{transform, PassManager};
use melior::{Context, Error};

/// Apply canonicalization transforms to simplify IR
pub fn apply_canonicalization(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
    let pass_manager = PassManager::new(context);
    pass_manager.add_pass(transform::create_canonicalizer());
    pass_manager.run(module)
}
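To make local pattern-based rewriting concrete, here is a std-only toy. It uses hypothetical Expr, rewrite_once and canonicalize items, not melior’s API. It folds constant additions and multiplications, drops x + 0 and x * 1, and moves constants to the right-hand operand of these commutative operations. It keeps rewriting until no pattern matches, which is the same fixpoint idea MLIR’s greedy rewrite driver uses. The real canonicalizer, however, works on MLIR operations and their registered patterns.

```rust
/// Toy expression IR (hypothetical; not melior's API).
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Const(i64),
    Arg(usize),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
}

/// Tries one local rewrite at the root of `e`; returns None if no pattern matches.
fn rewrite_once(e: &Expr) -> Option<Expr> {
    use Expr::*;
    match e {
        Add(a, b) => match (&**a, &**b) {
            (Const(x), Const(y)) => Some(Const(x.wrapping_add(*y))), // fold c1 + c2
            (x, Const(0)) | (Const(0), x) => Some(x.clone()),          // x + 0 => x
            (Const(_), _) => Some(Add(b.clone(), a.clone())),          // constant to the right
            _ => None,
        },
        Mul(a, b) => match (&**a, &**b) {
            (Const(x), Const(y)) => Some(Const(x.wrapping_mul(*y))), // fold c1 * c2
            (x, Const(1)) | (Const(1), x) => Some(x.clone()),          // x * 1 => x
            (Const(_), _) => Some(Mul(b.clone(), a.clone())),          // constant to the right
            _ => None,
        },
        Const(_) | Arg(_) => None,
    }
}

/// Canonicalizes children first, then rewrites the root until no pattern applies.
fn canonicalize(e: &Expr) -> Expr {
    let e = match e {
        Expr::Add(a, b) => Expr::Add(Box::new(canonicalize(a)), Box::new(canonicalize(b))),
        Expr::Mul(a, b) => Expr::Mul(Box::new(canonicalize(a)), Box::new(canonicalize(b))),
        leaf => leaf.clone(),
    };
    match rewrite_once(&e) {
        Some(rewritten) => canonicalize(&rewritten),
        None => e,
    }
}

fn main() {
    use Expr::*;
    // (%arg0 * 1) + (2 + 3) canonicalizes to %arg0 + 5
    let e = Add(
        Box::new(Mul(Box::new(Arg(0)), Box::new(Const(1)))),
        Box::new(Add(Box::new(Const(2)), Box::new(Const(3)))),
    );
    let c = canonicalize(&e);
    assert_eq!(c, Add(Box::new(Arg(0)), Box::new(Const(5))));
    println!("{c:?}"); // prints Add(Arg(0), Const(5))
}
```

The loop terminates in this toy because every rewrite either shrinks the tree or moves a constant to the right, and no pattern moves it back.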
Common subexpression elimination (CSE) removes redundant computations by reusing the result of an identical, side-effect-free operation computed earlier:
#![allow(unused)] fn main() { use melior::dialect::{arith, func, DialectRegistry}; use melior::ir::attribute::{IntegerAttribute, StringAttribute, TypeAttribute}; use melior::ir::operation::OperationLike; use melior::ir::r#type::{FunctionType, IntegerType}; use melior::ir::*; use melior::pass::{gpu, transform, PassManager}; use melior::utility::{register_all_dialects, register_all_llvm_translations, register_all_passes}; use melior::{Context, Error}; /// Creates a test context with all dialects loaded pub fn create_test_context() -> Context { let context = Context::new(); let registry = DialectRegistry::new(); register_all_dialects(&registry); register_all_passes(); context.append_dialect_registry(&registry); context.load_all_available_dialects(); register_all_llvm_translations(&context); context } /// Creates a simple function that adds two integers pub fn create_add_function(context: &Context) -> Result<Module<'_>, Error> { let location = Location::unknown(context); let module = Module::new(location); let index_type = Type::index(context); module.body().append_operation(func::func( context, StringAttribute::new(context, "add"), TypeAttribute::new( FunctionType::new(context, &[index_type, index_type], &[index_type]).into(), ), { let block = Block::new(&[(index_type, location), (index_type, location)]); let sum = block .append_operation(arith::addi( block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), location, )) .result(0) .unwrap(); block.append_operation(func::r#return(&[sum.into()], location)); let region = Region::new(); region.append_block(block); region }, &[], location, )); Ok(module) } /// Creates a function that multiplies two 32-bit integers pub fn create_multiply_function(context: &Context) -> Result<Module<'_>, Error> { let location = Location::unknown(context); let module = Module::new(location); let i32_type = IntegerType::new(context, 32).into(); module.body().append_operation(func::func( context, StringAttribute::new(context, 
"multiply"), TypeAttribute::new(FunctionType::new(context, &[i32_type, i32_type], &[i32_type]).into()), { let block = Block::new(&[(i32_type, location), (i32_type, location)]); let product = block .append_operation(arith::muli( block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), location, )) .result(0) .unwrap(); block.append_operation(func::r#return(&[product.into()], location)); let region = Region::new(); region.append_block(block); region }, &[], location, )); Ok(module) } /// Creates a constant integer value pub fn create_constant(context: &Context, value: i64) -> Result<Module<'_>, Error> { let location = Location::unknown(context); let module = Module::new(location); let i64_type = IntegerType::new(context, 64).into(); module.body().append_operation(func::func( context, StringAttribute::new(context, "get_constant"), TypeAttribute::new(FunctionType::new(context, &[], &[i64_type]).into()), { let block = Block::new(&[]); let constant = block .append_operation(arith::constant( context, IntegerAttribute::new(i64_type, value).into(), location, )) .result(0) .unwrap(); block.append_operation(func::r#return(&[constant.into()], location)); let region = Region::new(); region.append_block(block); region }, &[], location, )); Ok(module) } /// Shows how to verify MLIR modules pub fn verify_module(module: &Module<'_>) -> bool { module.as_operation().verify() } /// Shows how to print MLIR to string pub fn module_to_string(module: &Module<'_>) -> String { format!("{}", module.as_operation()) } /// Apply canonicalization transforms to simplify IR pub fn apply_canonicalization(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_canonicalizer()); pass_manager.run(module) } /// Apply inlining transformation to inline function calls pub fn apply_inlining(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = 
PassManager::new(context); pass_manager.add_pass(transform::create_inliner()); pass_manager.run(module) } /// Apply loop-invariant code motion to optimize loops pub fn apply_licm(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_loop_invariant_code_motion()); pass_manager.run(module) } /// Apply SCCP (Sparse Conditional Constant Propagation) for constant folding pub fn apply_sccp(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_sccp()); pass_manager.run(module) } /// Apply a pipeline of optimization passes pub fn optimize_module(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); // Standard optimization pipeline pass_manager.add_pass(transform::create_canonicalizer()); pass_manager.add_pass(transform::create_cse()); pass_manager.add_pass(transform::create_sccp()); pass_manager.add_pass(transform::create_inliner()); pass_manager.add_pass(transform::create_canonicalizer()); // Run again after inlining pass_manager.run(module) } /// Example of applying symbol DCE (Dead Code Elimination) pub fn apply_symbol_dce(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_symbol_dce()); pass_manager.run(module) } /// Apply mem2reg transformation to promote memory to registers pub fn apply_mem2reg(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_mem_2_reg()); pass_manager.run(module) } /// Convert parallel loops to GPU kernels pub fn convert_to_gpu(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); 
pass_manager.add_pass(gpu::create_gpu_kernel_outlining()); pass_manager.run(module) } /// Strip debug information from the module pub fn strip_debug_info(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_strip_debug_info()); pass_manager.run(module) } /// Type builder helper for creating complex types pub struct TypeBuilder<'c> { context: &'c Context, } impl<'c> TypeBuilder<'c> { pub fn new(context: &'c Context) -> Self { Self { context } } pub fn i32(&self) -> Type<'c> { IntegerType::new(self.context, 32).into() } pub fn i64(&self) -> Type<'c> { IntegerType::new(self.context, 64).into() } pub fn index(&self) -> Type<'c> { Type::index(self.context) } pub fn function(&self, inputs: &[Type<'c>], outputs: &[Type<'c>]) -> FunctionType<'c> { FunctionType::new(self.context, inputs, outputs) } } /// Attribute builder helper for creating attributes pub struct AttributeBuilder<'c> { context: &'c Context, } impl<'c> AttributeBuilder<'c> { pub fn new(context: &'c Context) -> Self { Self { context } } pub fn string(&self, value: &str) -> StringAttribute<'c> { StringAttribute::new(self.context, value) } pub fn integer(&self, ty: Type<'c>, value: i64) -> IntegerAttribute<'c> { IntegerAttribute::new(ty, value) } pub fn type_attr(&self, ty: Type<'c>) -> TypeAttribute<'c> { TypeAttribute::new(ty) } } /// Custom pass builder for creating transformation pipelines pub struct PassPipeline<'c> { pass_manager: PassManager<'c>, } impl<'c> PassPipeline<'c> { pub fn new(context: &'c Context) -> Self { Self { pass_manager: PassManager::new(context), } } /// Add a canonicalization pass pub fn canonicalize(self) -> Self { self.pass_manager .add_pass(transform::create_canonicalizer()); self } /// Add a CSE pass pub fn eliminate_common_subexpressions(self) -> Self { self.pass_manager.add_pass(transform::create_cse()); self } /// Add an inlining pass pub fn inline_functions(self) -> Self { 
self.pass_manager.add_pass(transform::create_inliner()); self } /// Add SCCP for constant propagation pub fn propagate_constants(self) -> Self { self.pass_manager.add_pass(transform::create_sccp()); self } /// Add loop optimizations pub fn optimize_loops(self) -> Self { self.pass_manager .add_pass(transform::create_loop_invariant_code_motion()); self } /// Run the pipeline on a module pub fn run(self, module: &mut Module<'c>) -> Result<(), Error> { self.pass_manager.run(module) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_add_function() { let context = create_test_context(); let module = create_add_function(&context).unwrap(); assert!(verify_module(&module)); let ir = module_to_string(&module); assert!(ir.contains("func.func @add")); assert!(ir.contains("arith.addi")); } #[test] fn test_multiply_function() { let context = create_test_context(); let module = create_multiply_function(&context).unwrap(); assert!(verify_module(&module)); let ir = module_to_string(&module); assert!(ir.contains("func.func @multiply")); assert!(ir.contains("arith.muli")); } #[test] fn test_constant_creation() { let context = create_test_context(); let module = create_constant(&context, 42).unwrap(); assert!(verify_module(&module)); let ir = module_to_string(&module); assert!(ir.contains("arith.constant")); assert!(ir.contains("42")); } #[test] fn test_type_builder() { let context = create_test_context(); let types = TypeBuilder::new(&context); let i32 = types.i32(); let i64 = types.i64(); let func_type = types.function(&[i32, i32], &[i64]); assert_eq!(func_type.input(0).unwrap(), i32); assert_eq!(func_type.input(1).unwrap(), i32); assert_eq!(func_type.result(0).unwrap(), i64); } #[test] fn test_attribute_builder() { let context = create_test_context(); let types = TypeBuilder::new(&context); let attrs = AttributeBuilder::new(&context); let name = attrs.string("test_function"); // StringAttribute doesn't expose string value directly in the API assert!(name.is_string()); let 
value = attrs.integer(types.i32(), 100); assert_eq!(value.value(), 100); } #[test] fn test_canonicalization() { let context = create_test_context(); let mut module = create_add_function(&context).unwrap(); let _before = module_to_string(&module); apply_canonicalization(&context, &mut module).unwrap(); let after = module_to_string(&module); assert!(verify_module(&module)); // Canonicalization should preserve or simplify the IR assert!(!after.is_empty()); } #[test] fn test_optimization_pipeline() { let context = create_test_context(); let mut module = create_multiply_function(&context).unwrap(); // Apply optimization pipeline optimize_module(&context, &mut module).unwrap(); assert!(verify_module(&module)); let ir = module_to_string(&module); assert!(ir.contains("func.func @multiply")); } #[test] fn test_pass_pipeline_builder() { let context = create_test_context(); let mut module = create_constant(&context, 100).unwrap(); let pipeline = PassPipeline::new(&context) .canonicalize() .eliminate_common_subexpressions() .propagate_constants(); pipeline.run(&mut module).unwrap(); assert!(verify_module(&module)); } } /// Apply CSE (Common Subexpression Elimination) to remove redundant /// computations pub fn apply_cse(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_cse()); pass_manager.run(module) } }
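To make the idea concrete, here is a toy sketch of local CSE in plain Rust (no melior) over a hypothetical straight-line IR. `Op` and `cse` are illustrative names and are not part of melior or MLIR. MLIR's real pass also reasons about dominance across blocks and about operations with memory effects. This sketch avoids both issues by allowing only pure arithmetic in a single block.

```rust
use std::collections::HashMap;

/// Hypothetical side-effect-free toy IR (not melior's or MLIR's types).
/// Operands are indices of earlier instructions in the same block.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum Op {
    Arg(usize),
    Const(i64),
    Add(usize, usize),
    Mul(usize, usize),
}

/// Local CSE: an instruction identical to an earlier one (after its operands
/// are renamed to the surviving instructions) is dropped, and later uses are
/// redirected to the earlier result.
/// Returns the deduplicated block and a map from old index to new index.
fn cse(ops: &[Op]) -> (Vec<Op>, Vec<usize>) {
    let mut seen: HashMap<Op, usize> = HashMap::new();
    let mut out: Vec<Op> = Vec::new();
    let mut remap: Vec<usize> = Vec::with_capacity(ops.len());
    for op in ops {
        // Rename operands first, so uses of an eliminated duplicate
        // compare equal to uses of the original.
        let op = match op {
            Op::Add(a, b) => Op::Add(remap[*a], remap[*b]),
            Op::Mul(a, b) => Op::Mul(remap[*a], remap[*b]),
            other => other.clone(),
        };
        let id = match seen.get(&op) {
            // Redundant: reuse the earlier result.
            Some(&earlier) => earlier,
            // First occurrence: keep it and remember where it lives.
            None => {
                out.push(op.clone());
                let idx = out.len() - 1;
                seen.insert(op, idx);
                idx
            }
        };
        remap.push(id);
    }
    (out, remap)
}
```

Consider the input `[Arg(0), Arg(1), Add(0, 1), Add(0, 1), Mul(2, 3)]`. CSE drops the second `Add`, and the `Mul` is rewritten to `Mul(2, 2)`, so it uses the first `Add` twice. That is the same effect `transform::create_cse()` has on duplicated `arith.addi` operations.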
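Sparse Conditional Constant Propagation (SCCP) optimistically assumes that values are constant. It propagates that assumption through the IR, ignoring branches that can never be taken, then replaces every value proven constant and erases operations that are left unused: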
use melior::ir::Module;
use melior::pass::{transform, PassManager};
use melior::{Context, Error};

/// Apply SCCP (Sparse Conditional Constant Propagation) for constant folding
pub fn apply_sccp(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
    let pass_manager = PassManager::new(context);
    pass_manager.add_pass(transform::create_sccp());
    pass_manager.run(module)
}
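The "conditional" part of SCCP is easiest to see on a small example. The sketch below is plain Rust (no melior) over a hypothetical mini SSA IR with `Const`, `Add`, and `Phi` instructions. All names in it are illustrative. For brevity, it re-evaluates every reachable block until nothing changes. The classic Wegman-Zadeck algorithm instead uses separate SSA-edge and CFG-edge worklists; the sketch should reach the same result, only less efficiently. Only the analysis is shown. A rewrite step would then replace each `Const` variable with a constant and delete unreachable blocks.

```rust
use std::collections::HashSet;

/// SSA value id in a hypothetical toy IR (not melior's or MLIR's types).
type Var = usize;

enum Expr {
    Const(i64),
    Add(Var, Var),
    /// One (predecessor block, incoming value) pair per CFG edge.
    Phi(Vec<(usize, Var)>),
}

enum Term {
    Jump(usize),
    /// Go to the first block if the condition is non-zero, else the second.
    Branch(Var, usize, usize),
    Return(Var),
}

struct Block {
    insts: Vec<(Var, Expr)>,
    term: Term,
}

/// Top: nothing known yet; Const: proven constant; Bottom: overdefined.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Lat {
    Top,
    Const(i64),
    Bottom,
}

fn meet(a: Lat, b: Lat) -> Lat {
    match (a, b) {
        (Lat::Top, x) | (x, Lat::Top) => x,
        (Lat::Const(x), Lat::Const(y)) if x == y => Lat::Const(x),
        _ => Lat::Bottom,
    }
}

/// Optimistic conditional constant propagation; block 0 is the entry.
/// Returns the lattice value of each variable and which blocks are reachable.
fn sccp(blocks: &[Block], num_vars: usize) -> (Vec<Lat>, Vec<bool>) {
    let mut values = vec![Lat::Top; num_vars];
    let mut reachable = vec![false; blocks.len()];
    let mut edges: HashSet<(usize, usize)> = HashSet::new();
    reachable[0] = true;
    let mut changed = true;
    while changed {
        changed = false;
        for b in 0..blocks.len() {
            // Blocks not yet proven reachable are never evaluated.
            if !reachable[b] {
                continue;
            }
            for (dest, expr) in &blocks[b].insts {
                let new = match expr {
                    Expr::Const(c) => Lat::Const(*c),
                    Expr::Add(x, y) => match (values[*x], values[*y]) {
                        (Lat::Const(l), Lat::Const(r)) => Lat::Const(l.wrapping_add(r)),
                        (Lat::Bottom, _) | (_, Lat::Bottom) => Lat::Bottom,
                        _ => Lat::Top,
                    },
                    // A phi merges only values arriving over executable edges.
                    Expr::Phi(incoming) => incoming
                        .iter()
                        .filter(|(pred, _)| edges.contains(&(*pred, b)))
                        .fold(Lat::Top, |acc, (_, v)| meet(acc, values[*v])),
                };
                if values[*dest] != new {
                    values[*dest] = new;
                    changed = true;
                }
            }
            // A branch on a known constant makes only one successor executable.
            let targets = match &blocks[b].term {
                Term::Jump(t) => vec![*t],
                Term::Branch(c, t, f) => match values[*c] {
                    Lat::Const(v) => vec![if v != 0 { *t } else { *f }],
                    Lat::Bottom => vec![*t, *f],
                    Lat::Top => vec![],
                },
                Term::Return(_) => vec![],
            };
            for t in targets {
                if edges.insert((b, t)) {
                    reachable[t] = true;
                    changed = true;
                }
            }
        }
    }
    (values, reachable)
}
```

As an example, suppose the entry block branches on a constant `1` to one of two blocks that produce `10` and `20`, and those blocks meet in a phi. SCCP never marks the `20` block reachable, so the phi resolves to the constant `10`. A propagator that ignored the branch condition would mark the phi as overdefined.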
Function Optimizations
Inlining replaces function calls with their bodies:
use melior::ir::Module;
use melior::pass::{transform, PassManager};
use melior::{Context, Error};

/// Apply inlining transformation to inline function calls
pub fn apply_inlining(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
    let pass_manager = PassManager::new(context);
    pass_manager.add_pass(transform::create_inliner());
    pass_manager.run(module)
}
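As a rough illustration of what inlining does, the sketch below is plain Rust (no melior) over a hypothetical expression-only IR. Each call to a known function is replaced by the callee's body, with the callee's parameters substituted by the call-site arguments. `Expr`, `substitute`, and `inline` are illustrative names. MLIR's inliner operates on SSA regions, where arguments are values rather than expression trees, so it does not duplicate work when a parameter is used twice, as this substitution can. MLIR also asks each dialect whether inlining is legal and works from the module's call graph; this sketch omits both.

```rust
use std::collections::HashMap;

/// Hypothetical expression-only toy IR (not melior's or MLIR's types).
#[derive(Clone, Debug, PartialEq)]
enum Expr {
    Param(usize),
    Const(i64),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
    Call(String, Vec<Expr>),
}

/// Replace the callee's parameters with the call-site arguments.
fn substitute(body: &Expr, args: &[Expr]) -> Expr {
    match body {
        Expr::Param(i) => args[*i].clone(),
        Expr::Const(c) => Expr::Const(*c),
        Expr::Add(a, b) => Expr::Add(Box::new(substitute(a, args)), Box::new(substitute(b, args))),
        Expr::Mul(a, b) => Expr::Mul(Box::new(substitute(a, args)), Box::new(substitute(b, args))),
        Expr::Call(name, xs) => {
            Expr::Call(name.clone(), xs.iter().map(|x| substitute(x, args)).collect())
        }
    }
}

/// Replace every call to a function defined in `funcs` (name -> body) with
/// the callee's body. Calls to unknown (external) functions are kept.
/// Assumes no recursion: a recursive callee would make this loop forever.
fn inline(expr: &Expr, funcs: &HashMap<String, Expr>) -> Expr {
    match expr {
        Expr::Call(name, args) => {
            let args: Vec<Expr> = args.iter().map(|a| inline(a, funcs)).collect();
            match funcs.get(name) {
                // Splice in the callee body, then inline any calls it contains.
                Some(body) => inline(&substitute(body, &args), funcs),
                None => Expr::Call(name.clone(), args),
            }
        }
        Expr::Add(a, b) => Expr::Add(Box::new(inline(a, funcs)), Box::new(inline(b, funcs))),
        Expr::Mul(a, b) => Expr::Mul(Box::new(inline(a, funcs)), Box::new(inline(b, funcs))),
        leaf => leaf.clone(),
    }
}
```

As an example, inlining `add(p0, 1) * ext()` with `add` defined as `p0 + p1` produces `(p0 + 1) * ext()`. The call to the undefined `ext` is left untouched.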
Symbol DCE (dead code elimination) deletes symbols, such as functions and globals, that are neither publicly visible nor referenced from live code:
use melior::ir::Module;
use melior::pass::{transform, PassManager};
use melior::{Context, Error};

/// Example of applying symbol DCE (Dead Code Elimination)
pub fn apply_symbol_dce(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
    let pass_manager = PassManager::new(context);
    // Symbol DCE works on symbol tables, so it runs on the top-level module.
    pass_manager.add_pass(transform::create_symbol_dce());
    pass_manager.run(module)
}
Loop Optimizations
Loop-Invariant Code Motion (LICM) hoists side-effect-free operations whose operands are all defined outside a loop, so they execute once instead of on every iteration:
use melior::ir::Module;
use melior::pass::{transform, PassManager};
use melior::{Context, Error};

/// Apply loop-invariant code motion to optimize loops
pub fn apply_licm(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
    let pass_manager = PassManager::new(context);
    pass_manager.add_pass(transform::create_loop_invariant_code_motion());
    pass_manager.run(module)
}
Memory Optimizations
Promote stack allocations to SSA values using mem2reg, which replaces their loads and stores with direct uses of the stored values:
#![allow(unused)] fn main() { use melior::dialect::{arith, func, DialectRegistry}; use melior::ir::attribute::{IntegerAttribute, StringAttribute, TypeAttribute}; use melior::ir::operation::OperationLike; use melior::ir::r#type::{FunctionType, IntegerType}; use melior::ir::*; use melior::pass::{gpu, transform, PassManager}; use melior::utility::{register_all_dialects, register_all_llvm_translations, register_all_passes}; use melior::{Context, Error}; /// Creates a test context with all dialects loaded pub fn create_test_context() -> Context { let context = Context::new(); let registry = DialectRegistry::new(); register_all_dialects(®istry); register_all_passes(); context.append_dialect_registry(®istry); context.load_all_available_dialects(); register_all_llvm_translations(&context); context } /// Creates a simple function that adds two integers pub fn create_add_function(context: &Context) -> Result<Module<'_>, Error> { let location = Location::unknown(context); let module = Module::new(location); let index_type = Type::index(context); module.body().append_operation(func::func( context, StringAttribute::new(context, "add"), TypeAttribute::new( FunctionType::new(context, &[index_type, index_type], &[index_type]).into(), ), { let block = Block::new(&[(index_type, location), (index_type, location)]); let sum = block .append_operation(arith::addi( block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), location, )) .result(0) .unwrap(); block.append_operation(func::r#return(&[sum.into()], location)); let region = Region::new(); region.append_block(block); region }, &[], location, )); Ok(module) } /// Creates a function that multiplies two 32-bit integers pub fn create_multiply_function(context: &Context) -> Result<Module<'_>, Error> { let location = Location::unknown(context); let module = Module::new(location); let i32_type = IntegerType::new(context, 32).into(); module.body().append_operation(func::func( context, StringAttribute::new(context, 
"multiply"), TypeAttribute::new(FunctionType::new(context, &[i32_type, i32_type], &[i32_type]).into()), { let block = Block::new(&[(i32_type, location), (i32_type, location)]); let product = block .append_operation(arith::muli( block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), location, )) .result(0) .unwrap(); block.append_operation(func::r#return(&[product.into()], location)); let region = Region::new(); region.append_block(block); region }, &[], location, )); Ok(module) } /// Creates a constant integer value pub fn create_constant(context: &Context, value: i64) -> Result<Module<'_>, Error> { let location = Location::unknown(context); let module = Module::new(location); let i64_type = IntegerType::new(context, 64).into(); module.body().append_operation(func::func( context, StringAttribute::new(context, "get_constant"), TypeAttribute::new(FunctionType::new(context, &[], &[i64_type]).into()), { let block = Block::new(&[]); let constant = block .append_operation(arith::constant( context, IntegerAttribute::new(i64_type, value).into(), location, )) .result(0) .unwrap(); block.append_operation(func::r#return(&[constant.into()], location)); let region = Region::new(); region.append_block(block); region }, &[], location, )); Ok(module) } /// Shows how to verify MLIR modules pub fn verify_module(module: &Module<'_>) -> bool { module.as_operation().verify() } /// Shows how to print MLIR to string pub fn module_to_string(module: &Module<'_>) -> String { format!("{}", module.as_operation()) } /// Apply canonicalization transforms to simplify IR pub fn apply_canonicalization(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_canonicalizer()); pass_manager.run(module) } /// Apply CSE (Common Subexpression Elimination) to remove redundant /// computations pub fn apply_cse(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = 
PassManager::new(context); pass_manager.add_pass(transform::create_cse()); pass_manager.run(module) } /// Apply inlining transformation to inline function calls pub fn apply_inlining(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_inliner()); pass_manager.run(module) } /// Apply loop-invariant code motion to optimize loops pub fn apply_licm(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_loop_invariant_code_motion()); pass_manager.run(module) } /// Apply SCCP (Sparse Conditional Constant Propagation) for constant folding pub fn apply_sccp(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_sccp()); pass_manager.run(module) } /// Apply a pipeline of optimization passes pub fn optimize_module(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); // Standard optimization pipeline pass_manager.add_pass(transform::create_canonicalizer()); pass_manager.add_pass(transform::create_cse()); pass_manager.add_pass(transform::create_sccp()); pass_manager.add_pass(transform::create_inliner()); pass_manager.add_pass(transform::create_canonicalizer()); // Run again after inlining pass_manager.run(module) } /// Example of applying symbol DCE (Dead Code Elimination) pub fn apply_symbol_dce(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(transform::create_symbol_dce()); pass_manager.run(module) } /// Convert parallel loops to GPU kernels pub fn convert_to_gpu(context: &Context, module: &mut Module<'_>) -> Result<(), Error> { let pass_manager = PassManager::new(context); pass_manager.add_pass(gpu::create_gpu_kernel_outlining()); 
pass_manager.run(module) }
/// Apply mem2reg transformation to promote memory to registers
pub fn apply_mem2reg(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
    let pass_manager = PassManager::new(context);
    pass_manager.add_pass(transform::create_mem_2_reg());
    pass_manager.run(module)
}
}
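The mem2reg pass used by apply_mem2reg turns loads and stores of a memory slot into direct uses of SSA values. The sketch below is a rough illustration of that idea in std-only Rust. It uses a hypothetical toy IR and is neither melior's API nor MLIR's implementation; MLIR also handles control flow by introducing block arguments at merge points. Within a single straight-line block, it forwards each stored value to later loads and then drops the slot's allocation and stores. Slots that escape, or that are loaded before any store, are left alone.

```rust
use std::collections::{HashMap, HashSet};

/// Hypothetical toy IR for one straight-line block; values and slots are plain ids.
#[derive(Debug, Clone, PartialEq)]
pub enum Inst {
    /// %dst = alloca (a stack slot)
    Alloca { dst: u32 },
    /// store %val -> %slot
    Store { slot: u32, val: u32 },
    /// %dst = load %slot
    Load { dst: u32, slot: u32 },
    /// %dst = add %lhs, %rhs
    Add { dst: u32, lhs: u32, rhs: u32 },
    /// return %val
    Ret { val: u32 },
}

/// Look up the value a promoted load result was replaced with (if any).
fn forward(subst: &HashMap<u32, u32>, v: u32) -> u32 {
    subst.get(&v).copied().unwrap_or(v)
}

/// Promote slots that never escape and are always stored before being loaded.
pub fn mem2reg(block: &[Inst]) -> Vec<Inst> {
    // Pass 1: find promotable slots.
    let mut candidates = HashSet::new();
    let mut stored = HashSet::new();
    let mut rejected = HashSet::new();
    for inst in block {
        match *inst {
            Inst::Alloca { dst } => {
                candidates.insert(dst);
            }
            Inst::Store { slot, val } => {
                stored.insert(slot);
                rejected.insert(val); // storing a slot id as a value makes it escape
            }
            Inst::Load { slot, .. } => {
                if !stored.contains(&slot) {
                    rejected.insert(slot); // load before any store: do not promote
                }
            }
            Inst::Add { lhs, rhs, .. } => {
                rejected.insert(lhs);
                rejected.insert(rhs);
            }
            Inst::Ret { val } => {
                rejected.insert(val);
            }
        }
    }
    let promotable: HashSet<u32> = candidates.difference(&rejected).copied().collect();

    // Pass 2: forward stored values to loads and drop the promoted slots' ops.
    let mut current: HashMap<u32, u32> = HashMap::new(); // slot -> last stored value
    let mut subst: HashMap<u32, u32> = HashMap::new(); // load result -> forwarded value
    let mut out = Vec::new();
    for inst in block {
        match *inst {
            Inst::Alloca { dst } if promotable.contains(&dst) => {}
            Inst::Store { slot, val } if promotable.contains(&slot) => {
                current.insert(slot, forward(&subst, val));
            }
            Inst::Load { dst, slot } if promotable.contains(&slot) => {
                subst.insert(dst, current[&slot]);
            }
            Inst::Store { slot, val } => out.push(Inst::Store { slot, val: forward(&subst, val) }),
            Inst::Add { dst, lhs, rhs } => out.push(Inst::Add {
                dst,
                lhs: forward(&subst, lhs),
                rhs: forward(&subst, rhs),
            }),
            Inst::Ret { val } => out.push(Inst::Ret { val: forward(&subst, val) }),
            ref other => out.push(other.clone()),
        }
    }
    out
}
```

For example, a block that stores function argument %10 into slot %0, reloads it as %1 and returns %1 + %1 becomes a single add of %10 with itself followed by the return.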
GPU Transformations
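Outline the bodies of gpu.launch operations into separate GPU kernel functions with the gpu-kernel-outlining pass. This pass does not create GPU launches itself; parallel loops must first be mapped to gpu.launch by earlier conversion passes: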
use melior::ir::Module;
use melior::pass::{gpu, PassManager};
use melior::{Context, Error};

/// Outline gpu.launch bodies into GPU kernel functions
pub fn convert_to_gpu(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
    let pass_manager = PassManager::new(context);
    pass_manager.add_pass(gpu::create_gpu_kernel_outlining());
    pass_manager.run(module)
}
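The following simplified model shows what outlining does structurally, in std-only Rust. It is a hypothetical toy, not melior's API. Stmt::Launch stands in for gpu.launch and Stmt::LaunchFunc for gpu.launch_func, and the kernel naming scheme is made up. Values used inside the launch region but defined outside it become kernel parameters, in order of first use. The real pass also places each kernel in a gpu.module and deals with details such as thread and block indices, which this sketch ignores.

```rust
use std::collections::HashSet;

/// A toy operation: optional result name, opcode, operand names (hypothetical IR).
#[derive(Debug, Clone, PartialEq)]
pub struct Op {
    pub result: Option<String>,
    pub opcode: String,
    pub operands: Vec<String>,
}

impl Op {
    pub fn new(result: Option<&str>, opcode: &str, operands: &[&str]) -> Self {
        Op {
            result: result.map(String::from),
            opcode: opcode.to_string(),
            operands: operands.iter().map(|s| s.to_string()).collect(),
        }
    }
}

/// Host-side statements of a toy function.
#[derive(Debug, Clone, PartialEq)]
pub enum Stmt {
    Op(Op),
    /// Stand-in for gpu.launch: a region of ops to run on the device.
    Launch(Vec<Op>),
    /// Stand-in for gpu.launch_func @kernel(args...).
    LaunchFunc { kernel: String, args: Vec<String> },
}

/// An outlined kernel function.
#[derive(Debug, Clone, PartialEq)]
pub struct Kernel {
    pub name: String,
    pub params: Vec<String>,
    pub body: Vec<Op>,
}

/// Values used in `body` but not defined there, in order of first use.
fn captured_values(body: &[Op]) -> Vec<String> {
    let defined: HashSet<&String> = body.iter().filter_map(|op| op.result.as_ref()).collect();
    let mut captured: Vec<String> = Vec::new();
    for op in body {
        for v in &op.operands {
            if !defined.contains(v) && !captured.contains(v) {
                captured.push(v.clone());
            }
        }
    }
    captured
}

/// Move every launch region into its own kernel and call it from the host.
pub fn outline_kernels(func_name: &str, body: Vec<Stmt>) -> (Vec<Stmt>, Vec<Kernel>) {
    let mut host = Vec::new();
    let mut kernels = Vec::new();
    for stmt in body {
        match stmt {
            Stmt::Launch(region) => {
                let name = format!("{}_kernel{}", func_name, kernels.len());
                let params = captured_values(&region);
                host.push(Stmt::LaunchFunc { kernel: name.clone(), args: params.clone() });
                kernels.push(Kernel { name, params, body: region });
            }
            other => host.push(other),
        }
    }
    (host, kernels)
}
```

Passing captured values explicitly is what lets the kernel body be compiled separately from the host function.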
Utility Passes
For release builds, the strip-debuginfo pass removes debug information by replacing every operation's location with an unknown location:
use melior::ir::Module;
use melior::pass::{transform, PassManager};
use melior::{Context, Error};

/// Strip debug information from the module
pub fn strip_debug_info(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
    let pass_manager = PassManager::new(context);
    pass_manager.add_pass(transform::create_strip_debug_info());
    pass_manager.run(module)
}
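Conceptually, stripping debug info rewrites every location to an unknown location while leaving the operations themselves untouched. The toy model below illustrates this in std-only Rust. Its Loc and Op types are hypothetical, not melior's Location or Operation. It walks nested regions recursively and only tracks operation locations.

```rust
/// Hypothetical toy model of MLIR locations.
#[derive(Debug, Clone, PartialEq)]
pub enum Loc {
    /// Printed by MLIR as loc(unknown)
    Unknown,
    /// A source position such as "input.mlir":3:7
    FileLineCol { file: String, line: u32, col: u32 },
    /// A named location wrapping another one
    Named { name: String, child: Box<Loc> },
}

/// A toy operation: a name, its location, and nested regions of operations.
#[derive(Debug, Clone, PartialEq)]
pub struct Op {
    pub name: String,
    pub loc: Loc,
    pub regions: Vec<Vec<Op>>,
}

/// Replace every operation's location (recursing into regions) with Loc::Unknown.
/// Returns how many locations were actually changed.
pub fn strip_locations(ops: &mut [Op]) -> usize {
    let mut stripped = 0;
    for op in ops.iter_mut() {
        if op.loc != Loc::Unknown {
            op.loc = Loc::Unknown;
            stripped += 1;
        }
        for region in op.regions.iter_mut() {
            stripped += strip_locations(region);
        }
    }
    stripped
}
```

Because only locations change, the stripped module computes exactly what the original did; it just loses the mapping back to source positions.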
Optimization Pipelines
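Combine multiple passes into an optimization pipeline. Passes added to one PassManager run in the order they were added. This pipeline canonicalizes, eliminates common subexpressions, and propagates constants. It then inlines and canonicalizes again to clean up the inlined code: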
use melior::ir::Module;
use melior::pass::{transform, PassManager};
use melior::{Context, Error};

/// Apply a pipeline of optimization passes
pub fn optimize_module(context: &Context, module: &mut Module<'_>) -> Result<(), Error> {
    let pass_manager = PassManager::new(context);
    // Standard optimization pipeline
    pass_manager.add_pass(transform::create_canonicalizer());
    pass_manager.add_pass(transform::create_cse());
    pass_manager.add_pass(transform::create_sccp());
    pass_manager.add_pass(transform::create_inliner());
    pass_manager.add_pass(transform::create_canonicalizer()); // Run again after inlining
    pass_manager.run(module)
}
Custom Pass Pipelines
Build fluent transformation pipelines with the PassPipeline builder:
```rust
use melior::ir::Module;
use melior::pass::{transform, PassManager};
use melior::{Context, Error};

/// Custom pass builder for creating transformation pipelines
pub struct PassPipeline<'c> {
    pass_manager: PassManager<'c>,
}

impl<'c> PassPipeline<'c> {
    pub fn new(context: &'c Context) -> Self {
        Self {
            pass_manager: PassManager::new(context),
        }
    }

    /// Add a canonicalization pass
    pub fn canonicalize(self) -> Self {
        self.pass_manager.add_pass(transform::create_canonicalizer());
        self
    }

    /// Add a CSE pass
    pub fn eliminate_common_subexpressions(self) -> Self {
        self.pass_manager.add_pass(transform::create_cse());
        self
    }

    /// Add an inlining pass
    pub fn inline_functions(self) -> Self {
        self.pass_manager.add_pass(transform::create_inliner());
        self
    }

    /// Add SCCP for constant propagation
    pub fn propagate_constants(self) -> Self {
        self.pass_manager.add_pass(transform::create_sccp());
        self
    }

    /// Add loop-invariant code motion
    pub fn optimize_loops(self) -> Self {
        self.pass_manager
            .add_pass(transform::create_loop_invariant_code_motion());
        self
    }

    /// Run the pipeline on a module; passes execute in the order they were added
    pub fn run(self, module: &mut Module<'c>) -> Result<(), Error> {
        self.pass_manager.run(module)
    }
}

// Usage, with create_test_context and create_constant from the earlier example:
// let context = create_test_context();
// let mut module = create_constant(&context, 100)?;
// PassPipeline::new(&context)
//     .canonicalize()
//     .eliminate_common_subexpressions()
//     .propagate_constants()
//     .run(&mut module)?;
```
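The PassPipeline builder provides a fluent API for constructing custom optimization sequences. Each transformation method takes self by value, adds one pass to the underlying PassManager, and returns self, so calls can be chained; run then consumes the pipeline. Passes execute in the order they were added, which gives precise control over optimization phases. Order matters: optimize_module, for example, runs the canonicalizer a second time after inlining so that it can simplify the code that inlining exposes.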
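To see why ordering matters, here is a minimal, dependency-free sketch in plain Rust. It mirrors the shape of PassPipeline but works on a toy IR instead of MLIR. Everything in it (Inst, ToyPipeline, fold, dce, sample) is hypothetical and is not part of melior: there is a straight-line instruction list, a constant-folding pass, and a dead-code-elimination pass. Folding first turns the arithmetic into constants that DCE can then delete. Running DCE first finds nothing dead, so the folded constants are left behind.

```rust
// Hypothetical toy IR: instruction i defines value i, and operands are indices
// of earlier instructions (straight-line, SSA-like form).
#[derive(Clone, Copy, Debug, PartialEq)]
enum Inst {
    Const(i64),
    Add(usize, usize),
    Mul(usize, usize),
    Ret(usize),
}

type Pass = fn(Vec<Inst>) -> Vec<Inst>;

/// Hypothetical fluent builder in the style of PassPipeline.
struct ToyPipeline {
    passes: Vec<Pass>,
}

impl ToyPipeline {
    fn new() -> Self {
        Self { passes: Vec::new() }
    }
    fn fold_constants(mut self) -> Self {
        self.passes.push(fold);
        self
    }
    fn eliminate_dead_code(mut self) -> Self {
        self.passes.push(dce);
        self
    }
    /// Runs the passes in the order they were added.
    fn run(&self, mut body: Vec<Inst>) -> Vec<Inst> {
        for pass in &self.passes {
            body = pass(body);
        }
        body
    }
}

fn const_of(body: &[Inst], idx: usize) -> Option<i64> {
    match body[idx] {
        Inst::Const(v) => Some(v),
        _ => None,
    }
}

/// Constant folding: Add/Mul whose operands are both constants becomes a constant.
fn fold(body: Vec<Inst>) -> Vec<Inst> {
    let mut out: Vec<Inst> = Vec::with_capacity(body.len());
    for inst in body {
        let folded = match inst {
            Inst::Add(a, b) => match (const_of(&out, a), const_of(&out, b)) {
                (Some(x), Some(y)) => Inst::Const(x + y),
                _ => inst,
            },
            Inst::Mul(a, b) => match (const_of(&out, a), const_of(&out, b)) {
                (Some(x), Some(y)) => Inst::Const(x * y),
                _ => inst,
            },
            other => other,
        };
        out.push(folded);
    }
    out
}

fn operands(inst: &Inst) -> Vec<usize> {
    match *inst {
        Inst::Const(_) => vec![],
        Inst::Add(a, b) | Inst::Mul(a, b) => vec![a, b],
        Inst::Ret(v) => vec![v],
    }
}

/// Dead-code elimination: keep only what the Ret transitively uses.
fn dce(body: Vec<Inst>) -> Vec<Inst> {
    // Mark: walk backwards so every user is visited before its operands.
    let mut live = vec![false; body.len()];
    for i in (0..body.len()).rev() {
        if matches!(body[i], Inst::Ret(_)) {
            live[i] = true;
        }
        if live[i] {
            for op in operands(&body[i]) {
                live[op] = true;
            }
        }
    }
    // Sweep: drop dead instructions and renumber operand indices.
    let mut new_index = vec![usize::MAX; body.len()];
    let mut out: Vec<Inst> = Vec::new();
    for (i, inst) in body.into_iter().enumerate() {
        if !live[i] {
            continue;
        }
        new_index[i] = out.len();
        let remap = |v: usize| new_index[v];
        out.push(match inst {
            Inst::Const(c) => Inst::Const(c),
            Inst::Add(a, b) => Inst::Add(remap(a), remap(b)),
            Inst::Mul(a, b) => Inst::Mul(remap(a), remap(b)),
            Inst::Ret(v) => Inst::Ret(remap(v)),
        });
    }
    out
}

/// x = (2 + 3) * (2 + 3); return x
fn sample() -> Vec<Inst> {
    vec![Inst::Const(2), Inst::Const(3), Inst::Add(0, 1), Inst::Mul(2, 2), Inst::Ret(3)]
}

fn main() {
    let fold_then_dce = ToyPipeline::new().fold_constants().eliminate_dead_code();
    let dce_then_fold = ToyPipeline::new().eliminate_dead_code().fold_constants();
    println!("{:?}", fold_then_dce.run(sample())); // [Const(25), Ret(0)]
    println!("{}", dce_then_fold.run(sample()).len()); // 5
}
```

With folding before DCE the body shrinks to two instructions. With the order reversed, all five remain. This is the same reason a real pipeline may repeat cheap cleanup passes after transformations that expose new opportunities.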
Best Practices
Structure compilation as progressive lowering through multiple abstraction levels. Start with domain-specific representations and lower gradually to executable code. This approach enables optimizations at appropriate abstraction levels and improves compiler modularity.
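As a rough illustration of progressive lowering, the sketch below uses three hypothetical levels written in plain Rust rather than real MLIR dialects:

- a domain level with square and relu;
- a generic arithmetic level with mul and max;
- a stack-machine level.

Each lowering step only has to bridge one gap. Evaluators at the top and bottom levels check that the meaning is preserved. None of these types correspond to melior or MLIR APIs.

```rust
/// Level 1 (hypothetical domain ops): an expression over a single input x.
#[derive(Debug, Clone)]
enum Domain {
    Input,
    Square(Box<Domain>),
    Relu(Box<Domain>),
}

/// Level 2 (hypothetical generic arithmetic).
#[derive(Debug, Clone)]
enum Arith {
    Input,
    Const(i64),
    Mul(Box<Arith>, Box<Arith>),
    Max(Box<Arith>, Box<Arith>),
}

/// Level 3 (hypothetical stack machine).
#[derive(Debug, Clone, Copy, PartialEq)]
enum Machine {
    LoadInput,
    Push(i64),
    Mul,
    Max,
}

/// Domain -> Arith: square(e) = e * e, relu(e) = max(e, 0).
fn lower_to_arith(e: &Domain) -> Arith {
    match e {
        Domain::Input => Arith::Input,
        Domain::Square(inner) => {
            let x = lower_to_arith(inner);
            Arith::Mul(Box::new(x.clone()), Box::new(x))
        }
        Domain::Relu(inner) => {
            Arith::Max(Box::new(lower_to_arith(inner)), Box::new(Arith::Const(0)))
        }
    }
}

/// Arith -> Machine: a post-order traversal emits stack code.
fn lower_to_machine(e: &Arith, code: &mut Vec<Machine>) {
    match e {
        Arith::Input => code.push(Machine::LoadInput),
        Arith::Const(c) => code.push(Machine::Push(*c)),
        Arith::Mul(a, b) => {
            lower_to_machine(a, code);
            lower_to_machine(b, code);
            code.push(Machine::Mul);
        }
        Arith::Max(a, b) => {
            lower_to_machine(a, code);
            lower_to_machine(b, code);
            code.push(Machine::Max);
        }
    }
}

fn eval_domain(e: &Domain, x: i64) -> i64 {
    match e {
        Domain::Input => x,
        Domain::Square(i) => {
            let v = eval_domain(i, x);
            v * v
        }
        Domain::Relu(i) => eval_domain(i, x).max(0),
    }
}

fn eval_machine(code: &[Machine], x: i64) -> i64 {
    let mut stack: Vec<i64> = Vec::new();
    for op in code {
        match *op {
            Machine::LoadInput => stack.push(x),
            Machine::Push(c) => stack.push(c),
            Machine::Mul | Machine::Max => {
                let b = stack.pop().expect("stack underflow");
                let a = stack.pop().expect("stack underflow");
                stack.push(if *op == Machine::Mul { a * b } else { a.max(b) });
            }
        }
    }
    stack.pop().expect("empty program")
}

fn main() {
    // square(relu(x))
    let program = Domain::Square(Box::new(Domain::Relu(Box::new(Domain::Input))));
    let mut code = Vec::new();
    lower_to_machine(&lower_to_arith(&program), &mut code);
    println!("{:?}", code); // [LoadInput, Push(0), Max, LoadInput, Push(0), Max, Mul]
    for x in [-3, 0, 4] {
        println!("x = {x}: {} vs {}", eval_domain(&program, x), eval_machine(&code, x));
    }
}
```

Lowering square duplicates its operand at the arithmetic level. That level, where expressions are still structured, is a natural place for a CSE-style cleanup before code is emitted.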
Leverage MLIR’s dialect system to separate concerns. Use high-level dialects for domain logic, mid-level dialects for general optimizations, and low-level dialects for code generation. This separation enables reuse across different compilation pipelines.
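One way to make dialect boundaries concrete is to run a legality check after each lowering stage. This is loosely modeled on MLIR's notion of a conversion target with a set of legal dialects. The sketch is plain Rust with hypothetical names (Op, Stage, illegal_ops). It only inspects the dialect prefix of each op's name, whereas real MLIR conversion targets can also mark individual ops as legal or illegal.

```rust
/// A hypothetical op identified by an MLIR-style "dialect.op" name.
struct Op {
    name: &'static str,
}

impl Op {
    /// The dialect prefix, e.g. "arith" for "arith.addi".
    fn dialect(&self) -> &'static str {
        self.name.split('.').next().unwrap_or("")
    }
}

/// A hypothetical lowering stage and the dialects allowed to remain after it.
struct Stage {
    name: &'static str,
    legal_dialects: &'static [&'static str],
}

/// Returns the names of ops that are not legal for `stage`, i.e. still need lowering.
fn illegal_ops(ops: &[Op], stage: &Stage) -> Vec<&'static str> {
    ops.iter()
        .filter(|op| !stage.legal_dialects.iter().any(|d| *d == op.dialect()))
        .map(|op| op.name)
        .collect()
}

fn main() {
    let llvm_stage = Stage { name: "llvm", legal_dialects: &["builtin", "llvm"] };
    // A module partway through lowering: one arith op has not been converted yet.
    let ops = [
        Op { name: "llvm.func" },
        Op { name: "arith.addi" },
        Op { name: "llvm.return" },
    ];
    let leftovers = illegal_ops(&ops, &llvm_stage);
    println!("{}: {:?}", llvm_stage.name, leftovers); // llvm: ["arith.addi"]
}
```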
Design custom operations to be composable and orthogonal. Avoid monolithic operations that combine multiple concepts. Instead, build complex behaviors from simple, well-defined operations that optimization passes can analyze and transform.
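The cost of monolithic operations becomes visible as soon as a generic pass tries to analyze them. In this hypothetical plain-Rust sketch, a simple common-subexpression elimination (cse) merges identical operations:

- When a*b + c and a*b - d are built from primitive Mul, Add and Sub ops, CSE finds the shared product and computes it once.
- When the same math is expressed with fused MulAdd and MulSub ops, the product is hidden inside each op and CSE cannot share it.

The op set and helper names are illustrative only.

```rust
use std::collections::HashMap;

/// Hypothetical ops; operands are indices of earlier ops in a straight-line body.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum Op {
    Arg(u8),
    Mul(usize, usize),
    Add(usize, usize),
    Sub(usize, usize),
    // Monolithic ops that combine a multiply with a second operation.
    MulAdd(usize, usize, usize),
    MulSub(usize, usize, usize),
}

/// CSE: ops with the same kind and (already remapped) operands are merged.
fn cse(body: &[Op]) -> Vec<Op> {
    let mut seen: HashMap<Op, usize> = HashMap::new();
    let mut remap: Vec<usize> = Vec::with_capacity(body.len());
    let mut out: Vec<Op> = Vec::new();
    for &op in body {
        let r = |i: usize| remap[i];
        let canonical = match op {
            Op::Arg(n) => Op::Arg(n),
            Op::Mul(a, b) => Op::Mul(r(a), r(b)),
            Op::Add(a, b) => Op::Add(r(a), r(b)),
            Op::Sub(a, b) => Op::Sub(r(a), r(b)),
            Op::MulAdd(a, b, c) => Op::MulAdd(r(a), r(b), r(c)),
            Op::MulSub(a, b, c) => Op::MulSub(r(a), r(b), r(c)),
        };
        // Reuse an identical earlier op, or append this one.
        let idx = *seen.entry(canonical).or_insert_with(|| {
            out.push(canonical);
            out.len() - 1
        });
        remap.push(idx);
    }
    out
}

/// Number of operations that perform a multiplication.
fn multiplies(body: &[Op]) -> usize {
    body.iter()
        .filter(|op| matches!(op, Op::Mul(..) | Op::MulAdd(..) | Op::MulSub(..)))
        .count()
}

/// a*b + c and a*b - d from primitive ops (args: a=0, b=1, c=2, d=3).
fn primitive() -> Vec<Op> {
    vec![Op::Arg(0), Op::Arg(1), Op::Arg(2), Op::Arg(3),
         Op::Mul(0, 1), Op::Add(4, 2), Op::Mul(0, 1), Op::Sub(6, 3)]
}

/// The same computation with fused ops.
fn fused() -> Vec<Op> {
    vec![Op::Arg(0), Op::Arg(1), Op::Arg(2), Op::Arg(3),
         Op::MulAdd(0, 1, 2), Op::MulSub(0, 1, 3)]
}

fn main() {
    println!("primitive ops after CSE: {} multiply", multiplies(&cse(&primitive()))); // 1
    println!("fused ops after CSE: {} multiplies", multiplies(&cse(&fused()))); // 2
}
```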
Use MLIR’s type system to enforce invariants. Rich types catch errors early and enable optimizations. Track properties like tensor dimensions, memory layouts, and value constraints through the type system rather than runtime checks.
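MLIR types such as tensor<2x3xf32> carry shape information that passes and verifiers can rely on. As a loose Rust analogy (not melior code), const generics can put matrix dimensions into the type. A mismatched inner dimension then fails to compile instead of being caught by a runtime check. Matrix and matmul are hypothetical names.

```rust
/// Hypothetical row-major matrix whose shape (R x C) is part of its type.
#[derive(Debug, Clone, PartialEq)]
struct Matrix<const R: usize, const C: usize> {
    data: [[i64; C]; R],
}

impl<const R: usize, const C: usize> Matrix<R, C> {
    fn new(data: [[i64; C]; R]) -> Self {
        Self { data }
    }

    /// (R x C) * (C x K) -> (R x K). The shared dimension C is enforced by the
    /// type checker, so there is no runtime shape check to forget.
    fn matmul<const K: usize>(&self, rhs: &Matrix<C, K>) -> Matrix<R, K> {
        let mut out = [[0i64; K]; R];
        for i in 0..R {
            for j in 0..K {
                for k in 0..C {
                    out[i][j] += self.data[i][k] * rhs.data[k][j];
                }
            }
        }
        Matrix { data: out }
    }
}

fn main() {
    let a = Matrix::new([[1, 2, 3], [4, 5, 6]]); // 2 x 3
    let b = Matrix::new([[7, 8], [9, 10], [11, 12]]); // 3 x 2
    let c = a.matmul(&b); // 2 x 2
    println!("{:?}", c.data); // [[58, 64], [139, 154]]
    // a.matmul(&a) does not compile: the right-hand side must be Matrix<3, K>.
}
```

Calling a.matmul(&a) in this sketch is rejected at compile time, because Matrix<2, 3> does not have the Matrix<3, K> shape required for the right-hand operand.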
Implement verification for custom operations. Verification catches IR inconsistencies early and provides better error messages. Well-verified IR enables aggressive optimizations without compromising correctness.
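In MLIR itself, custom operations usually get verification from ODS type constraints, optionally extended by a C++ verify() hook. The sketch below mirrors the idea in plain Rust on a hypothetical my.select op. It checks the operand and result types and returns a message naming the op and the violated invariant. The op, its type enum, and the message format are all illustrative and not a melior API.

```rust
/// A hypothetical, tiny type system for the sketch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Ty {
    I1,
    I32,
    I64,
    F32,
}

/// Hypothetical custom op `my.select`: result = cond ? true_value : false_value.
struct SelectOp {
    cond: Ty,
    true_value: Ty,
    false_value: Ty,
    result: Ty,
}

impl SelectOp {
    /// Checks the op's invariants and reports the first violation found.
    fn verify(&self) -> Result<(), String> {
        if self.cond != Ty::I1 {
            return Err(format!("'my.select' op condition must be i1, got {:?}", self.cond));
        }
        if self.true_value != self.false_value {
            return Err(format!(
                "'my.select' op branch types differ: {:?} vs {:?}",
                self.true_value, self.false_value
            ));
        }
        if self.result != self.true_value {
            return Err(format!(
                "'my.select' op result type {:?} does not match operand type {:?}",
                self.result, self.true_value
            ));
        }
        Ok(())
    }
}

fn main() {
    let ok = SelectOp { cond: Ty::I1, true_value: Ty::I64, false_value: Ty::I64, result: Ty::I64 };
    let bad = SelectOp { cond: Ty::I32, true_value: Ty::F32, false_value: Ty::F32, result: Ty::F32 };
    println!("{:?}", ok.verify()); // Ok(())
    println!("{:?}", bad.verify());
}
```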
Build reusable transformation patterns. Patterns that match common idioms enable optimization across different contexts. Parameterize patterns to handle variations while maintaining transformation correctness.
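A reusable pattern can be parameterized rather than duplicated. The sketch below is loosely modeled on MLIR's rewrite patterns and greedy pattern driver, but it is written in plain Rust with hypothetical names (Pattern, IdentityElim, apply_greedily). It defines a single identity-element pattern and instantiates it for x + 0, x * 1 and x - 0. The commutative flag keeps the pattern correct for subtraction, where 0 - x must not be rewritten to x.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum BinOp {
    Add,
    Mul,
    Sub,
}

#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Var(&'static str),
    Lit(i64),
    Bin(BinOp, Box<Expr>, Box<Expr>),
}

fn bin(op: BinOp, l: Expr, r: Expr) -> Expr {
    Expr::Bin(op, Box::new(l), Box::new(r))
}

trait Pattern {
    /// Returns a replacement if the pattern matches `e`, otherwise `None`.
    fn rewrite(&self, e: &Expr) -> Option<Expr>;
}

/// One pattern, parameterized by operator, identity element and commutativity:
/// `x op identity -> x`, plus `identity op x -> x` when `commutative` is true.
struct IdentityElim {
    op: BinOp,
    identity: i64,
    commutative: bool,
}

impl Pattern for IdentityElim {
    fn rewrite(&self, e: &Expr) -> Option<Expr> {
        if let Expr::Bin(op, lhs, rhs) = e {
            if *op != self.op {
                return None;
            }
            if **rhs == Expr::Lit(self.identity) {
                return Some((**lhs).clone());
            }
            if self.commutative && **lhs == Expr::Lit(self.identity) {
                return Some((**rhs).clone());
            }
        }
        None
    }
}

/// Applies patterns bottom-up and re-applies after every successful rewrite,
/// so the result is a fixed point (each rewrite shrinks the tree, so it terminates).
fn apply_greedily(e: Expr, patterns: &[&dyn Pattern]) -> Expr {
    let e = match e {
        Expr::Bin(op, l, r) => bin(op, apply_greedily(*l, patterns), apply_greedily(*r, patterns)),
        leaf => leaf,
    };
    for p in patterns {
        if let Some(new) = p.rewrite(&e) {
            return apply_greedily(new, patterns);
        }
    }
    e
}

fn main() {
    let add_zero = IdentityElim { op: BinOp::Add, identity: 0, commutative: true };
    let mul_one = IdentityElim { op: BinOp::Mul, identity: 1, commutative: true };
    let sub_zero = IdentityElim { op: BinOp::Sub, identity: 0, commutative: false };
    let patterns: [&dyn Pattern; 3] = [&add_zero, &mul_one, &sub_zero];

    // (1 * (x + 0)) - 0
    let e = bin(BinOp::Sub, bin(BinOp::Mul, Expr::Lit(1), bin(BinOp::Add, Expr::Var("x"), Expr::Lit(0))), Expr::Lit(0));
    println!("{:?}", apply_greedily(e, &patterns)); // Var("x")
}
```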