趋近智
JAX原语是JAX系统所知的基础构造块。在定义您自己的自定义操作时,主要的一步是让JAX知晓其形状和类型特征。这通过实现抽象求值规则来完成。
抽象求值是JAX追踪机制的一个重要组成部分。在JAX能够对包含您自定义原语的函数进行JIT编译、自动微分或向量化之前,它需要确定该原语产生的输出的属性(如形状和数据类型,即dtype),而无需实际在真实数据上运行计算。它通过操作抽象值来执行此分析,主要是jax.core.ShapedArray实例,这些实例封装了形状和dtype信息,但不持有实际的数值。
可以将抽象求值理解为在类型和形状层面定义您的原语的函数签名。当JAX在追踪期间遇到您的原语时(例如,当jax.jit应用于使用它的函数时),它会调用该原语的抽象求值规则。此规则接收与原语输入对应的抽象值以及任何必要的元数据(原语行为的特定参数),并计算原语输出的抽象值(形状和dtype)。
此信息之所以重要,有以下几个原因:
vmap等变换需要知晓形状如何变化才能正确实现批处理规则。grad需要知晓形状和类型才能正确设置反向传播。abstract_eval方法为了定义自定义原语的抽象求值规则,您通常需要继承jax.core.Primitive并实现其abstract_eval方法。
import jax
import jax.numpy as jnp
from jax.core import Primitive, ShapedArray
from jax.interpreters import xla
# 示例:一个将输入加1.0的原语
custom_add_one_p = Primitive('custom_add_one')
@custom_add_one_p.def_abstract_eval
def custom_add_one_abstract_eval(*avals, **params):
# avals 包含输入的抽象值(例如 ShapedArray)
# params 包含在原语绑定期间传递的任何参数
# 基本验证(可选但推荐)
if len(avals) != 1:
raise ValueError("custom_add_one 需要恰好一个输入操作数。")
input_aval = avals[0]
# 检查输入是否为 ShapedArray(最常见的情况)
if not isinstance(input_aval, ShapedArray):
raise TypeError(f"输入必须是 ShapedArray,得到 {type(input_aval)}")
# 确保此特定操作的 dtype 是浮点类型
if not jnp.issubdtype(input_aval.dtype, jnp.floating):
# 或者您可能定义特定的类型提升规则
raise TypeError(f"输入 dtype 必须是浮点类型,得到 {input_aval.dtype}")
# 核心逻辑:确定输出形状和 dtype
# 对于 'custom_add_one',形状和 dtype 与输入相同。
output_shape = input_aval.shape
output_dtype = input_aval.dtype
# 返回输出的抽象值
return ShapedArray(output_shape, output_dtype)
# 示例用法(说明性,如果没有实现规则则无法运行)
# def my_func(x):
# # 将原语与输入 x 绑定
# return custom_add_one_p.bind(x)
在此示例中:
custom_add_one_p。@custom_add_one_p.def_abstract_eval装饰器修饰函数custom_add_one_abstract_eval。这将该函数注册为我们原语的抽象求值规则。bind方法的每个位置参数的抽象值(avals),以及传递给bind的任何关键字参数(params)。isinstance(aval, ShapedArray)很常见。我们还检查了dtype。output_shape和output_dtype。对于我们简单的custom_add_one,它们与输入的形状和dtype保持一致。更复杂的原语在此处会有逻辑,以体现它们如何变换形状(例如矩阵乘法、卷积)。ShapedArray,表示原语输出的抽象属性。如果一个原语返回多个输出,该函数将返回一个ShapedArray实例的元组。params)有时,原语的行为以及其输出形状或类型,取决于并非输入数组本身的参数。例如,卷积原语需要步长和填充信息,或者归约原语需要归约的轴。这些参数通常作为关键字参数传递给原语的bind方法,并在abstract_eval函数内部的params字典中接收。
# 归约原语
# custom_reduce_sum_p = Primitive('custom_reduce_sum')
# @custom_reduce_sum_p.def_abstract_eval
# def custom_reduce_sum_abstract_eval(aval, *, axis): # axis 作为关键字参数传递给 bind
# if not isinstance(aval, ShapedArray):
# raise TypeError("输入必须是数组")
#
# # 根据输入形状和轴参数计算输出形状
# output_shape = tuple(d for i, d in enumerate(aval.shape) if i not in axis)
# output_dtype = aval.dtype # 假设 dtype 不变
#
# return ShapedArray(output_shape, output_dtype)
# 用法:
# result = custom_reduce_sum_p.bind(my_array, axis=(0,))
在这里,axis参数影响output_shape的计算。abstract_eval规则直接从params字典(或者如上所示的关键字参数,JAX对此处理得很好)中使用此参数来确定正确的输出形状。
抽象求值是您需要为新原语定义的第一个规则。它为JAX理解原语的签名奠定了基础。只有在定义了abstract_eval之后,您才能继续实现实际的计算逻辑(针对CPU/GPU/TPU等特定后端的“降级”规则)及其微分规则(JVP/VJP),这些都依赖于抽象求值提供的形状和类型信息。
抽象求值规则(
abstract_eval)在JAX追踪期间被调用,以确定原语输出的形状和类型,从而能够创建jaxpr(JAX的中间表示),jaxpr随后用于编译、通过降级规则执行以及定义微分行为。
通过仔细实现abstract_eval方法,您可以确保您的自定义原语与JAX的追踪和变换机制结合,在JIT编译、向量化和自动微分等不同环境中,在形状和数据类型方面表现出可预测性。这一步对于在JAX中创建自定义操作来说是基础。
这部分内容有帮助吗?
© 2026 ApX Machine Learning用心打造