Mixed Precision Training, Paulius Micikevicius, Sharan Narang, Jonah Alben, Gregory Diamos, Erich Elsen, David Garcia, Boris Ginsburg, Michael Houston, Oleksii Kuchaiev, Ganesh Venkatesh, Hao Wu, 2018 (International Conference on Learning Representations, ICLR; DOI: 10.48550/arXiv.1710.03740) - The seminal paper introducing mixed-precision training with float16 and the concept of loss scaling to mitigate numerical underflow (see the loss-scaling sketch after this list).
BFloat16: A New Standard For Deep Learning, Neal H. Liu, Norman P. Jouppi, 2019 (Google AI Blog) - Describes the bfloat16 format, its advantages for deep learning, particularly its wider dynamic range compared to float16, and its adoption in Google's TPUs.
Data types (dtypes), JAX Authors and Contributors, 2024 (JAX Documentation) - Provides detailed information on JAX's handling of numerical data types, including jnp.float16 and jnp.bfloat16, which is essential for understanding type compatibility in JAX (a short dtype-range sketch follows this list).
Mixed precision in Flax, Flax Authors and Contributors, 2024 (Flax Documentation) - A practical guide on configuring and implementing mixed-precision training within the Flax framework, covering the param_dtype and dtype settings (see the dtype/param_dtype sketch after this list).
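
As a rough illustration of the loss-scaling idea from Micikevicius et al., the sketch below multiplies the loss by a constant factor before differentiation and divides the resulting gradients by the same factor afterward, so that small float16 gradients are less likely to underflow. The names loss_fn and scaled_grads, the toy regression data, and the fixed scale of 2**15 are assumptions for illustration; production implementations typically use dynamic loss scaling.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # Hypothetical loss; stands in for a real model's forward pass.
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2)

def scaled_grads(params, batch, loss_scale=2.0 ** 15):
    # Scale the loss up before differentiation so small gradients survive
    # in low precision, then scale the gradients back down afterward.
    def scaled_loss(p):
        return loss_fn(p, batch) * loss_scale
    grads = jax.grad(scaled_loss)(params)
    return jax.tree_util.tree_map(lambda g: g / loss_scale, grads)

params = {"w": jnp.zeros((3, 1), dtype=jnp.float32)}
batch = {"x": jnp.ones((8, 3), dtype=jnp.float32),
         "y": jnp.ones((8, 1), dtype=jnp.float32)}
grads = scaled_grads(params, batch)
```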
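To make the dynamic-range difference between float16 and bfloat16 concrete, a quick check with jnp.finfo shows that bfloat16 keeps float32's exponent width (and therefore roughly its range) while float16 tops out near 65504. The numeric values in the comments are approximate.

```python
import jax.numpy as jnp

# float16: narrow dynamic range, so activations/gradients can overflow or underflow.
print(jnp.finfo(jnp.float16).max)    # 65504.0
print(jnp.finfo(jnp.float16).tiny)   # ~6.1e-05 (smallest normal)

# bfloat16: same exponent width as float32, so a much wider range
# at the cost of a shorter mantissa (less precision).
print(jnp.finfo(jnp.bfloat16).max)   # ~3.39e+38
print(jnp.finfo(jnp.bfloat16).tiny)  # ~1.18e-38

# Explicit casting between dtypes in JAX:
x = jnp.ones((4,), dtype=jnp.float32)
print(x.astype(jnp.bfloat16).dtype)  # bfloat16
```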
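As a minimal sketch of the dtype / param_dtype distinction in Flax, assuming a toy two-layer MLP: dtype controls the dtype used for computation and activations, while param_dtype controls the dtype in which the parameters themselves are stored (here kept in float32, as is common in mixed-precision setups).

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        # dtype: computation/activation dtype; param_dtype: stored parameter dtype.
        x = nn.Dense(features=128, dtype=jnp.bfloat16, param_dtype=jnp.float32)(x)
        x = nn.relu(x)
        return nn.Dense(features=10, dtype=jnp.bfloat16, param_dtype=jnp.float32)(x)

model = MLP()
x = jnp.ones((4, 32), dtype=jnp.bfloat16)
params = model.init(jax.random.PRNGKey(0), x)
y = model.apply(params, x)
print(y.dtype)  # bfloat16 activations, while parameters remain float32
```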