
TensorFlow vs PyTorch vs JAX: Performance Benchmark

By Wei Ming T. on Mar 24, 2025

This benchmark explores the training performance and memory usage of TensorFlow, PyTorch, and JAX using a simple convolutional neural network (CNN). While the test is primarily for fun and exploration, it aims to highlight practical differences across frameworks when used under consistent conditions. Deep learning framework selection often goes beyond syntax. Factors like training speed, memory efficiency, and compilation behaviour become crucial, especially when working with limited hardware or scaling to larger models.

I've run all the experiments on an NVIDIA L4 GPU, a common choice in cloud environments and accessible for development. Rather than pushing each framework to its limits, the goal is to evaluate their performance on a level playing field.

Setup and Methodology

All three frameworks used:

  • The same CNN architecture with Conv2D, MaxPooling, and Dense layers
  • A synthetic dataset of 224x224x3 RGB images with 10 output classes
  • Batch size of 32, trained over 5 epochs with 100 steps per epoch
  • Single GPU (L4) setup with no mixed precision, no data augmentation
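
The post doesn't list exact layer sizes, so here is a minimal Keras sketch of a CNN matching that description; the filter counts and Dense width are illustrative assumptions, not the benchmark's actual configuration:

```python
import tensorflow as tf

def build_cnn(num_classes: int = 10) -> tf.keras.Model:
    # Conv2D -> MaxPooling -> Conv2D -> MaxPooling -> Dense head,
    # for 224x224x3 inputs and 10 output classes.
    return tf.keras.Sequential([
        tf.keras.layers.Input(shape=(224, 224, 3)),
        tf.keras.layers.Conv2D(32, 3, activation="relu"),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(64, 3, activation="relu"),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ])
```

The equivalent PyTorch and JAX models mirror this layer stack so all three frameworks train the same parameter count.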

The dataset was generated in memory using NumPy to eliminate disk I/O variability. TensorFlow used tf.data with basic prefetching, PyTorch leveraged DataLoader with pinned memory, and JAX handled shuffling and batching manually to simulate similar behaviour.
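
Generating the data in memory is straightforward; a sketch of how such a synthetic dataset might be built with NumPy (the helper name and seed are my own):

```python
import numpy as np

def make_synthetic_dataset(num_samples: int, seed: int = 0):
    # Random RGB images and integer labels, generated entirely in
    # memory so disk I/O never enters the timing measurements.
    rng = np.random.default_rng(seed)
    images = rng.random((num_samples, 224, 224, 3), dtype=np.float32)
    labels = rng.integers(0, 10, size=num_samples)
    return images, labels
```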

A one-epoch warm-up was performed before recording metrics to account for GPU initialization and JIT compilation. Timing started immediately afterwards and included only the training time for the measured epochs.
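
The warm-up-then-measure protocol can be sketched framework-agnostically; `time_training` and its arguments are hypothetical names, not taken from the benchmark code:

```python
import time

def time_training(train_one_epoch, warmup_epochs: int = 1,
                  measured_epochs: int = 5) -> float:
    # Run warm-up epochs first so GPU initialization and JIT
    # compilation don't inflate the measured training time.
    for _ in range(warmup_epochs):
        train_one_epoch()
    # Only the epochs after the warm-up are timed.
    start = time.perf_counter()
    for _ in range(measured_epochs):
        train_one_epoch()
    return time.perf_counter() - start
```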

RAM and VRAM were measured with psutil and NVIDIA's pynvml:

  • Initial RAM/VRAM: Captured after model creation and data loading before training began.
  • Max RAM/VRAM: Captured during peak memory usage while the model was actively training.
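
A minimal sketch of how these two readings can be taken with psutil and pynvml (the function names are mine; the VRAM helper assumes an NVIDIA GPU with pynvml installed):

```python
import psutil

def process_ram_gb() -> float:
    # Resident set size of the current process, in GB.
    return psutil.Process().memory_info().rss / 1024**3

def gpu_vram_gb(device_index: int = 0) -> float:
    # Used VRAM on the given GPU, in GB. Requires NVIDIA drivers
    # and the pynvml package.
    import pynvml
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
    used = pynvml.nvmlDeviceGetMemoryInfo(handle).used
    pynvml.nvmlShutdown()
    return used / 1024**3
```

Sampling both helpers once after model creation and then periodically during training yields the "initial" and "max" figures in the table below.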

Compilation Differences

  • TensorFlow: Used jit_compile=True in Keras to enable XLA. Compiles layers individually, offering partial graph execution.
  • PyTorch: Used torch.compile() for static graph optimization.
  • JAX: Compiles entire functions using @jit with XLA, which increases VRAM usage due to staging buffers but delivers strong performance in larger models.
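
To make the JAX bullet concrete, here is a toy example of whole-function compilation with @jit; the function is a stand-in, not the benchmark's training step:

```python
import jax
import jax.numpy as jnp

@jax.jit
def mse(params, x, y):
    # The first call traces the entire Python function and compiles
    # it with XLA; later calls with the same shapes reuse that
    # compiled executable.
    pred = x @ params
    return jnp.mean((pred - y) ** 2)

params = jnp.zeros((3,))
x = jnp.ones((4, 3))
y = jnp.zeros((4,))
loss = mse(params, x, y)  # compiled on this first call
```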

Framework-specific optimizations were intentionally excluded to maintain fairness:

  • TensorFlow's tf.function outside of Keras
  • PyTorch's channels_last layout or AMP/mixed-precision
  • JAX's pjit or xmap sharding techniques

Benchmark Results

Training Performance and Memory Usage

| Framework  | Training Time (s) | Accuracy* | Initial RAM (GB) | Max RAM (GB) | Initial VRAM (GB) | Max VRAM (GB) |
| ---------- | ----------------- | --------- | ---------------- | ------------ | ----------------- | ------------- |
| TensorFlow | 90.88             | ~0.1      | 6.13             | 7.92         | 2.58              | 8.74          |
| PyTorch    | 82.86             | ~0.1      | 4.22             | 5.22         | 1.13              | 6.69          |
| JAX        | 99.44             | ~0.1      | 4.42             | 3.29         | 9.25              | 17.45         |

*Accuracy is included as a sanity check. Since the architecture and data are the same across all frameworks, consistency is expected.

Interpreting the Differences

Training time was shortest with PyTorch, benefiting from its dynamic execution model and optimized backend for CNNs. TensorFlow followed closely but showed slightly higher times, likely due to graph tracing and tf.data overhead. JAX was slower, not due to inefficient code paths but because of the up-front compilation cost that isn't amortized in small-scale tests.

RAM usage was highest in TensorFlow, mostly from dataset buffering and graph construction. PyTorch remained efficient, aided by lazy evaluation and a clear separation between host and device memory. JAX used the least RAM at peak, aggressively offloading both data and execution to GPU memory.

VRAM usage was highest in JAX due to XLA staging and limited intermediate memory reuse in smaller models. PyTorch had the leanest GPU footprint thanks to its flexible, eager execution. TensorFlow landed in the middle, using XLA but with more granular graph segments that don't hold as much in memory at once.

Framework Behavior Insights

While PyTorch came ahead in speed and memory efficiency, TensorFlow held its own with consistent performance and a more structured pipeline. Both are solid choices depending on your priorities and tooling needs.

JAX's numbers reflect its compile-first approach. At this scale, that cost is visible, but with larger models or tuned pipelines it can become an advantage. From experience, JAX shines when batch sizes are large and the training loop runs many times, allowing it to reuse compiled functions efficiently.

Limitations and Considerations

This benchmark avoids more complex scenarios, so results should be interpreted in that context. Specifically, it does not include:

  • Multi-GPU or distributed training
  • Mixed precision (e.g., bfloat16 or FP16)
  • Real-world datasets
  • Data augmentation or advanced pipeline techniques
  • Non-CNN architectures (e.g., transformers or LSTMs)

These factors could significantly impact the relative performance and efficiency of each framework.

Conclusion

PyTorch edged ahead in speed and VRAM efficiency for this setup, while TensorFlow remained competitive. JAX lagged behind, but this was primarily due to compilation overhead and memory staging.

  • PyTorch: Great for fast iteration and minimal memory usage on limited hardware.
  • TensorFlow: Ideal for production-ready pipelines and robust tooling.
  • JAX: Best suited for large-scale training and advanced users who can invest in tuning for performance.

All three frameworks are viable; the best choice depends on your project's context and priorities.

© 2025 ApX Machine Learning. All rights reserved.
