I would like to write a custom ELBO function and set num_particles > 1. num_particles is easy to set with the existing ELBO classes (e.g., Trace_ELBO), but how can one implement this in a custom ELBO, like this simple_elbo?
Hi @dilara, you can just wrap your model and guide in an additional pyro.plate context inside the ELBO implementation:
import pyro
from pyro.poutine import replay, trace

def simple_elbo_vectorized(model, guide, num_particles, max_plate_nesting, *args, **kwargs):
    # Placing the particle plate to the left of every model plate
    # (dim=-max_plate_nesting-2) ensures it does not collide with model plates.
    with pyro.plate("elbo_particles", size=num_particles, dim=-max_plate_nesting - 2):
        guide_trace = trace(guide).get_trace(*args, **kwargs)
        model_trace = trace(replay(model, trace=guide_trace)).get_trace(*args, **kwargs)
    # log_prob_sum() sums over all particles, so divide by num_particles to average.
    return -1. / num_particles * (model_trace.log_prob_sum() - guide_trace.log_prob_sum())