Getting Started with JAX
Chapter 1: Introduction to JAX
What is JAX?
JAX vs NumPy
Core Design Philosophy: Function Transformations
Installation and Setup
Working with JAX Arrays
Device Management: CPU, GPU, TPU
Practice: Basic Array Operations
Quiz for Chapter 1
Chapter 2: Accelerating Functions with JIT Compilation
The Need for Speed: Why Compile?
Introducing jax.jit
How JIT Works: Tracing and Compilation
Python Control Flow and jit
Static vs Traced Values
Common Challenges with jit
Hands-on Practical: Applying jit
Quiz for Chapter 2
Chapter 3: Automatic Differentiation with grad
Understanding Gradients
Introducing jax.grad
How Autodiff Works: Reverse Mode
Differentiating with Respect to Arguments
Higher-Order Derivatives (grad of grad)
Value and Gradient (jax.value_and_grad)
Differentiation and Control Flow
Limitations and Considerations
Hands-on Practical: Computing Gradients
Quiz for Chapter 3
Chapter 4: Automatic Vectorization with vmap
The Concept of Vectorization
Introducing jax.vmap
Mapping over Specific Arguments (in_axes, out_axes)
Handling Multiple Batched Arguments
Nesting vmap
Combining vmap with jit and grad
Performance Considerations for vmap
Hands-on Practical: Vectorizing Functions
Quiz for Chapter 4
Chapter 5: Parallelization Across Devices with pmap
Introduction to Data Parallelism (SPMD)
Introducing jax.pmap
Mapping Data to Devices (in_axes, out_axes)
Device Meshes and Axis Names
Collective Operations (lax.psum, lax.pmean, etc.)
Combining pmap with Other Transformations
Debugging pmapped Functions
Hands-on Practical: Parallel Computation
Quiz for Chapter 5
Chapter 6: Managing State in JAX
Functional Purity and Side Effects
The Challenge of State in Functional Code
Pattern: Explicit State Passing
Using PyTrees for Structured State
Example: Stateful Counter
Example: Simple Optimizer State
Combining State Management with Transformations
Practice: Implementing Stateful Functions
Quiz for Chapter 6