This chapter introduces JAX, a Python library designed for high-performance numerical computation, particularly relevant for machine learning research. We'll start by defining what JAX is and how it fits into the Python scientific computing ecosystem.
You will learn about the core ideas behind JAX, emphasizing its relationship with NumPy and its reliance on function transformations. We'll compare jax.numpy with the standard NumPy library, highlighting the key similarities and differences you need to be aware of, such as the immutability of JAX arrays. We will also cover the practical steps for installing JAX and configuring it for different hardware like CPUs, GPUs, and TPUs.
Finally, you'll get hands-on experience creating and manipulating JAX arrays, the fundamental data structure, and learn how JAX manages computations across available hardware devices. By the end of this chapter, you'll have a foundational understanding of JAX's purpose, its basic API, and how to set up your environment to start using it.
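As a small preview of the topics above, the sketch below contrasts in-place assignment in NumPy with the functional update style that jax.numpy uses, and then lists the devices JAX can see. It assumes JAX and NumPy are already installed; the variable names are illustrative only, and the printed device list will depend on your machine.

```python
import numpy as np
import jax
import jax.numpy as jnp

# NumPy arrays are mutable: in-place assignment modifies the array.
np_arr = np.arange(5)
np_arr[0] = 10

# JAX arrays are immutable: the same assignment raises a TypeError.
jax_arr = jnp.arange(5)
try:
    jax_arr[0] = 10
    in_place_allowed = True
except TypeError:
    in_place_allowed = False

# The functional alternative returns a new array and leaves the original unchanged.
updated = jax_arr.at[0].set(10)

# List the devices JAX can use (CPU, GPU, or TPU, depending on the installation).
devices = jax.devices()

print("In-place assignment allowed:", in_place_allowed)
print("Original JAX array:", jax_arr)
print("Updated copy:", updated)
print("Available devices:", devices)
```

The `.at[...].set(...)` form is covered in more detail in the sections on JAX vs NumPy and working with JAX arrays; here it only illustrates why immutability is one of the differences worth knowing early.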
1.1 What is JAX?
1.2 JAX vs NumPy
1.3 Core Design Philosophy: Function Transformations
1.4 Installation and Setup
1.5 Working with JAX Arrays
1.6 Device Management: CPU, GPU, TPU
1.7 Practice: Basic Array Operations
© 2025 ApX Machine Learning