Automation tools and high-level frameworks handle many optimizations transparently. However, significant performance gains often require custom intervention. Writing a graph pass allows you to define transformation logic tailored to your model architecture or hardware constraints. The process involves traversing the Intermediate Representation (IR), identifying patterns that match inefficient subgraphs, and rewriting them into more efficient structures.

## The Pass Manager and Visitor Pattern

Most modern ML compilers, including TVM and MLIR, organize optimizations using a Pass Manager. The Pass Manager orchestrates the execution of the individual passes, ensuring that dependencies between them are met and that the graph remains valid between transformations.

The core mechanism for implementing a pass is the Visitor Pattern. The compiler traverses the graph nodes, typically in post-order (leaves to root), and invokes a specific callback function for each node type. As a developer, you override these callbacks to inject custom logic.

When you write a pass, you are essentially defining two logical steps, sketched in the example after this list:

1. **Pattern Matching**: Inspecting the current node and its neighbors to see if they fit a target structure.
2. **Rewriting**: Creating new nodes that represent the optimized logic and updating the graph pointers to replace the old subgraph.
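To make that division of labor concrete, here is a minimal sketch of how a pass plugs into this machinery. The `Node`, `Pass`, and `PassManager` classes below are hypothetical stand-ins, not the API of TVM, MLIR, or any other framework; real pass managers add dependency tracking, type re-inference, and verification between passes.

```python
# Minimal, hypothetical sketch of a visitor-style pass and a pass manager.
# These classes only illustrate how matching/rewriting hang off a traversal;
# they do not correspond to any real compiler API.

class Node:
    """A toy IR node: an operator name plus a list of input nodes."""
    def __init__(self, op, inputs=()):
        self.op = op
        self.inputs = list(inputs)

    def __repr__(self):
        return f"{self.op}({', '.join(i.op for i in self.inputs)})"


class Pass:
    """Base class: walk the graph bottom-up and let subclasses rewrite nodes."""
    def run(self, node):
        # Visit children first (post-order), then give the subclass a chance
        # to pattern-match and rewrite the current node.
        node.inputs = [self.run(i) for i in node.inputs]
        return self.visit(node)

    def visit(self, node):
        # Default behavior: leave the node unchanged.
        return node


class PassManager:
    """Runs a sequence of passes over the graph in order."""
    def __init__(self, passes):
        self.passes = passes

    def optimize(self, root):
        for p in self.passes:
            root = p.run(root)
            # A real pass manager would re-run type inference / graph
            # verification here before handing the IR to the next pass.
        return root


# Illustrative usage (pass names are made up):
#   pm = PassManager([FuseConvReLU(), DeadNodeElimination()])
#   optimized_root = pm.optimize(graph_root)
```

A concrete fusion pass then only has to override the visit callback with its matching and rewriting logic; the `FuseConvReLU` example later in this section plays exactly that role, using the richer mutator API of its framework.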
## Identifying a Fusion Candidate

To demonstrate this, consider a common operator fusion scenario: fusing a convolution operation followed immediately by a Rectified Linear Unit (ReLU) activation. In a naive execution, the hardware performs the convolution, writes the result to memory, reads it back, applies the ReLU, and writes it out again. Fusing the two operations lets the hardware apply the activation to the convolution output while it is still in the accumulator or cache.

The following diagram illustrates the structural change required in the graph.

```dot
digraph G {
  rankdir=TB;
  bgcolor="transparent";
  node [shape=box, style="filled,rounded", fontname="Sans-Serif", fontsize=10, margin=0.2, color="transparent"];
  edge [fontname="Sans-Serif", fontsize=9, color="#868e96"];

  subgraph cluster_0 {
    label="Original Graph";
    style="dashed"; color="#adb5bd"; fontcolor="#495057";
    node_in [label="Input", fillcolor="#e9ecef"];
    node_conv [label="Conv2D", fillcolor="#a5d8ff"];
    node_relu [label="ReLU", fillcolor="#ffc9c9"];
    node_out [label="Output", fillcolor="#e9ecef"];
    node_in -> node_conv;
    node_conv -> node_relu;
    node_relu -> node_out;
  }

  subgraph cluster_1 {
    label="Transformed Graph";
    style="dashed"; color="#adb5bd"; fontcolor="#495057";
    t_in [label="Input", fillcolor="#e9ecef"];
    t_fused [label="Conv2D_ReLU", fillcolor="#b2f2bb"];
    t_out [label="Output", fillcolor="#e9ecef"];
    t_in -> t_fused;
    t_fused -> t_out;
  }
}
```

*Comparison of a standard operator sequence and a fused operator node.*

## Implementing the Transformation

In a Python-based compiler interface (similar to PyTorch FX or TVM Relay), a pass is implemented as a class that inherits from a Mutator or Transformer base class.

The implementation requires identifying a Call node whose operator is ReLU. Once found, you inspect the input to that ReLU. If the input is a Conv2D node, the match is confirmed.

Here is a structural example of how such a pass is implemented in a generic IR framework:

```python
class FuseConvReLU(ExprMutator):
    def visit_call(self, call_node):
        # First, visit children to ensure bottom-up optimization
        new_call = super().visit_call(call_node)

        # Step 1: Pattern Matching
        # Check if the current node is a ReLU operation
        if new_call.op.name == "nn.relu":
            # Inspect the input to the ReLU;
            # the producer is typically at index 0 of the arguments
            input_node = new_call.args[0]

            # Check if the input is a Conv2D operation
            if isinstance(input_node, Call) and input_node.op.name == "nn.conv2d":
                # Pattern matched: ReLU(Conv2D(...))
                return self.rewrite_conv_relu(input_node)

        # No match found, return the (possibly updated) node
        return new_call

    def rewrite_conv_relu(self, conv_node):
        # Step 2: Rewriting
        # Create a new operator that represents the fused kernel.
        # The attributes (strides, padding, etc.) of the original conv are reused.
        new_op = Op("nn.conv2d_relu")

        # Construct the new Call node using the original inputs.
        # The inputs to Conv2D (data, weight) become the inputs to Conv2D_ReLU.
        return Call(new_op, conv_node.args, conv_node.attrs)
```

## Handling Graph Validity and Types

When modifying the graph, maintaining the correctness of the program is the primary concern. The compiler relies on type information (tensor shapes and data types) to allocate memory. When you replace Conv2D and ReLU with Conv2D_ReLU, the output shape of the new node must match the output shape of the original ReLU node.

In element-wise fusion (like ReLU), the shape remains identical. However, if you were fusing operations that alter shapes (such as a pooling layer), you must ensure the new operator propagates shape information correctly. Most IR frameworks provide a type-inference (shape propagation) pass that should be run immediately after your custom mutation to update the metadata of the new nodes.

## Advanced Matching: The Multi-Consumer Problem

A common issue in graph substitution is the multi-consumer scenario. If the output of the Conv2D node is used by the ReLU and by another node (for example, a skip connection in a ResNet), you cannot simply fuse the Conv2D into the ReLU.

If you fuse them, the Conv2D instruction disappears. The other node expecting the raw convolution output loses its input, or the compiler is forced to duplicate the convolution computation.

To handle this, robust passes include a check on the number of consumers:

```python
# Inside the pattern matcher
if input_node.op.name == "nn.conv2d":
    # Check how many other nodes reference this convolution
    users = self.get_users(input_node)
    if len(users) > 1:
        # The convolution result is needed elsewhere.
        # We cannot fuse safely without duplicating work.
        return new_call
    return self.rewrite_conv_relu(input_node)
```

This check ensures that the optimization does not inadvertently increase the computational load by forcing re-computation of shared intermediate values.

## Verification and Testing

After implementing the pass, verification is necessary to ensure the transformation is semantically equivalent. This involves running the original graph and the transformed graph with the same random input data and comparing the outputs:

$$
\text{MSE} = \frac{1}{n} \sum_{i=1}^{n} \left( Y_{\text{original},\,i} - Y_{\text{transformed},\,i} \right)^2
$$

The Mean Squared Error (MSE) between the outputs should be zero (or within floating-point tolerance). If the outputs diverge, the rewriting logic likely mishandled an attribute, such as padding or stride, during the construction of the new node.
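The snippet below is a minimal sketch of this check. The `run_original` and `run_transformed` callables are assumptions standing in for whatever execution entry point your framework exposes (compiling the module and running it on a NumPy array); only the comparison logic is shown.

```python
import numpy as np

def verify_pass(run_original, run_transformed, input_shape, rtol=1e-5, atol=1e-6):
    """Compare graph outputs before and after the custom pass.

    `run_original` and `run_transformed` are framework-specific callables
    (hypothetical here) that execute the respective graph on a NumPy array
    and return a NumPy array.
    """
    data = np.random.uniform(-1.0, 1.0, size=input_shape).astype("float32")

    y_original = run_original(data)
    y_transformed = run_transformed(data)

    mse = float(np.mean((y_original - y_transformed) ** 2))
    print(f"MSE between outputs: {mse:.3e}")

    # Equality within floating-point tolerance: exact zero is expected for a
    # pure restructuring, but fused kernels may accumulate in a different order.
    assert np.allclose(y_original, y_transformed, rtol=rtol, atol=atol), \
        "Outputs diverged: check attributes (padding, stride) in the rewrite."
```

The tolerance accounts for the fact that a fused kernel may perform its accumulations in a different order than the unfused sequence, which can shift the least significant bits without indicating a real bug.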
Using visual inspection tools to print the IR before and after the pass helps isolate where the structure diverges from expectation.
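A before/after dump can be as simple as printing the module around the pass invocation. The calls below are placeholders, not a specific framework's API: `load_module` and the `.run(...)` entry point are assumptions, and most frameworks expose a textual IR dump either through `print` or a dedicated text-serialization method.

```python
# Hypothetical driver: dump the IR around the custom pass so structural
# differences are easy to spot. Replace `load_module` and the pass
# invocation with your framework's equivalents.

module = load_module("model.json")      # placeholder for however you obtain the IR

print("=== IR before FuseConvReLU ===")
print(module)                           # textual dump of the graph

module = FuseConvReLU().run(module)     # apply the pass (entry point is framework-specific)

print("=== IR after FuseConvReLU ===")
print(module)
```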