Reducing the precision of model parameters to BFloat16 or Float16 significantly lowers the static memory footprint of the model weights. However, during training, the transient memory required to store intermediate activations for backpropagation often exceeds the memory occupied by parameters. For large language models (LLMs) with long sequence lengths, this activation memory scales linearly with the number of layers and batch size, and quadratically with sequence length due to attention mechanisms. Activation checkpointing, also known as gradient checkpointing, addresses this bottleneck by discarding intermediate activations during the forward pass and recomputing them during the backward pass.
Standard backpropagation requires the output of every operation in the computation graph to be stored in GPU memory until the gradients are calculated. For a Transformer model with L layers, this results in a memory complexity of O(L).
Activation checkpointing alters this behavior by designating specific modules as "checkpoints." Only the inputs to these checkpointed modules are preserved in memory. All intermediate activations generated within the module are discarded. When the backward pass reaches a checkpointed module, the system performs a localized forward pass using the stored inputs to regenerate the necessary intermediate states. This effectively trades computational overhead for memory savings.
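As a minimal illustration of the mechanism, independent of FSDP, the sketch below checkpoints a small hypothetical block with torch.utils.checkpoint: only the block's input is retained, and the intermediate activations are rebuilt when the backward pass reaches that region.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# A small, hypothetical block used purely for illustration.
block = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
x = torch.randn(8, 1024, requires_grad=True)

# Standard forward: every intermediate tensor inside `block` stays alive
# until backward() consumes it.
y = block(x)

# Checkpointed forward: only `x` is stored; the intermediates are
# recomputed from it during the backward pass.
y_ckpt = checkpoint(block, x, use_reentrant=False)
y_ckpt.sum().backward()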
The theoretical memory consumption with optimal checkpoint placement typically follows a square-root relationship relative to the number of layers. If we divide a network of $N$ nodes into segments of length $\sqrt{N}$, we only store the $\sqrt{N}$ segment boundaries. During backpropagation, we recompute the $\sqrt{N}$ nodes within the current segment.
$\text{Memory}_{\text{checkpoint}} \approx O(\sqrt{N}) + O(\text{BlockSize})$
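As a concrete instance of this relationship, consider a network of $N = 64$ layers split into $\sqrt{64} = 8$ segments of 8 layers each:
$\text{Memory}_{\text{checkpoint}} \approx O(8) + O(8) \quad \text{versus} \quad \text{Memory}_{\text{standard}} = O(64)$
Only the 8 segment-boundary activations persist for the entire backward pass, plus the activations of a single 8-layer segment while it is being recomputed.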
The following diagram illustrates the difference in memory state between standard training and checkpointed training during the backward pass.
Comparison of stored tensors in standard training versus activation checkpointing. In the checkpointing approach, intermediate layers are discarded and recomputed on demand.
Implementing activation checkpointing in an FSDP environment requires careful orchestration. Naive application of PyTorch's torch.utils.checkpoint on sub-modules can lead to conflicts with FSDP's sharding logic. FSDP shards parameters across GPUs, but checkpointing requires full parameter access during the recomputation phase.
PyTorch provides the apply_activation_checkpointing utility specifically for this purpose. It traverses the module tree and wraps matching submodules in place, so the checkpoint wrapper and the FSDP wrapper nest cleanly at the chosen granularity. The recommended pattern for Transformers is to wrap each Transformer block (attention plus feed-forward network) as a single checkpointed unit.
When using apply_activation_checkpointing with FSDP, the utility applies the checkpoint_wrapper to each matching submodule for you. A critical configuration detail is the choice between the reentrant and non-reentrant wrapper implementations.
For FSDP, it is strongly advised to use non-reentrant checkpointing (use_reentrant=False). The legacy reentrant variant relies on global state and backward hooks that can interfere with FSDP's state synchronization, potentially leading to deadlocks or incorrect gradients. The non-reentrant implementation treats the checkpointed region as a distinct autograd function, providing better stability and compatibility with distributed settings.
Here is the implementation pattern for applying checkpointing to a Transformer-based model wrapped in FSDP:
import functools

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Define the check function to identify transformer blocks.
# Assuming the model structure has a class 'TransformerBlock'.
check_fn = lambda submodule: isinstance(submodule, TransformerBlock)

# Use the non-reentrant implementation, as recommended for FSDP.
non_reentrant_wrapper = functools.partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

# Apply checkpointing to every TransformerBlock.
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=non_reentrant_wrapper,
    check_fn=check_fn,
)

# Wrap in FSDP after applying the checkpointing logic.
# Note: In recent PyTorch versions, applying checkpointing to an
# already-wrapped FSDP module is also supported and often preferred.
fsdp_model = FSDP(model, auto_wrap_policy=...)
Activation checkpointing introduces a computational cost. Because the forward pass for each checkpointed segment runs twice (once during the initial forward pass and again during backpropagation to regenerate activations), the computational overhead is approximately 20% to 30% of step time.
However, this overhead is often justifiable. The reduction in memory usage allows for a significant increase in micro-batch size. Larger batch sizes improve GPU occupancy and arithmetic intensity, potentially offsetting the recomputation cost. In scenarios where the model fits in memory without checkpointing but with a very small batch size, the GPU compute units may be underutilized. By enabling checkpointing and increasing the batch size, you often achieve higher overall tokens-per-second throughput despite the recomputation penalty.
The following chart demonstrates the relationship between batch size and memory consumption with and without checkpointing.
Memory scaling behavior. The red line indicates standard training hitting the 80GB VRAM limit at batch size 8. The blue line shows checkpointing enabling batch sizes up to 32 within the same hardware constraints.
For extreme scale models where even standard checkpointing is insufficient, PyTorch FSDP allows for selective checkpointing. Instead of checkpointing every transformer layer, you might choose to checkpoint every n-th layer. This provides a granular dial to balance memory savings against computational overhead.
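One way to express this, reusing the hypothetical TransformerBlock class, the model object, and the non-reentrant wrapper from the earlier snippet, is to pre-select every second block and have check_fn test membership:

import functools

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)

# Checkpoint every 2nd TransformerBlock, so only half of the layers
# pay the recomputation penalty.
transformer_blocks = [
    m for m in model.modules() if isinstance(m, TransformerBlock)
]
selected = {id(m) for i, m in enumerate(transformer_blocks) if i % 2 == 0}

apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=functools.partial(
        checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
    ),
    check_fn=lambda submodule: id(submodule) in selected,
)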
Furthermore, the checkpoint_wrapper supports CPU offloading. While FSDP handles parameter offloading, activation offloading is distinct. By configuring offload_to_cpu=True in the checkpoint wrapper, the preserved inputs for the checkpointed modules are moved to system RAM and prefetched back to the GPU immediately before the recomputation step. This is particularly effective when PCIe bandwidth is not the bottleneck, allowing the training of models that are significantly larger than the aggregate GPU memory capacity of the cluster.
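A minimal sketch of this configuration, assuming the same model and TransformerBlock check as before, passes the offload flag through a partial wrapper; note that the exact keyword depends on the PyTorch version, and newer releases also expose a dedicated offload_wrapper for the same purpose.

import functools

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    apply_activation_checkpointing,
)

# Move the preserved inputs of each checkpointed block to host RAM and
# copy them back to the GPU just before recomputation.
offload_wrapper_fn = functools.partial(checkpoint_wrapper, offload_to_cpu=True)

apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=offload_wrapper_fn,
    check_fn=lambda submodule: isinstance(submodule, TransformerBlock),
)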
$\text{Total\_Throughput} \propto \dfrac{\text{BatchSize}}{\text{StepTime}}$
When optimizing, the goal is to maximize this ratio. If checkpointing increases StepTime by 30% but allows BatchSize to double without Out-Of-Memory (OOM) errors, the net result is a substantial gain in training efficiency.
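Plugging those numbers into the proportionality above:
$\dfrac{\text{Throughput}_{\text{ckpt}}}{\text{Throughput}_{\text{baseline}}} = \dfrac{2.0\,\text{BatchSize} / (1.3\,\text{StepTime})}{\text{BatchSize} / \text{StepTime}} = \dfrac{2.0}{1.3} \approx 1.54$
That is roughly a 54% increase in tokens processed per second, even though each individual step takes longer.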