让我们将理论付诸实践。您已经知道 jax.custom_vjp 允许您为函数定义自己的向量-雅可比积规则。这在以下几种情况下很有用:数值稳定性: 自动推导的梯度可能存在数值问题(如溢出或下溢),而精心设计的解析梯度可以避免这些问题。性能: 您可能知道一种比对前向传播求导计算梯度更有效率的方法。非 JAX 代码: 如果您的函数涉及调用外部库或 JAX 不知道如何求导的操作,您必须手动提供梯度。梯度停止/近似: 您可能希望提供近似梯度,或阻止梯度流经计算的某些部分。在本节中,我们将为一个常见函数实现自定义 VJP,侧重于其工作原理和验证。示例:Softplus 的自定义梯度考虑 Softplus 函数,其定义为: $$ \text{softplus}(x) = \log(1 + e^x) $$ 此函数常被用作 ReLU 激活函数的平滑近似。我们首先编写一个标准的 JAX 实现:import jax import jax.numpy as jnp import numpy as np # Used for comparison/testing # 朴素实现 - 对于大 x 值可能不稳定 def softplus_naive(x): return jnp.log(1 + jnp.exp(x)) # 示例用法 x_val = 10.0 print(f"softplus_naive({x_val}) = {softplus_naive(x_val)}") # 使用 JAX 的自动微分计算梯度 grad_softplus_naive = jax.grad(softplus_naive) print(f"在 {x_val} 处的梯度 (JAX 自动): {grad_softplus_naive(x_val)}")对于非常大的正数 $x$, $e^x$ 可能会超出标准浮点表示(如 float32)的范围。尽管 JAX 和 XLA 通常能够智能地处理此类情况,但假设我们希望为了教学目的,或者在自动梯度确实引发问题的情况下,明确提供梯度规则。softplus(x) 的解析导数为: $$ \frac{d}{dx} \text{softplus}(x) = \frac{e^x}{1 + e^x} = \frac{1}{1 + e^{-x}} = \text{sigmoid}(x) $$ Sigmoid 函数是数值稳定的。我们可以在自定义 VJP 中使用此解析形式。使用 jax.custom_vjp 实现要定义自定义 VJP,我们需要三个组成部分:原始函数,使用 @jax.custom_vjp 装饰。一个前向传播函数(fwd),它计算原始输出并保存反向传播所需的任何中间值(residuals)。它必须返回 (output, residuals)。一个反向传播函数(bwd),它接收 residuals 和上游梯度(g,表示 $d\mathcal{L}/dy$,其中 $y$ 是我们函数的输出),并计算相对于原始输入的下游梯度($d\mathcal{L}/dx = (d\mathcal{L}/dy) \times (dy/dx)$)。它必须返回一个梯度元组,元组中每个元素对应原始函数的一个位置参数。让我们为 softplus 实现这一点:import jax import jax.numpy as jnp import numpy as np # 1. 定义函数并进行装饰 @jax.custom_vjp def softplus_custom(x): """用于自定义 VJP 的 Softplus 实现。""" # 为演示目的,我们在此处使用可能不稳定的前向传播。 # 在实际场景中,您可能也会使用更稳定的前向传播。 return jnp.log(1 + jnp.exp(x)) # 2. 定义前向传播函数 (_fwd) # 它接受与原始函数 (x) 相同的参数。 # 它返回输出 (y = softplus(x)) 和反向传播所需的中间值。 # 我们需要原始输入 'x' 以在反向传播中计算 sigmoid(x)。 def softplus_fwd(x): y = softplus_custom(x) # 使用装饰过的函数计算原始输出 residuals = x # 保存 'x' 以供反向传播使用 return y, residuals # 3. 定义反向传播函数 (_bwd) # 它接收 _fwd 保存的中间值和上游梯度 'g'。 # 它返回一个关于原始函数输入的梯度元组。 # 这里,唯一的输入是 'x',因此我们返回一个包含一个元素的元组。 def softplus_bwd(residuals, g): x = residuals # 解包中间值 # 计算梯度:g * sigmoid(x) # 使用 jax.nn.sigmoid 确保数值稳定性 grad_x = g * jax.nn.sigmoid(x) return (grad_x,) # 将关于 'x' 的梯度作为元组返回 # 将 fwd 和 bwd 函数与 custom_vjp 函数关联 softplus_custom.defvjp(softplus_fwd, softplus_bwd) # --- 现在让我们测试它 --- # 使用自定义 VJP 实现计算梯度 grad_softplus_custom = jax.grad(softplus_custom) # 测试值 test_values = jnp.array([-10.0, 0.0, 10.0, 80.0, 100.0]) # 包含大值 # 比较朴素自动微分梯度与自定义梯度 print("输入 | 朴素梯度 (自动微分) | 自定义梯度 (VJP)") print("------|-----------------------|-------------------") for x_val in test_values: # 为朴素版本使用 try-except 块,该版本可能溢出/发出警告 try: naive_grad = jax.grad(softplus_naive)(x_val) except OverflowError: naive_grad = float('inf') # 或者适当处理 custom_grad = grad_softplus_custom(x_val) print(f"{x_val:<5.1f} | {naive_grad:<21.8f} | {custom_grad:<19.8f}") # JIT 编译按预期工作 jit_grad_softplus_custom = jax.jit(jax.grad(softplus_custom)) print(f"\nJIT 编译的自定义梯度在 10.0 处: {jit_grad_softplus_custom(10.0):.8f}")验证与分析运行上述代码时,您应该会发现由 jax.grad(softplus_naive) 和 jax.grad(softplus_custom) 计算的梯度是相同的(在浮点精度范围内)。正确性: 这证实了我们的自定义 VJP 规则正确计算了解析梯度(sigmoid)。数值稳定性(梯度): 反向传播显式使用了 jax.nn.sigmoid(x),这是数值稳定的。尽管 JAX 对 log(1 + exp(x)) 的默认自动微分通常也能为梯度生成稳定结果,但显式定义它能保证这种行为。请注意,softplus_custom 中的前向传播仍是朴素版本;在生产场景中,可能需要实现稳定的前向传播并提供自定义 VJP。机制: 我们定义了 softplus_fwd 来计算原始函数的值并缓存输入 x。softplus_bwd 函数随后使用此缓存的 x 和传入的梯度 g 来计算最终梯度 g * sigmoid(x)。defvjp 调用将这两个函数与 @jax.custom_vjp 装饰器关联起来。考量中间值: 仔细选择要在 residuals 中保存的内容。保存不必要的大的中间数组会增加内存消耗。只保存反向传播严格所需的内容。有时,在反向传播中重新计算值可以作为降低内存使用量的权衡。多输入/多输出: 如果您的函数有多个输入或输出,fwd 函数接收所有输入,bwd 函数必须返回一个梯度元组,其长度与原始函数的位置输入数量相同。对于输出,bwd 函数接收一个 g,其结构与原始输出匹配(或者对于不涉及梯度路径的输出为 None)。jax.custom_jvp: 对于前向模式微分,过程类似,使用 @jax.custom_jvp、defjvp,并定义一个函数,在给定 (x, x_dot) 的情况下计算 (y, y_dot)。本次实践练习说明了在 JAX 中定义自定义梯度的核心机制。它是处理特定数值或性能要求的有力工具。请记住,在可行的情况下,务必针对 JAX 的自动微分或数值梯度彻底测试您的自定义规则。