Training high-parameter models in full 32-bit precision (FP32) is inefficient and often impossible on standard GPU clusters due to memory limitations. Mixed precision training addresses this by using 16-bit formats for the majority of arithmetic operations while maintaining a master copy of the weights in FP32 for updates. For PyTorch FSDP, selecting between Brain Floating Point (BFloat16) and standard Floating Point (Float16) is not merely a preference but a decision dictated by hardware capabilities and convergence stability requirements.

## The Anatomy of 16-Bit Formats

To make an informed decision, one must understand the bit-level representation of these formats. The primary difference lies in how the 16 bits are allocated between the exponent (range) and the mantissa (precision).

Standard Float16 (IEEE 754) allocates 1 bit for the sign, 5 bits for the exponent, and 10 bits for the mantissa (significand). This format offers higher precision within a narrow dynamic range. The limited exponent width means the largest representable number is 65,504, and the smallest positive normal number is approximately $$6.1 \times 10^{-5}$$. In deep learning, gradients frequently fall below this threshold (underflow) or activations exceed the maximum (overflow), necessitating aggressive loss scaling.

BFloat16, developed by Google Brain, alters this tradeoff. It allocates 1 bit for the sign, 8 bits for the exponent, and 7 bits for the mantissa. The exponent width matches that of FP32. Consequently, BFloat16 preserves the dynamic range of a standard 32-bit float, effectively acting as a truncated FP32. While it gives up precision (3 fewer mantissa bits than FP16), the extended range makes it inherently robust against underflow and overflow without complex scaling logic.

*Figure: Comparison of bit allocations across floating-point formats (FP32: 1 sign / 8 exponent / 23 mantissa bits; BFloat16: 1 / 8 / 7; Float16: 1 / 5 / 10). BFloat16 maintains the exponent width of FP32 to preserve dynamic range.*

## Convergence Stability and Loss Scaling

The operational difference between these formats manifests primarily in training loop stability.

When using Float16, gradients often become small enough to vanish (underflow to zero). To counter this, PyTorch employs `GradScaler`. This utility multiplies the loss by a scale factor (e.g., $$2^{16}$$) before the backward pass, shifting gradients into the representable range of FP16. After backpropagation, the gradients are unscaled before the optimizer step. This introduces computational overhead and complexity: if the scale factor is too high, gradients overflow to infinity; if too low, they underflow. The scaler must dynamically adjust this factor, which can lead to skipped optimizer steps when Inf or NaN values are detected.
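The loop below is a minimal, runnable sketch of that pattern using the standard `torch.cuda.amp.GradScaler` on a toy linear layer; it is not FSDP-specific (the FSDP variant appears later), and the model and loss are illustrative stand-ins.

```python
import torch
import torch.nn as nn

device = "cuda"
model = nn.Linear(512, 512).to(device)                 # toy stand-in for a real network
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()                   # starts at 2**16 by default and adapts

for _ in range(10):
    x = torch.randn(32, 512, device=device)
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(x).pow(2).mean()                  # dummy loss computed in the FP16 autocast region

    scaler.scale(loss).backward()   # backward on the scaled loss keeps tiny gradients representable
    scaler.step(optimizer)          # unscales gradients first; skips the step if Inf/NaN is found
    scaler.update()                 # grows or shrinks the scale factor for the next iteration
```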
BFloat16, in contrast, eliminates the need for loss scaling entirely. Because its dynamic range matches FP32 ($$\approx 10^{-38} \text{ to } 10^{38}$$), gradients rarely underflow or overflow during standard training runs. This stability is particularly critical for Large Language Models (LLMs) built on Transformer architectures, where attention scores and activation spikes can be volatile.

The chart below visualizes the representable range constraints. Notice how quickly Float16 hits the ceiling compared to BFloat16.

*Figure: Representable magnitudes on a log scale. The effective dynamic range of BFloat16 ($$\approx 10^{-38}$$ to $$3.4 \times 10^{38}$$) matches the operational boundaries of FP32 and spans significantly wider than Float16 ($$\approx 6 \times 10^{-5}$$ to 65,504).*

## Configuring FSDP Mixed Precision

In FSDP, mixed precision is controlled via the `MixedPrecision` configuration object. This class dictates the dtype for three specific stages of the training lifecycle:

- `param_dtype`: the format to which parameters are cast before the forward pass.
- `reduce_dtype`: the format used for gradient synchronization across ranks (reduce-scatter or all-reduce, depending on the sharding strategy).
- `buffer_dtype`: the format for buffers (e.g., BatchNorm statistics).

Correctly setting these parameters determines the memory savings and numerical safety of your run.

### BFloat16 Configuration (Recommended for Ampere+)

If your cluster is equipped with NVIDIA Ampere (A100), Hopper (H100), or newer architectures, BFloat16 is the standard choice. It offers hardware acceleration for matrix multiplications via Tensor Cores.

```python
import torch
from torch.distributed.fsdp import MixedPrecision

bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,   # see note below on reduction
    buffer_dtype=torch.bfloat16,
)
```

Note the `reduce_dtype`. While it is tempting to set this to `bfloat16` to reduce communication bandwidth, doing so is risky. BFloat16 has low precision (only 7 mantissa bits), and when summing gradients across hundreds of GPUs you encounter the "swamping" phenomenon: adding small gradient contributions to a large accumulation buffer causes the small values to be lost entirely. Keeping `reduce_dtype=torch.float32` ensures that the gradient averaging remains precise, while `param_dtype=torch.bfloat16` ensures the heavy compute (forward/backward) uses the faster, lighter format.

### Float16 Configuration (Legacy Hardware)

Older hardware such as the V100 (Volta) or T4 (Turing) does not support BFloat16 natively. You must resort to Float16 and manage the gradient scaler yourself.

```python
import torch
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

fp16_policy = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float16,
    buffer_dtype=torch.float16,
)

# Requires external scaler management
scaler = ShardedGradScaler()
```

With `reduce_dtype=torch.float16`, you gain communication speed but increase the risk of overflow during the reduction. However, since FP16 carries more mantissa bits than BF16, it is less susceptible to swamping, making 16-bit reduction slightly safer than in the BF16 case, provided the values stay within range.
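To show where these policies plug in, here is a minimal sketch that wraps a toy module with FSDP via its `mixed_precision` argument. It assumes the process group has already been initialized (e.g., `torch.distributed.init_process_group("nccl")` under `torchrun`); the module, optimizer, and loss are illustrative placeholders rather than a real training setup.

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision

bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.bfloat16,
)

# Toy module standing in for a real network; assumes the process group
# was initialized earlier in the script.
model = FSDP(nn.Linear(1024, 1024).cuda(), mixed_precision=bf16_policy)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

x = torch.randn(8, 1024, device="cuda")
loss = model(x).float().pow(2).mean()  # forward/backward run in bfloat16
loss.backward()                        # gradients are reduced in float32 per reduce_dtype
optimizer.step()

# With fp16_policy instead, the same loop would route the loss through
# ShardedGradScaler (scale -> backward -> step -> update), mirroring the
# GradScaler pattern shown earlier but aware of sharded gradients.
```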
## Precision-Throughput Trade-offs

The choice of precision affects both memory throughput (DRAM bandwidth) and arithmetic throughput (TFLOPS).

- **Memory bandwidth:** Both BF16 and FP16 halve the memory traffic for model weights and activations compared to FP32. This is often the primary accelerator for LLM training, since these workloads are typically memory-bound rather than compute-bound.
- **Compute throughput:** On A100 GPUs, BF16 and FP16 Tensor Cores offer theoretically identical peak throughput (312 TFLOPS). In practice, BF16 often yields slightly better performance because it avoids the kernel launches and extra memory reads/writes associated with dynamic loss scaling checks.

In FSDP, the memory savings from `param_dtype` allow you to increase the local batch size. If a model layer occupies $$W$$ bytes in FP32, mixed precision reduces the active memory required for that layer's forward pass to $$W/2$$, plus the overhead of keeping the FP32 master weights in the optimizer state (which are sharded in FSDP).

## Summary of Recommendations

| Feature | BFloat16 | Float16 |
| --- | --- | --- |
| Hardware requirement | NVIDIA Ampere (A100) or newer | Volta (V100) or newer |
| Mantissa precision | Low (7 bits) | High (10 bits) |
| Dynamic range | High (8-bit exponent) | Low (5-bit exponent) |
| Loss scaling | Not required | Required (`GradScaler`) |
| Reduction stability | Poor (keep reduction in FP32) | Moderate (prone to overflow) |
| Use case | LLMs, Transformers, large clusters | Legacy hardware, convolutional nets |

For training large language models on modern clusters, BFloat16 with FP32 reduction is the dominant strategy. It minimizes memory usage and maximizes stability without the administrative overhead of loss scalers. Use Float16 only when strictly limited by hardware generation.
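If you want to verify the dynamic-range row of the table yourself, a quick interpreter check makes it concrete (exact printed values may vary slightly due to rounding):

```python
import torch

# Overflow: 70,000 exceeds FP16's maximum of 65,504 but fits easily in BFloat16.
print(torch.tensor(7e4, dtype=torch.float16))    # inf
print(torch.tensor(7e4, dtype=torch.bfloat16))   # ~7.0e4 (rounded to 7 mantissa bits)

# Underflow: a gradient-sized value below FP16's subnormal floor flushes to zero,
# while BFloat16's FP32-like exponent keeps it representable.
print(torch.tensor(1e-8, dtype=torch.float16))   # 0.0
print(torch.tensor(1e-8, dtype=torch.bfloat16))  # ~1.0e-8
```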