趋近智
自动微分(AD)是 jax.grad 背后的技术。它是一套用于数值计算计算机程序所定义函数导数的方法。与符号微分(操作数学表达式)或数值微分(使用有限差分,可能产生近似误差)不同,AD 通过在基本运算层面系统应用微积分链式法则,高效地计算出精确导数。
AD 主要有两种模式:正向模式和反向模式。jax.grad 主要使用反向模式自动微分,在机器学习 (machine learning)社区中常被称为反向传播 (backpropagation)。让我们来了解它是如何运作的。
反向模式 AD 的核心是将任何计算看作一系列基本运算(如加法、乘法、sin、cos 等)应用于某些输入值的过程。这个序列可以表示为计算图,图中的节点代表输入变量或运算,有向边代表数据流。
考虑一个简单函数:。我们可以将其分解为中间步骤:
计算图如下所示:。
的一个简单计算图。数据从输入 经过中间运算 流向最终输出 。
当你正常执行函数,例如 f(2.0) 时,你将进行一次图的正向传播:
在此正向传播过程中,JAX 等 AD 系统不仅计算最终值 ,通常还会存储中间值(如 )以及图的结构。这些是下一阶段所需的。
目标是计算最终输出 相对于初始输入 的梯度,即 。反向模式通过从输出开始,将导数反向传播通过图来实现这一目标。
初始化: 输出相对于自身的导数显然是1。我们将这种敏感度或伴随表示为 。这是从末端进入图的初始“梯度信号”。
反向步骤(sin 节点): 我们从 反向移动到 。我们需要找到 的变化如何影响 ,这就是局部导数 。然后我们使用链式法则找到到达 的梯度信号:
这个 表示最终输出 对中间变量 变化的敏感度。因为 ,所以局部导数 。因此,。使用从正向传播中存储的值 ,。
反向步骤(平方节点): 我们从 继续反向到 。我们需要局部导数 。我们再次使用链式法则,将传入的梯度信号 ( 对 的敏感度)乘以局部导数 :
因为 ,所以局部导数 。因此,。使用 和输入值 (也可能被存储或重新计算),。这个 就是我们所需的梯度 。
反向传播本质上是计算最终输出对每个中间变量和输入的敏感程度,从输出开始反向工作,并在每一步应用链式法则。
正向传播(计算值)和反向传播(使用链式法则计算梯度)的流程。反向传播需要正向传播期间计算得到的值(x, a),并反向传播敏感度。
为什么反向模式是 jax.grad 的默认模式,并在机器学习中普遍使用?考虑一个典型的神经网络 (neural network)损失函数 (loss function):,其中 是许多权重 (weight)矩阵/向量 (vector)(参数 (parameter)), 是输入数据, 是目标。这个函数可能接收数百万个输入(参数和数据),但只产生一个标量输出(损失 )。
反向模式的计算成本相对而言对输入的数量不敏感。它需要一次通过计算图的正向传播(用于计算中间值)和一次反向传播 (backpropagation)(用于计算梯度)。总成本通常是评估原始函数成本的一个小常数倍(例如,2-4倍)。这使得它在同时计算 相对于所有参数的梯度时非常高效,这正是梯度下降 (gradient descent)等基于梯度的优化算法所需要的。
正向模式则一次计算一个输入值的导数。其成本与你需要计算梯度的输入数量呈线性关系。虽然在某些情况下有用,但对于训练拥有数百万参数的大型模型来说,它的计算成本过高。
当你将 jax.grad 应用于你的 Python 函数时,JAX 会在底层执行以下步骤:
jax.value_and_grad,此传播也会计算最终函数输出)。你作为用户,使用熟悉的 Python 和 NumPy 类似语法定义正向计算。JAX 负责根据该定义,通过反向模式 AD 推导并执行高效的梯度计算。这种关注点分离使你能够专注于模型逻辑,同时依赖 JAX 进行高性能的微分。
这部分内容有帮助吗?
jax.grad 进行自动微分的原理和用法,包括反向模式的具体细节。© 2026 ApX Machine LearningAI伦理与透明度•