TraceEnum_ELBO reference

Is there a reference that explains the algorithm behind TraceEnum_ELBO and how it differs from Trace_ELBO? Specifically, in the GMM example (the "Gaussian Mixture Model" tutorial in the Pyro Tutorials 1.8.4 documentation), I don't understand why I cannot use Trace_ELBO to infer the discrete local variables (z). I tried removing the @config_enumerate decorators and switching to Trace_ELBO; the code runs, but the parameters do not converge to the correct values. Thanks

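import torch
from torch.distributions import constraints
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta

K = 2  # Fixed number of components.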

#@config_enumerate
def model(data):
    # Global variables.
    weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
    scale = pyro.sample('scale', dist.LogNormal(0., 2.))
    with pyro.plate('components', K):
        locs = pyro.sample('locs', dist.Normal(0., 10.))

    with pyro.plate('data', len(data)):
        # Local variables.
        assignment = pyro.sample('assignment', dist.Categorical(weights))
        pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)

global_guide = AutoDelta(poutine.block(model, expose=['weights', 'locs', 'scale']))

#@config_enumerate
def full_guide(data):
    # Global variables.
    #with poutine.block(hide_types=["param"]):  # Keep our learned values of global parameters.
    global_guide(data)

    # Local variables.
    with pyro.plate('data', len(data)):
        assignment_probs = pyro.param('assignment_probs', torch.ones(len(data), K) / K,
                                      constraint=constraints.unit_interval)
        pyro.sample('assignment', dist.Categorical(assignment_probs))
        
optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
#elbo = TraceEnum_ELBO(max_plate_nesting=1)
elbo = Trace_ELBO()
svi = SVI(model, full_guide, optim, loss=elbo)

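TraceEnum_ELBO marginalizes out discrete latent variables by summing over their values exactly, and therefore produces much lower-variance gradient estimates during training. Trace_ELBO instead samples each discrete assignment and has to rely on a score-function gradient estimator for those sites, which is noisy; that is why your version runs but does not converge. The lower variance allows TraceEnum_ELBO to use a higher learning rate and to train more quickly and reliably than Trace_ELBO in models with enumerable discrete latent variables, such as the GMM.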
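To make the variance point concrete, here is a minimal standard-library sketch. It is not Pyro code and not how either ELBO is implemented. It takes a single binary latent z with q(z=1) = sigmoid(theta) and compares the exact gradient of E_q[f(z)], obtained by summing over both values of z, with single-sample score-function estimates of the same gradient. The values in `f` are made-up stand-ins for per-assignment log-density terms, and all function names are hypothetical.

```python
import math
import random
import statistics


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def enumerated_grad(theta, f):
    """Exact d/dtheta of E_q[f(z)], obtained by summing over z in {0, 1}."""
    p = sigmoid(theta)
    # E_q[f] = (1 - p) * f[0] + p * f[1], and dp/dtheta = p * (1 - p).
    return p * (1.0 - p) * (f[1] - f[0])


def sampled_grad(theta, f, rng):
    """One-sample score-function estimate: f(z) * d/dtheta log q(z)."""
    p = sigmoid(theta)
    z = 1 if rng.random() < p else 0
    # d/dtheta log q(z) is (1 - p) for z = 1 and -p for z = 0.
    score = (1.0 - p) if z == 1 else -p
    return f[z] * score


rng = random.Random(0)
theta = 0.3
f = (-12.0, -9.5)  # hypothetical values, chosen only for illustration

exact = enumerated_grad(theta, f)
draws = [sampled_grad(theta, f, rng) for _ in range(20000)]
mc_mean = statistics.fmean(draws)
mc_std = statistics.pstdev(draws)

print(f"enumerated gradient:        {exact:.4f}")
print(f"mean of sampled gradients:  {mc_mean:.4f}")
print(f"std of a single sample:     {mc_std:.4f}")
```

Both estimators agree on average, but a single sampled gradient is spread over a range many times larger than the gradient itself, while the enumerated one carries no sampling noise from z at all. In the GMM there is one such discrete variable per data point, so the noise compounds.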

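The variable elimination algorithm used by TraceEnum_ELBO is described in (Obermeyer et al. 2019), "Tensor Variable Elimination for Plated Factor Graphs". Additionally, TraceEnum_ELBO constructs a differentiable ELBO loss using DiCE factors, following (Foerster et al. 2018), "DiCE: The Infinitely Differentiable Monte-Carlo Estimator".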
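For intuition about what variable elimination buys, here is a toy sketch in plain Python. It is only an illustration of the general idea on a small chain of discrete variables, not the plated tensor algorithm from the paper and not Pyro's implementation. The factor tables `unary` and `pairwise` and both function names are invented for the example.

```python
import itertools
import math


def brute_force_sum(unary, pairwise):
    """Sum the product of factors over every joint assignment: K**T terms."""
    T, K = len(unary), len(unary[0])
    total = 0.0
    for zs in itertools.product(range(K), repeat=T):
        term = 1.0
        for t, z in enumerate(zs):
            term *= unary[t][z]
            if t > 0:
                term *= pairwise[zs[t - 1]][z]
        total += term
    return total


def eliminate_chain(unary, pairwise):
    """Sum out z_0, z_1, ... one at a time: about T * K**2 operations."""
    K = len(unary[0])
    # message[j] holds the sum over all earlier variables, given z_t = j.
    message = list(unary[0])
    for t in range(1, len(unary)):
        message = [
            unary[t][j] * sum(message[i] * pairwise[i][j] for i in range(K))
            for j in range(K)
        ]
    return sum(message)


# Hypothetical factor values for a chain of T = 4 binary variables.
unary = [[0.6, 0.4], [0.1, 0.9], [0.7, 0.3], [0.5, 0.5]]
pairwise = [[0.8, 0.2], [0.3, 0.7]]

brute = brute_force_sum(unary, pairwise)
eliminated = eliminate_chain(unary, pairwise)
print(brute, eliminated)
```

Both functions return the same total, but the elimination version never materializes the full joint table. In the GMM above the structure is even simpler: each assignment is independent given the global variables, so marginalizing amounts to a sum over K components per data point.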

Thank you!