趋近智
jax.jitjitjit 的常见问题jitgrad 进行自动微分jax.gradgrad的grad)jax.value_and_grad)vmap 实现自动向量化jax.vmapin_axes,out_axes)vmapvmap 与 jit 和 gradvmap的性能考量pmap 在多设备上并行计算jax.pmapin_axes, out_axes)lax.psum、lax.pmean等)pmap 与其他变换结合使用pmap 化的函数在JAX中管理状态时,主要思想是函数纯粹性:函数不应修改外部状态或有隐藏的副作用。相反,它们应将当前状态作为输入,并返回新状态作为输出,同时包含任何结果。这种“状态输入,状态输出”模式在使用JAX变换(如jit和grad)时非常重要。
这些练习将帮助您熟悉在不同情况下实现这种模式。我们将从简单的开始,逐步完成一个类似于机器学习优化组成部分的任务。请记住,JAX经常使用PyTrees(嵌套元组、列表、字典)来方便地处理复杂状态。
让我们回顾一下有状态计数器的例子。您的任务是实现一个update_counter函数,它接受当前计数(状态)和增量值。它应返回新计数以及增量值本身作为辅助输出。然后,对此函数应用jax.jit并进行测试。
说明:
jax.numpy。update_counter函数。它应接受count和increment作为参数。new_count = count + 1。(new_count, increment)。第一个元素是更新后的状态,第二个是辅助输出。jax.jit创建此函数的jit编译版本。count = 0)。for循环多次调用jit编译的函数(例如,5次)。在每次迭代中:
count和一个示例increment值(例如,循环索引)。count(更新下一次迭代的状态)和returned_increment。import jax
import jax.numpy as jnp
# 1. 定义有状态函数
def update_counter(count, increment):
"""递增计数并返回新计数和增量值。"""
new_count = count + 1
# 返回 (新状态, 辅助输出)
return new_count, increment
# 2. JIT编译函数
jitted_update_counter = jax.jit(update_counter)
# 3. 初始化状态
current_count = 0
print(f"初始计数:{current_count}")
# 4. 运行循环,每次更新状态
num_steps = 5
for i in range(num_steps):
# 传入当前状态,获取新状态和输出
current_count, returned_increment = jitted_update_counter(current_count, i)
print(f"步骤 {i+1}: 新计数 = {current_count}, 返回增量 = {returned_increment}")
预期输出:
初始计数:0
步骤 1: 新计数 = 1, 返回增量 = 0
步骤 2: 新计数 = 2, 返回增量 = 1
步骤 3: 新计数 = 3, 返回增量 = 2
步骤 4: 新计数 = 4, 返回增量 = 3
步骤 5: 新计数 = 5, 返回增量 = 4
这个简单例子展示了主要模式:循环在编译函数外部管理状态(current_count),在每一步中将其传入并接收更新后的版本。
现在,让我们实现一个函数来计算简单移动平均。移动平均是根据新的输入值更新的。我们还需要记录目前为止已见值的总和以及值的总数量。
说明:
update_moving_average函数,它接受state和new_value作为输入。state将是一个元组(current_sum, count)。适当地初始化它(例如,(0.0, 0))。state元组。new_sum = current_sum + new_value。new_count = count + 1。current_average = new_sum / new_count。如有必要,处理count可能为0的初始情况(尽管此处加1可避免除以零)。new_state = (new_sum, new_count)。(new_state, current_average)。update_moving_average。在每次迭代中更新状态变量并打印计算出的平均值。update_moving_average应用jax.jit并查看它是否正常工作。import jax
import jax.numpy as jnp
# 1. 定义用于移动平均的有状态函数
def update_moving_average(state, new_value):
"""使用新值更新移动平均状态。"""
current_sum, count = state # 解包状态
new_sum = current_sum + new_value
new_count = count + 1
current_average = new_sum / new_count
new_state = (new_sum, new_count) # 打包新状态
return new_state, current_average
# 可选:JIT编译函数
# jitted_update_moving_average = jax.jit(update_moving_average)
# 如果您取消注释此行,请在下方使用 jitted_update_moving_average
# 2. 示例数据和初始状态
data_sequence = jnp.array([2.0, 4.0, 6.0, 8.0, 10.0])
initial_state = (0.0, 0) # (总和, 计数)
print(f"初始状态 (总和, 计数): {initial_state}")
print(f"数据序列: {data_sequence}")
# 3. 迭代并更新
current_state = initial_state
for i, value in enumerate(data_sequence):
# 传入当前状态和值,获取新状态和平均值
current_state, avg = update_moving_average(current_state, value)
# 如果正在JIT编译,请使用此行替代:
# current_state, avg = jitted_update_moving_average(current_state, value)
print(f"值 {value:.1f} 后: 新状态 = ({current_state[0]:.1f}, {current_state[1]}), 移动平均 = {avg:.2f}")
预期输出:
初始状态 (总和, 计数): (0.0, 0)
数据序列: [ 2. 4. 6. 8. 10.]
值 2.0 后: 新状态 = (2.0, 1), 移动平均 = 2.00
值 4.0 后: 新状态 = (6.0, 2), 移动平均 = 3.00
值 6.0 后: 新状态 = (12.0, 3), 移动平均 = 4.00
值 8.0 后: 新状态 = (20.0, 4), 移动平均 = 5.00
值 10.0 后: 新状态 = (30.0, 5), 移动平均 = 6.00
在这里,状态是一个元组,一个简单的PyTree。函数正确地更新并返回这个结构化状态以及计算出的平均值。JIT编译应该可以工作,因为函数是纯粹的并遵循状态传递模式。
本练习模拟了像梯度下降这样的优化算法中的一个更新步骤。我们将把被优化的参数作为状态进行管理。我们的目标是使一个简单函数最小化,例如 f(x)=x2。
说明:
loss_fn(x) = x**2。jax.grad定义梯度函数:grad_fn = jax.grad(loss_fn)。gradient_descent_step函数。它应接受params(x的当前值)和learning_rate作为参数。grad_fn计算损失相对于params的梯度。new_params = params - learning_rate * gradient_value。new_params(这是更新后的状态)。params(例如,jnp.array(5.0))。learning_rate(例如,0.1)。gradient_descent_step,传入当前的params和learning_rate。new_params更新params变量。gradient_descent_step应用jax.jit。它能工作吗?为什么?import jax
import jax.numpy as jnp
# 1. 定义损失函数
def loss_fn(x):
return x**2
# 2. 获取梯度函数
grad_fn = jax.grad(loss_fn)
# 3. 定义状态更新函数(梯度下降步骤)
def gradient_descent_step(params, learning_rate):
"""执行一步梯度下降。"""
gradient_value = grad_fn(params)
new_params = params - learning_rate * gradient_value
# 返回新状态(更新后的参数)
return new_params
# 可选:JIT编译步骤函数
# jitted_gradient_descent_step = jax.jit(gradient_descent_step, static_argnums=(1,))
# 注意:如果learning_rate在编译函数的作用域内不改变,它通常被视为静态参数。
# 4. 初始化
current_params = jnp.array(5.0)
learning_rate = 0.1
num_steps = 5
print(f"初始参数: {current_params}")
print(f"学习率: {learning_rate}")
# 5. 执行更新步骤
for i in range(num_steps):
# 传入当前状态 (参数),获取新状态
current_params = gradient_descent_step(current_params, learning_rate)
# 如果正在JIT编译,请使用此行替代:
# current_params = jitted_gradient_descent_step(current_params, learning_rate)
print(f"步骤 {i+1}: 更新后的参数 = {current_params:.4f}")
预期输出:
初始参数: 5.0
学习率: 0.1
步骤 1: 更新后的参数 = 4.0000
步骤 2: 更新后的参数 = 3.2000
步骤 3: 更新后的参数 = 2.5600
步骤 4: 更新后的参数 = 2.0480
步骤 5: 更新后的参数 = 1.6384
本练习展示了状态(模型参数params)是如何在更新函数外部明确管理的。gradient_descent_step函数是纯粹的;它根据输入计算新状态并返回。这种模式在JAX中构建优化器或训练循环时非常重要。请注意,如果您jit此函数,如果learning_rate在编译函数的作用域内不改变,它通常最好标记为静态参数(使用static_argnums或static_argnames),因为这可以提升编译效率。
这些实际练习巩固了JAX所需的状态管理函数式方法。通过明确地传入和传出状态,您的函数保持纯粹并与JAX的强大变换(如jit、grad、vmap和pmap)兼容。这种模式可以从简单的计数器有效扩展到涉及神经网络参数和优化器统计信息的嵌套PyTrees等复杂状态。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造