趋近智
为了在 CPU、GPU 或 TPU 等特定硬件上实际执行自定义 JAX 原语,需要一个重要的翻译过程。这个过程被称为“转换”,其中高级原语被转换为硬件后端能够理解的低级指令。尽管自定义原语在形状和数据类型方面已抽象定义,但使它们能够在不同硬件后端上实际执行是一个必要的步骤,通常通过 XLA (加速线性代数) 编译器完成。
JAX 通过使用 XLA 将 Python 函数即时编译为优化的可执行代码来提高性能。XLA 在其自身的中间表示形式 HLO (高级优化器 IR) 上运行。当 JAX 在编译过程中遇到一个原语(例如 jax.lax.add 或您的自定义原语)时,它需要一种方法将该原语转换为相应的 HLO 指令序列。如果没有转换规则,XLA 将不知道为您的自定义操作生成什么机器代码。
每个后端(CPU、GPU、TPU)可能有不同的最佳方式来实现相同的操作。因此,您需要为打算支持的每个后端提供专门的转换规则。
JAX 中的转换过程几乎总是以 XLA HLO 为目标。您的转换规则的任务是接收原语的输入(表示为 XLA 值),并使用 XLA 构建器对象构建执行预期操作的 HLO 计算图。XLA 随后会将此 HLO 图进一步编译为目标加速器的高度优化机器代码。
您可以使用 JAX 的内部 API 为特定原语和后端注册转换规则。一种常见的方法涉及 xla_client.register_translation 等函数。您将您的原语对象与执行转换的 Python 函数进行关联。
# 示例:为 'my_custom_op_p' 原语注册转换规则
from jax.lib import xla_client
from jax.interpreters import xla
# 假设 'my_custom_op_p' 是您的 Primitive 对象
def my_custom_op_lowering_cpu(builder, *args, **params):
"""my_custom_op 在 CPU 上的转换规则。"""
# 'builder' 是一个 XlaBuilder 实例
# 'args' 是输入的 XLA 计算值
# 'params' 是原语的静态参数
# 使用构建器来构建 HLO 操作...
# 示例:如果 my_custom_op 添加 1.0
# operand = args[0]
# one = builder.ConstantF32Scalar(1.0)
# result = builder.Add(operand, one)
# return [result] # 返回输出 XLA 值的列表
# 替换为您的原语的实际 HLO 构建
pass
# 为 CPU 后端注册
xla_client.register_translation(my_custom_op_p,
my_custom_op_lowering_cpu,
platform='cpu')
# 类似地,为 'gpu' 和 'tpu' 平台定义和注册
# def my_custom_op_lowering_gpu(builder, *args, **params): ...
# xla_client.register_translation(my_custom_op_p, my_custom_op_lowering_gpu, platform='gpu')
# def my_custom_op_lowering_tpu(builder, *args, **params): ...
# xla_client.register_translation(my_custom_op_p, my_custom_op_lowering_tpu, platform='tpu')
(注意:具体的 API 细节可能会有所变化;请查阅当前的 JAX 文档以获取最新的注册方法。)
传递给转换函数的 builder 对象是构建 HLO 的主要工具。它提供了以下方法:
builder.ConstantF32Scalar(1.0))。builder.Add、builder.Mul、builder.Log1p)。builder.DotGeneral)。builder.Reshape、builder.Transpose、builder.Slice)。您使用这些方法,将它们链接在一起,接收输入 args(它们是 XLA 值的句柄)并生成表示您的原语结果的输出 XLA 值。您之前定义的抽象评估规则提供了构建器通常需要正确构建这些操作的形状和数据类型信息。
尽管某些原语在不同后端上可能具有相同的转换规则,但性能要求高的操作通常受益于后端专门化:
gpu_kernel 调用,嵌入 (embedding) CUDA 代码(这会大大增加复杂性)。relu 的转换假设我们正在定义一个自定义 ReLU 原语 my_relu_p。ReLU 操作是 。
# my_relu_p 的转换
# 假设 my_relu_p 是 Primitive 对象
# 假设抽象评估规则已定义
def my_relu_lowering(builder, x_operand, **params):
"""将 my_relu(x) = max(0, x) 转换为 HLO。"""
# 从操作数获取形状和数据类型
# 抽象评估确保 x_operand 具有正确的抽象值
shape = builder.GetShape(x_operand)
dtype = shape.element_type() # 例如,F32
# 创建一个与操作数相同类型和形状的常数零
zero = builder.Broadcast(builder.ConstantElement(0, dtype), shape.dimensions())
# 计算逐元素最大值
result = builder.Max(zero, x_operand)
return [result] # 返回包含单个输出的列表
# 为相关平台注册此规则
# xla_client.register_translation(my_relu_p, my_relu_lowering, platform='cpu')
# xla_client.register_translation(my_relu_p, my_relu_lowering, platform='gpu')
# xla_client.register_translation(my_relu_p, my_relu_lowering, platform='tpu')
在此示例中,转换很简单:创建一个与输入形状和类型相同的零张量,并使用 Max HLO 操作。这条规则可能对所有后端都足够,因为 XLA 知道如何为每个后端高效地编译 Max。
编写转换规则可以从简单(如 ReLU 示例)到高度复杂。如果您的操作不能很好地映射到现有的 HLO 指令,您可能需要:
CustomCall 调用它。这需要 HLO 和目标硬件编程模型(例如 CUDA)方面的专业知识。调试转换规则也可能具有挑战性,通常需要使用 JAX 或 XLA 提供的工具检查生成的 HLO 本身。
通过实现转换规则,您提供了 JAX 在所需硬件上高效编译和运行您的自定义原语所需的最后一部分,使其成为 JAX 生态系统中的一等公民,可以用于 jit、vmap、pmap,甚至进行微分(一旦您定义了其微分规则)。
这部分内容有帮助吗?
xla_client.register_translation 等 API 的使用。© 2026 ApX Machine LearningAI伦理与透明度•