Sequential Enumeration - Basics

Here is the code I’m running:

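import torch
import pyro
import pyro.distributions as dist
from pyro.infer import TraceEnum_ELBO, config_enumerate

def model():
    # Discrete latent site marked for sequential enumeration on the model side
    z = pyro.sample("z", dist.Categorical(torch.ones(5)), infer={"enumerate": "sequential"})
    print(f"model z = {z}")

def guide():
    pass  # empty guide: there is no sample statement for "z"

elbo = TraceEnum_ELBO(max_plate_nesting=0)
# elbo.loss(model, config_enumerate(guide, "sequential"))
elbo.loss(model, guide)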

I’m trying to understand “sequential” enumeration, since I’ll likely be using it. When I run this code, I get the following error:

At site 'z', model-side sequential enumeration is not implemented. Try parallel enumeration or guide-side enumeration.

My question is: as with parallel enumeration, shouldn’t this variable be marginalized out, so that we don’t need to sample the discrete variable in the guide? I suspect this is what triggers the error. Can someone clarify?
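For reference, here is what "marginalized out" means compared with guide-side enumeration. The observed site x is hypothetical (the posted model has no observation), z takes the five values 0 to 4, and q(z) is the guide's distribution over z. Model-side enumeration sums z out of the joint, while guide-side enumeration computes the ELBO as an exact sum weighted by q(z):

```latex
\begin{aligned}
\log p(x) &= \log \sum_{k=0}^{4} p(z = k)\, p(x \mid z = k) \\
\mathrm{ELBO} &= \sum_{k=0}^{4} q(z = k)\,\bigl[\log p(x, z = k) - \log q(z = k)\bigr]
\end{aligned}
```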

As the error message says, you need to use either parallel enumeration (which requires writing indexing logic that stays correct when the discrete latent variables in the model are enumerated in parallel) or guide-side enumeration, which means the guide must define an approximate posterior distribution for each discrete latent variable (in your example, you’d need a sample statement for z in the guide).

Model-side sequential enumeration is not an option because it is not implemented.

Thanks a ton!