Setting num_particles in a custom ELBO

I would like to write a custom ELBO function and set num_particles > 1. num_particles is easy to set with the existing ELBO classes (e.g., Trace_ELBO), but how can one implement this in a custom ELBO, like this simple_elbo?

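Hi @dilara, you can just wrap your model and guide in an additional pyro.plate context inside the ELBO implementation. Every sample site then draws num_particles samples in parallel along an extra batch dimension, so this requires a model and guide that broadcast correctly (all batch dimensions declared with pyro.plate):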

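import pyro
from pyro.poutine import trace, replay

def simple_elbo_vectorized(model, guide, num_particles, max_plate_nesting, *args, **kwargs):
    # max_plate_nesting is the number of nested plates in model/guide; placing the
    # particle plate further left ensures it does not collide with those plates
    with pyro.plate("elbo_particles", size=num_particles, dim=-max_plate_nesting-2):
        # one vectorized guide pass draws num_particles samples of every latent
        guide_trace = trace(guide).get_trace(*args, **kwargs)
        # score the model at all of those samples
        model_trace = trace(replay(model, trace=guide_trace)).get_trace(*args, **kwargs)
    # log_prob_sum() also sums over the particle dim, so divide by num_particles to average
    return -1. / num_particles * (model_trace.log_prob_sum() - guide_trace.log_prob_sum())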