JAX computations often need to interact with other parts of the scientific Python stack or require specialized low-level functionality. This chapter addresses how to bridge JAX with external systems and extend its core operations.
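The sketch below shows one way to bridge JAX with the wider scientific Python stack: calling a plain NumPy function from inside a jit-compiled function. It assumes a recent JAX release in which jax.pure_callback takes the host function, a result shape and dtype built with jax.ShapeDtypeStruct, and the arguments. The names numpy_sinc and f are illustrative only. In recent JAX releases, jax.experimental.host_callback (named in section 5.3) has been deprecated in favor of jax.pure_callback, jax.experimental.io_callback, and jax.debug.callback.

```python
import jax
import jax.numpy as jnp
import numpy as np

def numpy_sinc(x):
    # Runs on the host with plain NumPy; JAX cannot trace through this body.
    # The cast keeps the result dtype equal to the declared result dtype.
    return np.sinc(np.asarray(x)).astype(x.dtype)

@jax.jit
def f(x):
    # Declare the shape and dtype of the callback's result so JAX can trace f.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    # pure_callback tells JAX the host function is free of side effects.
    y = jax.pure_callback(numpy_sinc, result_shape, x)
    return y * 2.0

x = jnp.linspace(-2.0, 2.0, 5)
print(f(x))
```

Because pure_callback declares the host function side-effect free, JAX is allowed to drop the call when its result is unused. Work whose side effects matter, such as logging or writing files, belongs in io_callback instead.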
You will learn practical techniques for the topics outlined below, including calling external code through host_callback and pure_callback while recognizing their use cases and limitations.
5.1 Integrating JAX with NumPy
5.2 Zero-Copy Data Sharing with DLPack
5.3 Calling External CPU/GPU Code with jax.experimental.host_callback
5.4 Using jax.pure_callback for Side-Effect Free Calls
5.5 Introduction to JAX Primitives
5.6 Defining Custom Primitives
5.7 Implementing Abstract Evaluation Rules
5.8 Implementing Lowering Rules for Backends (CPU/GPU/TPU)
5.9 Defining Differentiation Rules for Custom Primitives
5.10 Practice: Integrating a C++ Function
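JAX's automatic differentiation cannot see inside a callback, so jax.grad cannot differentiate a bare pure_callback. Defining a full primitive with the differentiation rules covered in 5.9 is one way to fix that. A lighter and commonly used alternative is to wrap the callback in jax.custom_jvp and supply the derivative yourself. The sketch below assumes the derivative can also be evaluated on the host. The names _host_sin, _host_cos, and callback_sin are hypothetical. This approach is a custom-derivative wrapper, not a custom primitive.

```python
import jax
import jax.numpy as jnp
import numpy as np

def _host_sin(x):
    # Hypothetical host-side implementation; opaque to JAX autodiff.
    return np.sin(x).astype(x.dtype)

def _host_cos(x):
    # Hypothetical host-side derivative of _host_sin.
    return np.cos(x).astype(x.dtype)

@jax.custom_jvp
def callback_sin(x):
    shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(_host_sin, shape, x)

@callback_sin.defjvp
def callback_sin_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    y = callback_sin(x)
    # d/dx sin(x) = cos(x), also evaluated on the host via a callback.
    dydx = jax.pure_callback(_host_cos, shape, x)
    # The tangent output is linear in x_dot, so jax.grad can transpose it.
    return y, dydx * x_dot

x = jnp.array([0.0, 0.5, 1.0])
grads = jax.grad(lambda v: callback_sin(v).sum())(x)
print(grads)  # should match jnp.cos(x)
```

Because the tangent rule only multiplies x_dot by a value computed from the primal inputs, both forward-mode (jax.jvp) and reverse-mode (jax.grad) differentiation work through this wrapper.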
© 2025 ApX Machine Learning