JAX dispatcher
Static arguments
Where they are needed:
- `shape` parameters
- Scan's `length` parameter
Static arguments to a JIT-compiled function must be hashable, i.e. no lists or NumPy arrays.
Typical error: `TypeError: Shapes must be 1D sequences of concrete values of integer type, got ...`
Related issues:
There are two underlying issues:
- JAX needs shapes to be determined at tracing time.
- Random variables need their `size` specified as a tuple (see the docstring of the Dirichlet distribution).
Let's reproduce the example exactly:
import jax

# Reproduce the failing example: the shape is a *JAX* array captured by the
# jitted function's closure. Inside `jax.jit`, entries read out of a JAX array
# may not be concrete Python ints, so `jax.random.normal` is expected to reject
# it (the exact behaviour depends on the installed JAX version).
shape = jax.numpy.array([1000])


def jax_funcified(prng_key):
    """Draw standard-normal samples of size ``shape`` using ``prng_key``."""
    return jax.random.normal(prng_key, shape)


key = jax.random.PRNGKey(0)
try:
    jax.jit(jax_funcified)(key)
except Exception as e:  # deliberately broad: this cell exists to show the error
    # Include the exception class so the output shows which error JAX raised.
    print(f"{type(e).__name__}: {e}")
else:
    # Without this branch a JAX version that accepts the call prints nothing.
    print("no error raised")
import jax
import numpy as np

# Same function as above, but the shape is a *NumPy* array. NumPy arrays are
# not traced by `jax.jit`, so their entries stay concrete integers at trace
# time and the call can be compiled.
shape = np.array([10])


def jax_funcified(prng_key):
    """Draw standard-normal samples of size ``shape`` using ``prng_key``."""
    return jax.random.normal(prng_key, shape)


key = jax.random.PRNGKey(0)
print(jax.jit(jax_funcified)(key))
import jax
import numpy as np

rng_key = jax.random.PRNGKey(0)

# A bare int is expected to be rejected: JAX wants a sequence of ints.
try:
    print(jax.random.normal(rng_key, shape=10))
except Exception as e:  # deliberately broad: this cell exists to show the error
    print(f"shape=10 ({type(e).__name__}): {e}")

# Outside of jit, concrete 1-D sequences of ints are accepted as shapes. Every
# call reuses the same key and shape, so all lines print the same samples.
for label, candidate in (
    ("list", [3]),
    ("tuple", (3,)),
    ("numpy array", np.array([3])),
    ("jax array", jax.numpy.array([3])),
):
    print(f"shape as {label}: {jax.random.normal(rng_key, shape=candidate)}")
import jax


def fun(x):
    """Sample standard normals with shape ``x`` from a fixed PRNG key (seed 0)."""
    rng_key = jax.random.PRNGKey(0)
    return jax.random.normal(rng_key, shape=x)


def report_error(label, thunk):
    """Call ``thunk()`` and print ``"<label>: <error>"`` if it raises.

    Returns the thunk's result, or ``None`` if it raised. The ``except`` is
    deliberately broad because this cell exists to display JAX's errors.
    """
    try:
        return thunk()
    except Exception as e:
        print(f"{label}: {e}")
        return None


# Non-static arguments are traced by `jax.jit`, so the shape values are
# abstract tracers rather than concrete ints; these calls are expected to fail.
report_error("shape as int", lambda: jax.jit(fun)(1))
report_error("shape as list", lambda: jax.jit(fun)([1, 2]))
report_error("shape as tuple", lambda: jax.jit(fun)((1, 2)))

# using static_argnums: the shape becomes a compile-time constant and must be
# hashable -- a tuple works, a list is expected to be rejected.
fun_static = jax.jit(fun, static_argnums=(0,))
res = fun_static((1, 2))
print(f"shape as tuple (static argnum): {res}")
report_error("shape as list (static_argnums)", lambda: fun_static([1, 2]))