While JAX provides a comprehensive library of numerical operations built on jax.numpy and jax.lax, there are situations where you might need to extend JAX's capabilities with your own fundamental operations. This is where custom primitives come into play. Primitives are the atomic, irreducible operations that JAX understands and transforms. Defining a custom primitive allows you to integrate highly specialized computations, potentially implemented in other languages like C++ or CUDA, directly into the JAX ecosystem, making them subject to JIT compilation, automatic differentiation, and vectorization just like built-in operations.
Think of primitives as the vocabulary JAX uses to describe computations. Adding a custom primitive is like adding a new word to this vocabulary. Reasons for doing this include:

- Interfacing with specialized external code, such as hand-written C++ or CUDA kernels.
- Implementing an operation that cannot be expressed efficiently in terms of existing jax.lax primitives.

Defining a custom primitive involves telling JAX several things about your new operation:

- Its name and the Python function users call to apply it.
- How to evaluate it abstractly (output shapes and dtypes).
- How to lower it to code the target backend can execute.
- How to differentiate it.
- How it behaves under transformations like jit, vmap, etc.

Let's outline the process:
First, you create a new primitive instance using jax.core.Primitive. This object serves as the identifier for your operation.
import jax
import jax.numpy as jnp
from jax.core import Primitive
# Create a unique Primitive object
my_custom_op_p = Primitive("my_custom_op")
The string name ("my_custom_op") is used for debugging and printing jaxprs.
Typically, you wrap the primitive usage in a Python function that users will call. This function takes standard Python/NumPy/JAX inputs and uses bind to apply the primitive to its arguments.
def my_custom_op(x, *, some_parameter: int):
"""Applies the custom operation to x.
Args:
x: Input JAX array.
some_parameter: A static parameter for the operation.
Returns:
A JAX array result.
"""
# 'bind' stages the primitive application with its arguments
# and parameters.
return my_custom_op_p.bind(x, some_parameter=some_parameter)
Note the use of bind. Under a transformation such as jit or grad, bind stages the operation into JAX's tracing machinery rather than executing it directly; outside of tracing, it dispatches to the primitive's implementation rule. Parameters like some_parameter that affect the computation's structure or abstract evaluation need to be passed as keyword arguments to bind; they become static parameters attached to the primitive in the jaxpr.
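As an aside, if you also want the wrapper to be callable eagerly (outside of jit), you can register a concrete implementation with def_impl. The sketch below is purely illustrative: it assumes the operation computes x**2 + some_parameter, matching the derivative example used later; a real primitive would typically dispatch to the specialized external code here.
# Optional: a concrete implementation so the primitive can run eagerly.
# Assumes the op computes x**2 + some_parameter; a real implementation
# would usually call into the specialized external code.
def my_custom_op_impl(x, *, some_parameter: int):
    return jnp.asarray(x) ** 2 + some_parameter

my_custom_op_p.def_impl(my_custom_op_impl)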
JAX needs to know the shape and dtype of the output without running the actual computation. This is achieved by defining an abstract evaluation rule using def_abstract_eval. This rule takes abstract versions of the inputs (objects containing shape and dtype information, like jax.core.ShapedArray) and computes the abstract output.
from jax.core import ShapedArray
def my_custom_op_abstract_eval(abstract_x, *, some_parameter: int):
"""Abstract evaluation rule for my_custom_op."""
# Example: Assume the op squares the input and adds the parameter,
# preserving shape and dtype.
# Real rules depend entirely on what the primitive does.
output_shape = abstract_x.shape
output_dtype = abstract_x.dtype
# Parameter checks might occur here
if some_parameter < 0:
raise ValueError("some_parameter must be non-negative")
return ShapedArray(output_shape, output_dtype)
# Register the abstract evaluation rule with the primitive
my_custom_op_p.def_abstract_eval(my_custom_op_abstract_eval)
This step is fundamental for JAX transformations. jit uses it to determine the structure of the computation graph, and vmap uses it to calculate batch dimensions.
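With just the primitive and its abstract evaluation rule registered, you can already trace the wrapper and see the new operation appear in a jaxpr (the exact printed form varies between JAX versions):
# Tracing needs only the abstract evaluation rule, not a lowering.
x = jnp.ones((3,), dtype=jnp.float32)
print(jax.make_jaxpr(lambda x: my_custom_op(x, some_parameter=3))(x))
# The printed jaxpr contains an equation like:
#   my_custom_op[some_parameter=3]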
This is often the most involved step. You need to tell JAX how to translate your primitive into low-level code that the target backend (CPU, GPU, TPU) can execute. This is usually done by emitting XLA HLO (High Level Operations) through MLIR, the intermediate representation layer JAX uses for lowering. You register lowering rules using jax.interpreters.mlir.register_lowering.
# Conceptual example - requires deeper knowledge of MLIR and XLA HLO
from jax.interpreters import mlir
# Assume we have functions to build the specific HLO operations
# from jaxlib.hlo_helpers import custom_call # Or similar helpers
def my_custom_op_lowering_cpu(ctx, x_operand, *, some_parameter: int):
"""Lowering rule for CPU."""
# 1. Define the output type for MLIR
# result_type = mlir.ir.RankedTensorType.get(output_shape, mlir_dtype)
# 2. Generate XLA HLO instructions
# This might involve standard HLO ops or a 'custom_call'
# to an external C++/CPU function.
# result_hlo = build_cpu_hlo_for_my_op(x_operand, some_parameter) # Fictional function
# return [result_hlo] # Must return a list of MLIR values
pass # Placeholder - Actual implementation is complex
def my_custom_op_lowering_gpu(ctx, x_operand, *, some_parameter: int):
"""Lowering rule for GPU."""
# Similar structure, but generates HLO targeting GPU.
# Might use 'custom_call' to invoke a CUDA kernel.
# result_hlo = build_gpu_hlo_for_my_op(x_operand, some_parameter) # Fictional function
# return [result_hlo]
pass # Placeholder
# Register the lowering rules
# mlir.register_lowering(my_custom_op_p, my_custom_op_lowering_cpu, platform='cpu')
# mlir.register_lowering(my_custom_op_p, my_custom_op_lowering_gpu, platform='gpu')
# ... potentially register for TPU as well
Writing lowering rules requires understanding the target backend's execution model and the XLA HLO instruction set, or how to interface external code using custom_call. This step bridges the gap between the high-level JAX primitive and the low-level compiled code.
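One lighter-weight option worth knowing about: if the operation can be expressed (perhaps slowly) in terms of existing JAX operations, jax.interpreters.mlir.lower_fun turns an ordinary JAX function into a lowering rule, so you do not have to build HLO by hand. A sketch, again assuming the op computes x**2 + some_parameter:
# Derive a lowering from a reference implementation written with standard
# JAX ops. This avoids hand-written MLIR/HLO, at the cost of not calling
# any specialized kernel.
def my_custom_op_reference(x, *, some_parameter: int):
    return x ** 2 + some_parameter

mlir.register_lowering(
    my_custom_op_p,
    mlir.lower_fun(my_custom_op_reference, multiple_results=False),
)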
To make your custom primitive work with jax.grad and other autodiff transformations, you need to define its JVP and VJP rules. These rules are registered using jax.interpreters.ad.primitive_jvps and jax.interpreters.ad.primitive_transposes (for VJPs).
from jax.interpreters import ad
# --- JVP Rule ---
def my_custom_op_jvp(primals, tangents, *, some_parameter: int):
"""JVP rule for my_custom_op."""
x, = primals
x_dot, = tangents # Tangent vector corresponding to x
# Compute primal output(s)
primal_out = my_custom_op_p.bind(x, some_parameter=some_parameter)
# Compute tangent output(s) based on the operation's derivative
# Example: If my_custom_op(x) = x^2 + some_parameter
# Derivative is 2x. JVP is (2x) * x_dot
if type(x_dot) is ad.Zero:
tangent_out = ad.Zero.from_value(primal_out)
else:
# Assume a helper primitive or function for the derivative exists
# tangent_out = my_custom_op_derivative_p.bind(x, x_dot, some_parameter=some_parameter)
# Or compute directly if possible:
tangent_out = 2 * x * x_dot # Example derivative computation
return primal_out, tangent_out
# Register JVP rule
ad.primitive_jvps[my_custom_op_p] = my_custom_op_jvp
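# Quick check of the JVP rule. This is illustrative only: it assumes
# my_custom_op computes x**2 + some_parameter and that an implementation
# or lowering for the primitive has been registered (see the earlier sketches).
x_check = jnp.array([1.0, 2.0, 3.0])
v_check = jnp.ones_like(x_check)
y_check, y_dot = jax.jvp(
    lambda x: my_custom_op(x, some_parameter=3), (x_check,), (v_check,)
)
# y_dot should equal 2 * x_check * v_check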
# --- VJP Rule ---
# VJP rules are often defined via transposition of JVP rules,
# but can be defined directly if needed (especially for efficiency).
# Implementing a transpose rule requires deeper understanding of ad.primitive_transposes.
# Alternatively, and often simpler, if the operation is implemented
# via a Python function using standard JAX ops or other primitives
# (even if those primitives use custom_calls), you can use
# jax.custom_vjp on the *user-facing* function instead of defining
# rules directly on the primitive.
If your primitive simply calls external code, you might need to provide analytical derivatives or implement the backward pass in the external code as well. Using jax.custom_vjp or jax.custom_jvp on the wrapper function can sometimes be a more manageable approach than defining the rules directly on the primitive itself, especially if the operation's implementation involves multiple steps.
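For completeness, here is a sketch of that alternative: attaching jax.custom_vjp to a user-facing wrapper. The forward/backward bodies below are placeholders that assume the x**2 + some_parameter example; in practice they might call external kernels (possibly via primitives of their own).
from functools import partial

# Differentiate the wrapper rather than the primitive. some_parameter is
# marked non-differentiable and must be a static (hashable) value.
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def my_custom_op_with_vjp(x, some_parameter):
    return my_custom_op(x, some_parameter=some_parameter)

def _my_op_fwd(x, some_parameter):
    # Save the residuals the backward pass needs (here, just x).
    return my_custom_op_with_vjp(x, some_parameter), x

def _my_op_bwd(some_parameter, residual_x, cotangent):
    # Assuming my_custom_op(x) = x**2 + some_parameter, the VJP is 2*x*cotangent.
    return (2.0 * residual_x * cotangent,)

my_custom_op_with_vjp.defvjp(_my_op_fwd, _my_op_bwd)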
The diagram below illustrates how these components fit together when JAX encounters your custom primitive within a transformed function (like one decorated with @jit).
The process of integrating a custom primitive into JAX: a Python call gets staged using bind, represented in a jaxpr, abstractly evaluated for shape/dtype, and then lowered to XLA HLO for backend execution. Differentiation requires corresponding JVP/VJP rules.
Defining custom primitives is a powerful feature for extending JAX, but it represents a significant step up in complexity compared to using standard JAX operations.
Before defining a custom primitive, consider whether your goal can be achieved using existing jax.lax operations, jax.experimental.host_callback, jax.pure_callback, or by expressing the computation differently. However, when performance or necessity dictates, custom primitives offer a way to fully integrate specialized code into the JAX ecosystem. The subsequent sections will delve into the implementation details for abstract evaluation, lowering, and differentiation rules.
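As a point of comparison with the machinery above, jax.pure_callback can wrap host-side code with no primitive machinery at all, at the cost of a host round-trip and of needing jax.custom_jvp/jax.custom_vjp for differentiation. A minimal sketch, using numpy.sin as a stand-in for arbitrary host code:
import numpy as np

# Call a host-side NumPy function from JAX-transformed code. Only the
# output shape/dtype must be declared; no abstract eval or lowering rule
# is written by hand.
def host_op(x):
    return np.sin(x)  # stands in for arbitrary host-side code

def call_host_op(x):
    return jax.pure_callback(host_op, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

y = jax.jit(call_host_op)(jnp.arange(4.0))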