趋近智
虽然 JAX 提供了一个基于 jax.numpy 和 jax.lax 的完备数值运算库,但在某些情况下,你可能需要使用自己的基本运算来扩充 JAX 的功能。这时,自定义原语就变得重要了。原语是 JAX 理解和转换的原子级的、不可再分的运算。定义自定义原语允许你将高度专业化的计算(可能以 C++ 或 CUDA 等其他语言实现)直接集成到 JAX 环境中,使它们像内置运算一样,可以进行 JIT 编译、自动微分和向量化处理。
你可以将原语视为 JAX 描述计算所用的词汇。添加自定义原语就像是为这个词汇表添加一个新词。这样做的一些原因包括:
jax.lax 原语表达的算法或访问硬件功能。定义自定义原语需要你向 JAX 说明关于你的新运算的几点信息:
jit、vmap 等使用的 JAX 追踪机制不可或缺。我们来概述一下这个过程:
首先,你使用 jax.core.Primitive 创建一个新的原语实例。此对象作为你的运算的标识符。
import jax
import jax.numpy as jnp
from jax.core import Primitive
# 创建一个唯一的原语对象
my_custom_op_p = Primitive("my_custom_op")
字符串名称("my_custom_op")用于调试和打印 jaxpr。
通常你会将原语的使用封装在一个供用户调用的 Python 函数中。此函数接受标准的 Python/NumPy/JAX 输入,并使用 bind 将原语应用于其参数。
def my_custom_op(x, *, some_parameter: int):
"""将自定义运算应用于 x。
参数:
x: 输入 JAX 数组。
some_parameter: 运算的静态参数。
返回:
一个 JAX 数组结果。
"""
# 'bind' 将原语应用与其参数
# 以及静态参数一同暂存。
return my_custom_op_p.bind(x, some_parameter=some_parameter)
请注意 bind 的使用。这并非立即执行运算,而是在 JAX 的追踪机制中将其暂存。影响计算结构或抽象求值的参数(如 some_parameter)通常需要作为关键字参数传递给 bind。
JAX 需要在不运行实际计算的情况下了解输出的形状和数据类型。这通过使用 def_abstract_eval 定义一个抽象求值规则来实现。此规则接受输入的抽象版本(包含形状和数据类型信息的对象,例如 jax.core.ShapedArray),并计算抽象输出。
from jax.core import ShapedArray
def my_custom_op_abstract_eval(abstract_x, *, some_parameter: int):
"""my_custom_op 的抽象求值规则。"""
# 示例:假设运算是输入值的平方加上参数,
# 保留形状和数据类型。
# 实际规则完全取决于原语的功能。
output_shape = abstract_x.shape
output_dtype = abstract_x.dtype
# 参数检查可以在这里进行
if some_parameter < 0:
raise ValueError("some_parameter must be non-negative")
return ShapedArray(output_shape, output_dtype)
# 向原语注册抽象求值规则
my_custom_op_p.def_abstract_eval(my_custom_op_abstract_eval)
这一步对于 JAX 转换非常重要。jit 使用它来确定计算图的结构,而 vmap 使用它来计算批处理维度。
这通常是最复杂的一步。你需要告诉 JAX 如何将你的原语转换为目标后端(CPU、GPU、TPU)可以执行的低级代码。这通常通过 MLIR(JAX 的中间表示层)生成 XLA HLO(高级运算)来完成。
你使用 jax.interpreters.mlir.register_lowering 注册降低规则。
# 示例 - 需要对 MLIR 和 XLA HLO 有更深理解
from jax.interpreters import mlir
# 假设我们有用于构建特定 HLO 运算的函数
# from jaxlib.hlo_helpers import custom_call # 或类似辅助函数
def my_custom_op_lowering_cpu(ctx, x_operand, *, some_parameter: int):
"""CPU 的降低规则。"""
# 1. 定义 MLIR 的输出类型
# result_type = mlir.ir.RankedTensorType.get(output_shape, mlir_dtype)
# 2. 生成 XLA HLO 指令
# 这可能涉及标准 HLO 运算或对外部 C++/CPU 函数的 'custom_call' 调用。
# result_hlo = build_cpu_hlo_for_my_op(x_operand, some_parameter) # Fictional function
# return [result_hlo] # 必须返回一个 MLIR 值列表
pass # 占位符 - 实际实现很复杂
def my_custom_op_lowering_gpu(ctx, x_operand, *, some_parameter: int):
"""GPU 的降低规则。"""
# 结构类似,但生成针对 GPU 的 HLO。
# 可能使用 'custom_call' 调用 CUDA 内核。
# result_hlo = build_gpu_hlo_for_my_op(x_operand, some_parameter) # Fictional function
# return [result_hlo]
pass # 占位符
# 注册降低规则
# mlir.register_lowering(my_custom_op_p, my_custom_op_lowering_cpu, platform='cpu')
# mlir.register_lowering(my_custom_op_p, my_custom_op_lowering_gpu, platform='gpu')
# ... 也可能为 TPU 注册
编写降低规则需要理解目标后端的执行模型和 XLA HLO 指令集,或者如何使用 custom_call 与外部代码连接。这一步弥合了高级 JAX 原语与低级编译代码之间的差异。
为了使你的自定义原语能够与 jax.grad 及其他自动微分转换配合使用,你需要定义其 JVP 和 VJP 规则。
这些规则使用 jax.interpreters.ad.primitive_jvps 和 jax.interpreters.ad.primitive_transposes(用于 VJP)进行注册。
from jax.interpreters import ad
# --- JVP 规则 ---
def my_custom_op_jvp(primals, tangents, *, some_parameter: int):
"""my_custom_op 的 JVP 规则。"""
x, = primals
x_dot, = tangents # 对应于 x 的切线向量
# 计算原始输出
primal_out = my_custom_op_p.bind(x, some_parameter=some_parameter)
# 根据运算的导数计算切线输出
# 示例:如果 my_custom_op(x) = x^2 + some_parameter
# 导数是 2x。JVP 是 (2x) * x_dot
if type(x_dot) is ad.Zero:
tangent_out = ad.Zero.from_value(primal_out)
else:
# 假设存在用于导数的辅助原语或函数
# tangent_out = my_custom_op_derivative_p.bind(x, x_dot, some_parameter=some_parameter)
# 或者如果可能,直接计算:
tangent_out = 2 * x * x_dot # 示例导数计算
return primal_out, tangent_out
# 注册 JVP 规则
ad.primitive_jvps[my_custom_op_p] = my_custom_op_jvp
# --- VJP 规则 ---
# VJP 规则通常通过 JVP 规则的转置来定义,
# 但如果需要(特别是为了效率),也可以直接定义。
# 实现转置规则需要对 ad.primitive_transposes 有更深的理解。
# 另外,如果运算通过 Python 函数实现,
# 使用标准 JAX 运算或其他原语
# (即使这些原语使用 custom_calls),
# 你可以在*面向用户*的函数上使用 jax.custom_vjp,而不是直接定义
# 在原语上定义规则。
如果你的原语仅调用外部代码,你可能需要在外部代码中提供解析导数或实现反向传播。在包装函数上使用 jax.custom_vjp 或 jax.custom_jvp 有时是比直接在原语本身上定义规则更易于管理的方法,特别是如果运算的实现涉及多个步骤。
下图说明了当 JAX 在转换后的函数(例如用 @jit 装饰的函数)中遇到你的自定义原语时,这些组件如何配合工作。
将自定义原语集成到 JAX 的过程。Python 调用通过
bind进行暂存,在 jaxpr 中表示,进行形状/数据类型的抽象求值,然后降低为 XLA HLO 进行后端执行。微分需要相应的 VJP/JVP 规则。
定义自定义原语是扩充 JAX 的强大功能,但与使用标准 JAX 运算相比,这在复杂性上显著增加。
在定义自定义原语之前,请考虑你的目标是否可以通过使用现有的 jax.lax 运算、jax.experimental.host_callback、jax.pure_callback,或者通过不同方式表达计算来实现。然而,当性能或必要性要求时,自定义原语提供了将专业代码完整集成到 JAX 环境中的方式。后续章节将介绍实现的具体信息,包括抽象求值、降低和微分规则。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造