趋近智
虽然 jax.experimental.host_callback 提供了一种在 JAX 计算期间在主机上执行任意 Python 代码的方式,但其性质会打断 JAX 转换的流程,主要用于调试或日志记录等副作用。要将 计算纯粹 的外部 Python 函数整合到需要转换(例如,使用 jit、vmap 或 grad)的 JAX 图中,JAX 提供了 jax.pure_callback。
jax.pure_callback 允许你在转换后的 JAX 代码中调用 Python 函数,但它附带一项重要的约定:你调用的函数必须是函数式纯粹的。
函数式纯粹性对于 JAX 的追踪和转换机制的正常运行非常重要。纯函数具有两个主要特性:
JAX 依赖于这种纯粹性。在追踪阶段(例如,当 jit 编译一个函数时),JAX 根据抽象的形状和类型分析操作。它需要相信回调函数的行为只取决于其输入,并且其抽象表示(输出形状/数据类型)准确反映了其在任何有效输入下的运行时行为。如果用 pure_callback 封装的函数违反这项约定(例如,对相同输入返回不同的值,或修改全局变量),JAX 转换的结果可能会变得不正确或不可预测,通常不会引发明确的错误。
jax.pure_callback 的工作原理当你使用 jax.pure_callback 时,你需要提供三项主要内容:
在 JAX 的追踪阶段,它并不会实际执行 Python 回调函数。相反,它使用提供的输出形状和数据类型信息(result_shape_dtypes)在计算图(jaxpr)中创建一个占位符。这使得追踪和 jit、vmap 或 grad 等转换能够进行,将回调视为具有已知输入/输出规范的黑盒 (black box)。
在运行时(当编译后的 JAX 函数执行时),实际的 Python 回调函数会使用具体的输入值被调用。JAX 相信产生的输出会与之前指定的形状和数据类型匹配。
让我们通过一个例子来说明。假设我们有一个纯 Python 函数,它执行一个特定的计算,可能使用 JAX 中不直接可用的库,但已知其具有确定性且无副作用。
import jax
import jax.numpy as jnp
import numpy as np # 使用 NumPy 作为“外部”函数
# 假设此函数代表一些复杂的纯计算
# 可能来自外部库或自定义 Python 代码。
def external_pure_python_computation(x: np.ndarray, y: float) -> np.ndarray:
"""一个纯 Python 函数的占位符。"""
# 确保此函数的输入是 NumPy 数组
if isinstance(x, jax.Array):
x = np.array(x)
# 纯计算示例:
return np.sin(x) * y + 1.0
def jax_function_using_callback(a, b):
"""一个包含纯回调的 JAX 函数。"""
# 定义预期的输出结构。
# 这里,我们期望一个与 'a' 具有相同形状和数据类型的数组。
output_shape_dtype = jax.ShapeDtypeStruct(a.shape, a.dtype)
# 创建回调
result = jax.pure_callback(
external_pure_python_computation, # 纯 Python 函数
output_shape_dtype, # 预期的输出形状/数据类型
a, # 示例/实际参数 'a'
b # 示例/实际参数 'b'
# 如果回调需要关键字参数,则传递它们
)
return result + a # 继续进行标准 JAX 操作
# JIT 内的用法示例
key = jax.random.PRNGKey(0)
input_array = jax.random.normal(key, (3, 3))
input_scalar = 2.0
jitted_function = jax.jit(jax_function_using_callback)
# 执行 JIT 编译的函数
output = jitted_function(input_array, input_scalar)
print("输入数组:\n", input_array)
print("\n输出数组:\n", output)
# 手动验证计算(近似回调)
expected_output = np.sin(np.array(input_array)) * input_scalar + 1.0 + np.array(input_array)
print("\n预期输出(近似):\n", expected_output)
# 检查结果是否接近
assert np.allclose(output, expected_output)
# 也可以通过 pure_callback 进行 vmap 或 grad
vmapped_function = jax.vmap(jax_function_using_callback, in_axes=(0, None))
batched_input_array = jax.random.normal(key, (10, 3, 3))
batched_output = vmapped_function(batched_input_array, input_scalar)
print("\n批量输出的形状:", batched_output.shape)
grad_function = jax.grad(lambda x: jnp.sum(jax_function_using_callback(x, input_scalar)))
# 注意:Grad 要求回调函数对其输入是可微的,
# 而 pure_callback 本身不保证这一点。你通常会使用
# 自定义 VJP 来通过外部代码进行微分。
# 这个例子之所以有效,是因为 sin(x)*y 是可微的。
gradients = grad_function(input_array)
print("\n关于 input_array 的梯度:\n", gradients)
在此示例中:
external_pure_python_computation 是外部纯 Python 函数的替代表达。它接受一个 NumPy 数组和一个浮点数。jax_function_using_callback 中,我们使用 jax.ShapeDtypeStruct 定义 output_shape_dtype,以告知 JAX 预期从回调中获得的输出类型(一个与输入 a 具有相同形状和数据类型的数组)。jax.pure_callback,传递 Python 函数、预期的输出结构和实际输入(a、b)。result + a)。jit 编译、vmap 映射,甚至使用 grad 进行微分(尽管微分依赖于底层数学运算的可微性,并且对于复杂情况可能需要自定义规则)。result_shape_dtypes提供正确的 result_shape_dtypes 非常重要。此参数 (parameter)告知 JAX 的追踪器回调输出的抽象值(形状和数据类型),而无需运行 Python 代码。
jax.Array、NumPy 数组或标量),请提供 jax.ShapeDtypeStruct(shape, dtype)。jax.ShapeDtypeStruct 实例的元组,每个实例对应一个输出元素。result_shape_dtypes 必须镜像此结构,在叶子节点处包含 jax.ShapeDtypeStruct 对象。# 多个输出示例
def multi_output_pure_function(x):
return np.sum(x), np.mean(x)
def jax_multi_output_callback(arr):
# 指定输出结构:一个 float32 标量和另一个 float32 标量
output_structure = (
jax.ShapeDtypeStruct((), arr.dtype), # 标量的形状为 ()
jax.ShapeDtypeStruct((), arr.dtype)
)
sum_val, mean_val = jax.pure_callback(
multi_output_pure_function,
output_structure,
arr
)
return sum_val * 2, mean_val * 3
input_arr = jnp.arange(5.0)
res1, res2 = jax.jit(jax_multi_output_callback)(input_arr)
print(f"\n多个输出:{res1=}, {res2=}") # 输出:res1=20.0, res2=6.0
jax.pure_callback 的使用场景jit 或 pmap 等转换下。pure_callback,尤其是在性能关键的循环中。XLA 编译的原生 JAX 操作几乎总是快得多。pure_callback 与 jit、vmap 和 pmap 兼容,但自动微分(grad、vjp、jvp)要求被封装的 Python 函数本身以 JAX 可理解的方式是可微的,或者你需要为回调操作定义自定义微分规则(使用 jax.custom_vjp 或 jax.custom_jvp)。pure_callback 本身并不能使不可微函数变得可微。host_callback 的比较| 特性 | jax.pure_callback |
jax.experimental.host_callback |
|---|---|---|
| 纯粹性 | 必需(用户保证) | 非必需(允许副作用) |
| 用途 | 纯计算,库集成 | 调试,日志记录,I/O,副作用 |
jit |
兼容 | 兼容(但在主机上执行) |
vmap / pmap |
兼容 | 有限(顺序执行回调) |
grad |
兼容(如果函数可微或提供了自定义规则) | 不直接可微 |
| 执行 | 在运行时执行 Python 函数 | 在主机上执行 Python 函数 |
| 返回值 | 将计算结果返回给 JAX | 通常不返回任何内容(None) |
当你需要将外部的、纯计算型 Python 代码整合到将进行 jit、vmap 或 grad 等转换的 JAX 函数中时,请选择 jax.pure_callback。务必确保该函数严格遵守纯粹性约定。如果你需要执行打印或日志记录等副作用,host_callback 是合适的工具,尽管功能更受限。
这部分内容有帮助吗?
jax.pure_callback的官方API规范、用法示例和契约细节。host_callback API,它提供了一种在JAX计算期间执行带副作用的Python代码的对比方法。pure_callback设计的基础。jax.pure_callback的契约至关重要。© 2026 ApX Machine LearningAI伦理与透明度•