While PyTorch's state_dict
and TorchScript provide excellent mechanisms for saving and deploying models within the PyTorch ecosystem, there are times when you need your models to work across different machine learning frameworks. Perhaps your team uses TensorFlow for deployment, or you want to leverage a specific hardware accelerator that has optimized support for a common model format. This is where the Open Neural Network Exchange (ONNX) format becomes incredibly valuable.
ONNX is an open standard for representing machine learning models. It defines a common set of operators (the building blocks of models, like convolutions or matrix multiplications) and a common file format (.onnx
). The goal is to enable interoperability: you can train a model in one framework (like PyTorch), export it to ONNX format, and then load and run it in another framework (like TensorFlow, Caffe2, MXNet) or specialized ONNX runtimes.
For developers familiar with TensorFlow's ecosystem, which includes SavedModel for serving and TensorFlow Lite for mobile, ONNX might seem like an extra step. However, it offers several advantages: framework interoperability, access to optimized runtimes and hardware accelerators, and a clean separation between the framework you train in and the environment you deploy to.
Think of ONNX as a universal translator for your neural network models.
The ONNX workflow: A PyTorch model is exported to the ONNX format, which can then be used by various runtimes and tools, including those in the TensorFlow ecosystem.
PyTorch has built-in support for exporting models to the ONNX format using the torch.onnx.export()
function. This function traces your model to convert its operations into an ONNX graph.
Let's look at the common parameters for torch.onnx.export():

- model: Your PyTorch model (an instance of torch.nn.Module).
- args: A tuple of example inputs that your model expects. This input is used to trace the model's execution path. The shape and data type of this example input are important.
- f: The path where the ONNX model will be saved (e.g., "my_model.onnx").
- input_names: (Optional) A list of names to assign to the input nodes in the ONNX graph.
- output_names: (Optional) A list of names to assign to the output nodes in the ONNX graph.
- dynamic_axes: (Optional) A dictionary specifying which axes of inputs/outputs are dynamic (e.g., batch size, sequence length). This is very useful for real-world models. For example, {'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}} indicates that the first dimension of both 'input' and 'output' is dynamic and named 'batch_size'.
- opset_version: The ONNX operator set version to use. ONNX evolves, and newer versions add support for more operators. It's generally good to use a reasonably recent version that your target deployment environment supports.

Here's a simple example of exporting a basic PyTorch model:
import torch
import torch.nn as nn
import torch.onnx
# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(10, 5)  # Input features: 10, Output features: 5

    def forward(self, x):
        return self.linear(x)
# Instantiate the model
model = SimpleModel()
model.eval() # Set the model to evaluation mode
# Create dummy input matching the model's expected input_shape
# Batch size of 1, 10 input features
dummy_input = torch.randn(1, 10)
# Define input and output names for clarity in the ONNX model
input_names = ["input_tensor"]
output_names = ["output_tensor"]
# Export the model
torch.onnx.export(model,
                  dummy_input,
                  "simple_model.onnx",
                  input_names=input_names,
                  output_names=output_names,
                  opset_version=12,  # Specify an ONNX opset version
                  dynamic_axes={'input_tensor': {0: 'batch_size'},    # batch dimension is dynamic
                                'output_tensor': {0: 'batch_size'}})
print("Model exported to simple_model.onnx")
When you run this, PyTorch executes your SimpleModel
with dummy_input
, records the operations, and translates them into the ONNX format, saving the result as simple_model.onnx
. The dynamic_axes
argument is particularly important. Without it, the exported ONNX model would expect inputs with a fixed batch size (1 in this case). By specifying dynamic_axes
, we tell the ONNX exporter that the batch dimension can vary.
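To confirm that the export worked and that the batch dimension really is dynamic, you can load the .onnx file with the onnx package and inspect the graph inputs. This is a minimal sketch; it assumes the onnx package is installed (pip install onnx) and that the file was saved as simple_model.onnx as above:

import onnx

# Load the exported model and run ONNX's structural validity checks
onnx_model = onnx.load("simple_model.onnx")
onnx.checker.check_model(onnx_model)

# Inspect the input tensor's shape: dynamic dimensions show a name (dim_param)
# instead of a fixed size (dim_value)
for inp in onnx_model.graph.input:
    dims = [d.dim_param or d.dim_value for d in inp.type.tensor_type.shape.dim]
    print(inp.name, dims)  # Expected: input_tensor ['batch_size', 10]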
Once you have an .onnx
file, you need a runtime to execute it. The most common one is ONNX Runtime, an open-source, high-performance inference engine for ONNX models. It's cross-platform and supports hardware acceleration.
You can install ONNX Runtime via pip:
pip install onnxruntime
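ONNX Runtime runs on the CPU by default; hardware acceleration is selected through execution providers when you create an inference session. As a small sketch (the CUDA provider assumes you installed the GPU-enabled onnxruntime-gpu build instead of the CPU-only package):

import onnxruntime

# List the execution providers available in this installation
print(onnxruntime.get_available_providers())

# Prefer CUDA when present, falling back to CPU otherwise
session = onnxruntime.InferenceSession(
    "simple_model.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"])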
Here's how you might load and run the simple_model.onnx
we just created:
import onnxruntime
import numpy as np
# Create an ONNX Runtime inference session
ort_session = onnxruntime.InferenceSession("simple_model.onnx")
# Prepare an example input (must match the model's expected input, including batch size)
# Let's use a batch size of 3 for this inference run
input_data = np.random.randn(3, 10).astype(np.float32)
# Get the input name from the model (or use the one we defined during export)
input_name = ort_session.get_inputs()[0].name
# Run inference
ort_inputs = {input_name: input_data}
ort_outs = ort_session.run(None, ort_inputs)
# The output is a list of numpy arrays
output_data = ort_outs[0]
print("Input shape:", input_data.shape)
print("Output shape:", output_data.shape)
print("Output data (first row):", output_data[0])
This code snippet loads the ONNX model, prepares an input NumPy array, and runs inference using ONNX Runtime. You'll notice we can use a different batch size (3) than the one used for the dummy_input
during export (1), thanks to dynamic_axes
.
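A common sanity check after exporting is to run the same batch through both the original PyTorch model and the ONNX Runtime session and confirm the outputs agree within a small numerical tolerance. A brief sketch, assuming the model, ort_session, input_name, and input_data objects from the snippets above are still in scope:

import numpy as np
import torch

# Run the same batch through the original PyTorch model...
with torch.no_grad():
    torch_out = model(torch.from_numpy(input_data)).numpy()

# ...and through ONNX Runtime, then compare the results
onnx_out = ort_session.run(None, {input_name: input_data})[0]
np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-4, atol=1e-5)
print("PyTorch and ONNX Runtime outputs match")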
As a TensorFlow developer, you might wonder how ONNX fits into your existing workflows. If you receive an ONNX model (perhaps exported from PyTorch), you can convert it to a TensorFlow format (like SavedModel) to integrate it into TensorFlow-based deployment pipelines (e.g., TensorFlow Serving).
The onnx-tf
converter is a popular tool for this:
pip install onnx-tf
You can then use it to convert an .onnx
file to a TensorFlow SavedModel:
# Assuming onnx_tf is installed
from onnx_tf.backend import prepare
import onnx
# Load the ONNX model
onnx_model = onnx.load("simple_model.onnx")
# Prepare the TensorFlow representation
tf_rep = prepare(onnx_model)
# Export as TensorFlow SavedModel
tf_rep.export_graph("simple_model_tf_savedmodel")
print("ONNX model converted to TensorFlow SavedModel format in 'simple_model_tf_savedmodel'")
This creates a standard TensorFlow SavedModel directory, which can then be loaded using tf.saved_model.load()
or deployed with TensorFlow Serving.
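To verify the conversion on the TensorFlow side, you can load the SavedModel and call one of its serving signatures. The sketch below assumes the converter exposed a default serving signature whose input keyword matches the ONNX input name; the exact signature key and tensor names can vary, so inspect loaded.signatures if they differ:

import numpy as np
import tensorflow as tf

# Load the converted model and pick a serving signature
loaded = tf.saved_model.load("simple_model_tf_savedmodel")
print(list(loaded.signatures.keys()))  # e.g., ['serving_default']
infer = loaded.signatures["serving_default"]

# Run a batch of 3 through the TensorFlow representation
tf_input = tf.constant(np.random.randn(3, 10).astype(np.float32))
result = infer(input_tensor=tf_input)  # keyword assumed to follow the ONNX input name
print({k: v.shape for k, v in result.items()})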
Conversely, if you have TensorFlow models, you can convert them to ONNX using tools like tf2onnx
(pip install tf2onnx
). This allows you to bring TensorFlow models into the ONNX ecosystem, potentially to be used by PyTorch or other ONNX-compatible tools, though this course focuses on the PyTorch-to-TensorFlow direction.
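Going in that direction looks similar in spirit. As a rough sketch, assuming a small Keras model and the tf2onnx package, the conversion fits in a few lines:

import tensorflow as tf
import tf2onnx

# A tiny Keras model standing in for a real TensorFlow model
keras_model = tf.keras.Sequential([tf.keras.layers.Dense(5, input_shape=(10,))])

# Convert to ONNX; None in the shape marks the batch dimension as dynamic
spec = (tf.TensorSpec((None, 10), tf.float32, name="input_tensor"),)
onnx_model, _ = tf2onnx.convert.from_keras(keras_model, input_signature=spec,
                                           opset=12, output_path="keras_model.onnx")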
While ONNX is powerful, keep these points in mind:

- Operator support: Not every PyTorch operation has an ONNX equivalent in a given opset_version. You might need to simplify your model or implement custom ONNX operators (an advanced topic). Always check the ONNX documentation for supported operators.
- Version compatibility: Be mindful of the opset_version you choose, the ONNX Runtime version, and any converter tools (like onnx-tf). Mismatches can lead to errors or unexpected behavior.
- Tracing limitations: The torch.onnx.export function works by tracing the model with a sample input. If your model has control flow that changes based on input data (which is less common for models intended for ONNX export but possible in PyTorch), the trace might not capture all execution paths. TorchScript, discussed earlier, can sometimes be more robust for models with complex control flow before exporting to ONNX.
- Naming: Providing descriptive input_names and output_names during export makes the ONNX model much easier to work with later, especially when using ONNX Runtime or converter tools.

ONNX provides a significant bridge between different ML frameworks. For TensorFlow developers learning PyTorch, understanding how to export PyTorch models to ONNX opens up possibilities for integrating these models into existing TensorFlow-centric deployment pipelines or leveraging the broader ONNX ecosystem for optimization and execution.