As introduced, Just-In-Time (JIT) compilation fundamentally shifts when optimizations and code generation occur, moving these phases from an offline Ahead-of-Time (AOT) process to the actual runtime execution. This temporal shift imposes unique and demanding requirements on the Intermediate Representation (IR) used within the JIT compiler. Unlike AOT scenarios where compilation time is less critical, the IR design in a JIT system must prioritize not only expressive power but also the speed of its construction, manipulation, and lowering, as these directly impact the user-perceived application latency.
Core Requirements for JIT Intermediate Representations
The effectiveness of an ML JIT compiler hinges significantly on the capabilities of its IR. Several properties are particularly important:
- Multi-Level Abstraction and Flexibility: JIT compilation often starts by capturing operations at a level close to the source framework (e.g., Python operations in PyTorch or TensorFlow graph nodes). The IR must faithfully represent these high-level constructs, including dynamic control flow and framework-specific semantics. Subsequently, it needs to support progressive lowering through potentially multiple intermediate levels, enabling graph-level optimizations (like fusion), tensor/loop-level transformations (like tiling), and finally mapping to low-level hardware instructions. This necessitates an IR capable of representing computations at varying granularities.
- Efficiency of Manipulation: Since compilation occurs during execution, the time spent constructing, traversing, and transforming the IR is critical overhead. The IR data structures must be lightweight and allow for fast pattern matching, rewriting, and analysis. Complex or slow-to-manipulate IRs can negate the performance benefits gained from JIT specialization.
- Representation of Dynamic Information: A primary motivation for JIT compilation is leveraging runtime information. The IR must have first-class support for representing dynamic properties, most commonly tensor shapes and data types, which might not be fully known when compilation begins. This often involves using symbolic dimensions, constraints, or type placeholders within the IR structure itself. This dynamic information is the key enabler for runtime specialization.
- Extensibility: The ML landscape evolves rapidly with new operators, hardware targets, and optimization techniques. A JIT system's IR needs to be extensible, allowing the straightforward addition of new operations, data types, or even entirely new abstraction levels (often through a dialect mechanism, as seen in MLIR, discussed in Chapter 2) without requiring fundamental changes to the compiler's core infrastructure.
- Efficient Lowering Pathways: While flexibility is essential, the IR must also provide well-defined and efficient pathways to lower the high-level representation towards executable code. This involves a series of transformation passes that progressively reduce abstraction, resolve dynamic properties (where possible), and map operations to hardware-specific constructs or a lower-level IR like LLVM IR.
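Several of these requirements (lightweight structures, cheap rewriting, extensibility via named operations) can be illustrated with a toy sketch. All names here are hypothetical, loosely inspired by MLIR's dialect naming; a real JIT IR would use arena allocation, use-def chains, and more careful pattern matching:

```python
from dataclasses import dataclass

# Hypothetical minimal IR: ops are lightweight records identified by a
# "dialect.name" string, so new operations can be added without changing
# the core data structures (an extensibility pattern loosely modeled on MLIR).
@dataclass(frozen=True)
class Value:
    name: str

@dataclass
class Op:
    opname: str            # e.g. "hl.matmul", "ll.fma"
    operands: list
    result: Value

def fuse_mul_add(ops):
    """Cheap peephole rewrite: a mul whose result feeds an add becomes an fma."""
    fused, producers = [], {}
    for op in ops:
        if (op.opname == "hl.add" and op.operands and
                op.operands[0] in producers and
                producers[op.operands[0]].opname == "hl.mul"):
            mul = producers[op.operands[0]]
            fused.remove(mul)
            op = Op("ll.fma", mul.operands + op.operands[1:], op.result)
        producers[op.result] = op
        fused.append(op)
    return fused

a, b, c, t, r = (Value(n) for n in "abctr")
prog = [Op("hl.mul", [a, b], t), Op("hl.add", [t, c], r)]
prog = fuse_mul_add(prog)
# prog is now a single ll.fma op taking a, b, c
```

The rewrite is a single linear scan over cheap records, the kind of low-overhead manipulation a JIT can afford on the critical path.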
Handling Dynamism in the IR
A core challenge addressed by JIT IRs is representing information that is initially unknown or variable. Tensor shapes are the canonical example. An AOT compiler might require all tensor dimensions to be static constants. A JIT compiler, however, often encounters tensors where some dimensions depend on runtime inputs.
The IR can handle this using mechanisms like:
- Symbolic Dimensions: Representing unknown dimensions with symbols (e.g., tensor<Nx1024xf32>).
- Shape Functions/Constraints: Associating operations with functions or constraints that define output shapes based on input shapes, even if some input dimensions are symbolic. For instance, a matrix multiplication C = matmul(A, B) where A has shape (M, K) and B has shape (K, N) would have an IR representation encoding that C has shape (M, N), regardless of whether M, K, or N are concrete values or symbols.
- Type Refinement: As compilation progresses or runtime information becomes available, the JIT compiler can refine the types and shapes within the IR, replacing symbolic dimensions with concrete values derived from tracing or input guards.
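The three mechanisms above can be sketched together in a few lines. This is a hypothetical illustration, not any framework's actual API: Sym stands in for a symbolic dimension, matmul_shape is a shape function, and refine performs type refinement once runtime values are known:

```python
# Hypothetical sketch: symbolic tensor shapes, a matmul shape function,
# and a refinement step that substitutes runtime-known values.
class Sym:
    """A named symbolic dimension, e.g. an unknown batch size M."""
    def __init__(self, name):
        self.name = name
    def __repr__(self):
        return self.name

def matmul_shape(a_shape, b_shape):
    """Shape function for C = matmul(A, B): (M, K) x (K, N) -> (M, N).

    Works whether dimensions are concrete ints or Sym placeholders;
    the inner K dimensions are checked only if both are concrete."""
    (m, k1), (k2, n) = a_shape, b_shape
    if isinstance(k1, int) and isinstance(k2, int):
        assert k1 == k2, "inner dimensions must agree"
    return (m, n)

def refine(shape, bindings):
    """Type refinement: replace symbolic dims with concrete runtime values."""
    return tuple(bindings.get(d.name, d) if isinstance(d, Sym) else d
                 for d in shape)

M = Sym("M")
c_shape = matmul_shape((M, 1024), (1024, 512))   # (M, 512), M still symbolic
concrete = refine(c_shape, {"M": 32})            # (32, 512) once M is known
```

Note that the shape function propagates the symbol M unchanged; refinement is a separate, later step, mirroring how a JIT first builds a partially specified IR and only specializes when runtime information arrives.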
This ability to represent and manipulate partially specified information is fundamental to enabling runtime specialization, where the JIT generates code optimized for the actual tensor shapes encountered during a specific execution trace.
Layered Abstraction and Lowering Flow
To balance high-level semantics with low-level optimization needs, JIT IRs often adopt a layered or multi-dialect approach. A typical flow within the JIT might look like this:
Figure: A conceptual view of IR lowering stages within a JIT compiler. Compilation starts from captured framework operations, passes through progressively lower-level IRs enabling different optimizations, and culminates in target code generation. Shape specialization often occurs during the transition from high-level to mid-level IR.
This layered approach allows optimizations to be applied at the most suitable level of abstraction: graph fusion on the high-level IR, loop tiling on the mid-level IR, and instruction scheduling on the low-level IR. The JIT compiler orchestrates the transitions (lowerings) between these layers.
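The orchestration itself can be sketched as an ordered pass pipeline. The stage names and the dict-based "IR" below are purely illustrative stand-ins for real IR modules and pass managers:

```python
# Hypothetical sketch of how a JIT might orchestrate layered lowering:
# each stage is a transform applied in order, mirroring the
# high -> mid -> low flow described above. The "IR" here is just a dict
# tagging the current abstraction level and recording which passes ran.

def fuse_graph(ir):          # graph-level pass on the high-level IR
    ir["passes"].append("fusion")
    return ir

def lower_to_loops(ir):      # high -> mid lowering, then loop tiling
    ir["level"] = "loops"
    ir["passes"].append("tiling")
    return ir

def lower_to_llvm(ir):       # mid -> low lowering, instruction scheduling
    ir["level"] = "llvm"
    ir["passes"].append("instruction-scheduling")
    return ir

PIPELINE = [fuse_graph, lower_to_loops, lower_to_llvm]

def compile_ir(captured_ops):
    ir = {"level": "graph", "passes": [], "ops": captured_ops}
    for stage in PIPELINE:
        ir = stage(ir)
    return ir

result = compile_ir(["matmul", "relu"])
# result["level"] is "llvm"; passes ran in graph -> loops -> llvm order
```

Each optimization runs only at the level where its abstraction is available: fusion sees whole operations, tiling sees loops, scheduling sees instructions.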
Relation to Graph Acquisition (Tracing vs. Scripting)
The method used to capture the user's model impacts the initial form of the JIT IR:
- Tracing: When a model is traced with sample inputs, the JIT observes the sequence of operations executed. The IR is typically built as a dataflow graph where nodes represent the executed operations and edges represent tensor dependencies. Handling control flow encountered during tracing (like conditionals or loops) requires careful IR design to capture the branching logic and potential shape variations accurately. The IR must represent the specific path taken during the trace while potentially embedding information needed to handle other paths if re-compilation occurs.
- Scripting: When a model is defined using a restricted subset of the host language (like TorchScript or TensorFlow's tf.function decorator with AutoGraph), the JIT parses this code directly. The resulting IR often more closely resembles an Abstract Syntax Tree (AST) or includes explicit control-flow structures (like scf.if or scf.for in MLIR terminology). This gives the compiler more explicit program structure to analyze compared to tracing.
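The tracing side of this contrast can be made concrete with a minimal tracer. Everything below is a hypothetical sketch: a proxy tensor overloads operators to record a dataflow graph, and Python control flow leaves only the taken branch in the trace:

```python
# Hypothetical minimal tracer: a proxy TracedTensor records each operation
# it participates in, building the dataflow graph the text describes. Only
# the path actually executed for this input is captured.
GRAPH = []

class TracedTensor:
    def __init__(self, name):
        self.name = name
    def _record(self, opname, other):
        out = TracedTensor(f"t{len(GRAPH)}")
        GRAPH.append((opname, self.name, other.name, out.name))
        return out
    def __add__(self, other):
        return self._record("add", other)
    def __mul__(self, other):
        return self._record("mul", other)

def model(x, w, use_bias):
    y = x * w
    if use_bias:            # Python control flow: only the taken branch
        y = y + w           # appears in the trace
    return y

x, w = TracedTensor("x"), TracedTensor("w")
model(x, w, use_bias=True)
# GRAPH now holds [("mul", "x", "w", "t0"), ("add", "t0", "w", "t1")]
```

Had use_bias been False, the add would simply be absent from GRAPH, which is exactly why traced IRs need guards or re-tracing when control flow depends on inputs, whereas a scripted IR would contain an explicit conditional.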
In both cases, the initial IR captures the program structure, which is then refined and optimized using the dynamic context available at runtime.
Contrasting JIT and AOT IR Requirements
While sharing foundational concepts with AOT compiler IRs (such as SSA form and well-defined operation semantics), JIT IRs operate under different constraints. AOT compilers can afford expensive analyses and transformations because compilation happens offline, and they often rely on detailed static information about shapes and types.
JIT IRs, conversely, must be:
- Fast to Generate: Capturing the model via tracing or scripting must be quick.
- Fast to Optimize: Core optimizations applied at runtime must be computationally inexpensive. More complex optimizations might be deferred or applied adaptively based on execution counts (see Section 7.6).
- Adaptable: Designed to incorporate runtime information gracefully and trigger re-optimization or recompilation when necessary (e.g., if tensor shapes change significantly between invocations).
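The Adaptable point is commonly realized as a shape-keyed compilation cache: a guard checks the input signature and triggers recompilation only when a new signature appears. The sketch below is hypothetical (real systems key on dtypes, ranks, and more, and may fall back to dynamic-shape code after too many recompiles):

```python
# Hypothetical sketch of shape-keyed specialization: the JIT caches a
# compiled artifact per input-shape signature and recompiles when a new
# signature (e.g. a changed batch size) appears.
COMPILE_CACHE = {}
compile_count = 0

def expensive_compile(shapes):
    global compile_count
    compile_count += 1
    # Stand-in for real specialization + code generation for these shapes.
    return lambda *tensors: f"ran kernel specialized for {shapes}"

def jit_call(*tensors):
    key = tuple(len(t) for t in tensors)     # shape signature (1-D lists here)
    if key not in COMPILE_CACHE:             # guard: recompile on new shapes
        COMPILE_CACHE[key] = expensive_compile(key)
    return COMPILE_CACHE[key](*tensors)

jit_call([1, 2, 3])      # first call: compiles for shape (3,)
jit_call([4, 5, 6])      # same signature: cache hit, no recompilation
jit_call([1, 2, 3, 4])   # new shape (4,): triggers one recompilation
# compile_count is 2
```

The guard check must be far cheaper than the compilation it avoids; this is the mechanism that lets amortized JIT cost stay low across repeated invocations with stable shapes.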
TensorFlow's XLA uses HLO (High Level Operations IR), which is graph-based and well suited to aggressive fusion, while PyTorch's TorchScript uses an IR that initially retains more Pythonic semantics before lowering. Both balance representational power against the performance demands of JIT compilation, embodying the principles discussed here. These systems are examined in more detail in Sections 7.7 and 7.8.
In summary, the intermediate representation is a cornerstone of any JIT compilation system for machine learning. Its design must navigate the trade-offs between faithfully representing high-level, potentially dynamic program semantics and enabling efficient, runtime-sensitive optimization and code generation. The ability to handle dynamic information, support multiple levels of abstraction, and facilitate rapid manipulation are defining characteristics of effective JIT IRs.