趋近智
反向模式自动微分是深度学习 (deep learning)中梯度下降 (gradient descent)的核心工具。它能高效计算一个标量输出(比如损失函数 (loss function))相对于可能数百万个输入(模型参数 (parameter))的梯度。尽管 jax.grad 提供了一种用户友好的方式来获取这些梯度,但理解其运作原理,即向量 (vector)-雅可比积 (VJP),能够实现更高级的功能。jax.vjp 是 JAX 中用于计算 VJP 的函数。
回顾一下,对于一个函数 ,它在点 处的雅可比矩阵 包含所有偏导数 。雅可比矩阵是一个 的矩阵。
反向模式自动微分计算一个 余切 向量 与雅可比矩阵的乘积,得到 。这个乘积就是向量-雅可比积 (VJP)。结果是一个 中的向量(或者严格来说,是一个 的行向量)。
这为什么有用?考虑一个典型的机器学习 (machine learning)情境,它涉及到一系列函数组合,最终会产生一个标量损失 。令 是这条链中的最后一个函数,它将某个中间表示 映射到最终的标量损失 。因此,。梯度 是一个行向量 。如果 且 ,链式法则告诉我们 。这正是一个 VJP!向量 是从输出反向传播 (backpropagation)的梯度(或余切),而 是我们想要其输入梯度的函数 的雅可比矩阵。
反向模式通过在计算图中反向传播这些余切向量来工作,对于一个标量损失函数 (loss function) ,从初始余切 开始。
jax.vjp 计算 VJPJAX 通过 jax.vjp 提供 VJP 计算的直接访问接口。其基本签名如下:
y, vjp_fun = jax.vjp(fun, *primals)
我们来分解一下:
fun: 要进行微分的 Python 可调用对象(函数)。*primals: 用于评估函数和计算 VJP 的输入参数 (parameter)(x, y 等)。这些是隐式计算雅可比矩阵的点。y: 函数在原像(primals)处评估的输出,即 y = fun(*primals)。这是前向传播的结果。vjp_fun: 一个计算 VJP 的 新函数(一个闭包)。这个函数接受一个参数,即余切向量 (vector) v,它必须与输出 y 具有相同的结构(形状和数据类型)。当作为 vjp_fun(v) 调用时,它返回 VJP:。结果是一个元组,包含针对每个原像输入的计算梯度。让我们从一个简单的函数 开始。
import jax
import jax.numpy as jnp
def f(x, y):
return x**2 * jnp.sin(y)
# 定义原始输入
x_primal = 3.0
y_primal = jnp.pi / 2.0
# 计算 VJP
y_out, vjp_fun = jax.vjp(f, x_primal, y_primal)
print(f"前向传播输出 y: {y_out}")
# 输出 y 是标量,所以余切 v 也是标量。
# 我们使用 v = 1.0,这相当于计算 d(1.0 * y) / dx 和 d(1.0 * y) / dy
# 这等同于找到 y 的梯度。
cotangent_v = 1.0
primal_grads = vjp_fun(cotangent_v)
print(f"关于 x 的梯度: {primal_grads[0]}") # 应该为 2 * x * sin(y) = 2 * 3.0 * sin(pi/2) = 6.0
print(f"关于 y 的梯度: {primal_grads[1]}") # 应该为 x^2 * cos(y) = 3.0^2 * cos(pi/2) = 0.0
# 与 jax.grad 比较
grad_f = jax.grad(f, argnums=(0, 1)) # 获取关于 x 和 y 的梯度
grads_via_grad = grad_f(x_primal, y_primal)
print(f"通过 jax.grad 获得的梯度: {grads_via_grad}")
在此情况下,调用 vjp_fun(1.0) 得到的结果与 jax.grad(f, argnums=(0, 1))(x, y) 相同。这表明了一个基本联系:针对标量函数的 jax.grad 实质上是 jax.vjp 的一个便捷封装,它自动提供初始余切 1.0。
现在考虑一个向量 (vector)输出的函数:,其中 。
import jax
import jax.numpy as jnp
def g(x):
return jnp.array([x[0]**2, x[1]**3])
# 定义原始输入
x_primal = jnp.array([2.0, 3.0])
# 计算 VJP
y_out, vjp_fun = jax.vjp(g, x_primal)
print(f"前向传播输出 y: {y_out}") # 应该为 [4., 27.]
# 输出 y 是一个形状为 (2,) 的向量,因此余切 v 也必须是形状为 (2,) 的向量
cotangent_v = jnp.array([10.0, 20.0]) # 一个任意的余切向量
# 计算 VJP: v^T J
primal_grad = vjp_fun(cotangent_v)
print(f"VJP 结果 (关于 x 的梯度): {primal_grad}")
# 预期:
# 雅可比矩阵 J = [[2*x0, 0], [0, 3*x1^2]] = [[4, 0], [0, 27]]
# v^T J = [10, 20] * [[4, 0], [0, 27]] = [10*4, 20*27] = [40, 540]
# jax.vjp 返回一个元组,即使只有一个原始输入,所以结果是 ([40., 540.],)
在这里,vjp_fun 接收向量 cotangent_v 并将其乘以在 x_primal 处评估的函数 g 的雅可比矩阵,从而产生返回给 x 的梯度贡献。余切的形状必须与函数输出 y_out 的形状匹配。
jax.vjp?尽管 jax.grad 足以应对需要标量损失梯度的标准梯度下降 (gradient descent)情况,但 jax.vjp 提供更多控制:
grad 更高效地组织。jax.custom_vjp 允许您为特定函数定义定制的 VJP 规则。理解 jax.vjp 是此项操作的必要条件。jax.vjp 可以清楚地了解反向模式 AD 如何运作以及 jax.grad 如何构建。总而言之,jax.vjp 为 JAX 强大的反向模式自动微分系统提供了一个更低级的接口。它计算向量 (vector)-雅可比积,这代表了反向传播 (backpropagation)的核心操作。尽管 jax.grad 处理标量损失梯度的常见情况,但掌握 jax.vjp 会为您提供处理更复杂的微分任务和更全面理解该系统的工具。
这部分内容有帮助吗?
jax.vjp 函数的官方API参考和使用详情,用于直接计算向量-雅可比积。© 2026 ApX Machine LearningAI伦理与透明度•