BlackJAX on Aesara
import aesara.tensor as at

# Seeded random stream from which every random variable below is drawn.
srng = at.random.RandomStream(0)

# Toy model: Y ~ Normal(mu, sigma), with priors on both mu and sigma.
# NOTE(review): sigma is drawn from Normal(1, 1), so it can be negative even
# though it is used as the scale of Y; a positive-support prior may have been
# intended — confirm.
sigma_rv = srng.normal(1.0)
mu_rv = srng.normal(0, 1)
Y_rv = srng.normal(mu_rv, sigma_rv)
Sampling from the prior
Sampling from the prior predictive distribution is a useful tool for debugging. This function should therefore be:
- Fast;
- Easy to use;
- Easy to customize; for example, it should be easy to change the value of a parameter.
import aesara
from aesara.graph.basic import io_toposort
from aesara.tensor.random.op import RandomVariable
import jax

rng_key = jax.random.PRNGKey(3)


def count_model_rvs(rv_out):
    """Count the `RandomVariable` nodes in the graph that produces `rv_out`."""
    # `io_toposort` with no inputs walks every ancestor of `rv_out`.
    return len(
        [
            node
            for node in io_toposort([], [rv_out])
            if isinstance(node.op, RandomVariable)
        ]
    )


def split_to_tuple(rng_key, num):
    """Split `rng_key` into a tuple of `num` new PRNG keys."""
    keys = jax.numpy.split(jax.random.split(rng_key, num), num)
    # `jnp.split` keeps a leading length-1 axis on each chunk; drop it.
    return tuple([key.squeeze() for key in keys])


def prior_sample(rng_key, num_samples, rv_out):
    """Return `num_samples` prior draws of `rv_out`.

    The graph producing `rv_out` is compiled with Aesara's JAX backend, and the
    resulting jitted function is vectorised with `jax.vmap` over
    `num_samples` independent PRNG keys derived from `rng_key`.
    """
    # Compile and inspect the graph of the requested output (not a global
    # variable), so the function works for any model output passed in.
    prior_fn = aesara.function([], rv_out, mode="JAX").vm.jit_fn
    # NOTE(review): assumes the compiled function takes exactly one RNG input
    # per RandomVariable in the graph — confirm for models built differently.
    num_rvs = count_model_rvs(rv_out)

    def take_one_sample(rng_key):
        # One fresh key per random variable, wrapped in the {"jax_state": key}
        # form this code feeds to the compiled function as RNG state.
        keys = [{"jax_state": key} for key in split_to_tuple(rng_key, num_rvs)]
        # Keep only the first output (the value of `rv_out`); the remaining
        # outputs are presumably updated RNG states — TODO confirm.
        return prior_fn(*keys)[0]

    return jax.vmap(take_one_sample)(jax.random.split(rng_key, num_samples))
# Draw ten prior samples of the observed variable and display them.
samples = prior_sample(rng_key, num_samples=10, rv_out=Y_rv)
samples
# NOTE(review): neither `prior_sampler` nor `a_tt` is defined anywhere in the
# visible notebook, so this cell raises NameError as written. It appears to
# sketch a desired API (sampling `Y_rv` and `mu_rv`, with a {variable: value}
# mapping that overrides a parameter's value) — confirm intent before running.
prior_sampler(Y_rv, mu_rv).run(rng_key, 1000, {a_tt: 1.})