JAX = NumPy + Autograd: it uses XLA to compile and run NumPy code on accelerators (GPUs and TPUs).

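Dispatch is asynchronous: a call returns before the device has finished computing. To wait for the result (e.g. when benchmarking), use block_until_ready().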

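import jax.numpy as jnp
from jax import random
key = random.PRNGKey(0)
x = random.normal(key, (10,))
# x is 1-D, so dot(x, x) is the inner product; the call returns before the
# device finishes, and block_until_ready() waits for the result
jnp.dot(x, x).block_until_ready()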
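Because dispatch is asynchronous, timing a call without waiting mostly measures the dispatch, not the computation. A minimal timing sketch, assuming wall-clock timing with time.perf_counter is good enough; the helper time_call is hypothetical:

```python
import time

import jax.numpy as jnp
from jax import random


def time_call(fn, *args):
    """Run fn(*args) once and return (result, seconds), waiting for the device."""
    start = time.perf_counter()
    out = fn(*args).block_until_ready()  # block until the computation has finished
    return out, time.perf_counter() - start


key = random.PRNGKey(0)
x = random.normal(key, (500, 500))
result, seconds = time_call(jnp.dot, x, x)
print(f"{seconds:.4f} s")
```

The first call of a jitted function also includes tracing and compilation, so time a warm-up call separately from the steady-state calls.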
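  • notable functions:
    • jit() for just-in-time compilation: a whole function (many operations) is compiled into one XLA computation
    • grad() for automatic differentiation (reverse-mode gradients; the related jvp() and vjp() compute Jacobian-vector and vector-Jacobian products)
    • vmap() for automatic vectorisation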
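A small sketch of the three transformations on one toy function (the sum-of-squares loss is illustrative, not from the notes):

```python
import jax
import jax.numpy as jnp


def loss(w):
    # Simple scalar function of a vector: sum of squares
    return jnp.sum(w ** 2)


fast_loss = jax.jit(loss)   # compile loss with XLA
grad_loss = jax.grad(loss)  # d(loss)/dw = 2 * w

w = jnp.array([1.0, 2.0, 3.0])
print(fast_loss(w))  # 14.0
print(grad_loss(w))  # [2. 4. 6.]

# vmap: apply the per-vector function over a leading batch axis without a Python loop
batch = jnp.stack([w, 2 * w])         # shape (2, 3)
per_example = jax.vmap(loss)(batch)   # shape (2,): [14., 56.]
```

The transformations compose, e.g. jax.jit(jax.vmap(jax.grad(loss))) gives a compiled per-example gradient.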

Arrays are immutable in JAX: instead of in-place assignment (x[i] = v), use x.at[i].set(v), which returns a new array.
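A minimal sketch of the functional update syntax:

```python
import jax.numpy as jnp

x = jnp.zeros(3)
# x[0] = 1.0 would raise TypeError: JAX arrays do not support item assignment
y = x.at[0].set(1.0)   # returns a new array; x is unchanged
z = y.at[1:].add(5.0)  # other updates exist too: add, multiply, min, max, ...
print(x, y, z)
```

Inside jit, XLA can often turn such updates into in-place writes when the old array is no longer needed, so the copy is not always paid for.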

  • Write functions as pure (no side effects; output depends only on the inputs) so they can be traced and compiled with XLA
from typing import Tuple
import jax
import jax.numpy as jnp
# EPS is used below but not defined in this snippet; 1e-10 is an assumed value
EPS = 1e-10
def kl_divergence(logp: jnp.ndarray, logq: jnp.ndarray) -> jnp.ndarray:
  """Compute KL divergence between two log probability distributions."""
  p = jnp.exp(logp)
  return jnp.sum(jnp.where(p > 0, p * (logp - logq), 0.0), axis=-1)
def ent_varent(logp: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Compute entropy and varentropy from log probabilities."""
  p = jnp.exp(logp)
  ent = -jnp.sum(p * logp, axis=-1)
  diff = logp + ent[..., None]
  varent = jnp.sum(p * diff**2, axis=-1)
  return ent, varent
def normalize_logits(logits: jnp.ndarray, noise_floor: float) -> jnp.ndarray:
  """Normalize logits to log probabilities with noise floor truncation."""
  shifted = logits - jnp.max(logits, axis=-1, keepdims=True)
  normalized = shifted - jax.nn.logsumexp(shifted + EPS, axis=-1, keepdims=True)
  # noise floor calculated for bfloat16
  return jnp.where(normalized < noise_floor, jnp.log(EPS), normalized)
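Why purity matters: jit runs the Python function only while tracing, then reuses the compiled computation. A minimal sketch (the names calls, scale and impure_scale are illustrative) showing that a side effect and a global read are captured once, not on every call:

```python
import jax
import jax.numpy as jnp

calls = []   # Python-side state mutated by the function (a side effect)
scale = 2.0  # global read inside the function


def impure_scale(x):
    calls.append(1)    # runs only while jit is tracing
    return x * scale   # the value of `scale` is baked in at trace time


f = jax.jit(impure_scale)
first = f(jnp.float32(1.0))   # traces and compiles: one entry in calls, result 2.0
scale = 10.0                  # changing the global does not trigger a retrace
second = f(jnp.float32(1.0))  # same shape/dtype: cached computation reused, still 2.0
print(len(calls), first, second)
```

Passing scale as an argument instead of reading a global makes the function pure, and the new value is then picked up on every call.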

Control flow

The following works under jit, because the loop bounds (3 and x.shape[0]) are known at trace time:

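import jax
import jax.numpy as jnp
@jax.jit
def f(x):
  # fixed trip count: the Python loop is unrolled during tracing
  for i in range(3): x = 2 * x
  return x
@jax.jit
def g(x):
  # x.shape[0] is a concrete int at trace time, so this loop is unrolled too
  y = 0.
  for i in range(x.shape[0]): y = y + x[i]
  return y
print(g(jnp.array([1., 2., 3.])))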

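Reasoning: jit traces the function with ShapedArray abstract values, where each abstract value represents the set of all arrays with a fixed shape and dtype. Shapes are therefore concrete during tracing (range(x.shape[0]) is fine), but array values are not.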
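One way to see this is to record what the function receives while it is being traced. A minimal sketch (the dict seen and the function describe are illustrative) that stores the tracer's abstract value type and its shape:

```python
import jax
import jax.numpy as jnp

seen = {}


@jax.jit
def describe(x):
    # While tracing, x is a tracer: it carries an abstract value (shape + dtype), not data
    seen["aval_type"] = type(x.aval).__name__
    seen["shape"] = x.shape  # a concrete Python tuple, so range(x.shape[0]) is allowed
    return jnp.sum(x)


total = describe(jnp.array([1.0, 2.0, 3.0]))
print(seen, total)
```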

type coercion tradeoff

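If we trace a Python function on a ShapedArray((), jnp.float32) that isn't committed to a specific concrete value, then at a line like if x < 3 the expression x < 3 evaluates to an abstract ShapedArray((), jnp.bool_) that represents the set {True, False}. Python's if needs a concrete True or False, so coercing the abstract value raises an error (jax.errors.ConcretizationTypeError, or its subclass TracerBoolConversionError in recent versions). The tradeoff: a more abstract trace is valid for more inputs, but Python cannot branch on it.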

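For example, jax.jit(f) fails on the function below, because if x < 3 needs a concrete bool:

def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

Fix: pass static_argnums (or static_argnames) to jax.jit to mark which arguments are static. Static arguments are treated as concrete, hashable Python values, and the function is retraced and recompiled for each new value. To branch on array values without recompiling, use jnp.where or jax.lax.cond instead.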
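A sketch of the failure and two fixes on the same f: marking x static, and rewriting the branch as an array operation with jnp.where (one possible alternative; jax.lax.cond is another):

```python
import jax
import jax.numpy as jnp


def f(x):
    if x < 3:  # needs a concrete Python bool
        return 3. * x ** 2
    else:
        return -4 * x


# 1) Plain jit fails: x < 3 is an abstract bool during tracing
try:
    jax.jit(f)(2.)
    plain_jit_failed = False
except jax.errors.ConcretizationTypeError:
    plain_jit_failed = True

# 2) static_argnums: x is a concrete, hashable Python value; each new value recompiles
f_static = jax.jit(f, static_argnums=(0,))


# 3) No recompilation per value: express the branch as an array operation
@jax.jit
def f_where(x):
    # jnp.where evaluates both branches and selects elementwise
    return jnp.where(x < 3, 3. * x ** 2, -4 * x)


print(plain_jit_failed, f_static(2.), f_static(4.), f_where(2.), f_where(4.))
```

static_argnums suits arguments that take few distinct values (flags, sizes); for data-dependent branches the jnp.where form avoids a recompile per value, at the cost of computing both branches.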


How does JAX handle memory buffers?

