趋近智
jax.jitjitjit 的常见问题jitgrad 进行自动微分jax.gradgrad的grad)jax.value_and_grad)vmap 实现自动向量化jax.vmapin_axes,out_axes)vmapvmap 与 jit 和 gradvmap的性能考量pmap 在多设备上并行计算jax.pmapin_axes, out_axes)lax.psum、lax.pmean等)pmap 与其他变换结合使用pmap 化的函数正如章节引言中提到的,梯度在许多计算分支中都很常用,尤其是在机器学习里。训练模型通常需要通过迭代调整参数来优化一个函数(比如损失函数)。梯度指明了如何高效地进行这些调整。但梯度究竟是什么呢?
通俗地说,对于一个接受多个数值输入并产生单个数值输出(即标量函数)的函数,梯度是一个向量,它指向函数在特定输入点上增长最快的方向。这个梯度向量的大小(长度)表明了增长的陡峭程度。
思考一个只有一个输入变量的简单函数 f(x)。你可能在微积分中学过,它的导数,常写为 f′(x) 或 dxdf,表示函数图上点 x 处切线的斜率。它告诉你函数输出相对于输入的瞬时变化率。
函数 f(x)=x2 和在 x=2 处的切线。切线的斜率(4)就是该点处的导数(梯度)。
现在,我们将其延伸到一个拥有多个输入变量的函数,比如 f(x1,x2,...,xn)。梯度推广了导数的思想。我们得到的不是一个单一的斜率,而是一个偏导数向量。偏导数 ∂xi∂f 衡量了当仅改变输入 xi 而保持所有其他输入(xj 且 j=i)不变时,函数 f 如何变化。
函数 f 在点 (x1,...,xn) 处的梯度是包含其所有偏导数的向量:
∇f(x1,...,xn)=[∂x1∂f,∂x2∂f,...,∂xn∂f]符号 ∇(nabla)常用于表示梯度算子。
例如,思考函数 f(x,y)=x2+sin(y)。 对 x 的偏导数是 ∂x∂f=2x。 对 y 的偏导数是 ∂y∂f=cos(y)。 梯度向量是 ∇f(x,y)=[2x,cos(y)]。
在点 (x=1,y=0) 处,梯度为 ∇f(1,0)=[2(1),cos(0)]=[2,1]。这个向量 [2,1] 表明了从 (1,0) 开始函数 f(x,y) 增长最快的方向。
梯度的一个主要用途是函数最小化。设想你有一个“成本”或“损失”函数,它衡量了你的机器学习模型表现有多差。你希望找到使这个成本尽可能小的模型参数。
梯度 ∇f 指向 最陡峭的上升 方向。因此,负梯度 −∇f 指向 最陡峭的下降 方向。
这是梯度下降算法的核心原理:
每一步都使函数更接近一个局部最小值。
新参数=旧参数−学习率×∇f(旧参数)这里,learning_rate 是一个小的正标量,它控制着步长。
手动使用微积分规则计算梯度,就像我们为 f(x,y)=x2+sin(y) 所做的那样,对于简单函数是可行的。然而,机器学习中遇到的函数(如深度神经网络)可能涉及数百万个参数和复杂的操作组合。手动推导梯度变得不切实际且极易出错。
我们可以尝试数值微分,它通过在轻微扰动的点处评估函数来近似梯度(例如,使用 hf(x+h)−f(x) 的定义,其中 h 很小)。然而,这种方法存在近似误差(由于 h 的选择),并且计算成本可能很高,因为它对梯度的每个维度都需要多次函数评估。
符号微分,由计算机代数系统执行,它操作数学表达式以找到精确的导数表达式。虽然精确,但它可能导致非常复杂且可能低效的表达式(“表达式膨胀”),特别是对于大型计算。
这就是**自动微分(AD)**发挥作用的地方。AD 是一套技术,它通过在函数计算中对基本运算(加法、乘法、sin、cos 等)层面系统地应用微积分的链式法则,来高效地计算函数梯度的精确数值。它避免了数值微分的近似误差和符号微分可能产生的表达式膨胀问题。
JAX 的 grad 转换建立在高度优化的 AD 实现之上,特别是反向模式 AD(在神经网络中也称为反向传播)。这种模式对于机器学习中常见的输入众多但映射到单个标量输出(如损失函数)的函数尤其高效。
在接下来的章节中,我们将了解如何使用 jax.grad,以便将自动微分的优势运用到你自己的 Python 函数中。
这部分内容有帮助吗?
grad函数的官方文档,包含实际使用示例和其应用的详细信息。© 2026 ApX Machine Learning用心打造