趋近智
虽然 jax.grad 和 jax.vjp 为反向模式自动微分提供支持,这对于训练大多数神经网络 (neural network)很重要,但在某些情况下,需要直接控制前向模式微分(雅可比向量 (vector)积)。前向模式计算输入变化(由“切线”向量表示)如何通过函数向前传播以影响输出。jax.jvp 函数直接计算此值。
有时,JAX 自动推导 JVP 规则的结果可能不是最佳的,甚至无法进行。这可能由以下几个原因引起:
sqrt(x) 或 log(x) 等函数)。jnp.clip 的边界处)。jax.pure_callback),JAX 无法看到其内部以计算导数。在这些情况下,您可以使用 jax.custom_jvp 装饰器定义自己的 JVP 规则。这使您能够精确地告诉 JAX 如何计算函数的前向模式导数。
要定义自定义 JVP,您需要使用 @jax.custom_jvp 装饰您的函数,然后使用 @[您的函数名].defjvp 装饰器定义一个与该函数关联的单独的 JVP 规则函数。
让我们通过一个简单例子进行说明:实现一个数值稳定的 jnp.sqrt 版本,它处理 处的导数。
import jax
import jax.numpy as jnp
@jax.custom_jvp
def stable_sqrt(x):
"""计算 sqrt(x),在接近零时使用一个小的 epsilon 保证稳定性。"""
is_zero = (x == 0.0)
# 如果 x 恰好为零,避免在梯度计算中出现除以零的情况
safe_x = jnp.where(is_zero, 1.0, x)
return jnp.sqrt(safe_x)
# 为 stable_sqrt 定义自定义 JVP 规则
@stable_sqrt.defjvp
def stable_sqrt_jvp(primals, tangents):
"""stable_sqrt 的自定义 JVP 规则。"""
x, = primals
t, = tangents # t 是切线向量 dx
# 原始计算(前向传播结果)
primal_out = stable_sqrt(x)
# 切线计算(导数部分)
# sqrt(x) 的导数是 0.5 * x**(-0.5)
# 显式处理 x=0 情况,避免出现 NaN/inf
is_zero = (x == 0.0)
# 在 0 处使用一个大的有限数作为导数,或者使用 0,
# 取决于所需行为。这里,我们使用 0。
safe_x = jnp.where(is_zero, 1.0, x) # 避免除以零
deriv = 0.5 * jax.lax.rsqrt(safe_x)
tangent_out = jnp.where(is_zero, 0.0, deriv * t) # 在 x=0 处将导数设置为 0
return primal_out, tangent_out
# 示例用法
x_val = jnp.array(0.0)
t_val = jnp.array(1.0) # 切线向量
# 使用自定义规则计算 JVP
primal_output, tangent_output = jax.jvp(stable_sqrt, (x_val,), (t_val,))
print(f"输入 x: {x_val}")
print(f"输入切线 t: {t_val}")
print(f"原始输出 (stable_sqrt(x)): {primal_output}")
print(f"切线输出 (x 处导数 * t): {tangent_output}")
x_val_pos = jnp.array(4.0)
primal_output_pos, tangent_output_pos = jax.jvp(stable_sqrt, (x_val_pos,), (t_val,))
print(f"\n输入 x: {x_val_pos}")
print(f"输入切线 t: {t_val}")
print(f"原始输出 (stable_sqrt(x)): {primal_output_pos}")
print(f"切线输出 (x 处导数 * t): {tangent_output_pos}") # 应为 0.5 * 4**(-0.5) * 1 = 0.25
使用 .defjvp 装饰的函数(此处为 stable_sqrt_jvp)接受两个参数 (parameter):
primals:一个元组,包含原始函数 (stable_sqrt) 的原始输入。在我们的示例中,primals 是 (x,)。tangents:一个元组,包含与原始输入对应的切线值。这些值表示微分的“方向”。在我们的示例中,tangents 是 (t,),而 t 对应于 x。JVP 规则函数必须返回一对值:
primal_out:计算原始函数的结果(前向传播)。通常通过调用原始函数本身来计算,如上所示:primal_out = stable_sqrt(x)。tangent_out:雅可比向量 (vector)积的结果。这表示输出如何响应输入沿 tangents 指定的方向变化。在数学上,如果 是函数, 是输入,它计算 ,其中 是 在 处的雅可比矩阵, 是输入切线向量。在示例中,这是 jnp.where(is_zero, 0.0, deriv * t)。jax.custom_vjp 的关系jax.custom_jvp 定义前向模式微分规则。其对应项 jax.custom_vjp 定义反向模式规则(向量 (vector)雅可比积),这是 jax.grad 的依据。
jax.jvp,或通过 jax.jacfwd 计算完整雅可比矩阵),则仅定义 jax.custom_jvp 可能就足够了。jax.grad),通常需要使用 jax.custom_vjp 定义自定义 VJP 规则。jax.custom_jvp 和 jax.custom_vjp 两者。定义自定义 JVP 规则在以下情况下特别有用:
clip 的导数设置为 0 或 1)。与大多数 JAX 变换一样,具有自定义 JVP 规则的函数可以与 jax.jit、jax.vmap 和其他变换正确配合使用。JAX 将在追踪和编译期间使用您的自定义规则。
通过熟练掌握 jax.custom_jvp,您可以更好地控制 JAX 的前向模式自动微分,从而为专门的计算任务实现更高效的实现。这补充了 jax.custom_vjp 为反向模式微分提供的控制。
这部分内容有帮助吗?
jax.custom_jvp, JAX developers, 2024 - JAX 中定义自定义雅可比向量积(前向模式微分规则)的官方文档。© 2026 ApX Machine LearningAI伦理与透明度•