Distributed Training with Flax and JAX, Flax team, 2024 (Flax Documentation) - Flax 官方指南,展示如何将 Flax 模型和 TrainState 与 JAX 的 pmap 集成,以进行数据并行训练,包括状态复制、数据分片和梯度聚合。
JAX: Composable transformations of Python+NumPy programs, James Bradbury, Roy Frostig, Peter Hawkins, Matthew James Johnson, Chris Leary, Skye Wanderman-Milne, and Adam Paszke, 2023Journal of Machine Learning Research, Vol. 24 (Microtome Publishing)DOI: 10.5555/3579549.3579555 - 介绍 JAX 的基础论文,阐述了其设计原则、自动微分以及通过 XLA 进行编译的机制,这些构成了其并行化能力的基础。