This benchmark explores the training performance and memory usage of TensorFlow, PyTorch, and JAX using a simple convolutional neural network (CNN). While the test is primarily for fun and exploration, it aims to highlight practical differences across frameworks when used under consistent conditions. Deep learning framework selection often goes beyond syntax. Factors like training speed, memory efficiency, and compilation behaviour become crucial, especially when working with limited hardware or scaling to larger models.
I've run all the experiments on an NVIDIA L4 GPU, a common choice in cloud environments and accessible for development. Rather than pushing each framework to its limits, the goal is to evaluate their performance on a level playing field.
All three frameworks used the same simple CNN architecture and training configuration on identical synthetic data, so the comparison stays apples to apples.
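For illustration, a comparable "simple CNN" in Keras might look something like the sketch below; the layer counts and sizes are placeholders, not the exact network behind the numbers that follow.

```python
import tensorflow as tf

# Hypothetical "simple CNN" -- layer counts and sizes are placeholders,
# not the precise network used for the benchmark results.
def build_simple_cnn(input_shape=(32, 32, 3), num_classes=10):
    return tf.keras.Sequential([
        tf.keras.layers.Input(shape=input_shape),
        tf.keras.layers.Conv2D(32, 3, activation="relu", padding="same"),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(64, 3, activation="relu", padding="same"),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ])
```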
The dataset was generated in memory using NumPy to eliminate disk I/O variability. TensorFlow used `tf.data` with basic prefetching, PyTorch leveraged `DataLoader` with pinned memory, and JAX handled shuffling and batching manually to simulate similar behaviour.
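As a rough sketch of these three loading paths (the array shapes, batch size, and class count here are illustrative assumptions, not the exact benchmark settings):

```python
import numpy as np
import tensorflow as tf
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic in-memory dataset; sizes and shapes are illustrative assumptions.
NUM_SAMPLES, BATCH_SIZE, NUM_CLASSES = 50_000, 128, 10
images = np.random.rand(NUM_SAMPLES, 32, 32, 3).astype("float32")
labels = np.random.randint(0, NUM_CLASSES, size=NUM_SAMPLES).astype("int64")

# TensorFlow: tf.data with basic prefetching.
tf_ds = (tf.data.Dataset.from_tensor_slices((images, labels))
         .shuffle(10_000)
         .batch(BATCH_SIZE)
         .prefetch(tf.data.AUTOTUNE))

# PyTorch: DataLoader with pinned memory (tensors converted to channels-first).
torch_ds = TensorDataset(torch.from_numpy(images).permute(0, 3, 1, 2),
                         torch.from_numpy(labels))
torch_loader = DataLoader(torch_ds, batch_size=BATCH_SIZE,
                          shuffle=True, pin_memory=True)

# JAX: shuffling and batching handled manually over the NumPy arrays.
def numpy_batches():
    perm = np.random.permutation(NUM_SAMPLES)
    for start in range(0, NUM_SAMPLES, BATCH_SIZE):
        idx = perm[start:start + BATCH_SIZE]
        yield images[idx], labels[idx]
```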
A one-epoch warm-up was performed before recording metrics to account for GPU initialization and JIT compilation. Timing started immediately afterwards and included only the training time for the measured epochs.
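A minimal version of that protocol might look like the following, where `train_one_epoch` is a placeholder for each framework's actual training loop and the number of measured epochs is an assumption:

```python
import time

def timed_epochs(train_one_epoch, measured_epochs=5):
    """Warm up for one epoch, then time only the measured epochs.

    `train_one_epoch` stands in for each framework's training loop; the
    number of measured epochs here is an assumption, not the benchmark's.
    """
    train_one_epoch()  # warm-up: GPU init and JIT/XLA compilation happen here

    start = time.perf_counter()
    for _ in range(measured_epochs):
        train_one_epoch()
    return time.perf_counter() - start  # reported training time, in seconds
```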
RAM and VRAM were measured with `psutil` and NVIDIA's `pynvml`.
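A small helper along those lines, assuming a single GPU at index 0, could look like this:

```python
import psutil
import pynvml

pynvml.nvmlInit()
_gpu = pynvml.nvmlDeviceGetHandleByIndex(0)  # assumes the L4 is device 0

def memory_snapshot():
    """Return (process RAM, GPU VRAM in use), both in GB."""
    ram_gb = psutil.Process().memory_info().rss / 1024**3
    vram_gb = pynvml.nvmlDeviceGetMemoryInfo(_gpu).used / 1024**3
    return ram_gb, vram_gb
```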
Each framework used its native compilation path (sketched in code after the list):

- **TensorFlow:** `jit_compile=True` in Keras to enable XLA, which compiles layers individually and offers partial graph execution.
- **PyTorch:** `torch.compile()` for static graph optimization.
- **JAX:** `@jit` with XLA, which increases VRAM usage due to staging buffers but delivers strong performance in larger models.
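As a minimal sketch of those three settings, with `keras_model`, `torch_model`, and `jax_train_step` as hypothetical placeholders for the benchmark's actual objects:

```python
import torch
import jax

def apply_compilation(keras_model, torch_model, jax_train_step):
    """Apply each framework's compilation setting; the arguments are
    placeholders for the benchmark's (unshown) models and step function."""
    # TensorFlow / Keras: enable XLA via jit_compile=True.
    keras_model.compile(optimizer="adam",
                        loss="sparse_categorical_crossentropy",
                        jit_compile=True)

    # PyTorch: static graph optimization via torch.compile().
    torch_model = torch.compile(torch_model)

    # JAX: XLA compilation of the training step via jax.jit (same as @jit).
    jax_train_step = jax.jit(jax_train_step)

    return keras_model, torch_model, jax_train_step
```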
Framework-specific optimizations were intentionally excluded to maintain fairness:

- `tf.function` outside of Keras
- `channels_last` layout or AMP/mixed precision
- `pjit` or `xmap` sharding techniques

Framework | Training Time (s) | Accuracy* | Initial RAM (GB) | Max RAM (GB) | Initial VRAM (GB) | Max VRAM (GB) |
---|---|---|---|---|---|---|
TensorFlow | 90.88 | ~ 0.1 | 6.13 | 7.92 | 2.58 | 8.74 |
PyTorch | 82.86 | ~ 0.1 | 4.22 | 5.22 | 1.13 | 6.69 |
JAX | 99.44 | ~ 0.1 | 4.42 | 3.29 | 9.25 | 17.45 |
*Accuracy is included as a sanity check. Since the architecture and data are the same across all frameworks, consistency is expected.
Training time was shortest with PyTorch, benefiting from its dynamic execution model and optimized backend for CNNs. TensorFlow followed closely but showed slightly higher times, likely due to graph tracing and `tf.data` overhead. JAX was slower, not because of inefficient code paths but because of the up-front compilation cost, which isn't amortized in small-scale tests.
RAM usage was highest in TensorFlow, mostly from dataset buffering and graph construction. PyTorch stayed lean, aided by lazy evaluation and a clear separation between host and device memory, while JAX ended with the lowest peak RAM by aggressively offloading data and execution to GPU memory.
VRAM usage was highest in JAX due to XLA staging and limited intermediate memory reuse in smaller models. PyTorch had the leanest GPU footprint thanks to its flexible, eager execution. TensorFlow landed in the middle, using XLA but with more granular graph segments that don't hold as much in memory at once.
While PyTorch came out ahead in speed and memory efficiency, TensorFlow held its own with consistent performance and a more structured pipeline. Both are solid choices depending on your priorities and tooling needs.
JAX's numbers reflect its compile-first approach. At this scale, that cost is visible, but with larger models or tuned pipelines it can become an advantage. From experience, JAX shines when batch sizes are larger and the training loop runs many times, allowing it to reuse compiled functions efficiently.
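A tiny, self-contained illustration of that reuse pattern, with a toy function standing in for a real training step:

```python
import jax
import jax.numpy as jnp

@jax.jit
def toy_step(x):
    # Stand-in for a real training step; compiled with XLA on the first call.
    return jnp.mean(x ** 2)

batch = jnp.ones((256, 32, 32, 3))   # fixed shape and dtype
toy_step(batch)                      # first call: trace + compile (slow)
for _ in range(100):
    toy_step(batch)                  # same shape/dtype: reuses compiled code
```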
This benchmark deliberately avoids more complex scenarios, such as the framework-specific optimizations excluded above, multi-device training, or much larger models, so the results should be interpreted in that context. Those factors could significantly change the relative performance and efficiency of each framework.
PyTorch edged ahead in speed and VRAM efficiency for this setup, while TensorFlow remained competitive. JAX lagged behind, but this was primarily due to compilation overhead and memory staging.
In conclusion, all three frameworks are viable. Your best choice depends on your project's context and priorities.