To effectively extend JAX or understand its inner workings, especially when interacting with external code or optimizing performance, we need to look beneath the surface of the familiar NumPy-like API. At the core of JAX's functionality lies the concept of primitives.
Think of primitives as the fundamental, atomic operations that JAX truly understands. While you write code using functions like jax.numpy.add or jax.numpy.dot, JAX ultimately breaks these down (or traces them) into a sequence of these elementary primitives. Examples include operations like add, mul, dot_general, reduce_sum, transpose, conv_general_dilated, and control flow operators like scan_p or cond_p.
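To make this concrete, the short sketch below traces a small function with jax.make_jaxpr and lists the primitives it was broken down into. The function name affine and the array shapes are illustrative assumptions, and the exact printed form of the jaxpr can differ between JAX versions.

```python
import jax
import jax.numpy as jnp


def affine(w, x, b):
    # Written with the NumPy-like API; no primitive is named explicitly.
    return jnp.dot(w, x) + b


# Illustrative inputs (shapes chosen arbitrarily for this sketch).
w = jnp.ones((3, 4))
x = jnp.ones((4,))
b = jnp.ones((3,))

# make_jaxpr traces the function with abstract inputs and returns its jaxpr.
closed_jaxpr = jax.make_jaxpr(affine)(w, x, b)
print(closed_jaxpr)

# Each equation in the jaxpr applies exactly one primitive.
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitive_names)
```

In the printed jaxpr, the call to jnp.dot typically shows up as a dot_general equation and the + as an add equation.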
Primitives are significant for several reasons:
Transformation Target: JAX transformations like jit, grad, and vmap don't operate directly on Python code. Instead, they work at the level of primitives.
jit: Traces the Python function into a sequence of primitives represented in an intermediate format called jaxpr. This jaxpr is then compiled by XLA.
grad: Relies on differentiation rules defined for each primitive encountered during the trace: Jacobian-Vector Product (JVP) rules, together with transpose rules from which the Vector-Jacobian Products (VJPs) used in reverse mode are derived.
vmap: Uses batching rules defined for each primitive to determine how to map the operation over an additional batch dimension.
Compilation Interface: Primitives form the bridge between the high-level JAX API and the backend compilers, primarily XLA (Accelerated Linear Algebra). JAX translates the jaxpr representation (a graph of primitives) into XLA's High Level Operations (HLO) instructions (in recent JAX versions, by way of StableHLO). XLA then optimizes these instructions and compiles them into efficient machine code specific to the target hardware (GPU, TPU, or CPU).
Extensibility: While JAX provides a comprehensive set of built-in primitives, the system is designed to be extensible. Understanding primitives is the foundation for adding new, custom operations to JAX. If you need to integrate a specialized algorithm written in C++ or CUDA, or define an operation with a unique differentiation rule, you often achieve this by defining a new primitive and providing the necessary rules for it.
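One way to see transformations acting on primitives is to compare the jaxpr of a function with the jaxprs of its transformed versions. The following is a minimal sketch; the function f and the input values are assumptions made for illustration, and the printed jaxprs may look slightly different depending on the JAX version.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Uses two primitives once traced: sin and mul.
    return jnp.sin(x) * x


# The untransformed function.
fwd_jaxpr = jax.make_jaxpr(f)(2.0)
print(fwd_jaxpr)

# grad applies the differentiation rule of each primitive; the rule for
# sin introduces a cos primitive into the resulting jaxpr.
grad_jaxpr = jax.make_jaxpr(jax.grad(f))(2.0)
print(grad_jaxpr)

# vmap applies the batching rule of each primitive; the same operations
# now act on arrays with an extra leading dimension.
xs = jnp.arange(3.0)
batched_jaxpr = jax.make_jaxpr(jax.vmap(f))(xs)
print(batched_jaxpr)
```

None of these transformations inspected the Python source of f. Each one only needed a rule for sin and a rule for mul.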
The process from Python function to execution generally follows these steps, with primitives playing a central role:
A conceptual view of how a Python function using JAX operations is transformed into executable code. Primitives form the core of the jaxpr representation, which is the input to the XLA compiler.
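The stages in this pipeline can be walked through by hand with JAX's ahead-of-time interface. The sketch below is one possible illustration, not the only way to inspect the pipeline; the function f and its input are assumptions, and the textual format returned by as_text() depends on the JAX version (recent versions emit StableHLO).

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sum(x * x)


x = jnp.arange(4.0)

# Stage 1: trace the Python function into a jaxpr (a sequence of primitives).
jaxpr = jax.make_jaxpr(f)(x)
print(jaxpr)

# Stage 2: lower the traced computation to the compiler's input representation.
lowered = jax.jit(f).lower(x)
ir_text = lowered.as_text()
print(ir_text)

# Stage 3: have XLA compile the lowered computation for the current backend.
compiled = lowered.compile()

# Stage 4: run the compiled executable on concrete data.
result = compiled(x)
print(result)
```

In everyday use a single call to a jit-decorated function performs all of these stages implicitly and caches the compiled result.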
It's helpful to distinguish between a JAX function like jax.numpy.add and a primitive like add_p. The function jax.numpy.add is a Python wrapper that provides a familiar NumPy-like interface. When JAX traces code using this function, it typically records an application of the corresponding primitive add_p in the jaxpr. The primitive is the internal representation that carries the necessary information for transformations and compilation.
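The distinction can be observed directly, since the add primitive is exposed as jax.lax.add_p. The sketch below calls the wrapper, the lower-level jax.lax.add function, and the primitive's bind method on the same inputs; the input values are arbitrary. Binding a primitive by hand skips the type promotion and argument handling that the wrappers perform, so it is shown here only for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.float32(1.5)
y = jnp.float32(2.5)

# The NumPy-like wrapper, the lax-level function, and the primitive itself.
via_wrapper = jnp.add(x, y)
via_lax = lax.add(x, y)
via_primitive = lax.add_p.bind(x, y)
print(via_wrapper, via_lax, via_primitive)

# The primitive is a Python object with a name, not an ordinary function.
print(type(lax.add_p).__name__, lax.add_p.name)

# Tracing lax.add records one equation whose primitive is that same object.
eqn = jax.make_jaxpr(lax.add)(x, y).jaxpr.eqns[0]
print(eqn.primitive is lax.add_p)
```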
In essence, primitives are the irreducible building blocks upon which JAX's powerful system of transformations and compilation is built. While you may not interact with them directly in everyday use, understanding their role is essential for advanced optimization, debugging, and extending JAX's capabilities, particularly when defining custom operations as we will explore in the subsequent sections.
© 2025 ApX Machine Learning