While `jax.experimental.host_callback` provides a way to execute arbitrary Python code on the host during a JAX computation, it is primarily intended for side effects like debugging or logging, and its side-effecting nature disrupts JAX transformations. For integrating external Python functions that are computationally pure into JAX computations that need to be transformed (e.g., with `jit`, `vmap`, or `grad`), JAX offers `jax.pure_callback`.
`jax.pure_callback` allows you to call a Python function from within transformed JAX code, but it comes with a significant contract: the function you call must be functionally pure.
Functional purity is essential for JAX's tracing and transformation mechanisms to work correctly. A pure function has two main properties:

1. **Determinism:** its output depends only on its explicit inputs; the same inputs always produce the same outputs.
2. **No side effects:** it does not read or modify external state (global variables, files, random number generators, etc.).
JAX relies on this purity. During tracing (e.g., when `jit` compiles a function), JAX analyzes the operations based on abstract shapes and types. It needs to trust that the callback function's behavior depends only on its inputs and that its abstract representation (output shape/dtype) accurately reflects its runtime behavior for any valid input. If a function wrapped with `pure_callback` violates this contract (e.g., returns different values for the same input, or modifies a global variable), the results of JAX transformations can become incorrect or unpredictable, often without raising explicit errors.
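To make the contract concrete, here is a minimal sketch contrasting a function that satisfies it with one that does not (the function names are illustrative):

```python
import numpy as np

# Pure: the result depends only on the inputs, and no external
# state is read or modified. Safe to wrap in jax.pure_callback.
def pure_fn(x):
    return np.tanh(x) + 1.0

# Impure: reads and mutates a global counter. Wrapping this in
# jax.pure_callback can yield silently wrong results, because JAX
# is free to cache, reorder, or elide calls during compilation.
_counter = 0
def impure_fn(x):
    global _counter
    _counter += 1
    return np.tanh(x) + _counter
```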
## How jax.pure_callback Works

When you use `jax.pure_callback`, you provide three main things:

1. The Python callable to execute (which must be pure).
2. A specification of the output's shape(s) and dtype(s) (`result_shape_dtypes`).
3. The arguments to pass to the callable.
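Schematically, a call has the form sketched below; positional and keyword arguments after `result_shape_dtypes` are forwarded to the callback:

```python
result = jax.pure_callback(callback, result_shape_dtypes, *args, **kwargs)
```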
During JAX's tracing phase, the Python callback is not actually executed. Instead, JAX uses the provided output shape and dtype information (`result_shape_dtypes`) to create a placeholder in the computation graph (the jaxpr). This allows tracing and transformations like `jit`, `vmap`, or `grad` to proceed, treating the callback as a black box with known input/output specifications.
At runtime (when the compiled JAX function executes), the actual Python callback function is invoked with the concrete input values. JAX trusts that the output produced will match the shape and dtype specified earlier.
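You can observe this two-phase behavior directly: tracing with `jax.make_jaxpr` records a `pure_callback` placeholder without running the Python code, while executing the compiled function invokes it. A small sketch (the `print` inside the callback is for demonstration only; a real callback must be pure):

```python
import jax
import jax.numpy as jnp
import numpy as np

def callback(x):
    print("callback executed")  # demonstration only; side effects are disallowed
    return np.asarray(x) * 2.0

def f(x):
    return jax.pure_callback(callback, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

# Tracing: the jaxpr is built from the declared shape/dtype; nothing is printed.
print(jax.make_jaxpr(f)(jnp.ones(3)))

# Runtime: the compiled function actually invokes the Python callback.
out = jax.jit(f)(jnp.ones(3))  # prints "callback executed"
```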
Let's illustrate with an example. Suppose we have a pure Python function that performs a specific calculation, perhaps using a library not directly available in JAX, but known to be deterministic and side-effect free.
```python
import jax
import jax.numpy as jnp
import numpy as np  # Using NumPy for the 'external' function

# Assume this function represents some complex, pure computation,
# perhaps from an external library or custom Python code.
def external_pure_python_computation(x: np.ndarray, y: float) -> np.ndarray:
    """A placeholder for a pure Python function."""
    # Ensure inputs are NumPy arrays for this function
    if isinstance(x, jax.Array):
        x = np.array(x)
    # Example pure computation:
    return np.sin(x) * y + 1.0

def jax_function_using_callback(a, b):
    """A JAX function that incorporates the pure callback."""
    # Define the expected output structure.
    # Here, we expect a single array with the same shape and dtype as 'a'.
    output_shape_dtype = jax.ShapeDtypeStruct(a.shape, a.dtype)

    # Create the callback
    result = jax.pure_callback(
        external_pure_python_computation,  # The pure Python function
        output_shape_dtype,                # The expected output shape/dtype
        a,                                 # Actual argument 'a'
        b,                                 # Actual argument 'b'
        # Pass keyword arguments if the callback expects them
    )
    return result + a  # Continue with standard JAX operations

# Example usage within JIT
key = jax.random.PRNGKey(0)
input_array = jax.random.normal(key, (3, 3))
input_scalar = 2.0

jitted_function = jax.jit(jax_function_using_callback)

# Execute the JIT-compiled function
output = jitted_function(input_array, input_scalar)

print("Input array:\n", input_array)
print("\nOutput array:\n", output)

# Verify the computation manually (mirrors the callback)
expected_output = np.sin(np.array(input_array)) * input_scalar + 1.0 + np.array(input_array)
print("\nExpected output (approx):\n", expected_output)

# Check if results are close
assert np.allclose(output, expected_output)

# You can also vmap through pure_callback
vmapped_function = jax.vmap(jax_function_using_callback, in_axes=(0, None))
batched_input_array = jax.random.normal(key, (10, 3, 3))
batched_output = vmapped_function(batched_input_array, input_scalar)
print("\nShape of batched output:", batched_output.shape)

# Note: pure_callback has no built-in differentiation rule. Calling
# jax.grad on a function containing it raises an error unless you attach
# custom rules (jax.custom_jvp / jax.custom_vjp) to the callback;
# see the sketch later in this section.
# grad_function = jax.grad(
#     lambda x: jnp.sum(jax_function_using_callback(x, input_scalar)))
# gradients = grad_function(input_array)  # would raise an error
```
In this example:

- `external_pure_python_computation` is our stand-in for an external, pure Python function. It takes a NumPy array and a float.
- In `jax_function_using_callback`, we define `output_shape_dtype` using `jax.ShapeDtypeStruct` to tell JAX what kind of output to expect from the callback (an array with the same shape and dtype as input `a`).
- `jax.pure_callback` is called, passing the Python function, the expected output structure, and the actual inputs (`a`, `b`).
- The callback's result is then used in ordinary JAX operations (`result + a`).
- The entire function can be `jit`-compiled and `vmap`-ped. Differentiating it with `grad` additionally requires custom differentiation rules for the callback, since JAX cannot see inside the external Python code (more on this below).

## Specifying `result_shape_dtypes`
Providing the correct `result_shape_dtypes` is fundamental. This argument tells JAX's tracer the abstract value (shape and dtype) of the callback's output without running the Python code.

- If the callback returns a single value (a `jax.Array`, NumPy array, or scalar), provide a single `jax.ShapeDtypeStruct(shape, dtype)`.
- If it returns multiple values (e.g., a tuple), provide a matching tuple of `jax.ShapeDtypeStruct` instances corresponding to each output element.
- If it returns a nested structure (a pytree), `result_shape_dtypes` must mirror this structure, with `jax.ShapeDtypeStruct` objects at the leaves.

```python
# Example for multiple outputs
def multi_output_pure_function(x):
    return np.sum(x), np.mean(x)

def jax_multi_output_callback(arr):
    # Specify output structure: two scalars matching the input dtype
    output_structure = (
        jax.ShapeDtypeStruct((), arr.dtype),  # shape () for a scalar
        jax.ShapeDtypeStruct((), arr.dtype)
    )
    sum_val, mean_val = jax.pure_callback(
        multi_output_pure_function,
        output_structure,
        arr
    )
    return sum_val * 2, mean_val * 3

input_arr = jnp.arange(5.0)
res1, res2 = jax.jit(jax_multi_output_callback)(input_arr)
print(f"\nMultiple outputs: {res1=}, {res2=}")  # res1=20.0, res2=6.0
```
## Considerations When Using jax.pure_callback

- **Performance:** The callback executes as Python on the host, so each call involves transferring data from the device to the host and back, even when the surrounding code is compiled with `jit` or `pmap`. Use `pure_callback` thoughtfully, especially within performance-critical loops. Native JAX operations compiled by XLA will almost always be significantly faster.
- **Differentiation:** While `pure_callback` works with `jit`, `vmap`, and `pmap`, automatic differentiation (`grad`, `vjp`, `jvp`) cannot see inside the callback: attempting it raises an error unless you define custom differentiation rules (using `jax.custom_vjp` or `jax.custom_jvp`) for the callback operation, as shown in the sketch after this list. `pure_callback` itself doesn't make a non-differentiable function differentiable.
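As a sketch of the custom-rule approach: if you know the mathematical derivative of the external function, you can expose it to JAX through a second callback and `jax.custom_jvp`. The names here are illustrative, assuming `sin`/`cos` stand in for an external computation and its derivative:

```python
import jax
import jax.numpy as jnp
import numpy as np

def ext_f(x):          # the external computation (stand-in: sin)
    return np.sin(x)

def ext_f_deriv(x):    # its known derivative (stand-in: cos)
    return np.cos(x)

@jax.custom_jvp
def callback_f(x):
    return jax.pure_callback(ext_f, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

@callback_f.defjvp
def callback_f_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # The derivative is also computed on the host via a pure callback.
    deriv = jax.pure_callback(
        ext_f_deriv, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
    return callback_f(x), deriv * x_dot

x = jnp.array([0.0, 1.0])
print(jax.grad(lambda v: jnp.sum(callback_f(v)))(x))  # ≈ cos([0, 1])
```

Because this JVP is linear in the tangents, JAX can also transpose it, so reverse-mode `grad` works as well as forward-mode `jvp`.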
## Comparison with host_callback

| Feature | `jax.pure_callback` | `jax.experimental.host_callback` |
|---|---|---|
| Purity | Required (user guaranteed) | Not required (allows side effects) |
| Use case | Pure computations, library integration | Debugging, logging, I/O, side effects |
| `jit` | Compatible | Compatible (but executes on host) |
| `vmap` / `pmap` | Compatible | Limited (executes callback sequentially) |
| `grad` | Compatible (if custom differentiation rules are provided) | Not directly differentiable |
| Execution | Executes Python function at runtime (on host) | Executes Python function on host |
| Return value | Returns computational result back to JAX | Typically returns nothing (`None`) |
Choose `jax.pure_callback` when you need to integrate external, purely computational Python code into JAX functions that will undergo transformations like `jit`, `vmap`, or `grad`. Always ensure the function strictly adheres to the purity contract. If you need to perform side effects like printing or logging, `host_callback` is the appropriate, albeit more limited, tool.