Training models with large parameter counts in full 32-bit precision (FP32) is inefficient and often impossible on standard GPU clusters due to memory limitations. Mixed precision training addresses this by using 16-bit formats for the majority of arithmetic operations while maintaining a master copy of the weights in FP32 for updates. For PyTorch FSDP, selecting between Brain Floating Point (BFloat16) and standard Floating Point (Float16) is not merely a preference but a decision dictated by hardware capabilities and convergence stability requirements.
To make an informed decision, one must understand the bit-level representation of these formats. The primary difference lies in how the 16 bits are allocated between the exponent (range) and the mantissa (precision).
Standard Float16 (IEEE 754) allocates 1 bit for the sign, 5 bits for the exponent, and 10 bits for the mantissa (significand). This format offers higher precision within a narrow dynamic range. The limited exponent width means the largest representable number is 65,504, and the smallest positive normal number is approximately 6.1×10⁻⁵. In deep learning, gradients frequently fall below this threshold (underflow) or activations exceed the maximum (overflow), necessitating aggressive loss scaling.
BFloat16, developed by Google Brain, alters this tradeoff. It allocates 1 bit for the sign, 8 bits for the exponent, and 7 bits for the mantissa. The exponent width matches that of FP32. Consequently, BFloat16 preserves the dynamic range of a standard 32-bit float, effectively acting as a truncated FP32. While it loses significant precision (3 fewer mantissa bits than FP16), the extended range makes it inherently robust against underflow and overflow without complex scaling logic.
Comparison of bit allocations across floating point formats. BFloat16 maintains the exponent width of FP32 to preserve dynamic range.
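The mantissa difference is easy to observe directly in PyTorch. The snippet below is a standalone illustration (the constant is chosen only because it is exactly representable in FP16, being 1 + 2⁻¹⁰): casting it to BFloat16 rounds the fractional part away, and torch.finfo confirms the coarser spacing between representable values.

import torch

x = torch.tensor(1.0009765625)              # exactly 1 + 2**-10 in FP32
print(x.to(torch.float16).item())           # 1.0009765625 -> kept (10 mantissa bits)
print(x.to(torch.bfloat16).item())          # 1.0 -> rounded away (7 mantissa bits)

# Machine epsilon: the gap between 1.0 and the next representable value.
print(torch.finfo(torch.float16).eps)       # 0.0009765625  (2**-10)
print(torch.finfo(torch.bfloat16).eps)      # 0.0078125     (2**-7)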
The operational difference between these formats manifests primarily in the training loop stability.
When using Float16, the gradients often become small enough to vanish (underflow to zero). To counter this, PyTorch employs GradScaler. This utility multiplies the loss by a scale factor (e.g., 2¹⁶) before the backward pass, shifting gradients into the representable range of FP16. After backpropagation, the gradients are unscaled before the optimizer step. This introduces computational overhead and complexity. If the scale factor is too high, gradients overflow to infinity; if too low, they underflow. The scaler must dynamically adjust this factor, which can lead to skipped steps if Inf or NaN values are detected.
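The underflow problem and the scaling workaround can be reproduced in a few lines. This is only an illustration with a hand-picked value, not a real gradient:

import torch

# A gradient-sized value below FP16's smallest subnormal (~6e-8) is lost on cast.
g = torch.tensor(1e-8)
print(g.to(torch.float16))                  # tensor(0., dtype=torch.float16)

# Loss scaling: multiply by 2**16 before casting so the value stays representable,
# then unscale in FP32 before the optimizer consumes it.
scale = 2.0 ** 16
scaled = (g * scale).to(torch.float16)      # about 6.55e-4, safely in range
recovered = scaled.to(torch.float32) / scale
print(recovered)                            # back to roughly 1e-8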
BFloat16 eliminates the need for loss scaling entirely. Because its dynamic range matches FP32 (≈10⁻³⁸ to 10³⁸), gradients rarely underflow or overflow during standard training runs. This stability is particularly critical for Large Language Models (LLMs) trained with Transformer architectures, where attention scores and activation spikes can be volatile.
The following chart visualizes the representable range constraints. Notice how quickly Float16 hits the ceiling compared to BFloat16.
The effective dynamic range of BFloat16 is significantly wider than that of Float16, matching the operational boundaries of FP32.
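The numbers behind the chart can be checked directly with torch.finfo, which reports the limits of each dtype:

import torch

# Largest finite value, smallest positive normal value, and machine epsilon.
for dtype in (torch.float16, torch.bfloat16, torch.float32):
    info = torch.finfo(dtype)
    print(f"{str(dtype):>15}  max={info.max:.3e}  tiny={info.tiny:.3e}  eps={info.eps:.3e}")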
In FSDP, mixed precision is controlled via the MixedPrecision configuration object. This class dictates the dtype for three specific stages of the training lifecycle:
- param_dtype: The format in which parameters are cast before the forward pass.
- reduce_dtype: The format used for gradient reduction (reduce-scatter or all-reduce) across ranks.
- buffer_dtype: The format for buffers (e.g., BatchNorm statistics).

Correctly setting these parameters determines the memory savings and numerical safety of your run.
If your cluster is equipped with NVIDIA Ampere (A100), Hopper (H100), or newer architectures, BFloat16 is the standard. It offers hardware acceleration for matrix multiplications via Tensor Cores.
import torch
from torch.distributed.fsdp import MixedPrecision

bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,   # See note below on reduction
    buffer_dtype=torch.bfloat16,
)
Note the reduce_dtype. While it is tempting to set this to bfloat16 to reduce communication bandwidth, doing so is risky. BFloat16 has low precision (only 7 mantissa bits). When summing gradients across hundreds of GPUs (AllReduce), you encounter the "swamping" phenomenon where adding small gradient updates to a large accumulation buffer results in the small values being lost entirely. Keeping reduce_dtype=torch.float32 ensures that the gradient averaging remains precise, while param_dtype=torch.bfloat16 ensures the heavy compute (forward/backward) uses the faster, lighter format.
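Applying the policy is a single constructor argument. The sketch below assumes the process group is already initialized (for example via torchrun) and that model is an existing nn.Module on the current CUDA device; auto-wrap policies and other FSDP options are omitted for brevity.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes torch.distributed.init_process_group("nccl") has already run and
# `model` is an ordinary nn.Module placed on the current CUDA device.
fsdp_model = FSDP(
    model,
    mixed_precision=bf16_policy,            # the policy defined above
    device_id=torch.cuda.current_device(),
)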
For older hardware like V100 (Volta) or T4 (Turing), BFloat16 is not supported natively. You must resort to Float16 and manage the GradScaler.
fp16_policy = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float16,
    buffer_dtype=torch.float16,
)

# Requires external scaler management
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
scaler = ShardedGradScaler()
When using reduce_dtype=torch.float16, you gain communication speed but increase the risk of overflow during the reduction. However, since FP16 has higher precision in the mantissa than BF16, it is less susceptible to swamping, making 16-bit reduction slightly safer than in the BF16 case, provided the values stay within range.
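The scaler then wraps the optimizer step. The following sketch assumes fsdp_model has been wrapped with fp16_policy and that optimizer, loss_fn, inputs, and targets already exist; only the scaler calls are the point here.

# Minimal FP16 training step (assumes fsdp_model, optimizer, loss_fn,
# inputs, and targets are defined elsewhere).
optimizer.zero_grad()
loss = loss_fn(fsdp_model(inputs), targets)

scaler.scale(loss).backward()   # scale the loss before backpropagation
scaler.step(optimizer)          # unscales gradients; skips the step on Inf/NaN
scaler.update()                 # adjusts the scale factor for the next iteration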
The choice of precision impacts memory throughput (DRAM bandwidth) and arithmetic throughput (TFLOPS).
In FSDP, the memory savings from param_dtype allow you to increase the local batch size. If a model layer occupies W bytes in FP32, using mixed precision reduces the active memory required for that layer's forward pass to W/2 plus the overhead of keeping the FP32 master weights in the optimizer state (which are sharded in FSDP).
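As a rough, illustrative estimate (assuming an Adam-style optimizer, full sharding, and ignoring activations, gradients, and fragmentation), the sharded parameter and optimizer footprint can be sketched as follows; the model size and GPU count are made-up example values.

# Back-of-the-envelope per-GPU memory for parameters and optimizer state
# under FSDP full sharding (illustrative values, Adam-style optimizer assumed).
params = 7e9          # example model size in parameters
world_size = 8        # GPUs sharing the shards

bf16_params  = 2 * params / world_size   # sharded BF16 working parameters
fp32_master  = 4 * params / world_size   # sharded FP32 master weights
adam_moments = 8 * params / world_size   # two FP32 moments per parameter

total_gb = (bf16_params + fp32_master + adam_moments) / 1e9
print(f"~{total_gb:.1f} GB per GPU, before activations and gradients")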
| Feature | BFloat16 | Float16 |
|---|---|---|
| Hardware Requirement | NVIDIA Ampere (A100) or newer | Volta (V100) or newer |
| Mantissa Precision | Low (7 bits) | High (10 bits) |
| Dynamic Range | High (8 bit exponent) | Low (5 bit exponent) |
| Loss Scaling | Not Required | Required (GradScaler) |
| Reduction Stability | Poor (keep reduction in FP32) | Moderate (prone to overflow) |
| Use Case | LLMs, Transformers, Large Clusters | Legacy Hardware, Convolutional Nets |
For training large language models on modern clusters, BFloat16 with FP32 reduction is the dominant strategy. It minimizes memory usage and maximizes stability without the administrative overhead of loss scalers. Use Float16 only when strictly limited by hardware generation.