定义一个自定义原语,并为其提供用于形状和数据类型推断的抽象求值规则和实现后端特定的降低规则,即可确立其主要功能。然而,要使该原语在JAX生态系统中真正发挥作用,它需要与JAX的自动微分系统配合。如果没有微分规则,将jax.grad、jax.jvp或jax.vjp应用于包含该原语的函数将导致错误,因为JAX将无法知道如何通过它传播梯度。本节说明了如何规定自定义原语的必需微分规则,使其能够完全参与基于梯度的优化及其他自动微分任务。自定义规则的必要性JAX的自动微分机制依赖于了解它在跟踪过程中遇到的每个原语操作的微分规则。对于jax.lax.add或jax.lax.sin等内置原语,这些规则已在内部规定。但对于您的自定义原语,JAX对其在微分方面的数学行为没有预先知识。因此,您必须明确提供这些规则。JAX使用两种基本模式的自动微分:前向模式 (JVP): 计算雅可比向量积。它在计算图中向前传播导数。您可以使用primitive.def_jvp来规定此项。反向模式 (VJP): 计算向量雅可比积。它在计算图中向后传播导数(余切)。这是jax.grad主要使用的模式。您可以使用primitive.def_vjp来规定此项。要使您的原语完全可微分并可用于JAX所有微分变换,您通常需要规定JVP和VJP规则。规定雅可比向量积 (JVP) 规则JVP规则描述了原语输入中的扰动(切线)如何影响其输出。在数学上,如果您的原语计算$y = f(x)$,则JVP规则计算$J v$,其中$J$是$f$在$x$处的雅可比,而$v$是输入切线向量。您可以使用原语对象的def_jvp方法来规定JVP规则。您提供给def_jvp的函数应具有特定签名:def primitive_jvp(primals, tangents): # primals: 原语的原始输入元组。 # tangents: 与原始值对应的切线值元组。 # 切线具有与相应原始值相同的形状和数据类型。 # 如果对应的原始值在微分方面被认为是常数,切线值可能是jax.ad.Zero。 # 1. 计算原始输出(与抽象求值或实现相同) primal_outputs = ... # 计算 f(原始值) # 2. 根据原始值和输入切线计算切线输出 tangent_outputs = ... # 计算 J @ 切线 return primal_outputs, tangent_outputsJVP规则函数接收原始输入及其对应的切线。它必须返回一对值:原始输出(应与原语的标准求值结果一致)和对应的输出切线。示例:自定义缩放原语的JVP让我们回顾一个简单的自定义原语custom_scale_p,它通过一个factor缩放输入x:$y = x \times factor$。import jax import jax.numpy as jnp from jax.core import Primitive from jax.interpreters import ad from jax.interpreters import mlir # 假设custom_scale_p已通过抽象求值和降低定义 # 例如: custom_scale_p = Primitive('custom_scale') @custom_scale_p.def_impl def _custom_scale_impl(x, factor): # CPU实现示例 return x * factor @custom_scale_p.def_abstract_eval def _custom_scale_abstract_eval(x_abs, factor_abs): # 抽象求值示例 assert jax.core.get_aval(factor_abs) == jax.core.ShapedArray((), factor_abs.dtype) return jax.core.ShapedArray(x_abs.shape, x_abs.dtype) # 假设MLIR降低也已定义... # 现在,规定JVP规则 @custom_scale_p.def_jvp def _custom_scale_jvp(primals, tangents): x, factor = primals x_dot, factor_dot = tangents # 输入切线 # 原始输出计算(也可以调用实现) y = x * factor # 计算输出切线 y_dot # 使用乘积规则:d(x*factor)/dt = (dx/dt)*factor + x*(dfactor/dt) # 正确处理零切线 y_dot = ad.Zero.zero_if_zero(factor_dot) if not isinstance(x_dot, ad.Zero): y_dot += x_dot * factor if not isinstance(factor_dot, ad.Zero): y_dot += x * factor_dot # 确保输出切线具有与输出原始值相同的结构 if isinstance(y_dot, ad.Zero) and y is not None: # 如果y_dot为Zero,则创建一个具有正确形状/数据类型的具体零切线 y_dot = jnp.zeros_like(y) print(f"自定义JVP: x={x}, factor={factor}, x_dot={x_dot}, factor_dot={factor_dot}, y={y}, y_dot={y_dot}") return y, y_dot # jax.jvp的使用示例 x_val = jnp.array([1.0, 2.0, 3.0]) factor_val = 2.0 x_tangent = jnp.array([0.1, 0.2, 0.3]) factor_tangent = 0.5 # 标量factor的切线 # 规定一个使用原语的函数 def apply_scale(x, factor): return custom_scale_p.bind(x, factor=factor) # 使用bind # 计算JVP primal_out, tangent_out = jax.jvp(apply_scale, (x_val, factor_val), (x_tangent, factor_tangent)) print(f"原始输出: {primal_out}") print(f"切线输出: {tangent_out}") # 预期输出切线: # y_dot = x_dot * factor + x * factor_dot # = [0.1, 0.2, 0.3] * 2.0 + [1.0, 2.0, 3.0] * 0.5 # = [0.2, 0.4, 0.6] + [0.5, 1.0, 1.5] # = [0.7, 1.4, 2.1]请注意我们如何处理ad.Zero以避免在输入切线为零时不必要的计算。JVP规则正确应用了微分的乘积规则。规定向量雅可比积 (VJP) 规则VJP规则在反向模式自动微分中起核心作用,这是jax.grad的工作方式。它描述了原语输出处的余切向量如何向后传播以生成输入值的余切。在数学上,如果$y = f(x)$,则VJP规则计算$v^T J$,其中$v^T$是输出余切向量(一个行向量),而$J$是雅可比。使用primitive.def_vjp规定VJP规则稍微复杂一些,因为反向模式需要前向传递的信息来计算反向传递。def_vjp装饰器期望一个执行前向传递并同时返回原始输出和残差值的函数。此残差包含前向传递中梯度计算所需的任何中间值。def_vjp还期望一个执行反向传递的第二个函数(通常在第一个函数内局部定义)。结构如下:def primitive_vjp(primals): # primals: 原始输入元组。 # 1. 计算原始输出 primal_outputs = ... # 计算 f(原始值) # 2. 确定反向传递所需的残差 residuals = ... # 例如,原始输入、中间值 # 3. 规定反向传递函数 def backward_pass(residuals, output_cotangents): # residuals: 从前向传递中保存的数据。 # output_cotangents: 与原始输出对应的余切向量。 # 根据残差和输出余切计算输入余切 input_cotangents = ... # 计算 output_cotangents^T @ J return input_cotangents # 必须是与原始值结构匹配的元组 return primal_outputs, backward_pass外部函数primitive_vjp接收原始输入,计算原始输出,并将必要数据打包到residuals中。它返回原始输出和backward_pass函数。JAX随后使用残差和传入的输出余切调用backward_pass以获得输入余切。示例:自定义缩放原语的VJP继续custom_scale_p的例子($y = x \times factor$):# 规定VJP规则 @custom_scale_p.def_vjp def _custom_scale_vjp(primals): x, factor = primals # 前向传递:计算输出并保存输入以用于反向传递 y = custom_scale_p.bind(x, factor=factor) # 使用bind进行实际计算 residuals = (x, factor) # 反向传递需要x和factor # 规定反向传递函数 def backward_pass(residuals, y_bar): # y_bar 是输出 y 的余切 x_res, factor_res = residuals # 解包残差 # 计算输入余切(梯度) # 关于x的梯度:dy/dx = factor => x_bar = y_bar * factor x_bar = y_bar * factor_res # 关于factor的梯度:dy/dfactor = x => factor_bar = sum(y_bar * x) # 如果x不是标量,则需要求和,因为factor是标量。 factor_bar = jnp.sum(y_bar * x_res) print(f"自定义VJP反向:y_bar={y_bar}, x_res={x_res}, factor_res={factor_res}, x_bar={x_bar}, factor_bar={factor_bar}") # 返回与原始值(x,factor)顺序匹配的余切元组 return (x_bar, factor_bar) return y, backward_pass # jax.grad的使用示例 x_val = jnp.array([1.0, 2.0, 3.0]) factor_val = 2.0 # 规定一个可微分函数 def loss_fn(x, factor): y = apply_scale(x, factor) # 通过apply_scale使用我们的原语 return jnp.sum(y * y) # 损失函数示例:平方和 # 使用jax.grad计算梯度 # grad计算关于指定argnums(0代表x,1代表factor)的导数 grad_x = jax.grad(loss_fn, argnums=0)(x_val, factor_val) grad_factor = jax.grad(loss_fn, argnums=1)(x_val, factor_val) print(f"关于x的梯度: {grad_x}") print(f"关于factor的梯度: {grad_factor}") # 预期梯度: # L = sum( (x*factor)^2 ) # dL/dy = 2*y = 2*x*factor # dL/dx = dL/dy * dy/dx = (2*x*factor) * factor = 2*x*factor^2 # = 2 * [1, 2, 3] * 2^2 = [8, 16, 24] # dL/dfactor = sum( dL/dy * dy/dfactor ) = sum( (2*x*factor) * x ) = sum( 2*x^2*factor ) # = 2 * factor * sum(x^2) = 2 * 2.0 * (1^2 + 2^2 + 3^2) # = 4 * (1 + 4 + 9) = 4 * 14 = 56.0在此VJP规则中,前向部分计算结果y并保存原始输入x和factor作为残差。backward_pass使用这些残差以及传入的输出余切y_bar(代表 $\partial Loss / \partial y$)来使用链式法则计算输入余切x_bar($\partial Loss / \partial x$)和factor_bar($\partial Loss / \partial factor$)。验证微分规则规定自定义微分规则容易出错。强烈建议验证其正确性。常见方法包括:数值检查: 使用有限差分近似导数,并与自定义JVP/VJP规则的结果进行比较。JAX提供了诸如jax.test_util.check_grads的实用工具,它自动完成VJP的此项比较(并通过JVP-VJP一致性检查隐式完成JVP的比较)。JVP-VJP一致性: JVP和VJP规则在数学上应该一致(通过转置相关)。虽然JAX不会自动强制此项,但显著差异通常表明一个或两个规则存在错误。jax.test_util.check_grads也有助于验证这种关系。# 使用jax.test_util进行验证的例子 from jax.test_util import check_grads # 检查使用原语的apply_scale函数的梯度 # check_grads 将分析梯度(来自VJP)与数值估计进行比较 check_grads(apply_scale, (x_val, factor_val), order=2, modes=['fwd', 'rev'], eps=1e-3) # order=2 检查一阶和二阶导数(如适用) # modes=['fwd', 'rev'] 检查JVP和VJP的一致性 print("梯度检查通过!") 运行check_grads提供信心表明您的微分规则已正确实现。与JAX变换的集成一旦您规定了自定义原语的JVP (def_jvp) 和 VJP (def_vjp) 规则,它就与JAX的自动微分系统集成。您现在可以对任何使用您的原语的JAX函数应用jax.grad、jax.jvp、jax.vjp,甚至组合这些变换以计算高阶导数,就像您使用内置操作一样。这完成了使您的自定义操作成为JAX生态系统完全集成部分的过程。