Let's transition from the theoretical discussion of graph optimization strategies to a practical implementation. This section provides a hands-on exercise focused on building a basic operator fusion pass. While real-world compiler frameworks employ highly sophisticated graph rewriting engines and cost models, constructing a simplified version helps solidify the core concepts of pattern matching and graph transformation.
We will focus on fusing a common sequence found in convolutional neural networks: a Conv2D operation, followed by a BiasAdd, and finally a ReLU activation. Fusing these three into a single FusedConv2D_BiasAdd_ReLU operation can significantly reduce memory bandwidth usage and kernel launch overhead, especially on accelerators such as GPUs.
Before implementing the fusion pass, we need a representation for our computation graph. For this exercise, we'll use a simple Python structure. Assume we have graph nodes represented as objects or dictionaries containing at least:
- id: A unique identifier for the node.
- op_type: The type of operation (e.g., 'Conv2D', 'BiasAdd', 'ReLU', 'Input', 'FusedOp').
- inputs: A list of ids of the nodes providing input to this node.
- outputs: A list of ids of the nodes consuming the output of this node.
- attributes: A dictionary containing operation-specific parameters (e.g., strides and padding for 'Conv2D').

A graph itself can be a collection (e.g., a dictionary or list) of these nodes, potentially along with methods to traverse it or find nodes by ID, as sketched below.
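For concreteness, one minimal way to express this representation in Python is shown below. The Node and Graph classes, their field names, and the helper methods are assumptions made for this exercise rather than the API of any particular framework; the rewriting snippet later in this section works with any structure exposing equivalent fields and methods.

from dataclasses import dataclass, field

@dataclass
class Node:
    id: str
    op_type: str                                     # e.g., 'Conv2D', 'BiasAdd', 'ReLU'
    inputs: list = field(default_factory=list)       # ids of producer nodes
    outputs: list = field(default_factory=list)      # ids of consumer nodes
    attributes: dict = field(default_factory=dict)   # op-specific parameters

class Graph:
    def __init__(self):
        self.nodes = {}  # maps node id -> Node

    def add_node(self, id, op_type, inputs=None, outputs=None, attributes=None):
        self.nodes[id] = Node(id, op_type, list(inputs or []),
                              list(outputs or []), dict(attributes or {}))

    def get_node(self, node_id):
        return self.nodes.get(node_id)  # None if the node does not exist

    def remove_node(self, node_id):
        self.nodes.pop(node_id, None)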
Our goal is to find subgraphs matching the Conv2D -> BiasAdd -> ReLU pattern. This requires traversing the graph and examining node sequences. A common approach is to iterate through all nodes in topological order (or simply to iterate through all nodes, if cycles aren't a primary concern for this specific pattern).
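If a deterministic producer-before-consumer order is desired, a topological ordering can be computed directly from the inputs and outputs lists. The sketch below uses Kahn's algorithm and assumes the Graph and Node structures from the previous snippet and an acyclic graph; it is illustrative, not part of the fusion pass itself.

from collections import deque

def topological_order(graph):
    """Return node ids so that every producer appears before its consumers (Kahn's algorithm)."""
    indegree = {nid: len(node.inputs) for nid, node in graph.nodes.items()}
    ready = deque(nid for nid, deg in indegree.items() if deg == 0)
    order = []
    while ready:
        nid = ready.popleft()
        order.append(nid)
        for consumer_id in graph.nodes[nid].outputs:
            indegree[consumer_id] -= 1
            if indegree[consumer_id] == 0:
                ready.append(consumer_id)
    return order  # shorter than len(graph.nodes) only if the graph contains a cycle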
For each node identified as a Conv2D, we check the following conditions:

1. The Conv2D node has exactly one consumer, and that consumer is a BiasAdd operation.
2. The BiasAdd node also has exactly one output connection.
3. The single consumer of the BiasAdd is a ReLU operation.

If all these conditions are met, we have identified an instance of our target fusion pattern.
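Production rewriting engines typically keep matching separate from rewriting. As a sketch against the assumed Node/Graph structure above, the checks can be written as a small hypothetical helper like match_conv_bias_relu below; the full pass later in this section folds the same checks into its rewrite function.

def match_conv_bias_relu(graph, conv_node_id):
    """Return (conv_id, add_id, relu_id) if the pattern is rooted at conv_node_id, else None."""
    conv = graph.get_node(conv_node_id)
    if conv is None or conv.op_type != 'Conv2D' or len(conv.outputs) != 1:
        return None
    add = graph.get_node(conv.outputs[0])
    if add is None or add.op_type != 'BiasAdd' or len(add.outputs) != 1:
        return None
    relu = graph.get_node(add.outputs[0])
    if relu is None or relu.op_type != 'ReLU':
        return None
    return (conv.id, add.id, relu.id)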
Let's visualize the pattern before fusion:
A typical computation graph segment involving Convolution, Bias Addition, and ReLU activation before fusion.
Once the pattern (Conv -> Add -> Relu) is identified, the graph needs to be transformed:

1. Create a new node, FusedNode, with op_type = 'FusedConv2D_BiasAdd_ReLU'.
2. The inputs of FusedNode should be the inputs of the original Conv node (the input tensor and the weights) plus the bias input from the BiasAdd node.
3. Any nodes that consumed the output of the original ReLU node should now consume the output of FusedNode. Update their inputs lists accordingly.
4. Remove the original Conv, Add, and Relu nodes from the graph's node collection and add the new FusedNode.
5. FusedNode needs to inherit relevant attributes. For example, it needs the convolution parameters (strides, padding) from the original Conv node, and the fact that a ReLU activation is included can be stored as an attribute within FusedNode.

Here is a conceptual Python snippet illustrating the core rewriting logic, assuming a graph object graph with methods such as get_node, add_node, and remove_node:
import uuid  # For generating unique IDs

def fuse_conv_bias_relu(graph, conv_node_id):
    """
    Attempts to fuse Conv2D -> BiasAdd -> ReLU starting from conv_node_id.
    Returns True if fusion occurred, False otherwise.
    """
    conv_node = graph.get_node(conv_node_id)
    if conv_node.op_type != 'Conv2D' or len(conv_node.outputs) != 1:
        return False

    add_node_id = conv_node.outputs[0]
    add_node = graph.get_node(add_node_id)
    if add_node is None or add_node.op_type != 'BiasAdd' or len(add_node.outputs) != 1:
        return False

    # Assume BiasAdd input order: [conv_output, bias_vector]
    if len(add_node.inputs) != 2 or add_node.inputs[0] != conv_node_id:
        return False
    bias_node_id = add_node.inputs[1]

    relu_node_id = add_node.outputs[0]
    relu_node = graph.get_node(relu_node_id)
    if relu_node is None or relu_node.op_type != 'ReLU':
        # Note: We might allow ReLU to have multiple consumers
        return False

    print(f"Found pattern: {conv_node.id} -> {add_node.id} -> {relu_node.id}")

    # 1. Create Fused Node
    fused_node_id = f"fused_{uuid.uuid4().hex[:6]}"
    fused_node_attrs = conv_node.attributes.copy()  # Inherit conv attributes
    fused_node_attrs['activation'] = 'ReLU'  # Mark activation type
    fused_node_inputs = [conv_node.inputs[0], conv_node.inputs[1], bias_node_id]  # Input, Weights, Bias

    # 2. Store original outputs of the ReLU node
    original_relu_outputs = list(relu_node.outputs)  # Copy before modifying

    # 3. Create the fused node structure (depends on your graph representation)
    #    This part is conceptual - adapt to your Node/Graph class structure
    graph.add_node(
        id=fused_node_id,
        op_type='FusedConv2D_BiasAdd_ReLU',
        inputs=fused_node_inputs,
        outputs=original_relu_outputs,  # Initially connect to ReLU's consumers
        attributes=fused_node_attrs
    )

    # 4. Update consumers of the original ReLU node
    for consumer_id in original_relu_outputs:
        consumer_node = graph.get_node(consumer_id)
        if consumer_node:
            # Find where relu_node_id was an input and replace it with fused_node_id
            try:
                idx = consumer_node.inputs.index(relu_node_id)
                consumer_node.inputs[idx] = fused_node_id
            except ValueError:
                print(f"Warning: Consumer {consumer_id} did not list {relu_node_id} as input.")

    # 5. Update producers connected to the original nodes
    #    The inputs to Conv and Bias are now inputs to the fused node, handled in step 3.
    #    We need to ensure the original input nodes point *away* from the deleted nodes
    #    and *towards* the new fused node if needed (depends on representation).
    #    For simplicity here, we assume edge updates happen primarily via the consumer's 'inputs' list.

    # 6. Remove the original nodes
    graph.remove_node(conv_node_id)
    graph.remove_node(add_node_id)
    graph.remove_node(relu_node_id)

    print(f"Successfully fused into node {fused_node_id}")
    return True


# --- Example Usage ---
# Assume 'graph' is populated with nodes
# for node_id in list(graph.nodes.keys()):  # Iterate over a copy of the keys
#     if graph.get_node(node_id) and graph.get_node(node_id).op_type == 'Conv2D':
#         fuse_conv_bias_relu(graph, node_id)
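To see the pass end to end, here is one way to exercise it on a tiny hand-built graph, assuming the Graph and Node sketch from earlier in this section. The node names, attribute values, and the final print are purely illustrative.

g = Graph()
g.add_node('x', 'Input', outputs=['conv1'])
g.add_node('w', 'Input', outputs=['conv1'])
g.add_node('b', 'Input', outputs=['bias1'])
g.add_node('conv1', 'Conv2D', inputs=['x', 'w'], outputs=['bias1'],
           attributes={'strides': [1, 1], 'padding': 'SAME'})
g.add_node('bias1', 'BiasAdd', inputs=['conv1', 'b'], outputs=['relu1'])
g.add_node('relu1', 'ReLU', inputs=['bias1'], outputs=['out'])
g.add_node('out', 'Output', inputs=['relu1'])

for node_id in list(g.nodes.keys()):  # iterate over a copy, since fusion removes nodes
    node = g.get_node(node_id)
    if node and node.op_type == 'Conv2D':
        fuse_conv_bias_relu(g, node_id)

print(sorted(n.op_type for n in g.nodes.values()))
# Expected: three 'Input' nodes, one 'Output', and one 'FusedConv2D_BiasAdd_ReLU'

Note that the Input nodes still list the removed conv1 and bias1 ids in their outputs; as stated in step 5 of the snippet, this sketch treats the consumers' inputs lists as the source of truth for edges.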
After applying this transformation, the graph segment would look like this:
The computation graph segment after applying the fusion pass, combining the three operations into one.
This example simplifies many aspects of a production fusion pass. In particular, it assumes the target backend provides an efficient kernel for FusedConv2D_BiasAdd_ReLU. If not, the fusion might offer no performance gain, or it might require generating complex code for the fused operation; libraries like NVIDIA's cuDNN or Intel's oneDNN heavily influence which fusions are practical.

This practical exercise provides a foundational understanding of how graph-level fusion passes operate. Building upon this, you can explore more complex patterns, integrate cost models, and consider target-specific constraints, moving closer to the sophisticated optimization passes found in production ML compilers.