趋近智
JAX 通过采用函数式模式来管理状态,主要是通过明确地将状态传入函数并返回更新后的状态。这种方法保持了函数纯度,这对 JAX 的变换兼容性很重要。如何将这些函数式状态管理原则与 jax.jit、jax.grad、jax.vmap 和 jax.pmap 结合使用是一个主要关注点。函数纯度确保有状态函数可以与这些变换良好地组合。
将 jax.jit 与管理状态的函数一起使用是简单直接的。由于函数将当前状态作为参数 (parameter)并返回新状态,jit 可以像追踪任何其他纯函数一样追踪它。状态(通常表示为 PyTree)被追踪器视为常规输入和输出。
我们来回顾一下简单的有状态计数器例子:
import jax
import jax.numpy as jnp
def stateful_counter(count_state, increment):
"""递增计数器状态。"""
new_count = count_state['count'] + increment
return {'count': new_count} # 返回新状态
# 初始状态
initial_state = {'count': 0}
# 应用函数
state1 = stateful_counter(initial_state, jnp.array(1))
state2 = stateful_counter(state1, jnp.array(5))
print(f"初始状态: {initial_state}")
print(f"首次递增后的状态: {state1}")
print(f"第二次递增后的状态: {state2}")
现在,我们使用 jax.jit 编译 stateful_counter:
# JIT 编译函数
jit_counter = jax.jit(stateful_counter)
# 运行编译后的版本
jitted_state1 = jit_counter(initial_state, jnp.array(1))
jitted_state2 = jit_counter(jitted_state1, jnp.array(5)) # 如果形状/类型匹配,会重用编译后的代码
print(f"\n首次递增后的 JIT 编译状态: {jitted_state1}")
print(f"第二次递增后的 JIT 编译状态: {jitted_state2}")
# 验证结构和值是否相同
assert jax.tree_util.tree_all(jax.tree_map(lambda x, y: jnp.all(x == y), state2, jitted_state2))
如你所见,jit 毫无问题地处理状态字典(一个 PyTree)。JAX 使用初始状态结构和参数类型/形状来追踪函数。后续具有匹配结构和类型的调用会重用编译后的代码,为复杂的有状态计算(例如神经网络 (neural network)训练步骤)带来显著的速度提升。
重要提示: 请记住,jit 追踪函数是基于状态 PyTree 的结构及其叶节点(数组)的类型/形状。如果你的状态结构在调用之间发生改变(例如,向字典添加新键),jit 将需要重新编译函数,这可能影响性能。保持一致的状态结构是更好的做法。
jax.grad 的自动微分也与明确的状态传递结合良好。通常,状态包含我们希望求导的参数 (parameter)(例如模型权重 (weight))。函数通常既返回一个要微分的值(如损失)又返回更新后的状态。
考虑一个计算平方误差损失并更新参数状态的简单函数:
import jax
import jax.numpy as jnp
def predict_and_update(params, x):
"""一个简单的线性预测函数。状态 = 参数。"""
# 预测使用状态中的参数
pred = params['w'] * x + params['b']
# 返回预测值(value)和未改变的状态
return pred, params # 状态在此处未被修改
def loss_fn(params, x, y_target):
"""计算损失且不改变状态。"""
pred, _ = predict_and_update(params, x) # 使用预测函数
loss = jnp.mean((pred - y_target)**2)
# 只返回损失值;状态单独处理
return loss
# 示例参数(状态)和数据
params_state = {'w': jnp.array(2.0), 'b': jnp.array(1.0)}
x_data = jnp.array([1.0, 2.0, 3.0])
y_target_data = jnp.array([3.5, 5.5, 7.5]) # 理想值: w=2.5, b=1.0
# 计算损失
current_loss = loss_fn(params_state, x_data, y_target_data)
print(f"当前损失: {current_loss}")
# 获取损失函数关于 'params'(参数 0)的梯度函数
grad_loss_fn = jax.grad(loss_fn, argnums=0) # 对参数求导
# 计算梯度
grads = grad_loss_fn(params_state, x_data, y_target_data)
print(f"梯度: {grads}")
这里,loss_fn 将 params(我们的状态)作为输入。我们使用 jax.grad 并指定 argnums=0 来获取关于 params 的梯度。JAX 正确追踪 loss_fn 内部使用的 predict_and_update 函数,并计算 w 和 b 的梯度。
通常,你既需要损失值也需要梯度。jax.value_and_grad 非常适合这个场景:
# 获取一个同时返回损失和梯度的函数
value_and_grad_fn = jax.value_and_grad(loss_fn, argnums=0)
# 同时计算损失和梯度
loss_val, grads_val = value_and_grad_fn(params_state, x_data, y_target_data)
print(f"\n使用 value_and_grad:")
print(f"损失: {loss_val}")
print(f"梯度: {grads_val}")
现在,让我们将其与状态更新步骤结合起来,模拟梯度下降 (gradient descent)的一个步骤:
def training_step(params, x, y_target, learning_rate):
"""执行一步梯度下降。"""
loss, grads = jax.value_and_grad(loss_fn, argnums=0)(params, x, y_target)
# 使用梯度更新参数(明确的状态更新)
# jax.tree_map 将函数逐元素应用于 PyTree
updated_params = jax.tree_map(
lambda p, g: p - learning_rate * g, params, grads
)
# 返回损失和新状态(更新后的参数)
return loss, updated_params
# 执行一步训练
learning_rate = 0.1
loss_step1, params_step1 = training_step(params_state, x_data, y_target_data, learning_rate)
print(f"\n经过一步训练后:")
print(f"损失: {loss_step1}")
print(f"更新后的参数: {params_step1}")
# 执行另一步
loss_step2, params_step2 = training_step(params_step1, x_data, y_target_data, learning_rate)
print(f"\n经过第二步训练后:")
print(f"损失: {loss_step2}")
print(f"更新后的参数: {params_step2}")
这个 training_step 函数接受参数状态,计算梯度,并返回更新后的参数状态。它是一个纯函数,使其适合进行如 jit 的进一步变换。
vmap 的向量 (vector)化jax.vmap 允许你自动向量化 (quantization)函数,包括那些管理状态的函数。这对于处理批量数据非常有用。你需要使用 in_axes 参数 (parameter)告诉 vmap 如何将每个参数(包括状态)映射到批处理维度上。
我们来修改计数器,使其能在批处理上操作。假设我们有一批增量,并希望有独立的计数器(尽管共享相同的逻辑):
# 定义一批计数器的初始状态
# 假设我们批处理中有 3 个计数器
batch_size = 3
# 状态需要相应地复制或批处理
batched_initial_state = {'count': jnp.array([0, 0, 0])} # 批处理状态
# 增量批处理
batched_increments = jnp.array([1, 5, 10])
# 对计数器函数按状态('count')和增量进行向量化
# state['count'] 轴 0,增量轴 0
vmap_counter = jax.vmap(stateful_counter, in_axes=({'count': 0}, 0))
# 应用向量化函数
batched_state1 = vmap_counter(batched_initial_state, batched_increments)
print(f"\n向量化计数器:")
print(f"批处理初始状态: {batched_initial_state}")
print(f"批处理增量: {batched_increments}")
print(f"增量后的批处理状态: {batched_state1}")
# 使用不同的增量再次应用
batched_increments2 = jnp.array([2, 3, 4])
batched_state2 = vmap_counter(batched_state1, batched_increments2)
print(f"第二次增量后的批处理状态: {batched_state2}")
这里,in_axes=({'count': 0}, 0) 告诉 vmap:
count_state),查看字典内部。对于 'count',沿轴 0 进行映射。increment),沿轴 0 进行映射。out_axes 参数(默认为所有输出的 0)指定了输出的结构方式。在这种情况下,返回的状态字典 {'count': ...} 的 'count' 值沿轴 0 堆叠。
在机器学习 (machine learning)中,批处理中共享模型参数很常见。vmap 通过将相应的 in_axes 设置为 None,可以轻松处理这种情况。
# 示例: 对一批 x 应用 predict_and_update,使用*相同*的参数
# 参数(状态)是共享的: in_axes=None
# x_data 是批处理的: in_axes=0
vmap_predict = jax.vmap(predict_and_update, in_axes=(None, 0))
# 我们原始的单一参数状态
params_state = {'w': jnp.array(2.0), 'b': jnp.array(1.0)}
# x 数据的批处理
x_batch = jnp.array([1.0, 2.0, 3.0, 4.0])
# 运行向量化预测
# 输出状态将被复制,预测结果将被批处理
batched_preds, batched_params_out = vmap_predict(params_state, x_batch)
print(f"\n向量化预测 (共享参数):")
print(f"输入参数 (状态): {params_state}")
print(f"输入 x 批处理: {x_batch}")
print(f"批处理预测结果: {batched_preds}")
# 注意: 输出状态只是被 vmap 复制的输入状态
# print(f"Output Params (State): {batched_params_out}") # 会显示复制的参数
pmap 的并行化当使用 jax.pmap 在多个设备(GPU/TPU)上进行并行化时,类似的原则也适用,尽管细节涉及设备放置和集体操作。与 vmap 类似,pmap 需要使用 in_axes 指定如何将输入(包括状态)映射到设备。
in_axes=None)。每个设备都持有完整副本。in_axes=0)。当状态并行更新时(例如,在不同的数据分片上计算梯度),你通常需要在 pmap 化的函数中使用集体操作(lax.psum、lax.pmean 等),以便在更新复制状态之前聚合结果(如平均梯度)。使用 pmap 处理状态需要仔细考虑数据分布和同步,这建立在 pmap 章中讨论的原则之上。
当你组合这些变换时,它们真正的用处就显现出来。例如,典型的机器学习 (machine learning)训练循环涉及计算一批数据的梯度(vmap + grad),并为了性能而编译整个步骤(jit)。
我们来 JIT 编译我们的 training_step 函数:
# JIT 编译训练步骤函数
jit_training_step = jax.jit(training_step)
# 重置状态
params_state = {'w': jnp.array(2.0), 'b': jnp.array(1.0)}
learning_rate = 0.1
x_data = jnp.array([1.0, 2.0, 3.0])
y_target_data = jnp.array([3.5, 5.5, 7.5])
print(f"\nJIT 编译后的训练步骤:")
print(f"初始参数: {params_state}")
# 运行编译后的训练步骤
loss_jitted1, params_jitted1 = jit_training_step(params_state, x_data, y_target_data, learning_rate)
print(f"步骤 1 损失: {loss_jitted1}, 参数: {params_jitted1}")
# 再次运行(由于编译缓存,应该更快)
loss_jitted2, params_jitted2 = jit_training_step(params_jitted1, x_data, y_target_data, learning_rate)
print(f"步骤 2 损失: {loss_jitted2}, 参数: {params_jitted2}")
这个 jit_training_step 现在能高效执行梯度计算和参数 (parameter)更新。如果我们需要同时处理多个独立的训练批次(不太常见),我们可以进一步用 vmap 封装它;或者将其集成到 pmap 中进行分布式训练。
明确的状态传递模式结合 PyTree,提供了一种有效的方式来处理状态,使其与 JAX 的核心变换和谐地协同工作,从而实现复杂、高性能的计算,例如训练大型机器学习模型。
这部分内容有帮助吗?
© 2026 ApX Machine LearningAI伦理与透明度•