@jax.jit
def f(x):
    for i in range(3):
        x = 2 * x
    return x

print(f(3))

@jax.jit
def g(x):
    y = 0.
    for i in range(x.shape[0]):
        y = y + x[i]
    return y

print(g(jnp.array([1., 2., 3.])))
Doesn't work: a Python if that branches on the value of a traced argument
@jax.jit
def fail(x):
    if x < 3:
        return 3. * x ** 2
    else:
        return -4 * x

fail(2)
Reason: jit traces the function on the ShapedArray abstraction, where each abstract value represents the set of all array values with a fixed shape and dtype.
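Tradeoff: tracing on abstract values lets one compiled version serve every input of that shape and dtype, but Python cannot coerce an abstract value to a concrete bool.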
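If we trace a Python function on a ShapedArray((), jnp.float32) that isn't committed to a specific concrete value, then when we hit a line like if x < 3, the expression x < 3 evaluates to an abstract ShapedArray((), jnp.bool_) that represents the set {True, False}. Python then tries to coerce that to a concrete True or False, and JAX raises an error because it doesn't know which branch to take.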
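A small sketch of that behaviour, assuming a reasonably recent JAX version. jax.eval_shape evaluates a function on abstract inputs only, so it shows the shape and dtype that x < 3 has while tracing; the try block then shows the branch failing under jit. The names compare and branch are made up for illustration.

```python
import jax
import jax.numpy as jnp

def compare(x):
    return x < 3

# Evaluate abstractly: only shape and dtype are known, never a concrete value.
abstract_in = jax.ShapeDtypeStruct((), jnp.float32)
abstract_out = jax.eval_shape(compare, abstract_in)
print(abstract_out.shape, abstract_out.dtype)

@jax.jit
def branch(x):
    if x < 3:  # Python asks for a concrete bool here, but x is a tracer
        return 3. * x ** 2
    else:
        return -4 * x

try:
    branch(2.)
    raised = False
except Exception as err:  # the exact exception class differs across JAX versions
    raised = True
    print(type(err).__name__)
```

The comparison itself traces fine; it is only the Python if, which needs a single True or False, that fails.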
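Fix: use static_argnums to specify which arguments should be treated as static. Static arguments are traced on their concrete Python values, so the if can be resolved, at the cost of recompiling for each new value.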
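from functools import partial

# partial hands static_argnums to jit before it wraps f
@partial(jax.jit, static_argnums=(0,))
def f(x):
    if x < 3:
        return 3. * x ** 2
    else:
        return -4 * x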
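To see the recompilation cost of a static argument, here is a minimal sketch that records each trace. It assumes Python side effects inside a jitted function run only while JAX traces it; f_static and trace_log are hypothetical names, not part of JAX.

```python
from functools import partial
import jax

trace_log = []  # hypothetical log: one entry per trace of f_static

@partial(jax.jit, static_argnums=(0,))
def f_static(x):
    trace_log.append(x)  # executes only during tracing, not on cached calls
    if x < 3:            # x is a plain Python float here, so this is a normal if
        return 3. * x ** 2
    else:
        return -4 * x

a = f_static(2.)  # first call: traces with x = 2.0
b = f_static(2.)  # same static value: served from the cache
c = f_static(5.)  # new static value: traced and compiled again
print(a, b, c, len(trace_log))
```

Each distinct value of a static argument gets its own compiled version, so this fix suits arguments that take only a few different values.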