Is there any reference that explains the algorithm behind TraceEnum_ELBO and how it is different from Trace_ELBO? To be specific, I don’t understand why in the GMM example (Gaussian Mixture Model — Pyro Tutorials 1.8.4 documentation) if I want to infer discrete local variables (z) I cannot use Trace_ELBO? I tried removing @config_enumerate decorators and using Trace_ELBO which runs but parameters do not converge to correct values. Thanks
K = 2 # Fixed number of components.
#@config_enumerate
def model(data):
# Global variables.
weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
scale = pyro.sample('scale', dist.LogNormal(0., 2.))
with pyro.plate('components', K):
locs = pyro.sample('locs', dist.Normal(0., 10.))
with pyro.plate('data', len(data)):
# Local variables.
assignment = pyro.sample('assignment', dist.Categorical(weights))
pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)
global_guide = AutoDelta(poutine.block(model, expose=['weights', 'locs', 'scale']))
#@config_enumerate
def full_guide(data):
# Global variables.
#with poutine.block(hide_types=["param"]): # Keep our learned values of global parameters.
global_guide(data)
# Local variables.
with pyro.plate('data', len(data)):
assignment_probs = pyro.param('assignment_probs', torch.ones(len(data), K) / K,
constraint=constraints.unit_interval)
pyro.sample('assignment', dist.Categorical(assignment_probs))
optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
#elbo = TraceEnum_ELBO(max_plate_nesting=1)
elbo = Trace_ELBO()
svi = SVI(model, full_guide, optim, loss=elbo)