趋近智
有效训练大型机器学习模型需要仔细处理的不仅仅是模型参数本身,还有相关的状态,例如优化器统计信息和伪随机数生成器(PRNG)密钥。JAX 的函数式编程方法,其中函数理想情况下没有副作用,这意味着状态必须显式地进行管理:作为输入传入函数并作为输出返回。尽管这最初可能看起来冗长,但它提供了很大的清晰度并简化了调试,尤其是在复杂的分布式环境中。
在 JAX 中,模型参数、优化器状态(例如,动量缓冲区、学习率调度)、批量归一化统计信息和 PRNG 密钥通常表示为 PyTree。PyTree 只是 JAX 可以视为容器(如列表、元组、字典)和叶子节点(如 JAX 数组或标准 Python 类型)的嵌套结构的任何 Python 对象。
import jax
import jax.numpy as jnp
import optax # Common JAX optimizer library
# 参数的示例结构(可能更深)
params = {
'encoder': {
'layer_1': {'w': jnp.ones((128, 256)), 'b': jnp.zeros(256)},
'layer_norm': {'scale': jnp.ones(256), 'bias': jnp.zeros(256)}
},
'decoder': {
'output': {'w': jnp.ones((256, 10)), 'b': jnp.zeros(10)}
}
}
# 优化器状态示例(结构通常与参数相似)
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)
# 单个PRNG密钥
key = jax.random.PRNGKey(0)
# 为了方便,您可以将它们捆绑在一起
training_state = {
'params': params,
'opt_state': opt_state,
'rng_key': key,
'step': 0
}
# 验证它是一个PyTree
leaves, treedef = jax.tree_util.tree_flatten(training_state)
print(f"Number of leaf nodes (arrays, scalars): {len(leaves)}")
# 输出:叶子节点(数组、标量)的数量:11(取决于具体的结构和优化器)
使用 PyTree 是一个重要方面,因为 JAX 转换(jit、grad、vmap、pmap)旨在对这些结构进行操作。当您将 jax.grad 应用于一个接收并返回 PyTree 的函数时,JAX 会计算指定输入 PyTree(s) 中所有数值叶子节点(数组)的梯度。同样,pmap 根据其参数自动在设备上复制或分发 PyTree 结构。这使得管理可能复杂的嵌套状态结构更加系统化。
尽管您可以手动使用字典或自定义类来管理状态,但 Flax 和 Haiku 等库提供了专门为神经网络设计的高级抽象,显著简化了状态管理。
Flax 通常鼓励将相关的状态组件归入专用对象,通常使用 TrainState 等模式。这个类通常包含模型参数(params)、优化器状态(opt_state)、当前训练步骤(step),有时还包括模型定义本身(apply_fn)。
# Flax 示例
from flax.training import train_state
import optax
class SimpleModel(nn.Module):
features: int
@nn.compact
def __call__(self, x):
x = nn.Dense(features=self.features)(x)
return x
class TrainState(train_state.TrainState):
# 可选地在这里添加批量统计或其他状态
batch_stats: Any = None # BatchNorm 示例
# 初始化
key = jax.random.PRNGKey(0)
model = SimpleModel(features=10)
dummy_input = jnp.ones([1, 128])
params = model.init(key, dummy_input)['params']
optimizer = optax.adam(1e-3)
# 创建状态对象
state = TrainState.create(
apply_fn=model.apply, # 运行模型的函数
params=params,
tx=optimizer # 优化器转换
)
# 在训练步骤中,您将以不可变的方式更新此状态
# new_state = state.apply_gradients(grads=grads)
主要思路保持不变:状态显式地保存在一个对象(一个 PyTree)中,更新会生成 新的 状态对象,而不是就地修改。这与 JAX 的函数式方法完全吻合。
Haiku 采用略微不同的方法,使用 hk.transform 或 hk.transform_with_state。它将纯函数逻辑与参数(params)和可变状态(state,例如用于批量归一化统计信息)分离。
import haiku as hk
import optax
def forward_fn(x, is_training):
net = hk.Sequential([
hk.Linear(512), jax.nn.relu,
hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.99),
hk.Linear(10)
])
return net(x, is_training=is_training)
# 转换函数以处理状态
forward = hk.transform_with_state(forward_fn)
key = hk.PRNGSequence(0)
dummy_input = jnp.ones([1, 128])
# 初始化参数和可变状态
params, state = forward.init(next(key), dummy_input, is_training=True)
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)
# 应用函数(示例:is_training=True 更新批量归一化状态)
# logits, new_state = forward.apply(params, state, next(key), batch_data, is_training=True)
# grads, new_state = grad_fn(params, state, ...) # grad_fn 需要处理状态
# updates, new_opt_state = optimizer.update(grads, opt_state, params)
# new_params = optax.apply_updates(params, updates)
Haiku 要求您显式管理 init 和 apply 返回的 params 和 state PyTree。同样,函数式原则适用:状态作为输入传入,更新后的状态作为输出返回。
pmap)当使用 pmap 进行数据并行扩展训练时,正确管理跨多个设备的状态变得必要。
在典型的
pmap数据并行设置中,参数和优化器状态会被复制,数据会被分片,PRNG 密钥必须是每个设备独有的。映射的函数计算每个设备的损失和梯度等结果。
参数和优化器状态: 对于标准的数据并行,模型参数(params)和相关的优化器状态(opt_state)通常在所有设备上是相同的。您在主机上初始化它们一次,pmap 会自动将这些 PyTree 广播(复制)到每个设备。当在每个设备上计算梯度时,需要先聚合(例如,使用 lax.pmean 进行平均),然后才能更新参数和优化器状态。这个更新步骤通常也在 pmap 化函数中执行,确保复制的状态在所有设备上保持一致。
PRNG 密钥: 这经常是不易察觉的错误来源。如果您在所有设备上为诸如 dropout 或数据增强等操作使用 相同 的 PRNG 密钥,所有设备将生成 相同 的随机数,这会违背随机性的目的或导致相关结果。正确的方法是在调用 pmap 之前,在主机上生成一个主密钥,然后将其拆分为每个设备独有的子密钥。
num_devices = jax.local_device_count()
key = jax.random.PRNGKey(42)
# 在主机上拆分一次
device_keys = jax.random.split(key, num_devices)
# pmap 化函数示例(简化)
@jax.pmap
def train_step_pmap(params, opt_state, local_key, batch):
# 在内部使用每个设备的 'local_key'
dropout_key, new_local_key = jax.random.split(local_key)
# ... 使用 dropout_key 执行带有 dropout 的前向传播 ...
# ... 计算梯度 ...
# grads = ...
# grads = lax.pmean(grads, axis_name='devices') # 聚合梯度
# ... 更新参数和优化器状态 ...
# return loss, new_params, new_opt_state, new_local_key
# 调用 pmap,传入设备特定密钥的数组
# loss, params, opt_state, device_keys = train_step_pmap(params, opt_state, device_keys, sharded_batch)
创建一个包含唯一 PRNG 密钥的数组,每个设备一个,通过在将它们传入
pmap之前 拆分一个主密钥。
pmap 化函数在特定设备上的每次执行将从 device_keys 数组中接收其对应的唯一密钥。请记住也要从 pmap 化函数中返回更新后的密钥,以便在下一步中正确地继续 PRNG 序列。
训练大型模型可能需要数小时、数天甚至数周。定期将训练状态(检查点)保存到磁盘上是必要的。这使得您可以在中断后继续训练,并保存最终训练好的模型。
需要保存什么?通常,您需要:
params):学习到的权重和偏置。opt_state):对于正确恢复训练很重要,特别是对于带有动量或自适应学习率的优化器(如 Adam)。JAX 生态系统内的库为此提供工具。例如,orbax.checkpoint 正在成为一个标准方案,提供异步检查点(在后台保存状态而不中断训练)和灵活的数据保存结构等功能。像 Flax 这样的框架也内置了序列化工具(flax.training.checkpoints)。
# 使用 Orbax 的示例(简化)
import orbax.checkpoint as ocp
# 假设 'state' 是一个包含 params、opt_state、step 等的 PyTree
# 对于 pmap,状态可能会被复制。通常从设备 0 保存。
# state_to_save = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state)) # 从设备 0 获取状态
checkpointer = ocp.StandardCheckpointer() # Or AsyncCheckpointer
save_path = '/path/to/checkpoints/step_10000'
checkpointer.save(save_path, args=ocp.args.StandardSave(state))
# 恢复:
# restored_state = checkpointer.restore(save_path, args=ocp.args.StandardRestore(state))
# 如果使用 pmap,将恢复的状态复制到所有设备
# state = jax.device_put_replicated(restored_state, jax.local_devices())
有效管理参数和状态是在 JAX 中构建和训练大型模型的一个基本组成部分。通过使用 PyTree 和生态系统库提供的抽象,并仔细处理状态的分布和持久化,您可以构建可扩展的训练循环。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造