虽然 JAX 提供了一个基于 jax.numpy 和 jax.lax 的完备数值运算库,但在某些情况下,你可能需要使用自己的基本运算来扩充 JAX 的功能。这时,自定义原语就变得重要了。原语是 JAX 理解和转换的原子级的、不可再分的运算。定义自定义原语允许你将高度专业化的计算(可能以 C++ 或 CUDA 等其他语言实现)直接集成到 JAX 环境中,使它们像内置运算一样,可以进行 JIT 编译、自动微分和向量化处理。你可以将原语视为 JAX 描述计算所用的词汇。添加自定义原语就像是为这个词汇表添加一个新词。这样做的一些原因包括:性能: 调用预编译的、高度优化的核心程序(例如,来自 CUDA 库或自定义 C++ 代码)来执行特定任务,这可能比等效的 JAX 实现更快。不支持的运算: 实现无法直接使用现有 jax.lax 原语表达的算法或访问硬件功能。与自定义硬件对接: 为专用加速器提供 JAX 绑定。定义自定义原语需要你向 JAX 说明关于你的新运算的几点信息:是什么: 为该运算创建一个唯一标识符。其抽象行为方式: 根据输入形状和数据类型指定输出形状和数据类型,无需实际值。这对 jit、vmap 等使用的 JAX 追踪机制不可或缺。如何执行: 为不同后端(CPU、GPU、TPU)提供实现,通常是将原语转换为相应的低级表示形式(如 XLA HLO)。如何进行微分: 定义计算雅可比-向量积(JVPs)和向量-雅可比积(VJPs)的规则。我们来概述一下这个过程:1. 定义原语对象首先,你使用 jax.core.Primitive 创建一个新的原语实例。此对象作为你的运算的标识符。import jax import jax.numpy as jnp from jax.core import Primitive # 创建一个唯一的原语对象 my_custom_op_p = Primitive("my_custom_op")字符串名称("my_custom_op")用于调试和打印 jaxpr。2. 创建面向用户的函数通常你会将原语的使用封装在一个供用户调用的 Python 函数中。此函数接受标准的 Python/NumPy/JAX 输入,并使用 bind 将原语应用于其参数。def my_custom_op(x, *, some_parameter: int): """将自定义运算应用于 x。 参数: x: 输入 JAX 数组。 some_parameter: 运算的静态参数。 返回: 一个 JAX 数组结果。 """ # 'bind' 将原语应用与其参数 # 以及静态参数一同暂存。 return my_custom_op_p.bind(x, some_parameter=some_parameter)请注意 bind 的使用。这并非立即执行运算,而是在 JAX 的追踪机制中将其暂存。影响计算结构或抽象求值的参数(如 some_parameter)通常需要作为关键字参数传递给 bind。3. 实现抽象求值规则JAX 需要在不运行实际计算的情况下了解输出的形状和数据类型。这通过使用 def_abstract_eval 定义一个抽象求值规则来实现。此规则接受输入的抽象版本(包含形状和数据类型信息的对象,例如 jax.core.ShapedArray),并计算抽象输出。from jax.core import ShapedArray def my_custom_op_abstract_eval(abstract_x, *, some_parameter: int): """my_custom_op 的抽象求值规则。""" # 示例:假设运算是输入值的平方加上参数, # 保留形状和数据类型。 # 实际规则完全取决于原语的功能。 output_shape = abstract_x.shape output_dtype = abstract_x.dtype # 参数检查可以在这里进行 if some_parameter < 0: raise ValueError("some_parameter must be non-negative") return ShapedArray(output_shape, output_dtype) # 向原语注册抽象求值规则 my_custom_op_p.def_abstract_eval(my_custom_op_abstract_eval)这一步对于 JAX 转换非常重要。jit 使用它来确定计算图的结构,而 vmap 使用它来计算批处理维度。4. 实现后端降低规则这通常是最复杂的一步。你需要告诉 JAX 如何将你的原语转换为目标后端(CPU、GPU、TPU)可以执行的低级代码。这通常通过 MLIR(JAX 的中间表示层)生成 XLA HLO(高级运算)来完成。你使用 jax.interpreters.mlir.register_lowering 注册降低规则。# 示例 - 需要对 MLIR 和 XLA HLO 有更深理解 from jax.interpreters import mlir # 假设我们有用于构建特定 HLO 运算的函数 # from jaxlib.hlo_helpers import custom_call # 或类似辅助函数 def my_custom_op_lowering_cpu(ctx, x_operand, *, some_parameter: int): """CPU 的降低规则。""" # 1. 定义 MLIR 的输出类型 # result_type = mlir.ir.RankedTensorType.get(output_shape, mlir_dtype) # 2. 生成 XLA HLO 指令 # 这可能涉及标准 HLO 运算或对外部 C++/CPU 函数的 'custom_call' 调用。 # result_hlo = build_cpu_hlo_for_my_op(x_operand, some_parameter) # Fictional function # return [result_hlo] # 必须返回一个 MLIR 值列表 pass # 占位符 - 实际实现很复杂 def my_custom_op_lowering_gpu(ctx, x_operand, *, some_parameter: int): """GPU 的降低规则。""" # 结构类似,但生成针对 GPU 的 HLO。 # 可能使用 'custom_call' 调用 CUDA 内核。 # result_hlo = build_gpu_hlo_for_my_op(x_operand, some_parameter) # Fictional function # return [result_hlo] pass # 占位符 # 注册降低规则 # mlir.register_lowering(my_custom_op_p, my_custom_op_lowering_cpu, platform='cpu') # mlir.register_lowering(my_custom_op_p, my_custom_op_lowering_gpu, platform='gpu') # ... 也可能为 TPU 注册编写降低规则需要理解目标后端的执行模型和 XLA HLO 指令集,或者如何使用 custom_call 与外部代码连接。这一步弥合了高级 JAX 原语与低级编译代码之间的差异。5. 定义微分规则为了使你的自定义原语能够与 jax.grad 及其他自动微分转换配合使用,你需要定义其 JVP 和 VJP 规则。JVP 规则: 定义如何计算雅可比-向量积 $J v$。它接受原始输入和切线输入(向量),并返回原始输出和切线输出。VJP 规则: 定义如何计算向量-雅可比积 $v^T J$。它接受原始输入,并返回原始输出以及一个在给定输出余切向量时计算 VJP 的函数。这些规则使用 jax.interpreters.ad.primitive_jvps 和 jax.interpreters.ad.primitive_transposes(用于 VJP)进行注册。from jax.interpreters import ad # --- JVP 规则 --- def my_custom_op_jvp(primals, tangents, *, some_parameter: int): """my_custom_op 的 JVP 规则。""" x, = primals x_dot, = tangents # 对应于 x 的切线向量 # 计算原始输出 primal_out = my_custom_op_p.bind(x, some_parameter=some_parameter) # 根据运算的导数计算切线输出 # 示例:如果 my_custom_op(x) = x^2 + some_parameter # 导数是 2x。JVP 是 (2x) * x_dot if type(x_dot) is ad.Zero: tangent_out = ad.Zero.from_value(primal_out) else: # 假设存在用于导数的辅助原语或函数 # tangent_out = my_custom_op_derivative_p.bind(x, x_dot, some_parameter=some_parameter) # 或者如果可能,直接计算: tangent_out = 2 * x * x_dot # 示例导数计算 return primal_out, tangent_out # 注册 JVP 规则 ad.primitive_jvps[my_custom_op_p] = my_custom_op_jvp # --- VJP 规则 --- # VJP 规则通常通过 JVP 规则的转置来定义, # 但如果需要(特别是为了效率),也可以直接定义。 # 实现转置规则需要对 ad.primitive_transposes 有更深的理解。 # 另外,如果运算通过 Python 函数实现, # 使用标准 JAX 运算或其他原语 # (即使这些原语使用 custom_calls), # 你可以在*面向用户*的函数上使用 jax.custom_vjp,而不是直接定义 # 在原语上定义规则。如果你的原语仅调用外部代码,你可能需要在外部代码中提供解析导数或实现反向传播。在包装函数上使用 jax.custom_vjp 或 jax.custom_jvp 有时是比直接在原语本身上定义规则更易于管理的方法,特别是如果运算的实现涉及多个步骤。流程下图说明了当 JAX 在转换后的函数(例如用 @jit 装饰的函数)中遇到你的自定义原语时,这些组件如何配合工作。digraph G { rankdir=TB; node [shape=box, style="filled", fillcolor="#e9ecef", fontname="sans-serif"]; edge [fontname="sans-serif"]; PyFunc [label="Python 调用:\nmy_custom_op(a, param=5)", fillcolor="#a5d8ff"]; Bind [label="my_custom_op_p.bind(a, param=5)", fillcolor="#bac8ff"]; Jaxpr [label="Jaxpr 表示\n(包含 my_custom_op 原语)", fillcolor="#d0bfff"]; AbstractEval [label="抽象求值:\nmy_custom_op_abstract_eval(Abstract(a), param=5)", fillcolor="#eebefa"]; ShapeDtype [label="输出形状和数据类型", fillcolor="#fcc2d7"]; Lowering [label="降低规则:\nmlir.register_lowering(...)", fillcolor="#ffc9c9"]; HLO [label="XLA HLO / MLIR\n(可能包含 custom_call)", fillcolor="#ffd8a8"]; Backend [label="后端执行\n(CPU/GPU/TPU)", fillcolor="#b2f2bb"]; Result [label="结果数组", fillcolor="#96f2d7"]; subgraph cluster_jax_internal { label = "JAX 内部处理"; style = "dashed"; color = "#868e96"; Bind; Jaxpr; AbstractEval; ShapeDtype; Lowering; HLO; } PyFunc -> Bind [label=" 调用"]; Bind -> Jaxpr [label=" 暂存到"]; Jaxpr -> AbstractEval [label=" 触发"]; AbstractEval -> ShapeDtype [label=" 确定"]; Jaxpr -> Lowering [label=" 编译时触发"]; Lowering -> HLO [label=" 生成"]; HLO -> Backend [label=" 发送至"]; Backend -> Result; // 分化路径 Grad [label="@jax.grad", fillcolor="#ffe066"]; VJP [label="VJP / 转置规则", fillcolor="#ffc078"]; Grad -> Jaxpr [style=dotted, label=" 需要微分规则"]; Jaxpr -> VJP [style=dotted, label=" 使用"]; }将自定义原语集成到 JAX 的过程。Python 调用通过 bind 进行暂存,在 jaxpr 中表示,进行形状/数据类型的抽象求值,然后降低为 XLA HLO 进行后端执行。微分需要相应的 VJP/JVP 规则。总结与注意事项定义自定义原语是扩充 JAX 的强大功能,但与使用标准 JAX 运算相比,这在复杂性上显著增加。复杂性: 需要理解 JAX 的内部机制(追踪、抽象求值)、MLIR/XLA HLO,以及可能的低级后端编程(CPU 线程、CUDA)。维护: 自定义原语需要与 JAX 更新同步维护,因为内部 API 可能会发生变化。降低规则可能需要针对不同的 JAX/XLA/驱动版本或硬件进行调整。调试: 调试降低规则或自定义核心程序中的问题可能具有挑战性。在定义自定义原语之前,请考虑你的目标是否可以通过使用现有的 jax.lax 运算、jax.experimental.host_callback、jax.pure_callback,或者通过不同方式表达计算来实现。然而,当性能或必要性要求时,自定义原语提供了将专业代码完整集成到 JAX 环境中的方式。后续章节将介绍实现的具体信息,包括抽象求值、降低和微分规则。