Now, let's put the theory into practice. We'll take a common numerical task, calculating pairwise distances between two sets of vectors, and apply the optimization principles discussed in this chapter. Our goal is not just to make it faster, but to understand why certain approaches perform better within the JAX/XLA ecosystem.

Suppose we have two sets of points, $X$ with $N$ points and $Y$ with $M$ points, both in a $D$-dimensional space. We want to compute a distance matrix $D_{ij}$ representing the Euclidean distance between the $i$-th point in $X$ and the $j$-th point in $Y$:

$$ X \in \mathbb{R}^{N \times D}, \quad Y \in \mathbb{R}^{M \times D}, \quad D \in \mathbb{R}^{N \times M} $$

$$ D_{ij} = \sqrt{\sum_{k=1}^{D} (X_{ik} - Y_{jk})^2} $$

### Baseline Implementation

A straightforward way to implement this in JAX, mimicking NumPy, might involve explicit broadcasting.

```python
import jax
import jax.numpy as jnp
import timeit

# Generate some sample data
key = jax.random.PRNGKey(0)
N, M, D = 1000, 1500, 64
X = jax.random.normal(key, (N, D))
Y = jax.random.normal(key, (M, D))

# Ensure data is on the device before timing
X = jax.device_put(X)
Y = jax.device_put(Y)

def pairwise_distance_v1(X, Y):
    """Computes pairwise distances using explicit broadcasting."""
    # Expand dimensions for broadcasting: X becomes (N, 1, D), Y becomes (1, M, D)
    diff = X[:, None, :] - Y[None, :, :]
    # Calculate squared Euclidean distance
    squared_dist = jnp.sum(diff**2, axis=-1)
    # Return the square root
    return jnp.sqrt(squared_dist)

# Compile the function
pairwise_distance_v1_jit = jax.jit(pairwise_distance_v1)

# First run to compile
_ = pairwise_distance_v1_jit(X, Y).block_until_ready()

# Benchmark
runs = 10
start_time = timeit.default_timer()
for _ in range(runs):
    result_v1 = pairwise_distance_v1_jit(X, Y).block_until_ready()
elapsed_v1 = (timeit.default_timer() - start_time) / runs

print(f"Baseline JIT version (v1) average time: {elapsed_v1:.6f} seconds")
# Example Output (will vary based on hardware):
# Baseline JIT version (v1) average time: 0.002512 seconds
```

This implementation is already quite efficient due to JAX's NumPy-like API and JIT compilation. JAX traces the function, converts it to a jaxpr, and XLA compiles it into optimized kernels. The broadcasting operations (`[:, None, :]` and `[None, :, :]`) create intermediate arrays, and `jnp.sum` performs the reduction. XLA's fusion capabilities will likely combine some of these operations.
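If you want to see how JAX represents this computation before XLA optimizes it, you can print the traced program. The snippet below is a small optional sketch, not part of the benchmark; exact primitive names can vary between JAX versions, but you should be able to spot the broadcasted subtraction and the reduction that produce and consume the `(N, M, D)` intermediate.

```python
# Optional: inspect the jaxpr for the baseline version.
# Expect primitives along the lines of broadcast_in_dim, sub, integer_pow,
# reduce_sum, and sqrt, with an intermediate of shape (1000, 1500, 64).
jaxpr_v1 = jax.make_jaxpr(pairwise_distance_v1)(X, Y)
print(jaxpr_v1)
```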
### Alternative Implementation: Leveraging Matrix Algebra

We can express the squared Euclidean distance using a matrix algebra identity:

$$ \|x_i - y_j\|^2 = \|x_i\|^2 - 2\, x_i^\top y_j + \|y_j\|^2 $$

This suggests an alternative computation involving dot products and sums of squares.

```python
def pairwise_distance_v2(X, Y):
    """Computes pairwise distances using a matrix algebra identity."""
    # Calculate squared norms for each vector in X and Y
    x_sq_norms = jnp.sum(X**2, axis=1)  # Shape (N,)
    y_sq_norms = jnp.sum(Y**2, axis=1)  # Shape (M,)

    # Calculate the dot products between all pairs of vectors.
    # X @ Y.T results in an (N, M) matrix where element (i, j) is dot(X[i], Y[j])
    dot_products = jnp.dot(X, Y.T)

    # Compute squared distances using the identity: ||x-y||^2 = ||x||^2 - 2*x.y + ||y||^2
    # We need to reshape norms for broadcasting:
    #   x_sq_norms[:, None] -> (N, 1)
    #   y_sq_norms[None, :] -> (1, M)
    squared_dist = x_sq_norms[:, None] - 2 * dot_products + y_sq_norms[None, :]

    # Handle potential small negative values due to floating point inaccuracies
    squared_dist = jnp.maximum(0.0, squared_dist)

    return jnp.sqrt(squared_dist)

# Compile the function
pairwise_distance_v2_jit = jax.jit(pairwise_distance_v2)

# First run to compile
_ = pairwise_distance_v2_jit(X, Y).block_until_ready()

# Benchmark
start_time = timeit.default_timer()
for _ in range(runs):
    result_v2 = pairwise_distance_v2_jit(X, Y).block_until_ready()
elapsed_v2 = (timeit.default_timer() - start_time) / runs

print(f"Algebraic JIT version (v2) average time: {elapsed_v2:.6f} seconds")
# Example Output (will vary based on hardware):
# Algebraic JIT version (v2) average time: 0.001855 seconds
```

### Analysis and Comparison

Why might v2 be faster, especially on accelerators?

1. **Intermediate allocation (v1):** `pairwise_distance_v1` explicitly creates a large intermediate array `diff` of shape `(N, M, D)`. For our example sizes (1000, 1500, 64), this is 1000 * 1500 * 64 = 96,000,000 elements, roughly 384 MB in float32. Materializing and reading back an array of that size consumes significant memory bandwidth, which is often the bottleneck on GPUs/TPUs.
2. **Optimized kernels (v2):** `pairwise_distance_v2` relies heavily on `jnp.dot(X, Y.T)`. Matrix multiplication is a fundamental operation for which highly optimized kernels exist (such as cuBLAS on NVIDIA GPUs or dedicated TPU kernels), and XLA can target them effectively. The remaining operations (summing squares, broadcast additions and subtractions) are elementwise and can often be fused by XLA with the matrix multiplication or with each other.
3. **XLA fusion:** While XLA can fuse operations in v1, the large intermediate tensor might limit the extent or efficiency of fusion compared to v2, which breaks the problem into operations (a matrix multiply plus elementwise work) that map well to accelerator hardware capabilities.

*Figure: Pairwise Distance Calculation Time (JIT Compiled). A bar chart comparing the average execution times of the two JIT-compiled implementations, Baseline (v1) vs. Algebraic (v2). Lower is better.*
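A caveat with the algebraic formulation is numerical accuracy: for points that are very close together, subtracting nearly equal float32 quantities can introduce small errors, which is exactly why v2 clamps `squared_dist` with `jnp.maximum`. Before adopting the faster version, it's worth a quick sanity check that the two implementations agree within a reasonable tolerance; the tolerance below is an illustrative choice, not a hard rule.

```python
# Sanity check: v1 and v2 should agree up to float32 rounding error.
result_v1 = pairwise_distance_v1_jit(X, Y)
result_v2 = pairwise_distance_v2_jit(X, Y)

max_abs_diff = jnp.max(jnp.abs(result_v1 - result_v2))
print(f"Max absolute difference: {max_abs_diff:.2e}")
print("Allclose:", jnp.allclose(result_v1, result_v2, atol=1e-4))
```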
### The Optimization Process

How would we arrive at this optimization in a real scenario?

1. **Baseline and JIT:** Start with a clear, readable implementation (v1) and apply `jax.jit`.
2. **Benchmark:** Use `block_until_ready()` and a timer (`timeit`, or `%timeit` in notebooks) to get reliable performance numbers. Run multiple times to average out noise.
3. **Profile (identify bottlenecks):** If performance isn't satisfactory, use JAX's profiling tools:
   - `jax.profiler.start_trace()` / `jax.profiler.stop_trace()` capture execution traces viewable in TensorBoard. This helps visualize operation durations and identify which parts of the computation take the most time; in v1 you might see significant time spent in the broadcasting or reduction operations. (A short sketch of this step follows at the end of the section.)
   - Device-specific profilers (like NVIDIA Nsight Systems) provide deeper hardware-level analysis.
4. **Inspect the jaxpr (optional):** Use `jax.make_jaxpr` to examine the intermediate representation. While often verbose, it reveals how JAX sees the computation before XLA optimization, potentially highlighting large intermediate structures.
5. **Hypothesize and refactor:** Based on profiling or an understanding of XLA and the hardware, form hypotheses: "Is the large intermediate `diff` array costly?" or "Can I use a more direct matrix multiplication?" Refactor the code (v2) based on these ideas.
6. **Re-benchmark:** Compare the performance of the new version against the baseline.

This iterative process of implementing, benchmarking, profiling, and refactoring is essential for optimizing JAX code effectively. Understanding how operations map to hardware capabilities and how XLA performs fusion allows you to write code that JAX can compile into highly efficient routines. Remember that the "best" implementation can sometimes depend on the specific dimensions (N, M, D) and the target hardware (CPU, GPU, TPU).
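For reference, here is a minimal sketch of the trace-capture step from point 3. It assumes TensorBoard (with the profiler plugin) is installed to view the result, and the log directory `/tmp/jax-trace` is just an illustrative choice.

```python
import jax.profiler

# Capture a trace of the baseline version and write it to disk.
# Afterwards, inspect it with: tensorboard --logdir /tmp/jax-trace
jax.profiler.start_trace("/tmp/jax-trace")
for _ in range(runs):
    pairwise_distance_v1_jit(X, Y).block_until_ready()
jax.profiler.stop_trace()
```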