趋近智
jax.jitjitjit 的常见问题jitgrad 进行自动微分jax.gradgrad的grad)jax.value_and_grad)vmap 实现自动向量化jax.vmapin_axes,out_axes)vmapvmap 与 jit 和 gradvmap的性能考量pmap 在多设备上并行计算jax.pmapin_axes, out_axes)lax.psum、lax.pmean等)pmap 与其他变换结合使用pmap 化的函数机器学习中的常见情况涉及管理优化算法的状态,这可以通过显式状态传递模式来处理。随机梯度下降(SGD)或 Adam 这样的优化器根据计算出的梯度迭代更新模型参数。除了参数本身,许多优化器也维护自己的内部状态,比如动量向量或自适应学习率。
考虑基本的梯度下降更新规则。为了找到函数 f(params) 的最小值,我们通过沿着梯度 ∇f(params) 相反的方向移动来迭代更新参数 params:
params新=params旧−学习率×∇f(params旧)在函数式环境中,我们不能就地修改 params。相反,我们需要一个函数,它接受当前参数和梯度,并返回 新 参数。
让我们尝试最小化一个简单的二次函数 f(x)=(x−3)2。最小值显然在 x=3。梯度是 ∇f(x)=2(x−3)。
import jax
import jax.numpy as jnp
# 定义要最小化的函数
def objective_function(x):
return (x - 3.0)**2
# 使用 jax.grad 计算梯度
grad_fn = jax.grad(objective_function)
# 定义优化器更新步骤(简单 SGD)
# 此函数接受当前状态(此处仅为 'params')
# 和梯度,并返回新状态。
def sgd_update(params, gradients, learning_rate):
"""执行一个 SGD 更新步骤。"""
new_params = params - learning_rate * gradients
# 返回更新后的状态
return new_params
# --- 优化循环 ---
# 初始参数值(我们的初始状态)
current_params = jnp.array(0.0)
learning_rate = 0.1
num_steps = 20
print(f"初始参数: {current_params:.4f}")
# 运行优化循环
for step in range(num_steps):
# 1. 计算当前参数的梯度
gradients = grad_fn(current_params)
# 2. 使用更新函数计算新参数
# 传递当前状态 ('current_params') 和梯度
# 接收新状态 ('next_params')
next_params = sgd_update(current_params, gradients, learning_rate)
# 3. 更新状态以进行下一次迭代
current_params = next_params
if (step + 1) % 5 == 0:
print(f"步骤 {step+1:3d}, 参数: {current_params:.4f}, 梯度: {gradients:.4f}")
print(f"\n最终优化参数: {current_params:.4f}")
在这个循环中,current_params 持有状态。sgd_update 函数是纯粹的;它接受状态和梯度,并返回一个 新 状态 (next_params)。然后我们显式地重新赋值 current_params = next_params,以便将状态向前传递。这种模式与 JAX 转换完美配合。
许多优化器需要额外状态。让我们实现带动量的 SGD。更新涉及一个速度项 v:
v新=动量×v旧+学习率×∇f(params旧)params新=params旧−v新注意到我们现在需要追踪 params 和 velocity (v)。这种组合信息构成优化器的状态。我们可以使用 PyTree,例如字典,来存储这种结构化状态。
import jax
import jax.numpy as jnp
from typing import NamedTuple # 或使用字典
# 定义要最小化的函数(同前)
def objective_function(x):
return (x - 3.0)**2
# 使用 jax.grad 计算梯度
grad_fn = jax.grad(objective_function)
# 使用 NamedTuple(或字典)定义优化器状态的结构
class OptimizerState(NamedTuple):
params: jax.Array
velocity: jax.Array
# 定义优化器更新步骤(带动量的 SGD)
# 此函数接受组合状态(参数和速度)
# 和梯度,并返回新的组合状态。
def momentum_update(state: OptimizerState, gradients, learning_rate, momentum):
"""执行一个带动量的 SGD 更新步骤。"""
new_velocity = momentum * state.velocity + learning_rate * gradients
new_params = state.params - new_velocity
# 将更新后的状态作为新的 OptimizerState 对象返回
return OptimizerState(params=new_params, velocity=new_velocity)
# --- 带动量的优化循环 ---
# 初始参数值和速度
initial_params = jnp.array(0.0)
initial_velocity = jnp.array(0.0)
# 初始状态现在是一个包含参数和速度的结构
current_state = OptimizerState(params=initial_params, velocity=initial_velocity)
learning_rate = 0.1
momentum_coeff = 0.9 # 常用动量值
num_steps = 20
print(f"初始状态: 参数={current_state.params:.4f}, 速度={current_state.velocity:.4f}")
# 运行优化循环
for step in range(num_steps):
# 1. 计算状态中当前参数的梯度
gradients = grad_fn(current_state.params)
# 2. 使用动量更新函数计算新状态
# 传递当前的组合状态和梯度
# 接收新的组合状态
next_state = momentum_update(current_state, gradients, learning_rate, momentum_coeff)
# 3. 更新状态以进行下一次迭代
current_state = next_state
if (step + 1) % 5 == 0:
print(f"步骤 {step+1:3d}, 参数: {current_state.params:.4f}, 速度: {current_state.velocity:.4f}, 梯度: {gradients:.4f}")
print(f"\n最终优化参数: {current_state.params:.4f}")
这里,状态 (current_state) 是一个 NamedTuple(也可以是字典 { 'params': ..., 'velocity': ... })。momentum_update 函数将整个状态对象作为输入,并返回一个新的、已更新的状态对象。循环结构保持不变:计算梯度,使用当前状态调用更新函数,并使用返回的新状态进行下一步。
这种显式状态管理模式的一个重要优势在于它与 jax.jit 等 JAX 转换的自然兼容性。我们可以轻松编译我们的更新函数以获得更好的性能:
# 编译动量更新函数
jitted_momentum_update = jax.jit(momentum_update, static_argnums=(2, 3)) # learning_rate 和 momentum 是静态的
# --- 使用 JIT 编译的函数重新运行优化循环 ---
# 重置状态
current_state = OptimizerState(params=initial_params, velocity=initial_velocity)
print("\n--- 正在使用 JIT 编译的更新运行 ---")
print(f"初始状态: 参数={current_state.params:.4f}, 速度={current_state.velocity:.4f}")
for step in range(num_steps):
gradients = grad_fn(current_state.params) # 如果需要,也可以对 grad_fn 进行 JIT 编译
# 使用编译后的更新函数
next_state = jitted_momentum_update(current_state, gradients, learning_rate, momentum_coeff)
current_state = next_state
if (step + 1) % 5 == 0:
print(f"步骤 {step+1:3d}, 参数: {current_state.params:.4f}, 速度: {current_state.velocity:.4f}, 梯度: {gradients:.4f}")
print(f"\n最终优化参数 (JIT 编译): {current_state.params:.4f}")
因为 momentum_update 是一个纯函数(它的输出仅取决于其输入,没有副作用)并遵循显式状态传递模式,所以 jax.jit 可以有效地追踪和编译它。我们将 learning_rate 和 momentum 标记为静态参数,因为它们的值在循环期间不会改变,并且它们不是 JAX 数组,这可以防止不必要的重新编译。
这个例子展示了如何在 JAX 的函数式方法中处理优化器的演变状态。通过显式地将状态传递进出纯更新函数,通常使用 PyTree 进行结构化,我们创建了清晰、易于管理且易于兼容 JAX 强大转换(如 jit 和 grad)的代码。这种模式在 JAX 中构建更复杂的模型和训练循环时很基本。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造