Understanding the theoretical underpinnings of Just-In-Time (JIT) compilation, such as tracing versus scripting or adaptive optimization, is essential. However, gaining practical insight requires inspecting the actual output generated by these JIT compilers. Analyzing the Intermediate Representation (IR) or compiled code provides concrete evidence of how optimizations like operator fusion, constant folding, and specialization are applied, directly connecting the high-level Python code to the low-level execution plan. This practical analysis is invaluable for debugging performance issues, verifying optimization effectiveness, and deepening your understanding of the JIT process.
In this hands-on section, we'll walk through the process of JIT-compiling simple model fragments using popular frameworks like PyTorch (TorchScript) and TensorFlow (XLA) and then analyze the resulting IR. We assume you have a working Python environment with PyTorch and TensorFlow installed.
PyTorch's JIT module, TorchScript, provides two main ways to convert Python code to an optimizable graph representation: tracing (torch.jit.trace) and scripting (torch.jit.script). Let's start with tracing.
Consider a simple sequence of operations: a linear layer followed by a ReLU activation and another linear layer.
import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self, input_features, hidden_features, output_features):
        super().__init__()
        self.linear1 = nn.Linear(input_features, hidden_features)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_features, output_features)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

# Instantiate the model and create dummy input
input_size = 128
hidden_size = 256
output_size = 64
model = SimpleModel(input_size, hidden_size, output_size)
dummy_input = torch.randn(32, input_size)  # Batch size 32

# Trace the model
traced_model = torch.jit.trace(model, dummy_input)
print("Model successfully traced.")
Tracing executes the model with the provided dummy input and records the operations performed. The resulting traced_model contains a static graph representation specialized to the input shape (32, 128). We can inspect this graph:
# Print the TorchScript Graph IR
print(traced_model.graph)
You will see output resembling this (details may vary slightly across PyTorch versions):
graph(%self.1 : __torch__.SimpleModel,
      %x : Float(32, 128, strides=[128, 1], requires_grad=0, device=cpu)):
  %linear1 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear1"](%self.1)
  %3 : Float(256, 128, strides=[128, 1], device=cpu) = prim::GetAttr[name="weight"](%linear1)
  %4 : Float(256, strides=[1], device=cpu) = prim::GetAttr[name="bias"](%linear1)
  %5 : Tensor = aten::linear(%x, %3, %4) # <eval_with_key>.13:10:8
  %relu : __torch__.torch.nn.modules.activation.ReLU = prim::GetAttr[name="relu"](%self.1)
  %7 : Tensor = aten::relu(%5) # <eval_with_key>.14:8:8
  %linear2 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear2"](%self.1)
  %9 : Float(64, 256, strides=[256, 1], device=cpu) = prim::GetAttr[name="weight"](%linear2)
  %10 : Float(64, strides=[1], device=cpu) = prim::GetAttr[name="bias"](%linear2)
  %11 : Tensor = aten::linear(%7, %9, %10) # <eval_with_key>.15:8:8
  return (%11)
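Before walking through the graph node by node, note that TorchScript also exposes a more readable, Python-like rendering of the same function through the code property; a quick sketch (output formatting varies across PyTorch versions):

# Python-like view of the traced forward; often easier to scan than the raw graph IR
print(traced_model.code)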
Analysis:

- The graph takes the module instance (%self.1) and the input tensor %x with its traced shape Float(32, 128, ...).
- Each node is either a tensor computation (aten::linear, aten::relu) or an attribute access (prim::GetAttr).
- The model parameters (%3, %4, %9, %10) are embedded as graph inputs or attributes, retrieved via prim::GetAttr.
- The graph is static: it contains no loops or if statements (unless they were constant during tracing), and it is specialized for inputs of shape (32, 128). Running traced_model with a different shape might work if the operations support it, but the graph itself doesn't explicitly represent dynamic shape logic (a quick check of this follows the next paragraph).

TorchScript automatically applies optimization passes, including fusion. While simple Linear -> ReLU fusion isn't always guaranteed or easily visible directly in this high-level graph dump (it often happens in later lowering stages), let's consider how we might look for it. More complex patterns, especially in convolutional networks (Conv-BN-ReLU), are common fusion targets.
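Returning to the shape specialization noted above, here is that quick check as a minimal sketch (the exact error message varies across PyTorch versions):

# The traced graph was recorded for input shape (32, 128), but aten::linear is
# batch-size polymorphic, so a different batch dimension still runs.
print(traced_model(torch.randn(8, input_size)).shape)  # torch.Size([8, 64])

# Changing the feature dimension, however, no longer matches linear1's weight
# matrix and fails at run time with a shape-mismatch error.
# traced_model(torch.randn(8, 64))  # raises RuntimeError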
We can examine the graph after potential optimizations, although the default printout often shows the initial graph. To see optimized graphs or execution plans, you might need profiling tools or internal APIs (use with caution as they can change):
# Example using an internal API (subject to change)
# This provides a different view, sometimes showing optimized/fused ops
# print(torch._C._jit_get_profiling_executor_graph(traced_model.graph))
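A less fragile route is to use the public torch.jit.freeze and torch.jit.optimize_for_inference passes, which inline parameters and run additional optimizations whose effects show up in the printed graph. A minimal sketch (the exact set of fusions applied depends on your PyTorch build and version):

# Freezing requires eval mode; it inlines weights as constants and runs
# optimization passes. optimize_for_inference may fuse further patterns
# (e.g., Conv-BN) depending on the build.
frozen = torch.jit.freeze(traced_model.eval())
optimized = torch.jit.optimize_for_inference(frozen)
print(optimized.graph)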
Alternatively, analyzing the performance characteristics (e.g., using torch.profiler.profile) before and after JIT compilation can reveal the impact of optimizations like fusion, even if the IR visualization isn't explicit about the fused operation.
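A minimal profiling comparison might look like the sketch below; the reported operator names and timings depend on your hardware and PyTorch version, and the warm-up loop matters because the profiling executor optimizes a traced module only after a few invocations:

from torch.profiler import profile, ProfilerActivity

# Warm-up runs so that JIT profiling/optimization has already happened.
for _ in range(5):
    model(dummy_input)
    traced_model(dummy_input)

with profile(activities=[ProfilerActivity.CPU]) as prof_eager:
    for _ in range(20):
        model(dummy_input)

with profile(activities=[ProfilerActivity.CPU]) as prof_traced:
    for _ in range(20):
        traced_model(dummy_input)

print(prof_eager.key_averages().table(sort_by="cpu_time_total", row_limit=10))
print(prof_traced.key_averages().table(sort_by="cpu_time_total", row_limit=10))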
Now, let's use scripting, which directly analyzes the Python source code (its AST) rather than recording an execution trace.
import torch
import torch.nn as nn
# Assume SimpleModel class is defined as above
# Script the model
scripted_model = torch.jit.script(model)
print("Model successfully scripted.")
# Print the TorchScript Graph IR
print(scripted_model.graph)
The output graph will look very similar to the traced one for this linear model because it contains no Python control flow that tracing would miss. However, if the model included data-dependent control flow (e.g., an if statement based on tensor values), the scripted graph would explicitly represent this control flow using prim::If nodes, whereas a traced graph would only contain the path taken during tracing.
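To see this difference concretely, here is a small hypothetical module (not part of the example above) with a data-dependent branch; scripting preserves the branch as a prim::If node, while tracing records only the path taken for the example input and emits a TracerWarning:

class GatedModel(nn.Module):
    # Hypothetical module, used only to illustrate data-dependent control flow.
    def forward(self, x):
        if x.sum() > 0:          # condition depends on tensor values
            return torch.relu(x)
        else:
            return -x

gated = GatedModel()

scripted_gated = torch.jit.script(gated)
print(scripted_gated.graph)      # contains a prim::If node with two blocks

traced_gated = torch.jit.trace(gated, torch.randn(4, 4))
print(traced_gated.graph)        # only the branch taken during tracing appears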
Analysis (script vs trace):

- Scripting captures Python control flow (if, for) directly in the graph, making it more flexible for models with dynamic behavior. Tracing only captures the operations executed for the specific trace input.

TensorFlow uses XLA (Accelerated Linear Algebra) as its optimizing compiler, often invoked via tf.function(jit_compile=True). XLA operates on its own IR, HLO (High-Level Optimizer IR).
import tensorflow as tf

# Define a simple function
@tf.function(jit_compile=True)
def simple_computation(x, w1, b1, w2, b2):
    y = tf.matmul(x, w1) + b1
    y = tf.nn.relu(y)
    z = tf.matmul(y, w2) + b2
    return z

# Create some example inputs
input_shape = (32, 128)
hidden_shape = 256
output_shape = 64
x_in = tf.random.normal(input_shape)
w1_in = tf.random.normal((input_shape[1], hidden_shape))
b1_in = tf.random.normal((hidden_shape,))
w2_in = tf.random.normal((hidden_shape, output_shape))
b2_in = tf.random.normal((output_shape,))

# Execute the JIT-compiled function
result = simple_computation(x_in, w1_in, b1_in, w2_in, b2_in)
print("XLA JIT function executed.")
# print(result.numpy())  # To see the output
Inspecting the HLO generated by XLA typically involves using environment variables or TensorFlow logging/profiling tools. One common way is to use TensorBoard for profiling, which can visualize the XLA compilation steps and the resulting HLO graph. Another method involves setting environment variables before running the script:
export XLA_FLAGS="--xla_dump_to=/path/to/dump/folder --xla_dump_hlo_as_text --xla_dump_hlo_as_dot"
# Now run your Python script
python your_script.py
These flags instruct XLA to dump the stages of its compilation process, including HLO graphs as text files and .dot graph files, into the specified directory. You can then inspect these files.
For example, a .dot file can be visualized using Graphviz tools (dot -Tpng input.dot -o output.png). The HLO graph might look conceptually like this (simplified):
[Figure: Simplified conceptual HLO graph for the simple_computation function. Nodes represent operations (dot for matmul, add, maximum for ReLU), edges show data flow, and parameter nodes represent the inputs.]
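If you prefer to pull the IR into Python instead of reading dump files, newer TensorFlow versions expose an experimental hook on tf.function objects. The following is a sketch assuming experimental_get_compiler_ir is available in your installation; the API is experimental and its stages and signature may change between releases:

# Sketch: retrieve the HLO for this concrete input signature directly from Python.
hlo_text = simple_computation.experimental_get_compiler_ir(
    x_in, w1_in, b1_in, w2_in, b2_in
)(stage="hlo")
print(hlo_text)
# Depending on the version, other stages such as "optimized_hlo" or
# "optimized_hlo_dot" may be available to inspect the post-optimization graph.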
Analysis:

- The HLO graph is built from XLA's own primitive operations, such as dot (for matrix multiplication), add, maximum (often used for ReLU), broadcast, and parameters.
- After optimization, XLA typically combines several of these primitives into a single fusion node. For example, the MatMul + BiasAdd + ReLU sequence is a prime candidate for fusion into a single optimized kernel, which would appear as one node in the optimized HLO graph.

This hands-on analysis demonstrates how to access and interpret the intermediate representations generated by ML JIT compilers. By examining TorchScript graphs or XLA HLO dumps, you can:

- Verify which operations the JIT actually captured from your high-level Python code.
- See how the graph is specialized to particular input shapes, dtypes, and devices.
- Check whether optimizations such as constant folding and operator fusion have been applied.
- Connect performance measurements back to the specific operations in the execution plan.
Analyzing JIT output is a fundamental skill for performance engineers working with ML frameworks. It bridges the gap between high-level model code and the optimized, hardware-specific execution plan, providing crucial insights for achieving maximum performance. Use these techniques to explore the JIT behavior for your own models and investigate the impact of different JIT strategies discussed in this chapter.