JAX has been quietly picking up momentum in the research community, particularly at places like DeepMind and Google Brain. If you've heard about it but haven't quite understood what makes it different from NumPy or PyTorch, this post walks through everything from the basics to the internals — including why it's fast, what "functional" really means in practice, and when you'd actually want to use it.

1. What JAX Is

JAX is a Python numerical computing library developed by Google. The elevator pitch is that it looks like NumPy but can do things NumPy can't: automatic differentiation, JIT compilation to GPU/TPU, and a set of function transformations that compose cleanly with each other.

But that description undersells it. The more interesting thing about JAX is the design philosophy behind it. NumPy is a library. PyTorch is a framework. JAX is closer to a compiler infrastructure that happens to expose a Python interface. Understanding that distinction is what makes the rest of it click.

import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
print(jnp.sin(x))
# [0.841 0.909 0.141]

That code looks unremarkable, but what actually happened is different from NumPy. The array lives on JAX's default device (a GPU or TPU if one is available, otherwise the CPU), the operation was dispatched through XLA, and the whole thing is set up to compose with the transformations we'll cover next.

2. The Four Core Transformations

JAX's power comes from four composable transformations. Each one takes a function and returns a new function. This is the key idea: you write a simple, clean function, then wrap it to add capabilities.

grad — automatic differentiation

from jax import grad

def f(x):
    return x**2 + 3*x

df = grad(f)
print(df(2.0))  # 7.0  (derivative of x^2 + 3x is 2x + 3, evaluated at x=2)

grad takes a scalar-valued function and returns a new function that computes its gradient. There is no manual chain rule and no symbolic math: JAX traces the computation and applies reverse-mode automatic differentiation. The result is exact up to floating-point rounding, not a finite-difference approximation.

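You can differentiate through ordinary Python control flow (with one caveat: under jit, branches can't depend on array values, as section 3 explains), and the transformations compose: grad(grad(f)) gives the second derivative, grad(jit(f)) works fine, and so on.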
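As a quick check of both points, here is a minimal sketch. piecewise is a hypothetical function that branches on the value of x with a plain Python if, which grad handles because nothing here is under jit; the second half composes grad with itself on the f from above.

```python
import jax

def piecewise(x):
    # Ordinary Python branch on the value of x; fine under grad without jit
    if x > 0:
        return x ** 2
    return -x

dpiecewise = jax.grad(piecewise)
slope_pos = dpiecewise(3.0)   # derivative of x**2 at x=3 -> 6.0
slope_neg = dpiecewise(-2.0)  # derivative of -x -> -1.0

def f(x):
    return x**2 + 3*x

d2f = jax.grad(jax.grad(f))   # second derivative of x**2 + 3x is the constant 2
curvature = d2f(2.0)
```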

jit — compilation

from jax import jit

@jit
def f(x):
    return x**2 + 3*x

f(1000.0)   # first call: compiles, then runs
f(999.0)    # subsequent calls: runs the compiled version directly

The first call to a @jit-decorated function is slow: JAX traces the computation and compiles it via XLA. Later calls with the same input shapes and dtypes reuse the cached executable, which for non-trivial functions can be much faster than dispatching each operation eagerly. Calling with a new shape or dtype triggers a fresh trace and compilation.
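To see the caching in action, the sketch below counts how many times JAX traces the function. It relies on the fact that Python side effects, such as incrementing trace_count (a hypothetical counter), run only while JAX traces the function, not when the cached compiled version executes:

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def f(x):
    global trace_count
    trace_count += 1        # Python side effect: runs only while JAX traces f
    return x**2 + 3*x

f(jnp.float32(1.0))         # first call: trace + compile
f(jnp.float32(2.0))         # same shape and dtype: cached, no retrace
f(jnp.ones(3))              # new shape (3,): traced and compiled again
```

The counter is deliberately impure, which is exactly why it only fires during tracing; section 3 explains why real code shouldn't depend on side effects like this.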

vmap — vectorization

from jax import vmap

def f(x):
    return x**2

xs = jnp.array([1.0, 2.0, 3.0, 4.0])

vf = vmap(f)
print(vf(xs))  # [1. 4. 9. 16.]

vmap takes a function written for a single input and automatically vectorizes it to run over a batch. Unlike a Python for loop, vmap pushes the batch dimension into each primitive operation, so the whole batch is handled by vectorized array operations rather than one Python iteration per element. Combined with jit, the batched computation is compiled as a single program, which is how you get efficient batch processing without writing batch-aware code.
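vmap also lets you choose which arguments get batched through in_axes. A minimal sketch with a hypothetical per-example dot function, checked against an explicit Python loop:

```python
import jax
import jax.numpy as jnp

def dot(w, x):
    # written for a single example: w and x are both 1-D vectors
    return jnp.dot(w, x)

w = jnp.array([1.0, 2.0, 3.0])
xs = jnp.arange(12.0).reshape(4, 3)   # a batch of 4 examples

# in_axes=(None, 0): share w across the batch, map over axis 0 of xs
batched = jax.vmap(dot, in_axes=(None, 0))(w, xs)

# the same result computed one example at a time in Python
looped = jnp.stack([dot(w, x) for x in xs])
```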

pmap — multi-device parallelism

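import jax
import jax.numpy as jnp
from jax import pmap

def f(x):
    return x**2

# The leading axis can't exceed the number of devices: one slice per device
n = jax.local_device_count()
xs = jnp.arange(float(n))
result = pmap(f)(xs)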

pmap is like vmap but distributes computation across multiple GPUs or TPU cores. Each device gets a slice of the input and runs the function independently. For large-scale training across many accelerators, this is the primitive you build on.
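In a real training setup the devices usually need to communicate, for example to sum or average gradients. Below is a minimal sketch using jax.lax.psum inside pmap; normalize and the axis name 'devices' are hypothetical, and the input shape comes from jax.local_device_count() so it also runs on a single-device machine:

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()              # 1 on a single-device machine
xs = jnp.arange(n * 4.0).reshape(n, 4)    # leading axis: one row per device

def normalize(x):
    # psum adds the per-device sums across every device on the 'devices' axis
    total = jax.lax.psum(jnp.sum(x), axis_name='devices')
    return x / total

out = jax.pmap(normalize, axis_name='devices')(xs)
```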

Composing transformations

The transformations are designed to compose. This is where JAX becomes genuinely expressive:

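import jax.numpy as jnp
from jax import grad, jit, vmap

# toy definitions so the example runs
def loss_fn(w, x):
    return jnp.sum((w * x) ** 2)

def f(x):
    return jnp.sin(x)

# per-example gradients of a loss (w shared, x batched), compiled
batched_grad = jit(vmap(grad(loss_fn), in_axes=(None, 0)))

# second derivative
d2f = jit(grad(grad(f)))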

In PyTorch, combining gradients, batching, and compilation has traditionally required more setup (the newer torch.func module offers JAX-like grad and vmap transforms). In JAX, it's just function composition.
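The same composition, made concrete: a minimal sketch with a hypothetical squared-error loss_fn and a tiny hand-written batch. vmap(grad(...)) gives one gradient per example, and because differentiation is linear, their average matches the gradient of the mean loss:

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    # squared error for a single example
    return (jnp.dot(w, x) - y) ** 2

w = jnp.array([0.5, -1.0])
X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
Y = jnp.array([1.0, 0.0, -1.0])

# one gradient per example: map over X and Y, share w
per_example = jax.jit(jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0)))(w, X, Y)

# gradient of the mean loss over the whole batch
mean_loss = lambda w: jnp.mean(jax.vmap(loss_fn, in_axes=(None, 0, 0))(w, X, Y))
full = jax.grad(mean_loss)(w)
```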

3. What "Functional" Actually Means Here

JAX is described as a functional library, and most explanations of this are either too abstract or too hand-wavy. Here's the concrete version.

Pure functions

A pure function has two properties: the output depends only on the input, and it doesn't modify anything outside itself. These are both pure:

def add(a, b):
    return a + b

def normalize(x):
    return (x - x.mean()) / x.std()

These are not:

# Not pure: reads and modifies external state
counter = 0
def f(x):
    global counter
    counter += 1
    return x + counter

# Not pure: modifies its input
def f(lst):
    lst.append(10)
    return lst

# Not pure: output is non-deterministic
import random
def f(x):
    return x + random.random()

The reason purity matters for JAX is practical, not philosophical. The compiler needs to reason about what your code does in order to optimize it. If a function reads from external state, modifies inputs, or has other side effects, the compiler can't safely reorder operations, fuse kernels, or parallelize anything — it would risk changing the program's behavior. Pure functions give the compiler full visibility into what's happening, which is what enables aggressive optimization.

How JAX enforces this

JAX doesn't just ask you to write pure functions — it makes impure code either fail or behave unexpectedly. The main mechanisms:

Immutable arrays. You can't do in-place mutation:

import jax.numpy as jnp

x = jnp.array([1, 2, 3])
x[0] = 10   # TypeError: JAX arrays are immutable

# Correct approach — creates a new array
x = x.at[0].set(10)
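The .at property covers more than set. A short sketch of a few of its update methods, each returning a new array and leaving the original untouched:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
y = x.at[0].set(10)                            # new array with x[0] replaced
z = x.at[1].add(5)                             # functional form of x[1] += 5
doubled = x.at[jnp.array([0, 2])].multiply(2)  # update several indices at once
# x itself is still [1, 2, 3]
```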

JIT captures the function at trace time. If your function reads an external variable, JIT bakes in the value at the time of first compilation:

from jax import jit

a = 10

@jit
def f(x):
    return x + a

print(f(1.0))   # 11.0

a = 100
print(f(1.0))   # still 11.0 — the compiled version doesn't see the change

This will catch you off guard the first time. The fix is to pass a as an explicit argument.
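Here is that fix as a minimal sketch: once a is an argument, JAX traces it like x, so a new value flows through without recompiling (as long as its shape and dtype stay the same):

```python
from jax import jit

@jit
def f(x, a):
    # a is now an input, traced just like x
    return x + a

a = 10.0
r1 = f(1.0, a)
a = 100.0
r2 = f(1.0, a)   # sees the new value; same shapes and dtypes, so no recompilation
```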

Explicit random keys. Instead of a hidden global RNG state, JAX requires you to pass keys explicitly:

import jax.random as random

key = random.PRNGKey(42)
x = random.normal(key, shape=(3,))

# To get different values, split the key
key, subkey = random.split(key)
y = random.normal(subkey, shape=(3,))

This design means random functions are pure — same key always produces the same output — which lets them work correctly with JIT and vmap.
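Explicit keys also compose with vmap: split one key into several and map over them to draw independent samples. A minimal sketch, where sample is a hypothetical helper:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
keys = jax.random.split(key, 4)   # four independent keys

def sample(k):
    # draws 3 standard-normal values from a single key
    return jax.random.normal(k, shape=(3,))

batch = jax.vmap(sample)(keys)    # shape (4, 3): one row per key
again = jax.vmap(sample)(keys)    # same keys, so identical values
```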

Functional control flow. Regular Python if/else works inside JIT for conditions on static values, but dynamic conditions (based on array values) need JAX's functional equivalents:

import jax.lax as lax

# dynamic condition: use lax.cond instead of if/else
def f(x):
    return lax.cond(
        x > 0,
        lambda v: v,     # branch taken when x > 0
        lambda v: -v,    # branch taken otherwise
        x                # operand passed to whichever branch runs
    )

# loop: use lax.fori_loop instead of a Python for, so XLA sees one loop
def body(i, val):
    return val + 1

result = lax.fori_loop(0, 1000, body, 0)

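The reason for this is that JIT traces your function once into a static computation graph. A Python for loop with a fixed trip count simply unrolls during tracing (fine for small counts), but a condition or loop bound that depends on a runtime array value can't be evaluated during tracing, so it has to be expressed as a primitive such as lax.cond, lax.fori_loop, or lax.while_loop, which XLA compiles into real control flow.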
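When the number of iterations depends on data, lax.while_loop is the primitive to reach for. A minimal sketch; halvings_until_below is a hypothetical example that counts how many times x can be halved before it drops below a threshold:

```python
import jax
from jax import lax

@jax.jit
def halvings_until_below(x, threshold):
    # The trip count depends on the runtime value of x, so a Python while can't express it under jit
    def cond_fun(state):
        value, _ = state
        return value >= threshold

    def body_fun(state):
        value, count = state
        return value / 2.0, count + 1

    _, count = lax.while_loop(cond_fun, body_fun, (x, 0))
    return count
```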

4. Why JAX Is Fast

The speed story has several layers, and it's worth understanding all of them.

XLA compilation

XLA (Accelerated Linear Algebra) is the compiler that sits underneath JAX. When you call jit(f)(x), JAX traces f into its own intermediate representation (a jaxpr), then lowers it to the form XLA consumes (StableHLO, which XLA turns into its HLO, short for High Level Operations). XLA then optimizes this graph and compiles it to machine code for your target hardware.

The key difference from ordinary Python execution is that XLA sees the entire function at once, not one operation at a time. This global view enables optimizations that are impossible when you execute operations eagerly.

Operator fusion

Consider this:

y = jnp.sin(x) + jnp.cos(x)

Without fusion, this makes three separate passes over memory: compute sin(x) and write it out, compute cos(x) and write it out, then read both intermediate arrays back, add them, and write the result. That is four array reads and three array writes.

With XLA's kernel fusion, the compiler merges these into a single kernel that reads x once, computes sin and cos simultaneously, adds them, and writes the result once. This matters a lot in practice because on modern hardware, memory bandwidth is often the bottleneck, not raw compute.

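import jax.numpy as jnp
from jax import jit

x = jnp.linspace(0.0, 10.0, 1_000_000)

# Give XLA the opportunity to fuse by putting everything in one jit'd function
@jit
def f(x):
    return jnp.sin(x) + jnp.cos(x)

# Don't do this if you care about performance:
sin_x = jnp.sin(x)       # dispatched as its own operation, result written to memory
cos_x = jnp.cos(x)       # another separate pass over x
result = sin_x + cos_x   # a third pass, reading both intermediates back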
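If you want to see what XLA did with a function, JAX's ahead-of-time API lets you lower and compile it explicitly and dump the optimized HLO as text. A minimal sketch; the exact HLO differs by backend and JAX version, but a fused computation typically shows up as a fusion instruction:

```python
import jax
import jax.numpy as jnp

def sin_plus_cos(x):
    return jnp.sin(x) + jnp.cos(x)

x = jnp.linspace(0.0, 10.0, 1024)

# Lower for this input's shape/dtype, compile, and grab the optimized HLO text
compiled = jax.jit(sin_plus_cos).lower(x).compile()
hlo_text = compiled.as_text()   # print(hlo_text) to inspect it

# The compiled object can be called directly with matching inputs
result = compiled(x)
```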

The functional design enables the compiler

This is the connection most explanations miss. JIT by itself isn't the reason JAX is fast — it's that the functional design gives the compiler the freedom to optimize aggressively.

When a function has no side effects and all state is explicit, the compiler can:

  • Safely reorder operations (if they don't share state, order doesn't matter)
  • Fuse operations (no risk of altering program semantics)
  • Parallelize computations across cores (each is independent)
  • Eliminate redundant computations (pure functions with the same input always produce the same output)

PyTorch in eager mode executes operations one at a time in order. Even torch.compile has to work harder to prove it's safe to optimize, because PyTorch code can have side effects. JAX's restriction to pure functions means the compiler can assume safety by construction.

lax.scan for loops

Python loops inside JIT get unrolled: a 1000-iteration loop hands the compiler 1000 copies of the loop body. For short loops this is fine, but for long loops it creates enormous HLO programs and slow compilation. lax.scan is the alternative:

from jax import lax

# Compute a running sum
def step(carry, x):
    new_carry = carry + x
    output = new_carry
    return new_carry, output

xs = jnp.arange(1000.0)
final_carry, outputs = lax.scan(step, init=0.0, xs=xs)

lax.scan compiles to a single loop in XLA: the loop body is compiled once and the hardware executes it repeatedly. For long sequences this keeps both the program and the compile time small compared with unrolling, and it's how you'd implement things like RNNs or sequential algorithms in JAX.
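The carry can be any pytree, not just a single number. A minimal sketch of an exponential moving average that threads a (running average, step count) tuple through lax.scan; ema is a hypothetical helper:

```python
import jax.numpy as jnp
from jax import lax

def ema(xs, decay):
    # carry is a pytree: (running average, number of steps seen)
    def step(carry, x):
        avg, n = carry
        avg = decay * avg + (1.0 - decay) * x
        return (avg, n + 1), avg   # (new carry, per-step output)
    init = (jnp.float32(0.0), 0)
    (final_avg, steps), history = lax.scan(step, init, xs)
    return final_avg, steps, history

final_avg, steps, history = ema(jnp.ones(5), decay=0.5)
```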

5. The Main Modules

JAX is organized into a handful of modules and top-level transformation functions, each with a different level of abstraction and purpose:

Module / function | Purpose | When to use it
--- | --- | ---
jax.numpy | NumPy-compatible array operations | Most of the time; your default for array math
jax.lax | Low-level functional primitives | Dynamic control flow (cond, scan, while_loop), or when you need precise control over what XLA sees
jax.grad / jax.value_and_grad | Automatic differentiation | Any time you need gradients: loss functions, optimization
jax.jit | JIT compilation | Functions you call repeatedly with the same input shapes
jax.vmap | Automatic vectorization | Applying a per-sample function to a batch
jax.pmap | Multi-device parallelism | Distributing training across multiple GPUs/TPUs
jax.random | Functional random number generation | Any randomness; always use explicit keys
jax.tree_util | Operations on nested Python containers (pytrees) | Manipulating model parameters stored as nested dicts or lists

jax.numpy vs jax.lax

Most people use jax.numpy and drop into jax.lax when they need something jnp doesn't expose, such as structured control flow. The practical differences:

import jax.numpy as jnp
import jax.lax as lax

x = jnp.array([-2.0, 0.5, 3.0])

# jnp.where: elementwise select; works on traced values inside jit, but computes both branches
y = jnp.where(x > 0, x, -x)

# lax.cond: picks one of two functions using a scalar predicate
s = jnp.float32(-1.5)
z = lax.cond(s > 0, lambda v: v, lambda v: -v, s)

# lax.scan: carries state through a long loop without unrolling it
def step_fn(carry, inp):
    carry = carry + inp
    return carry, carry

final, history = lax.scan(step_fn, 0.0, x)
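A small sketch of the difference under jit: jnp.where accepts a condition computed from traced values, while a plain Python if on a traced array fails at trace time. abs_where and abs_if are hypothetical names:

```python
import jax
import jax.numpy as jnp

@jax.jit
def abs_where(x):
    # elementwise select on a traced condition: fine under jit
    return jnp.where(x > 0, x, -x)

@jax.jit
def abs_if(x):
    # Python if needs a concrete bool, but x is a tracer here
    if x > 0:
        return x
    return -x

ok = abs_where(jnp.array([-2.0, 3.0]))

try:
    abs_if(jnp.float32(-2.0))
    raised = False
except jax.errors.ConcretizationTypeError:
    raised = True
```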

Pytrees

JAX's transformations work not just on arrays but on arbitrary nested Python containers — dicts, lists, tuples, and named tuples — as long as the leaves are arrays. These are called pytrees, and they're how neural network parameters are typically represented:

import jax.numpy as jnp
from jax import grad

# Parameters as a nested dict (a pytree)
params = {
    'w': jnp.array([1.0, 2.0, 3.0]),
    'b': jnp.array(0.5)
}

def loss(params, x, y):
    pred = jnp.dot(params['w'], x) + params['b']
    return (pred - y)**2

# grad works on pytrees — returns a dict with the same structure
grads = grad(loss)(params, x=jnp.array([1., 0., 0.]), y=2.0)
# grads = {'w': array([...]), 'b': array(...)}
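jax.tree_util is what you use to operate on the whole structure at once. A minimal sketch that applies a plain gradient descent update leaf by leaf and counts parameters; the gradient values are what grad returns for the example above (x = [1, 0, 0], y = 2.0), and lr is an arbitrary choice:

```python
import jax
import jax.numpy as jnp

params = {'w': jnp.array([1.0, 2.0, 3.0]), 'b': jnp.array(0.5)}
grads = {'w': jnp.array([-1.0, 0.0, 0.0]), 'b': jnp.array(-1.0)}
lr = 0.1

# apply the same update to every matching pair of leaves
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# total number of scalar parameters across all leaves
n_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
```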

6. Hardware: CPU vs GPU vs TPU

This is one of the areas where JAX differs most clearly from NumPy and PyTorch.

Library | CPU | GPU | TPU | Notes
--- | --- | --- | --- | ---
NumPy | Yes | No | No | CPU only. GPU requires CuPy or Numba as a separate replacement.
PyTorch | Yes | Yes | ⚠️ Partial | Excellent GPU support via CUDA. TPU requires torch_xla, which is a separate install and not always seamless.
JAX | Yes | Yes | Yes | All three are first-class targets. The same code runs on any backend; XLA handles code generation for each.

The reason JAX handles all three is architectural. PyTorch's GPU support is tensor-level: individual operations are dispatched to CUDA kernels. JAX's approach is to compile the entire function to the target hardware through XLA, which has backends for CPU (via LLVM), NVIDIA GPUs (via CUDA/PTX), and Google TPUs (natively). The same Python code produces different machine code depending on where you're running, without any changes from you.

In practice, you can check and control the device placement like this:

import jax
import jax.numpy as jnp

# See available devices
print(jax.devices())
# e.g. [CpuDevice(id=0)] on a CPU-only machine; GPU or TPU devices when available

# Move data to a specific device
x = jax.device_put(jnp.array([1.0, 2.0]), jax.devices()[0])

# Check where an array lives (a set, since an array can be sharded across devices)
print(x.devices())  # e.g. {CpuDevice(id=0)}
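Two more calls come up constantly when checking placement: jax.default_backend() reports which platform JAX picked, and jax.device_get() copies an array back to host memory as a NumPy array. A minimal sketch:

```python
import jax
import jax.numpy as jnp
import numpy as np

backend = jax.default_backend()     # e.g. 'cpu', 'gpu', or 'tpu'
dev = jax.devices()[0]

x = jax.device_put(jnp.arange(4.0), dev)
host_copy = jax.device_get(x)       # NumPy array in host memory
```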

7. Writing XLA-Friendly Code

Knowing that XLA is doing the heavy lifting, there are concrete patterns that let it optimize well and patterns that prevent it from doing so.

Write large functions, not small fragments

# Bad: three separate operations, no fusion opportunity
y = jnp.sin(x)
z = jnp.cos(y)
w = z + 1.0

# Good: one compiled function, XLA sees everything and can fuse
@jit
def f(x):
    y = jnp.sin(x)
    z = jnp.cos(y)
    return z + 1.0

Replace Python loops with lax primitives

# Bad: Python loop unrolls inside jit — slow compilation, large HLO
@jit
def f(x):
    for i in range(1000):
        x = x + 1.0
    return x

# Good: lax.fori_loop compiles to a single loop kernel
from jax import lax

@jit
def f(x):
    def body(i, val):
        return val + 1.0
    return lax.fori_loop(0, 1000, body, x)

Use vmap instead of for loops over batches

# Bad: Python loop, no parallelism
results = [f(x) for x in xs]

# Good: single vectorized kernel
from jax import vmap, jit

batched_f = jit(vmap(f))
results = batched_f(xs)

Keep shapes static

JIT caches compiled code by input shape and dtype. If either changes between calls, JAX recompiles. For shapes that genuinely vary, you can pad inputs to a fixed maximum size, or mark shape-determining arguments as static with static_argnums, which compiles one version per distinct value of that argument:

from functools import partial

@partial(jit, static_argnums=(1,))
def f(x, size):
    return x[:size]

f(x, 10)   # compiles for size=10
f(x, 20)   # compiles again for size=20 (different static arg)
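The padding approach, as a minimal sketch: inputs of different lengths are padded to a hypothetical fixed size MAX_LEN, and the true length is passed as a traced argument and used as a mask, so both calls below reuse one compiled version:

```python
import jax
import jax.numpy as jnp

MAX_LEN = 8   # hypothetical fixed bucket size

@jax.jit
def masked_mean(padded, length):
    # positions at or beyond `length` are padding and get masked out
    mask = jnp.arange(MAX_LEN) < length
    return jnp.sum(jnp.where(mask, padded, 0.0)) / length

def pad(x):
    # pad with zeros on the right up to MAX_LEN
    return jnp.pad(x, (0, MAX_LEN - x.shape[0]))

m1 = masked_mean(pad(jnp.array([1.0, 2.0, 3.0])), 3)
m2 = masked_mean(pad(jnp.array([4.0, 6.0])), 2)
```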

Minimize host-device transfers

# loss_fn, params, and batch are assumed to be defined elsewhere
fast_loss = jit(loss_fn)   # wrap once, outside the loop

# Bad: pulling a value to the host every iteration
for i in range(1000):
    loss_val = fast_loss(params, batch)
    print(loss_val)   # waits for the device, then copies the value to the host

# Better: only transfer when you actually need the value
for i in range(1000):
    loss_val = fast_loss(params, batch)
    if i % 100 == 0:
        print(float(loss_val))   # occasional device-to-host transfer
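A related gotcha is timing. JAX dispatches work asynchronously, so a jit'd call returns before the device has finished; printing or converting to float is what forces the wait. When benchmarking, call block_until_ready() and keep compilation out of the timed region. A minimal sketch with a hypothetical step function:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    return jnp.tanh(x @ x)

x = jnp.ones((512, 512))
step(x).block_until_ready()       # warm-up: compile outside the timed region

start = time.perf_counter()
y = step(x).block_until_ready()   # wait for the device before stopping the clock
elapsed = time.perf_counter() - start
```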

8. A Working Optimization Example

Putting it together: here's gradient descent implemented in JAX, demonstrating how all the pieces interact.

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
import jax.random as random

# --- Data ---
key = random.PRNGKey(0)
key, k_x, k_noise = random.split(key, 3)   # separate keys for X and for the noise

# True parameters we're trying to recover
w_true = jnp.array([2.0, -1.0, 0.5])
X = random.normal(k_x, (100, 3))
y = X @ w_true + 0.1 * random.normal(k_noise, (100,))

# --- Model and loss ---
def predict(w, x):
    return jnp.dot(x, w)

def loss_fn(w, X, y):
    preds = vmap(predict, in_axes=(None, 0))(w, X)
    return jnp.mean((preds - y)**2)

# --- Training step, compiled ---
@jit
def step(w, X, y, lr=0.01):
    loss, grads = jax.value_and_grad(loss_fn)(w, X, y)
    w = w - lr * grads
    return w, loss

# --- Run optimization ---
key, subkey = random.split(key)
w = random.normal(subkey, (3,))

for i in range(200):
    w, loss = step(w, X, y)
    if i % 50 == 0:
        print(f"Step {i:3d} | loss: {loss:.4f} | w: {w}")

print(f"\nTrue w:      {w_true}")
print(f"Recovered w: {w}")

A few things to notice. jax.value_and_grad returns both the function value and its gradient in one pass — more efficient than computing them separately. vmap(predict, in_axes=(None, 0)) says "batch over the second argument (x) but not the first (w)". And the entire training step, including the gradient computation, is JIT-compiled.
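You can go one step further and move the training loop itself onto the device with lax.scan, so the whole optimization is a single compiled call. A minimal sketch that regenerates the same kind of synthetic data; train is a hypothetical name, and the learning rate of 0.1 is an assumption chosen so that 200 steps converge, not the 0.01 used above:

```python
import jax
import jax.numpy as jnp
from jax import lax

key = jax.random.PRNGKey(0)
k_x, k_noise = jax.random.split(key)
w_true = jnp.array([2.0, -1.0, 0.5])
X = jax.random.normal(k_x, (100, 3))
y = X @ w_true + 0.1 * jax.random.normal(k_noise, (100,))

def loss_fn(w, X, y):
    return jnp.mean((X @ w - y) ** 2)

@jax.jit
def train(w0, X, y, lr=0.1):
    def step(w, _):
        loss, g = jax.value_and_grad(loss_fn)(w, X, y)
        return w - lr * g, loss          # (new carry, per-step output)
    # xs=None with length=200: run the body 200 times, collecting each loss
    w_final, losses = lax.scan(step, w0, None, length=200)
    return w_final, losses

w_final, losses = train(jnp.zeros(3), X, y)
```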

9. Neural Networks: Flax and Haiku

JAX itself doesn't have a built-in neural network layer API, and that's intentional. The core library stays minimal and composable; higher-level APIs are built on top. The two best-known are Flax (from Google) and Haiku (from DeepMind); Haiku is now in maintenance mode, and Google DeepMind recommends Flax for new projects.

Both follow JAX's functional philosophy: model parameters are stored in explicit pytree structures and passed as arguments, rather than being hidden inside objects with mutable state. Here's a minimal Flax example:

import jax
import jax.numpy as jnp
from flax import linen as nn

# Define model
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(64)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x

# Initialize — note: params are returned explicitly, not stored inside the model
model = MLP()
key = jax.random.PRNGKey(0)
params = model.init(key, jnp.ones((1, 10)))

# Forward pass
out = model.apply(params, jnp.ones((32, 10)))

# Loss and training step
def loss_fn(params, x, y):
    pred = model.apply(params, x)
    return jnp.mean((pred - y)**2)

@jax.jit
def train_step(params, x, y, lr=0.001):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss

The pattern of model.init returning explicit parameters and model.apply taking them as an argument is distinctively JAX. There's no hidden state — everything is in the pytree you pass around.

10. JAX vs PyTorch: An Honest Comparison

Aspect | JAX | PyTorch
--- | --- | ---
Execution model | Functional + compiled (via XLA) | Eager by default; torch.compile optional
Automatic differentiation | Strong; composes with vmap, jit, etc. | Strong; autograd is mature and well-documented
Hardware support | CPU, GPU, TPU, all first-class | CPU, GPU (excellent); TPU via torch_xla
Debugging | Harder; tracing makes stack traces confusing | Easier; eager execution, intuitive errors
Ecosystem | Smaller but growing (Flax, Haiku, Optax, Equinox) | Mature and large (Hugging Face, Lightning, etc.)
Code style | Functional; explicit state, no mutation | Object-oriented; modules with internal state
Learning curve | Steeper; requires understanding functional patterns | Gentler; more familiar Python/OOP conventions
Research adoption | Growing fast in academic ML (DeepMind, etc.) | Dominant in industry and academia overall

The honest summary: if you're building products or need a large ecosystem, PyTorch is the pragmatic choice. If you're doing research that involves novel training algorithms, custom differentiable computations, or anything that runs at scale on TPUs, JAX's design pays off. The functional model is genuinely more expressive for composing transformations, and the XLA backend gives you hardware flexibility that PyTorch requires more work to achieve.

11. Common Pitfalls

Forgetting that jit captures external variables. Covered above, but it will catch you. Anything that shouldn't be baked in at compile time needs to be passed as an explicit argument.

Using Python loops for long sequences inside jit. They unroll. Use lax.scan, lax.fori_loop, or lax.while_loop instead.

Expecting in-place updates. JAX arrays are immutable. x[i] = v raises an error. Use x.at[i].set(v), which returns a new array.

Not splitting keys properly. Reusing the same random key gives the same result every time. Always split before using:

key = random.PRNGKey(0)

# Wrong: same key used twice → same output
x = random.normal(key, (3,))
y = random.normal(key, (3,))   # identical to x

# Correct: split the key before each use
key, subkey1 = random.split(key)
key, subkey2 = random.split(key)
x = random.normal(subkey1, (3,))
y = random.normal(subkey2, (3,))

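Dynamic shapes inside jit. Under jit, every array's shape must be known at trace time, so operations whose output shape depends on array values (boolean-mask indexing, or jnp.nonzero without a size argument) raise an error. Separately, jit recompiles whenever input shapes or dtypes change, so if you find it recompiling on every call, check whether your inputs keep changing shape.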
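With jnp.nonzero, the size argument is what makes it usable under jit: it fixes the output shape and pads any unused slots with fill_value. A minimal sketch; positive_indices is a hypothetical name:

```python
import jax
import jax.numpy as jnp

@jax.jit
def positive_indices(x):
    # size=4 fixes the output shape; unused slots are filled with -1
    return jnp.nonzero(x > 0, size=4, fill_value=-1)[0]

idx = positive_indices(jnp.array([3.0, -1.0, 2.0, -5.0, 0.0]))
```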

Small computations with jit. Compilation has overhead. For tiny calculations, jit can be slower than just running eagerly. Profile before assuming jit always helps.

12. Where to Go From Here

If this post made JAX click conceptually, the practical next step is to install it and try the examples above. The official JAX documentation is well-written and has good tutorials on the core transforms. For neural networks, Flax's documentation includes full training loop examples.

The things worth understanding next, roughly in order: how pytrees work in more depth (they're everywhere in JAX code), how lax.scan handles state for sequential models, and how pmap works for multi-GPU training. Once those are clear, you're in good shape to read most JAX research code.

JAX has a steeper ramp than PyTorch, but the concepts it forces you to understand — pure functions, explicit state, composable transformations, compilation — are genuinely useful beyond JAX itself. Most of the ideas transfer to understanding other systems, and the style of thinking tends to make for cleaner code regardless of what library you're ultimately using.