JAX中的自动求导与其结构化控制流原语(如lax.cond、lax.while_loop和lax.scan)紧密结合。然而,了解求导如何与这些操作交互对于编写高效代码和解释结果十分必要,特别是JAX需要在执行前追踪计算图。对lax.cond求导回顾一下,lax.cond(pred, true_fun, false_fun, operand)根据pred的布尔值执行true_fun(operand)或false_fun(operand)。当JAX在被求导的函数中(例如使用jax.grad时)遇到lax.cond时,它必须生成能够计算梯度的代码,无论实际执行时会选择哪个分支。为此,JAX在求导过程中会追踪true_fun和false_fun这两个分支。生成的求导函数将包含两个分支的梯度逻辑。在执行期间,原始正向传播确定使用哪个分支的结果。反向传播随后计算实际执行的分支对应的梯度。考虑这个示例:import jax import jax.numpy as jnp from jax import lax def conditional_computation(x, y): # 条件取决于输入'y' return lax.cond(y > 0, lambda op: jnp.sin(op * 2.0), # true_fun lambda op: jnp.cos(op / 2.0), # false_fun x) # operand grad_x = jax.grad(conditional_computation, argnums=0) grad_y = jax.grad(conditional_computation, argnums=1) # 注意:关于pred的梯度通常为零 # 示例执行 x_val = jnp.pi / 4.0 y_val_pos = 1.0 y_val_neg = -1.0 print(f"f(x, y>0) = {conditional_computation(x_val, y_val_pos)}") print(f"df/dx(x, y>0) = {grad_x(x_val, y_val_pos)}") # sin(2x)的梯度是2*cos(2x)。在x=pi/4时,2*cos(pi/2) = 0.0 print(f"\nf(x, y<0) = {conditional_computation(x_val, y_val_neg)}") print(f"df/dx(x, y<0) = {grad_x(x_val, y_val_neg)}") # cos(x/2)的梯度是-0.5*sin(x/2)。在x=pi/4时,-0.5*sin(pi/8) 大约 -0.191 # 关于'y'(谓词条件)的梯度通常为零 # 因为谓词本身通常被视为离散的。 print(f"\ndf/dy(x, y>0) = {grad_y(x_val, y_val_pos)}") print(f"df/dy(x, y<0) = {grad_y(x_val, y_val_neg)}")重要提示:两个分支的代码都包含在编译后的求导函数中。传递给true_fun和false_fun的操作数必须具有相同的结构(形状和数据类型),因为JAX需要一致的类型签名。通过谓词pred本身的求导通常没有实际意义或不受支持,除非谓词计算涉及可求导操作且其输出以允许梯度的方式使用(这对于典型的布尔条件并不常见)。JAX通常对仅影响谓词的参数给出零梯度。对lax.while_loop求导原语lax.while_loop(cond_fun, body_fun, init_val)只要cond_fun返回True就会重复应用body_fun。对while_loop求导涉及将链式法则反向应用于正向传播期间执行的迭代。这类似于用于训练循环神经网络(RNN)的反向传播(BPTT)。初始状态init_val的梯度取决于状态通过body_fun在所有迭代中如何演变。import jax import jax.numpy as jnp from jax import lax def loop_sum(max_val): # 求和从0到(但不包括)max_val的数字 init_state = (0, 0.0) # (当前i, 当前和) def cond_fun(state): i, _ = state return i < max_val # 当i < max_val时继续 def body_fun(state): i, current_sum = state return (i + 1, current_sum + jnp.sqrt(i.astype(jnp.float32))) # 示例操作 final_state = lax.while_loop(cond_fun, body_fun, init_state) _, final_sum = final_state return final_sum # 对循环结果关于初始'max_val'输入求导 # 注意:max_val影响迭代的*次数*,使其梯度变得复杂。 # 为了清晰起见,我们对循环*内部*的某些东西求导。 def loop_sum_param(scale, n_iters): init_state = (0, 0.0) # (i, 当前和) def cond_fun(state): i, _ = state return i < n_iters def body_fun(state): i, current_sum = state # 在循环中使用'scale' return (i + 1, current_sum + scale * i) _, final_sum = lax.while_loop(cond_fun, body_fun, init_state) return final_sum grad_loop_sum = jax.grad(loop_sum_param, argnums=0) # 关于'scale'的梯度 scale_val = 2.0 iters = 5 print(f"Loop sum(scale={scale_val}, iters={iters}) = {loop_sum_param(scale_val, iters)}") # 正向传播:2*0 + 2*1 + 2*2 + 2*3 + 2*4 = 0 + 2 + 4 + 6 + 8 = 20 # 预期梯度 d(sum)/d(scale) = 0 + 1 + 2 + 3 + 4 = 10 print(f"d(Sum)/d(scale) = {grad_loop_sum(scale_val, iters)}")重要提示:反向传播需要正向传播期间为循环的每次迭代计算的中间值(原值)。对于有许多迭代的循环,存储所有中间值可能会占用大量内存。这是使用梯度检查点(jax.checkpoint)的主要原因,该技术在反向传播期间重新计算中间值而不是存储它们。这是用计算换取内存的做法。类似于深度网络或RNN,通过许多循环迭代传播的梯度可能会消失或爆炸,尽管循环体内的细致初始化或归一化等方法可以缓解这种情况。对lax.scan求导lax.scan(f, init, xs)原语在序列xs上累积应用函数f,并携带状态init。它通常用于实现RNN或其他已知步数的顺序过程。通过lax.scan的求导是定义明确且高效的。JAX自动处理梯度通过携带状态(carry)和每步输出(y)的传播。与lax.while_loop类似,反向传播类似于BPTT。import jax import jax.numpy as jnp from jax import lax def simple_rnn_step(carry, x_t): # 一个非常基本的RNN单元:carry是隐藏状态 h_t-1 # x_t 是时间 t 的输入 # 输出是新的隐藏状态 h_t 和一个输出 y_t prev_h = carry weight_hh = 0.5 # 为简单起见,固定参数 weight_xh = 1.5 # 为简单起见,固定参数 # 简单线性更新 new_h = jnp.tanh(prev_h * weight_hh + x_t * weight_xh) y_t = new_h * 2.0 # 基于隐藏状态的一些输出 return new_h, y_t # 新携带状态, y_t def run_rnn(initial_state, inputs): # 初始状态:h_0 # 输入:x_t 值的序列 final_state, outputs_y = lax.scan(simple_rnn_step, initial_state, inputs) return jnp.sum(outputs_y) # 返回输出总和作为标量损失 grad_rnn_params = jax.grad(run_rnn, argnums=0) # 关于初始状态的梯度 grad_rnn_inputs = jax.grad(run_rnn, argnums=1) # 关于输入的梯度 h0 = jnp.zeros(()) # 初始隐藏状态(标量) xts = jnp.array([0.1, 0.2, -0.1, 0.3]) # 输入序列 total_output = run_rnn(h0, xts) print(f"Total output = {total_output}") # 关于初始状态和输入的梯度 dh0 = grad_rnn_params(h0, xts) dxts = grad_rnn_inputs(h0, xts) print(f"d(Sum)/dh0 = {dh0}") print(f"d(Sum)/dxts = {dxts}")重要提示:对于固定长度的序列,lax.scan通常优于lax.while_loop,因为其结构更明确,通常会带来更简单的分析和XLA可能进行的更多优化。scan的反向传播通过反向遍历正向传播中执行的步骤,使用保存的中间值,从而高效地计算梯度。对于非常长的序列,内存使用仍然是一个问题,jax.checkpoint也可以应用于scan中使用的函数f。digraph ScanDifferentiation { rankdir=LR; node [shape=box, style=rounded, fontname="Arial", fontsize=10]; edge [fontname="Arial", fontsize=9]; subgraph cluster_forward { label = "正向传播 (scan)"; style=filled; color="#e9ecef"; // gray node [fillcolor="#a5d8ff", style=filled]; // blue x0 [label="x₀"]; x1 [label="x₁"]; x2 [label="x₂"]; h_init [label="h_init", shape=ellipse, fillcolor="#ffec99"]; // yellow h0 [label="h₀"]; y0 [label="y₀", shape=cds, fillcolor="#96f2d7"]; // teal f0 [label="f(h_init, x₀)", shape=invhouse]; h1 [label="h₁"]; y1 [label="y₁", shape=cds, fillcolor="#96f2d7"]; // teal f1 [label="f(h₀, x₁)", shape=invhouse]; h2 [label="h₂"]; y2 [label="y₂", shape=cds, fillcolor="#96f2d7"]; // teal f2 [label="f(h₁, x₂)", shape=invhouse]; h_final [label="h_final", shape=ellipse, fillcolor="#ffec99"]; // yellow h_init -> f0 [label="携带状态"]; x0 -> f0; f0 -> h0 [label="携带状态"]; f0 -> y0; h0 -> f1 [label="携带状态"]; x1 -> f1; f1 -> h1 [label="携带状态"]; f1 -> y1; h1 -> f2 [label="携带状态"]; x2 -> f2; f2 -> h2 [label="携带状态"]; f2 -> y2; h2 -> h_final [label="携带状态"]; } subgraph cluster_backward { label = "反向传播 (VJP)"; style=filled; color="#e9ecef"; // gray node [fillcolor="#ffc9c9", style=filled]; // red grad_out [label="dL/dL=1"]; // 最终损失的梯度 grad_y0 [label="dL/dy₀"]; grad_y1 [label="dL/dy₁"]; grad_y2 [label="dL/dy₂"]; grad_h0 [label="dL/dh₀"]; grad_h1 [label="dL/dh₁"]; grad_h2 [label="dL/dh₂"]; grad_h_init [label="dL/dh_init", shape=ellipse, fillcolor="#ffd8a8"]; // orange grad_x0 [label="dL/dx₀", shape=cds, fillcolor="#b2f2bb"]; // green grad_x1 [label="dL/dx₁", shape=cds, fillcolor="#b2f2bb"]; // green grad_x2 [label="dL/dx₂", shape=cds, fillcolor="#b2f2bb"]; // green f0_vjp [label="VJP(f₀)", shape=house]; f1_vjp [label="VJP(f₁)", shape=house]; f2_vjp [label="VJP(f₂)", shape=house]; // 显示梯度流的连接 grad_out -> grad_y0; grad_out -> grad_y1; grad_out -> grad_y2; // 假设损失是sum(y) grad_h_final -> grad_h2 [style=invis]; // 如果需要,使用最终携带状态梯度,此处假设为0 {grad_y2, grad_h2} -> f2_vjp; f2_vjp -> grad_h1 [label="关于携带状态"]; f2_vjp -> grad_x2 [label="关于x"]; {grad_y1, grad_h1} -> f1_vjp; f1_vjp -> grad_h0 [label="关于携带状态"]; f1_vjp -> grad_x1 [label="关于x"]; {grad_y0, grad_h0} -> f0_vjp; f0_vjp -> grad_h_init [label="关于携带状态"]; f0_vjp -> grad_x0 [label="关于x"]; } }lax.scan求导的数据流。正向传播按顺序计算状态($h$)和输出($y$)。反向传播(VJP)使用正向传播中的中间值,反向传播梯度($dL/d\cdot$)通过展开的计算。总之,JAX的自动求导系统旨在与结构化控制流原语正确协同工作。尽管从用户角度看求导“直接可用”,但知道cond追踪两个分支,以及while_loop和scan的求导类似于BPTT,有助于理解内存使用、潜在的数值问题以及梯度检查点等技术的适用性。