As discussed previously, traditional compiler IRs often lack the expressiveness required to represent the hierarchical nature and domain-specific operations found in machine learning models. MLIR (Multi-Level Intermediate Representation) addresses this by providing a common, extensible infrastructure. Instead of a fixed IR, MLIR offers a framework for defining and composing different dialects, each tailored to a specific level of abstraction or domain.
MLIR's power stems from a few fundamental, well-defined structures:
Operations (Ops): The primary unit of semantics and structure. An Op represents a computation (like matrix multiplication), a structural element (like a function definition), a constant value, or a meta-operation (like module termination). Every Op belongs to a specific dialect, and its name is prefixed by that dialect (e.g., `arith.addi`, `func.func`, `linalg.matmul`).

Attributes: Compile-time constant values attached to Ops. They represent static information that doesn't change during execution, such as literal constants, type information, dimension permutations, or configuration strings. MLIR has built-in attributes (integers, floats, strings, types), and dialects can define custom ones. For example, a convolution Op might have attributes for stride (`#stride<[1, 1]>`) and padding (`#padding<"SAME">`).
Types: Define the data types of SSA values (Op results and block arguments). MLIR provides standard built-in types like integers (`i32`, `i8`), floats (`f32`, `f16`), vectors (`vector<4xf32>`), and tensors (`tensor<10x20xf32>`). Crucially, the type system is extensible via dialects, allowing the representation of specialized hardware types or abstract domain-specific types (like quantized types).
Regions: Ordered lists of Blocks contained within an Op. Regions provide the mechanism for nesting and defining scopes. For instance, a `func.func` Op contains a Region representing the function body, and looping Ops like `scf.for` contain Regions for their loop bodies. This hierarchical structure is fundamental to MLIR's multi-level nature.
Blocks: Sequences of Ops within a Region, analogous to basic blocks in traditional compilers. A Block takes a list of arguments (SSA values) and contains a sequence of Ops. The last Op in a Block must be a terminator Op (like `func.return` or `cf.br`), which defines control flow, transferring execution to other Blocks or terminating the containing Region.
High-level structure of an MLIR function Op (`func.func`) containing a Region with one Block. The Block takes an argument, executes Ops sequentially (conceptually using SSA values), and ends with a terminator Op.
The core MLIR structures are intentionally minimal. The real power and specificity come from Dialects. A dialect acts as a namespace for a set of related Ops, Attributes, and Types. Think of it like a library or module providing domain-specific constructs.
MLIR comes with several standard dialects:
- `builtin`: Defines fundamental concepts like module structure, basic types (integers, floats), attributes (strings, arrays), and function type definitions. It forms the bedrock of any MLIR program.
- `func`: Defines function-related Ops like `func.func` (function definition), `func.call`, and `func.return`.
- `arith`: Provides standard arithmetic operations (addition, subtraction, comparison, etc.), primarily on scalar and vector types.
- `vector`: Defines Ops for manipulating values of the `vector` type (e.g., broadcasts, shuffles, contractions).
- `tensor`: Defines Ops operating on the `tensor` type (e.g., `tensor.extract`, `tensor.insert`, `tensor.cast`).
- `linalg`: Offers high-level, structured operations on tensors and buffers, often expressed generically (e.g., `linalg.generic`, `linalg.matmul`, `linalg.conv_2d_nhwc_hwcf`). This dialect is important for abstracting linear algebra computations before detailed loop generation.
- `memref`: Represents memory buffers (the `memref` type) with associated layout information and provides Ops for allocation (`memref.alloc`), deallocation, and access. Used closer to buffer-based code generation.
- `scf`: The Structured Control Flow dialect, providing Ops for loops (`scf.for`, `scf.parallel`) and conditionals (`scf.if`) with explicit region nesting.
- `cf`: The Control Flow dialect, providing traditional basic-block branching (`cf.br`, `cf.cond_br`), similar to LLVM IR's control flow.

Beyond these, compiler projects define their own dialects. For example, TensorFlow might use a `tf` or `mhlo` dialect to represent its graph operations directly. A compiler targeting a specific accelerator might define a custom dialect representing the hardware's unique instructions.
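To see how these dialects compose, here is a hedged sketch (the function name and constants are illustrative) of a loop summing the integers 0 through 9, combining `func`, `scf`, and `arith`:

```mlir
// scf.for nests its body in a Region; %acc is a loop-carried value
// threaded through iterations and updated via scf.yield.
func.func @sum_to_ten() -> i32 {
  %c0   = arith.constant 0 : index
  %c10  = arith.constant 10 : index
  %c1   = arith.constant 1 : index
  %init = arith.constant 0 : i32
  %sum = scf.for %i = %c0 to %c10 step %c1
      iter_args(%acc = %init) -> (i32) {
    %iv   = arith.index_cast %i : index to i32
    %next = arith.addi %acc, %iv : i32
    scf.yield %next : i32
  }
  func.return %sum : i32
}
```

Note how control flow is expressed structurally through region nesting rather than through branches between basic blocks, which keeps loop structure explicit for later transformations.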
Let's look at a small MLIR snippet representing a function that adds two 32-bit floating-point tensors element-wise:
```mlir
// func.func defines a function named 'add_tensors'.
// It takes two arguments, %arg0 and %arg1, both of type tensor<10x20xf32>,
// and returns one result of the same type.
func.func @add_tensors(%arg0: tensor<10x20xf32>, %arg1: tensor<10x20xf32>) -> tensor<10x20xf32> {
  // tensor.empty creates the destination (init) tensor for the result.
  %init = tensor.empty() : tensor<10x20xf32>
  // The 'linalg.generic' Op performs an element-wise operation.
  // 'indexing_maps' define how each dimension of the inputs/output relates
  // to the loop iterators (affine maps).
  // 'iterator_types' specify parallel loops for dimensions 'd0' and 'd1'.
  %result = linalg.generic {
    indexing_maps = [
      affine_map<(d0, d1) -> (d0, d1)>, // Input 0 map
      affine_map<(d0, d1) -> (d0, d1)>, // Input 1 map
      affine_map<(d0, d1) -> (d0, d1)>  // Output map
    ],
    iterator_types = ["parallel", "parallel"] // Two parallel loops
  }
  // 'ins' specifies the input tensors (%arg0, %arg1).
  // 'outs' specifies the destination tensor the result is computed into.
  ins(%arg0, %arg1 : tensor<10x20xf32>, tensor<10x20xf32>)
  outs(%init : tensor<10x20xf32>) {
  // The region defines the element-wise computation.
  ^bb0(%in0: f32, %in1: f32, %out_unused: f32): // Block arguments correspond to elements of the inputs/output
    // 'arith.addf' performs floating-point addition on the scalar elements.
    %sum = arith.addf %in0, %in1 : f32
    // 'linalg.yield' returns the computed element for the output tensor.
    linalg.yield %sum : f32
  } -> tensor<10x20xf32> // The Op returns the resulting tensor.
  // 'func.return' terminates the function, returning the result tensor.
  func.return %result : tensor<10x20xf32>
}
```
In this example:

- `func.func` and `func.return` are from the `func` dialect.
- `linalg.generic` and `linalg.yield` are from the `linalg` dialect.
- `arith.addf` is from the `arith` dialect.
- `tensor<10x20xf32>` and `f32` are built-in types.
- `@add_tensors` is a symbol name (an attribute).
- `%arg0`, `%arg1`, `%result`, `%in0`, `%in1`, and `%sum` are SSA values.
- The `{...}` for `linalg.generic` contains attributes (`indexing_maps`, `iterator_types`) and a Region defining the element-level computation.

This structured representation, combining Ops from different dialects (`func`, `linalg`, `arith`), clearly defines the function's signature, its high-level tensor operation (`linalg.generic`), and the precise element-wise computation (`arith.addf`) within its nested Region. It serves as a starting point for various optimizations before eventually being lowered to machine code.
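To give a sense of that lowering: after bufferization and a pass such as `-convert-linalg-to-loops`, the `linalg.generic` in the example above might become an explicit loop nest over memory buffers. The sketch below is approximate; the exact IR depends on the bufferization and lowering pipeline used.

```mlir
// Element-wise addition after lowering to buffers (memref) and
// explicit scf loops; one loop per parallel iterator of the
// original linalg.generic.
func.func @add_tensors(%a: memref<10x20xf32>, %b: memref<10x20xf32>,
                       %out: memref<10x20xf32>) {
  %c0  = arith.constant 0 : index
  %c1  = arith.constant 1 : index
  %c10 = arith.constant 10 : index
  %c20 = arith.constant 20 : index
  scf.for %i = %c0 to %c10 step %c1 {
    scf.for %j = %c0 to %c20 step %c1 {
      %x = memref.load %a[%i, %j] : memref<10x20xf32>
      %y = memref.load %b[%i, %j] : memref<10x20xf32>
      %s = arith.addf %x, %y : f32
      memref.store %s, %out[%i, %j] : memref<10x20xf32>
    }
  }
  func.return
}
```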
The combination of a minimal core structure (Ops, Regions, Blocks, Attributes, Types) and the extensible dialect system makes MLIR a powerful foundation for building modern compilers, especially in complex domains like machine learning where multiple levels of abstraction must coexist and interact. We will see in subsequent sections how MLIR dialects are used to represent entire ML graphs and how transformations progressively lower these representations towards executable code.
© 2025 ApX Machine Learning