Waste-free Sequential Monte Carlo
Algorithm from this article.
- How many MCMC steps are needed for optimal performance? (In practice the number is set arbitrarily.)
- The number of steps should be set adaptively, especially in situations like tempering.

In standard SMC the intermediate MCMC steps are wasted; waste-free SMC recycles them as new particles.
Principle
If there are \(N\) particles, we resample only \(M = N/P\) of them. Each resampled particle is moved \(P-1\) times, and every iterate, starting point included, is kept to form a new sample of size \(N\).
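For example, with \(N = 1000\) and \(P = 10\), we resample \(M = 100\) particles; each one seeds a chain made of its starting point plus \(9\) MCMC iterates, and the \(100 \times 10 = 1000\) resulting states form the new particle set.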
How can we integrate both variants in the same algorithm?
The implementation of sequential MC starts with a resampling step:
resampling_idx = resampling_fn(weights, rng_key, num_resampled)
particles = jax.tree_map(lambda x: x[resampling_idx], particles)
We can resample any number of particles between \(1\) and \(N\), so the same resampling step covers both variants: vanilla SMC resamples \(N\) particles, waste-free SMC only \(M\).
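For concreteness, here is what a minimal multinomial resampling_fn with this signature could look like (a sketch; systematic or stratified resampling would work just as well):

import jax

def multinomial_resampling(weights, rng_key, num_resampled):
    # Draw num_resampled ancestor indices with probability
    # proportional to the normalized particle weights.
    return jax.random.choice(
        rng_key, weights.shape[0], shape=(num_resampled,), p=weights
    )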
Waste-free SMC may be recast as a standard SMC sampler that propagates and reweights particles that are Markov chains of length \(P\): indeed, we have \(M\) chains of length \(P\), and each particle is a chain \(z \in \mathcal{Z} = \Xi^P\). The potential function becomes:
\begin{equation} G_{t}^{wf}(z) = \frac{1}{P} \sum_{p=1}^{P} G_{t}(z[p]) \end{equation}
the initial distribution:
\begin{equation} \nu^{wf}(\mathrm{d}z) = \prod_{p=1}^{P} \nu(\mathrm{d}z[p]) \end{equation}
and the transition kernel:
\begin{equation} M_{t}^{wf}(z_{t-1}, \mathrm{d}z_{t}) = \sum_{p=1}^{P} \frac{G_{t-1}(z_{t-1}[p])}{\sum_{q=1}^{P} G_{t-1}(z_{t-1}[q])} \, \delta_{z_{t-1}[p]}(\mathrm{d}z_{t}[1]) \prod_{k=2}^{P} M_{t}(z_{t}[k-1], \mathrm{d}z_{t}[k]) \end{equation}
In words, the update proceeds as follows:
- Choose one chain of length \(P\) with probability \(\propto \sum_p G_{t-1}(z_{t-1}[p])\) (this is the resampling step);
- Choose one component \(q\) of that chain with probability \(\propto G_{t-1}(z_{t-1}[q])\) to be the starting point of the next chain;
- Move the starting point \(P-1\) times with the MCMC kernel, keeping every iterate; a sketch follows the list.
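A per-chain sketch of this kernel; waste_free_move, mcmc_step and potentials are hypothetical names (mcmc_step(key, x) returns the next state, potentials[p] holds \(G_{t-1}(z_{t-1}[p])\)):

import jax
import jax.numpy as jnp

def waste_free_move(rng_key, chain, potentials, mcmc_step, P):
    # chain has shape (P, ...); potentials[p] is G_{t-1}(chain[p]).
    choice_key, *step_keys = jax.random.split(rng_key, P)
    # Pick the starting point with probability proportional to its potential.
    start_idx = jax.random.choice(choice_key, P, p=potentials / potentials.sum())
    iterates = [chain[start_idx]]
    # Move it P - 1 times with the MCMC kernel, keeping every iterate.
    for key in step_keys:
        iterates.append(mcmc_step(key, iterates[-1]))
    return jnp.stack(iterates)  # the new chain of length P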
This leads to Algorithm 2 in the paper:
- sampled_idx <- resample(M, weights)
- sampled_particles = particles[sampled_idx]
- traces = scan(vmap(mcmc), sampled_particles, P - 1 steps)
- particles = concat(sampled_particles, traces).flatten()
- weights = G(particles)
- return particles, weights
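The flatten step is where the recycling happens: the scan stacks iterates along a leading axis, and reshaping merges the \(M\) chains of length \(P\) into \(N\) particles. A shape-only illustration with hypothetical sizes:

import jax.numpy as jnp

M, P = 4, 3                               # 4 chains of length 3, so N = 12
chains = jnp.arange(M * P).reshape(P, M)  # scan output: (num_iterates, num_chains)
particles = chains.reshape(-1)            # shape (N,): every iterate is a particle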
The equivalent Algorithm 1 (vanilla SMC) would be:
- sampled_idx <- resample(N, weights)
- sampled_particles = particles[sampled_idx]
- traces = scan(vmap(mcmc), sampled_particles, P steps)
- particles = traces[-1]
- weights = G(particles)
- return particles, weights
Code
Let us now refactor the original smc code:
from typing import Callable, Tuple

import jax
import jax.numpy as jnp

# PRNGKey, SMCState, SMCInfo and _normalize come from the surrounding SMC module.


def smc(
    mcmc_kernel_factory: Callable,
    mcmc_state_generator: Callable,
    resampling_fn: Callable,
    num_mcmc_iterations: int,
    is_waste_free: bool = False,
):
    if is_waste_free:
        # P - 1 MCMC moves; the starting point provides the P-th iterate.
        num_mcmc_iterations = num_mcmc_iterations - 1

    def kernel(
        rng_key: PRNGKey,
        state: SMCState,
        logprob_fn: Callable,
        log_weight_fn: Callable,
    ) -> Tuple[SMCState, SMCInfo]:
        weights, particles = state
        scan_key, resampling_key = jax.random.split(rng_key, 2)
        num_particles = weights.shape[0]

        if is_waste_free:
            sub_num_particles, remainder = divmod(
                num_particles, num_mcmc_iterations + 1
            )
            if remainder > 0:
                raise ValueError(
                    "`num_mcmc_iterations` + 1 must be a divisor "
                    f"of `num_particles`, {num_mcmc_iterations} and "
                    f"{num_particles} were given"
                )
        else:
            sub_num_particles = num_particles

        resampling_index = resampling_fn(weights, resampling_key, sub_num_particles)
        particles = jax.tree_map(lambda x: x[resampling_index], particles)

        # First advance the particles using the MCMC kernel
        mcmc_kernel = mcmc_kernel_factory(logprob_fn)

        def mcmc_body_fn(carry, curr_key):
            curr_particles, n_accepted = carry
            keys = jax.random.split(curr_key, sub_num_particles)
            new_particles, mcmc_info = jax.vmap(mcmc_kernel, in_axes=(0, 0))(
                keys, curr_particles
            )
            n_accepted = n_accepted + mcmc_info.is_accepted
            return (new_particles, n_accepted), new_particles

        mcmc_state = jax.vmap(mcmc_state_generator, in_axes=(0, None))(
            particles, logprob_fn
        )
        keys = jax.random.split(scan_key, num_mcmc_iterations)
        (proposed_states, total_accepted), proposed_states_history = jax.lax.scan(
            mcmc_body_fn, (mcmc_state, jnp.zeros((sub_num_particles,))), keys
        )
        acceptance_rate = jnp.mean(total_accepted / num_mcmc_iterations)

        if is_waste_free:
            # Prepend the starting points to the chains' histories and flatten
            # the M chains of length P into N particles.
            initial_position, tree_def = jax.tree_flatten(mcmc_state.position)
            chains_history, _ = jax.tree_flatten(proposed_states_history.position)
            position_history = [
                jnp.concatenate([jnp.expand_dims(elem1, 0), elem2])
                for elem1, elem2 in zip(initial_position, chains_history)
            ]
            position_history = jax.tree_unflatten(tree_def, position_history)
            proposed_particles = jax.tree_map(
                lambda z: jnp.reshape(z, (num_particles,) + z.shape[2:]),
                position_history,
            )
        else:
            proposed_particles = proposed_states.position

        # Reweight the particles with respect to the new target
        log_weights = jax.vmap(log_weight_fn, in_axes=(0,))(proposed_particles)
        weights, log_likelihood_increment = _normalize(log_weights)

        state = SMCState(weights, proposed_particles)
        info = SMCInfo(resampling_index, log_likelihood_increment, acceptance_rate)

        return state, info

    return kernel
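Before refactoring, a hypothetical usage sketch to fix ideas; rmh_factory, rmh_init, logprob_fn and log_weight_fn are placeholders for whichever MCMC kernel and sequence of targets you use:

kernel = smc(
    mcmc_kernel_factory=rmh_factory,       # logprob_fn -> (rng_key, state) -> (state, info)
    mcmc_state_generator=rmh_init,         # (position, logprob_fn) -> state
    resampling_fn=multinomial_resampling,  # e.g. the sketch above
    num_mcmc_iterations=10,                # chains of length P = 10
    is_waste_free=True,
)
state = SMCState(weights, particles)       # N = 1000 particles -> M = 100 chains
new_state, info = kernel(rng_key, state, logprob_fn, log_weight_fn)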
First, separate vanilla and waste-free SMC: