NumPy + Autograd. Uses XLA to compile and run NumPy code on accelerators. Dispatch is asynchronous; to synchronize, call block_until_ready():

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
x = random.normal(key, (10,))
jnp.dot(x, x.T).block_until_ready()
```

Notable functions (sketches follow the entropix example below):

- jit() for compiling a sequence of computations
- grad() for automatic differentiation (with related transforms for Jacobian-vector products)
- vmap() for auto-vectorisation

Arrays are immutable in JAX. Treat functions as pure so they can be compiled with XLA.

entropix/dslider.py (excerpt):

```python
from functools import partial
from typing import NamedTuple, Tuple

import jax
import jax.numpy as jnp
import jax.scipy as jsp

EPS = 1e-8  # small epsilon constant; defined elsewhere in dslider.py, reproduced here so the excerpt runs

@jax.jit
def kl_divergence(logp: jnp.ndarray, logq: jnp.ndarray) -> jnp.ndarray:
    """Compute KL divergence between two log probability distributions."""
    p = jnp.exp(logp)
    return jnp.sum(jnp.where(p > 0, p * (logp - logq), 0.0), axis=-1)

@jax.jit
def ent_varent(logp: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Compute entropy and varentropy from log probabilities."""
    p = jnp.exp(logp)
    ent = -jnp.sum(p * logp, axis=-1)
    diff = logp + ent[..., None]
    varent = jnp.sum(p * diff**2, axis=-1)
    return ent, varent

@jax.jit
def normalize_logits(logits: jnp.ndarray, noise_floor: float) -> jnp.ndarray:
    """Normalize logits to log probabilities with noise floor truncation."""
    shifted = logits - jnp.max(logits, axis=-1, keepdims=True)
    normalized = shifted - jax.nn.logsumexp(shifted + EPS, axis=-1, keepdims=True)
    # noise floor calculated for bfloat16
    return jnp.where(normalized < noise_floor, jnp.log(EPS), normalized)
```
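A quick usage sketch for these helpers (toy values; the noise_floor here is an assumed value of about log(1e-8), not one taken from entropix):

```python
logits = jnp.array([2.0, 1.0, 0.1, -20.0])
logp = normalize_logits(logits, noise_floor=-18.42)  # assumed noise floor, ~log(1e-8)
ent, varent = ent_varent(logp)
print(ent, varent)
print(kl_divergence(logp, jnp.log(jnp.full(4, 0.25))))  # KL against a uniform distribution
```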
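And the combined jit/grad/vmap sketch promised above, plus the immutable-update idiom (illustrative only, not from the original notes):

```python
def loss(w, x):
    return jnp.sum((w * x - 1.0) ** 2)

grad_loss = jax.grad(loss)                        # reverse-mode autodiff w.r.t. w
batched = jax.vmap(grad_loss, in_axes=(None, 0))  # vectorise over a batch of x
fast = jax.jit(batched)                           # compile the whole pipeline with XLA

w = jnp.ones(3)
xs = random.normal(key, (8, 3))
print(fast(w, xs).shape)  # (8, 3): one gradient per batch element

# arrays are immutable: updates go through .at[...].set(...), returning a new array
w2 = w.at[0].set(5.0)
```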
references: github, control flow, see also link

The following works (Python loops over concrete bounds like range(3) or x.shape[0] are simply unrolled during tracing):

```python
@jax.jit
def f(x):
    for i in range(3):
        x = 2 * x
    return x

print(f(3))

@jax.jit
def g(x):
    y = 0.
    for i in range(x.shape[0]):
        y = y + x[i]
    return y

print(g(jnp.array([1., 2., 3.])))
```
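Unrolling is fine for small, static loop bounds, but it can bloat compile time; for loops that should stay loops in the compiled program, JAX provides structured control-flow primitives. A minimal sketch with lax.fori_loop (not in the original notes):

```python
from jax import lax

@jax.jit
def g_loop(x):
    # the loop stays inside the compiled computation instead of being unrolled
    return lax.fori_loop(0, x.shape[0], lambda i, y: y + x[i], 0.0)

print(g_loop(jnp.array([1., 2., 3.])))  # 6.0
```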
The following doesn't work:
```python
@jax.jit
def fail(x):
    if x < 3:
        return 3. * x ** 2
    else:
        return -4 * x

fail(2)  # error: the traced condition has no concrete boolean value
```
Reasoning: jit traces code at the ShapedArray level of abstraction, where each abstract value represents the set of all array values with a fixed shape and dtype.
The tradeoff of this abstraction shows up when Python tries to coerce a traced value to a concrete one:
If we trace a Python function on a ShapedArray((), jnp.float32) that isn’t committed to a specific concrete value,
when we hit a line like if x < 3, the expression x < 3 evaluates to an abstract ShapedArray((), jnp.bool_) representing the set {True, False}. Python needs a concrete True or False to decide which branch to take, so tracing fails with an error.
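You can see this abstraction directly with jax.make_jaxpr (a quick sketch; the exact printed form varies across JAX versions):

```python
import jax

print(jax.make_jaxpr(lambda x: x < 3)(2.0))
# the input is abstracted to a f32[] tracer; `x < 3` becomes an abstract bool[]
# with no single concrete True/False value, so Python's `if` cannot branch on it
```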
Fix: you can use static_argnums to specify which arguments should be treated as static; those are traced with concrete values, at the cost of recompiling whenever a new value is passed:
```python
from functools import partial

@partial(jax.jit, static_argnums=(0,))
def f(x):
    if x < 3:
        return 3. * x ** 2
    else:
        return -4 * x

print(f(2.))  # x is concrete at trace time, so Python control flow is fine
```
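If the branch has to depend on a traced value instead, the jit-compatible route is to express it functionally, e.g. with lax.cond (or jnp.where for elementwise selects); a minimal sketch:

```python
from jax import lax

@jax.jit
def f_cond(x):
    # both branches are traced; the selection happens at run time
    return lax.cond(x < 3, lambda x: 3. * x ** 2, lambda x: -4. * x, x)

print(f_cond(2.))  # 12.0
print(f_cond(5.))  # -20.0
```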