趋近智
jax.custom_vjp 允许为函数定义自己的向量 (vector)-雅可比积规则。这在以下几种情况下很有用:
在本节中,我们将为一个常见函数实现自定义 VJP,侧重于其工作原理和验证。
考虑 Softplus 函数,其定义为:
此函数常被用作 ReLU 激活函数 (activation function)的平滑近似。
我们首先编写一个标准的 JAX 实现:
import jax
import jax.numpy as jnp
import numpy as np # Used for comparison/testing
# 朴素实现 - 对于大 x 值可能不稳定
def softplus_naive(x):
return jnp.log(1 + jnp.exp(x))
# 示例用法
x_val = 10.0
print(f"softplus_naive({x_val}) = {softplus_naive(x_val)}")
# 使用 JAX 的自动微分计算梯度
grad_softplus_naive = jax.grad(softplus_naive)
print(f"在 {x_val} 处的梯度 (JAX 自动): {grad_softplus_naive(x_val)}")
对于非常大的正数 , 可能会超出标准浮点表示(如 float32)的范围。尽管 JAX 和 XLA 通常能够智能地处理此类情况,但假设我们希望为了教学目的,或者在自动梯度确实引发问题的情况下,明确提供梯度规则。
softplus(x) 的解析导数为:
Sigmoid 函数是数值稳定的。我们可以在自定义 VJP 中使用此解析形式。
jax.custom_vjp 实现要定义自定义 VJP,我们需要三个组成部分:
@jax.custom_vjp 装饰。fwd),它计算原始输出并保存反向传播 (backpropagation)所需的任何中间值(residuals)。它必须返回 (output, residuals)。bwd),它接收 residuals 和上游梯度(g,表示 ,其中 是我们函数的输出),并计算相对于原始输入的下游梯度()。它必须返回一个梯度元组,元组中每个元素对应原始函数的一个位置参数 (parameter)。让我们为 softplus 实现这一点:
import jax
import jax.numpy as jnp
import numpy as np
# 1. 定义函数并进行装饰
@jax.custom_vjp
def softplus_custom(x):
"""用于自定义 VJP 的 Softplus 实现。"""
# 为演示目的,我们在此处使用可能不稳定的前向传播。
# 在实际场景中,您可能也会使用更稳定的前向传播。
return jnp.log(1 + jnp.exp(x))
# 2. 定义前向传播函数 (_fwd)
# 它接受与原始函数 (x) 相同的参数。
# 它返回输出 (y = softplus(x)) 和反向传播所需的中间值。
# 我们需要原始输入 'x' 以在反向传播中计算 sigmoid(x)。
def softplus_fwd(x):
y = softplus_custom(x) # 使用装饰过的函数计算原始输出
residuals = x # 保存 'x' 以供反向传播使用
return y, residuals
# 3. 定义反向传播函数 (_bwd)
# 它接收 _fwd 保存的中间值和上游梯度 'g'。
# 它返回一个关于原始函数输入的梯度元组。
# 这里,唯一的输入是 'x',因此我们返回一个包含一个元素的元组。
def softplus_bwd(residuals, g):
x = residuals # 解包中间值
# 计算梯度:g * sigmoid(x)
# 使用 jax.nn.sigmoid 确保数值稳定性
grad_x = g * jax.nn.sigmoid(x)
return (grad_x,) # 将关于 'x' 的梯度作为元组返回
# 将 fwd 和 bwd 函数与 custom_vjp 函数关联
softplus_custom.defvjp(softplus_fwd, softplus_bwd)
# --- 现在让我们测试它 ---
# 使用自定义 VJP 实现计算梯度
grad_softplus_custom = jax.grad(softplus_custom)
# 测试值
test_values = jnp.array([-10.0, 0.0, 10.0, 80.0, 100.0]) # 包含大值
# 比较朴素自动微分梯度与自定义梯度
print("输入 | 朴素梯度 (自动微分) | 自定义梯度 (VJP)")
print("------|-----------------------|-------------------")
for x_val in test_values:
# 为朴素版本使用 try-except 块,该版本可能溢出/发出警告
try:
naive_grad = jax.grad(softplus_naive)(x_val)
except OverflowError:
naive_grad = float('inf') # 或者适当处理
custom_grad = grad_softplus_custom(x_val)
print(f"{x_val:<5.1f} | {naive_grad:<21.8f} | {custom_grad:<19.8f}")
# JIT 编译按预期工作
jit_grad_softplus_custom = jax.jit(jax.grad(softplus_custom))
print(f"\nJIT 编译的自定义梯度在 10.0 处: {jit_grad_softplus_custom(10.0):.8f}")
运行上述代码时,您应该会发现由 jax.grad(softplus_naive) 和 jax.grad(softplus_custom) 计算的梯度是相同的(在浮点精度范围内)。
jax.nn.sigmoid(x),这是数值稳定的。尽管 JAX 对 log(1 + exp(x)) 的默认自动微分通常也能为梯度生成稳定结果,但显式定义它能保证这种行为。请注意,softplus_custom 中的前向传播仍是朴素版本;在生产场景中,可能需要实现稳定的前向传播并提供自定义 VJP。softplus_fwd 来计算原始函数的值并缓存输入 x。softplus_bwd 函数随后使用此缓存的 x 和传入的梯度 g 来计算最终梯度 g * sigmoid(x)。defvjp 调用将这两个函数与 @jax.custom_vjp 装饰器关联起来。residuals 中保存的内容。保存不必要的大的中间数组会增加内存消耗。只保存反向传播 (backpropagation)严格所需的内容。有时,在反向传播中重新计算值可以作为降低内存使用量的权衡。fwd 函数接收所有输入,bwd 函数必须返回一个梯度元组,其长度与原始函数的位置输入数量相同。对于输出,bwd 函数接收一个 g,其结构与原始输出匹配(或者对于不涉及梯度路径的输出为 None)。jax.custom_jvp: 对于前向模式微分,过程类似,使用 @jax.custom_jvp、defjvp,并定义一个函数,在给定 (x, x_dot) 的情况下计算 (y, y_dot)。本次实践练习说明了在 JAX 中定义自定义梯度的核心机制。它是处理特定数值或性能要求的有力工具。请记住,在可行的情况下,务必针对 JAX 的自动微分或数值梯度彻底测试您的自定义规则。
这部分内容有帮助吗?
© 2026 ApX Machine LearningAI伦理与透明度•