As we established, traditional compiler IRs often lack the expressiveness to adequately capture the high-level semantics of machine learning computation graphs. They operate at a level too low to effectively reason about concepts like convolutions, batch normalizations, or tensor layouts directly. This is where multi-level IRs, particularly MLIR, provide a significant advantage by allowing representations tailored to specific abstraction domains through the use of dialects.
MLIR's core design philosophy embraces the idea that no single IR level is sufficient for the entire compilation pipeline. Instead, it provides a framework for defining multiple dialects, each encapsulating a specific set of operations, types, and attributes relevant to a particular domain or abstraction level. For representing high-level ML graphs, dialects serve as the primary mechanism to bridge the gap between the source framework (like TensorFlow, PyTorch, JAX) or a standardized format (like TOSA) and the compiler's internal representation.
Common dialects used at this stage include:
- tf dialect: Directly mirrors operations and types from the TensorFlow ecosystem. This allows a near one-to-one import of a TensorFlow GraphDef or SavedModel into MLIR, preserving TensorFlow-specific semantics and attributes.
- tosa dialect: Represents the Tensor Operator Set Architecture, a standardized set of tensor operations intended as a common target for different frameworks. It provides a more framework-agnostic representation than the tf dialect.
- mhlo / stablehlo dialects: Originally from XLA (Accelerated Linear Algebra), these dialects represent operations at a level suitable for high-level optimization passes such as fusion, and are often used as an intermediate step after importing from a framework-specific dialect.

The choice of which high-level dialect to use first often depends on the compiler's frontend strategy and the source model format. The essential point is that these dialects allow the compiler to "speak the language" of the ML framework, at least initially.
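To make the differences concrete, here is how a single elementwise addition might appear in each of these dialects. The operation names shown (tf.AddV2, tosa.add, stablehlo.add) reflect common spellings, but exact names and syntax vary across dialect versions, so treat this as a rough sketch rather than importer output:

// The same elementwise addition expressed in three high-level dialects
// (generic MLIR syntax; %a and %b are assumed to be tensor<4xf32> values).
%sum_tf   = "tf.AddV2"(%a, %b)      : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%sum_tosa = "tosa.add"(%a, %b)      : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%sum_shlo = "stablehlo.add"(%a, %b) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>

Each form carries the same mathematical meaning; what differs is which framework or standard the dialect is designed to mirror, and therefore which attributes and type constraints accompany the operation.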
Within a high-level dialect, each MLIR operation corresponds conceptually to a node in the original computation graph. For instance, a 2D convolution operation from TensorFlow might be represented by a tf.Conv2D operation in the tf dialect or a tosa.conv2d operation in the tosa dialect.
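For contrast with the TensorFlow form shown later in this section, the same convolution could be written roughly as follows in the tosa dialect. The OHWI weight layout and the attribute spellings below follow one version of the TOSA specification and may differ in others, so treat this as an illustrative sketch:

// A 5x5 convolution in the tosa dialect (sketch). TOSA convolutions take
// the bias as an explicit operand and use an output-channels-first weight layout.
%conv = "tosa.conv2d"(%input, %weights, %bias) {
  pad = [2, 2, 2, 2],
  stride = [1, 1],
  dilation = [1, 1]
} : (tensor<1x28x28x1xf32>, tensor<32x5x5x1xf32>, tensor<32xf32>) -> tensor<1x28x28x32xf32>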
Crucially, these MLIR operations carry attributes that capture the high-level configuration details necessary to preserve the exact semantics of the original framework operation. These attributes are not typically representable in lower-level IRs like LLVM IR. Examples include:
- Padding modes (e.g., SAME, VALID)
- Data layout formats (e.g., NHWC, NCHW)

Consider a simplified conceptual representation of a TensorFlow Conv2D operation followed by a BiasAdd and Relu activation:
Figure: A conceptual TensorFlow graph snippet involving convolution, bias addition, and ReLU activation.
This sequence might be represented in the MLIR tf dialect like this (syntax simplified for clarity):
// %input, %filter, %bias are MLIR SSA values representing tensors.
// 2D convolution: framework-level attributes are carried verbatim on the op.
%conv_output = "tf.Conv2D"(%input, %filter) {
  strides = [1, 1, 1, 1],
  padding = "SAME",
  data_format = "NHWC",
  dilations = [1, 1, 1, 1]
} : (tensor<1x28x28x1xf32>, tensor<5x5x1x32xf32>) -> tensor<1x28x28x32xf32>

// Bias addition along the channel dimension of the NHWC output.
%bias_add_output = "tf.BiasAdd"(%conv_output, %bias) {
  data_format = "NHWC"
} : (tensor<1x28x28x32xf32>, tensor<32xf32>) -> tensor<1x28x28x32xf32>

// Elementwise ReLU activation.
%output = "tf.Relu"(%bias_add_output)
    : (tensor<1x28x28x32xf32>) -> tensor<1x28x28x32xf32>
Notice how the MLIR operations (tf.Conv2D, tf.BiasAdd, tf.Relu) directly mirror the TensorFlow concepts. The attributes (strides, padding, data_format) are attached to the relevant operation, preserving essential metadata. The tensor types include shape and element type information.
MLIR adheres to the Static Single Assignment (SSA) form, common in modern compilers. In this form, each variable (represented as an MLIR value) is defined exactly once and can be used multiple times. This naturally models the data flow dependencies inherent in computation graphs.
- Each operation's results define new SSA values (%conv_output, %bias_add_output, %output in the example above).
- Each operation's operands refer to values defined earlier, either by other operations or as block arguments.

This use-def chain structure explicitly represents the directed edges of the original computation graph, making data dependencies clear and facilitating analysis and transformation. The MLIR representation, even using high-level dialects, is fundamentally a dataflow graph expressed in SSA form.
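The example above is a straight chain, but the same SSA structure also captures fan-out: a value defined once can be used by several operations, which corresponds to a graph node with multiple outgoing edges. A minimal sketch (the operations and shapes here are illustrative):

// %pre_activation is assumed to be a tensor<8xf32> value defined earlier.
// %features is defined once and consumed by two independent operations,
// representing two outgoing edges from the same graph node.
%features = "tf.Relu"(%pre_activation) : (tensor<8xf32>) -> tensor<8xf32>
%branch_a = "tf.Exp"(%features) : (tensor<8xf32>) -> tensor<8xf32>
%branch_b = "tf.Neg"(%features) : (tensor<8xf32>) -> tensor<8xf32>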
Representing the ML model within the compiler using these high-level dialects is significant because it allows optimizations to operate on familiar, domain-specific concepts before crucial information is lost through lowering. For example:
- Operator fusion: patterns such as a tf.Conv2D followed by tf.BiasAdd and tf.Relu are explicit and easy to match, so a graph-level pass can fuse them (one possible fused form is sketched at the end of this section). Lower-level IRs would obscure this pattern behind generic memory access and arithmetic operations.
- Layout optimization: the tensor layouts in use (captured in the data_format attribute) are visible, so passes can reason about or change them while this information is still explicit.

By preserving the high-level semantics of the source graph, MLIR dialects like tf and tosa provide the necessary abstraction level for powerful graph-level optimizations. This representation serves as the entry point for the progressive lowering process, where these high-level operations are gradually transformed and decomposed into lower-level dialects, eventually reaching hardware-specific instructions, as we will explore in subsequent sections and chapters.
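As a concrete illustration of the fusion point above, a graph-level pass might rewrite the convolution, bias addition, and activation from the earlier example into a single fused operation. The tf._FusedConv2D form below is one way TensorFlow expresses such fusions; the exact operation name and attribute set depend on the compiler pipeline, so treat this as a sketch:

// One possible result of fusing tf.Conv2D + tf.BiasAdd + tf.Relu.
// The fused_ops attribute records which operations were folded into the convolution.
%fused = "tf._FusedConv2D"(%input, %filter, %bias) {
  strides = [1, 1, 1, 1],
  padding = "SAME",
  data_format = "NHWC",
  dilations = [1, 1, 1, 1],
  fused_ops = ["BiasAdd", "Relu"]
} : (tensor<1x28x28x1xf32>, tensor<5x5x1x32xf32>, tensor<32xf32>) -> tensor<1x28x28x32xf32>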