趋近智
训练现代机器学习 (machine learning)模型经常会达到单个加速器内存和计算能力的极限。本章着重介绍有效扩展JAX应用所需的方法。
我们将考察如何使用Flax和Haiku等库来构建大型模型,以及如何在这些框架中结合pmap实现分布式数据并行。你将学习到一些实用策略,例如梯度累积(模拟更大的有效批量大小)、梯度检查点(jax.checkpoint)(以重新计算为代价减少内存使用)以及混合精度训练(进一步节省内存并可能提高速度)。我们还将介绍与模型并行性相关的一些思想,并讨论适用于分布式环境的优化器。
本章结束后,你将明白如何结合这些不同的JAX功能和生态系统工具,以应对训练大型神经网络 (neural network)的挑战。
6.1 大型模型训练中的挑战概述
6.2 JAX 生态系统库(Flax, Haiku)简介
6.3 模型参数和状态的处理
6.4 将 pmap 与训练框架结合使用
6.5 梯度累积
6.6 梯度检查点(再物化)
6.7 混合精度训练
6.8 模型并行策略
6.9 大规模优化算法
6.10 实践:实现梯度检查点