有效训练大型机器学习模型需要仔细处理的不仅仅是模型参数本身,还有相关的状态,例如优化器统计信息和伪随机数生成器(PRNG)密钥。JAX 的函数式编程方法,其中函数理想情况下没有副作用,这意味着状态必须显式地进行管理:作为输入传入函数并作为输出返回。尽管这最初可能看起来冗长,但它提供了很大的清晰度并简化了调试,尤其是在复杂的分布式环境中。状态作为 PyTree在 JAX 中,模型参数、优化器状态(例如,动量缓冲区、学习率调度)、批量归一化统计信息和 PRNG 密钥通常表示为 PyTree。PyTree 只是 JAX 可以视为容器(如列表、元组、字典)和叶子节点(如 JAX 数组或标准 Python 类型)的嵌套结构的任何 Python 对象。import jax import jax.numpy as jnp import optax # Common JAX optimizer library # 参数的示例结构(可能更深) params = { 'encoder': { 'layer_1': {'w': jnp.ones((128, 256)), 'b': jnp.zeros(256)}, 'layer_norm': {'scale': jnp.ones(256), 'bias': jnp.zeros(256)} }, 'decoder': { 'output': {'w': jnp.ones((256, 10)), 'b': jnp.zeros(10)} } } # 优化器状态示例(结构通常与参数相似) optimizer = optax.adam(learning_rate=1e-3) opt_state = optimizer.init(params) # 单个PRNG密钥 key = jax.random.PRNGKey(0) # 为了方便,您可以将它们捆绑在一起 training_state = { 'params': params, 'opt_state': opt_state, 'rng_key': key, 'step': 0 } # 验证它是一个PyTree leaves, treedef = jax.tree_util.tree_flatten(training_state) print(f"Number of leaf nodes (arrays, scalars): {len(leaves)}") # 输出:叶子节点(数组、标量)的数量:11(取决于具体的结构和优化器)使用 PyTree 是一个重要方面,因为 JAX 转换(jit、grad、vmap、pmap)旨在对这些结构进行操作。当您将 jax.grad 应用于一个接收并返回 PyTree 的函数时,JAX 会计算指定输入 PyTree(s) 中所有数值叶子节点(数组)的梯度。同样,pmap 根据其参数自动在设备上复制或分发 PyTree 结构。这使得管理可能复杂的嵌套状态结构更加系统化。使用生态系统库进行状态管理尽管您可以手动使用字典或自定义类来管理状态,但 Flax 和 Haiku 等库提供了专门为神经网络设计的高级抽象,显著简化了状态管理。Flax:显式状态对象Flax 通常鼓励将相关的状态组件归入专用对象,通常使用 TrainState 等模式。这个类通常包含模型参数(params)、优化器状态(opt_state)、当前训练步骤(step),有时还包括模型定义本身(apply_fn)。# Flax 示例 from flax.training import train_state import optax class SimpleModel(nn.Module): features: int @nn.compact def __call__(self, x): x = nn.Dense(features=self.features)(x) return x class TrainState(train_state.TrainState): # 可选地在这里添加批量统计或其他状态 batch_stats: Any = None # BatchNorm 示例 # 初始化 key = jax.random.PRNGKey(0) model = SimpleModel(features=10) dummy_input = jnp.ones([1, 128]) params = model.init(key, dummy_input)['params'] optimizer = optax.adam(1e-3) # 创建状态对象 state = TrainState.create( apply_fn=model.apply, # 运行模型的函数 params=params, tx=optimizer # 优化器转换 ) # 在训练步骤中,您将以不可变的方式更新此状态 # new_state = state.apply_gradients(grads=grads)主要思路保持不变:状态显式地保存在一个对象(一个 PyTree)中,更新会生成 新的 状态对象,而不是就地修改。这与 JAX 的函数式方法完全吻合。Haiku:分离参数和状态Haiku 采用略微不同的方法,使用 hk.transform 或 hk.transform_with_state。它将纯函数逻辑与参数(params)和可变状态(state,例如用于批量归一化统计信息)分离。import haiku as hk import optax def forward_fn(x, is_training): net = hk.Sequential([ hk.Linear(512), jax.nn.relu, hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.99), hk.Linear(10) ]) return net(x, is_training=is_training) # 转换函数以处理状态 forward = hk.transform_with_state(forward_fn) key = hk.PRNGSequence(0) dummy_input = jnp.ones([1, 128]) # 初始化参数和可变状态 params, state = forward.init(next(key), dummy_input, is_training=True) optimizer = optax.adam(1e-3) opt_state = optimizer.init(params) # 应用函数(示例:is_training=True 更新批量归一化状态) # logits, new_state = forward.apply(params, state, next(key), batch_data, is_training=True) # grads, new_state = grad_fn(params, state, ...) # grad_fn 需要处理状态 # updates, new_opt_state = optimizer.update(grads, opt_state, params) # new_params = optax.apply_updates(params, updates)Haiku 要求您显式管理 init 和 apply 返回的 params 和 state PyTree。同样,函数式原则适用:状态作为输入传入,更新后的状态作为输出返回。分布式环境中的状态管理(pmap)当使用 pmap 进行数据并行扩展训练时,正确管理跨多个设备的状态变得必要。digraph 状态流 { rankdir=LR; node [shape=box, style=rounded, fontname="Arial", fontsize=10, color="#495057", fontcolor="#495057"]; edge [fontname="Arial", fontsize=9, color="#adb5bd"]; subgraph cluster_inputs { label = "输入"; style=filled; color="#e9ecef"; batch [label="数据批次\n(分片)"]; params_in [label="参数\n(复制)"]; opt_state_in [label="优化器状态\n(复制)"]; rng_key_in [label="PRNG密钥\n(每设备)"]; } train_step [label="pmap化训练步骤", shape=ellipse, color="#1c7ed6", fontcolor="#1c7ed6", style=bold]; subgraph cluster_outputs { label = "输出"; style=filled; color="#e9ecef"; loss [label="损失\n(每设备)"]; grads [label="梯度\n(每设备)"]; new_rng_key_out [label="更新后的PRNG密钥\n(每设备)"]; } batch -> train_step; params_in -> train_step; opt_state_in -> train_step; rng_key_in -> train_step; train_step -> loss; train_step -> grads; # 梯度已计算,需要聚合 train_step -> new_rng_key_out; # 参数/优化器状态的更新在梯度聚合后发生 label = "pmap 中的状态流"; fontsize=12; }在典型的 pmap 数据并行设置中,参数和优化器状态会被复制,数据会被分片,PRNG 密钥必须是每个设备独有的。映射的函数计算每个设备的损失和梯度等结果。参数和优化器状态: 对于标准的数据并行,模型参数(params)和相关的优化器状态(opt_state)通常在所有设备上是相同的。您在主机上初始化它们一次,pmap 会自动将这些 PyTree 广播(复制)到每个设备。当在每个设备上计算梯度时,需要先聚合(例如,使用 lax.pmean 进行平均),然后才能更新参数和优化器状态。这个更新步骤通常也在 pmap 化函数中执行,确保复制的状态在所有设备上保持一致。PRNG 密钥: 这经常是不易察觉的错误来源。如果您在所有设备上为诸如 dropout 或数据增强等操作使用 相同 的 PRNG 密钥,所有设备将生成 相同 的随机数,这会违背随机性的目的或导致相关结果。正确的方法是在调用 pmap 之前,在主机上生成一个主密钥,然后将其拆分为每个设备独有的子密钥。num_devices = jax.local_device_count() key = jax.random.PRNGKey(42) # 在主机上拆分一次 device_keys = jax.random.split(key, num_devices) # pmap 化函数示例(简化) @jax.pmap def train_step_pmap(params, opt_state, local_key, batch): # 在内部使用每个设备的 'local_key' dropout_key, new_local_key = jax.random.split(local_key) # ... 使用 dropout_key 执行带有 dropout 的前向传播 ... # ... 计算梯度 ... # grads = ... # grads = lax.pmean(grads, axis_name='devices') # 聚合梯度 # ... 更新参数和优化器状态 ... # return loss, new_params, new_opt_state, new_local_key # 调用 pmap,传入设备特定密钥的数组 # loss, params, opt_state, device_keys = train_step_pmap(params, opt_state, device_keys, sharded_batch)digraph PmapRNG { rankdir=LR; node [shape=box, style=rounded, fontname="Arial", fontsize=10, color="#495057", fontcolor="#495057"]; edge [fontname="Arial", fontsize=9, color="#adb5bd"]; main_key [label="主 PRNG 密钥"]; split_op [label="jax.random.split(key,\n num_devices)", shape=invhouse, color="#f76707", fontcolor="#f76707"]; pmap_func [label="pmap(...)", shape=cds, height=1.0, width=1.0, fixedsize=true, color="#7048e8", fontcolor="#7048e8"]; subgraph cluster_devices { label = "pmap 的输入"; style=dashed; color="#adb5bd"; key_0 [label="设备 0 密钥"]; key_1 [label="设备 1 密钥"]; key_N [label="设备 N 密钥"]; } main_key -> split_op; split_op -> key_0 [label="keys[0]"]; split_op -> key_1 [label="keys[1]"]; split_op -> key_N [label="keys[N]"]; key_0 -> pmap_func; key_1 -> pmap_func; key_N -> pmap_func; label = "在 pmap 前拆分 PRNG 密钥"; fontsize=12; }创建一个包含唯一 PRNG 密钥的数组,每个设备一个,通过在将它们传入 pmap 之前 拆分一个主密钥。pmap 化函数在特定设备上的每次执行将从 device_keys 数组中接收其对应的唯一密钥。请记住也要从 pmap 化函数中返回更新后的密钥,以便在下一步中正确地继续 PRNG 序列。序列化和检查点训练大型模型可能需要数小时、数天甚至数周。定期将训练状态(检查点)保存到磁盘上是必要的。这使得您可以在中断后继续训练,并保存最终训练好的模型。需要保存什么?通常,您需要:模型参数(params):学习到的权重和偏置。优化器状态(opt_state):对于正确恢复训练很重要,特别是对于带有动量或自适应学习率的优化器(如 Adam)。训练步数计数器:用于跟踪进度和管理学习率调度。PRNG 密钥(可选但推荐):保存最新的 PRNG 密钥状态可以确保更好的可复现性,如果您的输入管道或模型涉及您希望在不同运行中精确控制的随机性。其他状态(例如,批量统计):如果您的模型包含像 BatchNorm 这样的组件,它们的移动平均值也必须被保存和恢复。JAX 生态系统内的库为此提供工具。例如,orbax.checkpoint 正在成为一个标准方案,提供异步检查点(在后台保存状态而不中断训练)和灵活的数据保存结构等功能。像 Flax 这样的框架也内置了序列化工具(flax.training.checkpoints)。# 使用 Orbax 的示例(简化) import orbax.checkpoint as ocp # 假设 'state' 是一个包含 params、opt_state、step 等的 PyTree # 对于 pmap,状态可能会被复制。通常从设备 0 保存。 # state_to_save = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state)) # 从设备 0 获取状态 checkpointer = ocp.StandardCheckpointer() # Or AsyncCheckpointer save_path = '/path/to/checkpoints/step_10000' checkpointer.save(save_path, args=ocp.args.StandardSave(state)) # 恢复: # restored_state = checkpointer.restore(save_path, args=ocp.args.StandardRestore(state)) # 如果使用 pmap,将恢复的状态复制到所有设备 # state = jax.device_put_replicated(restored_state, jax.local_devices())有效管理参数和状态是在 JAX 中构建和训练大型模型的一个基本组成部分。通过使用 PyTree 和生态系统库提供的抽象,并仔细处理状态的分布和持久化,您可以构建可扩展的训练循环。