JAX原语是JAX系统所知的基础构造块。在定义您自己的自定义操作时,主要的一步是让JAX知晓其形状和类型特征。这通过实现抽象求值规则来完成。抽象求值是JAX追踪机制的一个重要组成部分。在JAX能够对包含您自定义原语的函数进行JIT编译、自动微分或向量化之前,它需要确定该原语产生的输出的属性(如形状和数据类型,即dtype),而无需实际在真实数据上运行计算。它通过操作抽象值来执行此分析,主要是jax.core.ShapedArray实例,这些实例封装了形状和dtype信息,但不持有实际的数值。抽象求值的目的可以将抽象求值理解为在类型和形状层面定义您的原语的函数签名。当JAX在追踪期间遇到您的原语时(例如,当jax.jit应用于使用它的函数时),它会调用该原语的抽象求值规则。此规则接收与原语输入对应的抽象值以及任何必要的元数据(原语行为的特定参数),并计算原语输出的抽象值(形状和dtype)。此信息之所以重要,有以下几个原因:静态分析: 这使得JAX能够在追踪阶段,在尝试编译或执行之前,及早检查形状不匹配或类型错误。变换兼容性: vmap等变换需要知晓形状如何变化才能正确实现批处理规则。grad需要知晓形状和类型才能正确设置反向传播。编译: XLA编译器需要计算图中所有操作的精确形状和类型信息,以便为目标硬件(CPU、GPU、TPU)生成高效、专门的代码。实现abstract_eval方法为了定义自定义原语的抽象求值规则,您通常需要继承jax.core.Primitive并实现其abstract_eval方法。import jax import jax.numpy as jnp from jax.core import Primitive, ShapedArray from jax.interpreters import xla # 示例:一个将输入加1.0的原语 custom_add_one_p = Primitive('custom_add_one') @custom_add_one_p.def_abstract_eval def custom_add_one_abstract_eval(*avals, **params): # avals 包含输入的抽象值(例如 ShapedArray) # params 包含在原语绑定期间传递的任何参数 # 基本验证(可选但推荐) if len(avals) != 1: raise ValueError("custom_add_one 需要恰好一个输入操作数。") input_aval = avals[0] # 检查输入是否为 ShapedArray(最常见的情况) if not isinstance(input_aval, ShapedArray): raise TypeError(f"输入必须是 ShapedArray,得到 {type(input_aval)}") # 确保此特定操作的 dtype 是浮点类型 if not jnp.issubdtype(input_aval.dtype, jnp.floating): # 或者您可能定义特定的类型提升规则 raise TypeError(f"输入 dtype 必须是浮点类型,得到 {input_aval.dtype}") # 核心逻辑:确定输出形状和 dtype # 对于 'custom_add_one',形状和 dtype 与输入相同。 output_shape = input_aval.shape output_dtype = input_aval.dtype # 返回输出的抽象值 return ShapedArray(output_shape, output_dtype) # 示例用法(说明性,如果没有实现规则则无法运行) # def my_func(x): # # 将原语与输入 x 绑定 # return custom_add_one_p.bind(x) 在此示例中:我们定义了一个原语custom_add_one_p。我们使用@custom_add_one_p.def_abstract_eval装饰器修饰函数custom_add_one_abstract_eval。这将该函数注册为我们原语的抽象求值规则。该函数接收传递给原语bind方法的每个位置参数的抽象值(avals),以及传递给bind的任何关键字参数(params)。我们对输入数量和类型执行基本验证。对于数组操作,检查isinstance(aval, ShapedArray)很常见。我们还检查了dtype。主要任务是计算output_shape和output_dtype。对于我们简单的custom_add_one,它们与输入的形状和dtype保持一致。更复杂的原语在此处会有逻辑,以体现它们如何变换形状(例如矩阵乘法、卷积)。最后,它返回一个ShapedArray,表示原语输出的抽象属性。如果一个原语返回多个输出,该函数将返回一个ShapedArray实例的元组。处理参数(params)有时,原语的行为以及其输出形状或类型,取决于并非输入数组本身的参数。例如,卷积原语需要步长和填充信息,或者归约原语需要归约的轴。这些参数通常作为关键字参数传递给原语的bind方法,并在abstract_eval函数内部的params字典中接收。# 归约原语 # custom_reduce_sum_p = Primitive('custom_reduce_sum') # @custom_reduce_sum_p.def_abstract_eval # def custom_reduce_sum_abstract_eval(aval, *, axis): # axis 作为关键字参数传递给 bind # if not isinstance(aval, ShapedArray): # raise TypeError("输入必须是数组") # # # 根据输入形状和轴参数计算输出形状 # output_shape = tuple(d for i, d in enumerate(aval.shape) if i not in axis) # output_dtype = aval.dtype # 假设 dtype 不变 # # return ShapedArray(output_shape, output_dtype) # 用法: # result = custom_reduce_sum_p.bind(my_array, axis=(0,))在这里,axis参数影响output_shape的计算。abstract_eval规则直接从params字典(或者如上所示的关键字参数,JAX对此处理得很好)中使用此参数来确定正确的输出形状。在原语定义工作流中的位置抽象求值是您需要为新原语定义的第一个规则。它为JAX理解原语的签名奠定了基础。只有在定义了abstract_eval之后,您才能继续实现实际的计算逻辑(针对CPU/GPU/TPU等特定后端的“降级”规则)及其微分规则(JVP/VJP),这些都依赖于抽象求值提供的形状和类型信息。digraph G { rankdir=TB; node [shape=box, style="filled,rounded", fontname="Arial", fillcolor="#e9ecef"]; edge [fontname="Arial", fontsize=10]; JaxFunction [label="使用原语的函数", fillcolor="#a5d8ff"]; Tracing [label="JAX 追踪(例如 jit, vmap)", fillcolor="#bac8ff"]; AbstractEval [label="原语.abstract_eval\n(输入: 抽象值, 参数\n输出: 抽象输出值)", fillcolor="#96f2d7"]; Jaxpr [label="jaxpr\n(已知形状/dtype的图)", fillcolor="#ffec99"]; Compiler [label="XLA 编译", fillcolor="#ffd8a8"]; BackendExecution [label="后端执行 (CPU/GPU/TPU)\n(需要降级规则)", fillcolor="#ffc9c9"]; Differentiation [label="自动微分 (grad)\n(需要微分规则)", fillcolor="#fcc2d7"]; JaxFunction -> Tracing; Tracing -> AbstractEval [label=" 遇到\n 自定义原语 "]; AbstractEval -> Jaxpr [label=" 提供输出\n 形状/dtype "]; Jaxpr -> Compiler; Compiler -> BackendExecution; Jaxpr -> Differentiation [style=dashed, label=" 告知结构 "]; subgraph cluster_primitive { label = "自定义原语定义"; style="dashed"; color="#adb5bd"; AbstractEval; BackendExecution; Differentiation; } } 抽象求值规则(abstract_eval)在JAX追踪期间被调用,以确定原语输出的形状和类型,从而能够创建jaxpr(JAX的中间表示),jaxpr随后用于编译、通过降级规则执行以及定义微分行为。通过仔细实现abstract_eval方法,您可以确保您的自定义原语与JAX的追踪和变换机制结合,在JIT编译、向量化和自动微分等不同环境中,在形状和数据类型方面表现出可预测性。这一步对于在JAX中创建自定义操作来说是基础。