Gaussian mixture model in the latent space of a VAE

I was curious whether Pyro would easily let me put a Gaussian mixture model (GMM) prior on the latent space of a VAE. I took the VAE tutorial code and changed the model to the following, using MixtureSameFamily. It runs (on MNIST), but after 30 or so epochs I get NaNs. I wanted to confirm that this implementation is correct in principle before trying to debug the numerical issues. For the latter, I wonder whether the HalfCauchy on the scales is a poor fit here and I’d be better served by something friendlier like Gamma(2, 2). Thanks!

    # define the model p(x|z)p(z)
    def model(self, x):
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        
        # mixture proportions
        pi = pyro.sample("pi", dist.Dirichlet(torch.ones(self.K, device=x.device)))
        mix = dist.Categorical(pi)
        one = torch.tensor(1., device=x.device)

        # component centers and scales
        mix_locs = pyro.sample("mix_locs", dist.Normal(0 * one, one).expand([self.K, self.z_dim]).to_event(2))
        mix_scales = pyro.sample("mix_scales", dist.HalfCauchy(one).expand([self.K, self.z_dim]).to_event(2))

        # construct the Gaussian mixture model
        comp = dist.Normal(mix_locs, mix_scales).to_event(1)
        gmm = dist.MixtureSameFamily(mix, comp)

        with pyro.plate("data", x.shape[0]):
            # sample from prior (value will be sampled by guide when computing the ELBO)
            z = pyro.sample("latent", gmm)
            # decode the latent code z
            loc_img = self.decoder(z)
            # score against actual images
            pyro.sample("obs", dist.Bernoulli(loc_img).to_event(1), obs=x.reshape(-1, 784))
        return loc_img
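For reference, the Gamma(2, 2) alternative would just be a one-line swap in the model (sketch, untested):

        # lighter-tailed prior on the component scales than HalfCauchy
        mix_scales = pyro.sample("mix_scales", dist.Gamma(2 * one, 2 * one).expand([self.K, self.z_dim]).to_event(2))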

Edit: I missed that, done naively, ‘pi’, ‘mix_locs’, and ‘mix_scales’ aren’t learnt because they aren’t part of vae.guide. I thought this would work:

    guide = AutoGuideList(vae.model)
    guide.append(vae.guide)
    guide.append(AutoDelta(poutine.block(vae.model, expose=["pi", "mix_locs", "mix_scales"])))
    svi = SVI(vae.model, guide, optimizer, loss=Trace_ELBO())

but it gives me “RuntimeError: Multiple sample sites named ‘data’”, which I don’t understand.

looks fine to me. it might be hard to get to work. the easier version would do MAP w.r.t. pi, mix_locs, and mix_scales

Thanks Martin, just edited after realizing I wasn’t actually learning the GMM (and I agree, MAP for that feels like it should be enough). Any idea where that error is coming from?

you probably need to use create_plates; see here for usage examples
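something like this (a sketch; the create_plates callable takes the same arguments as the model):

    guide = AutoGuideList(vae.model)
    guide.append(vae.guide)
    guide.append(AutoDelta(
        poutine.block(vae.model, expose=["pi", "mix_locs", "mix_scales"]),
        create_plates=lambda x: pyro.plate("data", x.shape[0]),
    ))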

Hmm, that was giving me the same error.

What seems to work is creating the guide for the GMM parameters in VAE.__init__ and then calling it inside vae.guide; full notebook here.
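Roughly, as a sketch (the notebook has the full version):

    # in VAE.__init__
    self.gmm_guide = AutoDelta(
        poutine.block(self.model, expose=["pi", "mix_locs", "mix_scales"]))

    # in VAE.guide
    def guide(self, x):
        pyro.module("encoder", self.encoder)
        self.gmm_guide(x)  # point estimates for pi, mix_locs, mix_scales
        with pyro.plate("data", x.shape[0]):
            z_loc, z_scale = self.encoder(x)
            pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))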

I suspect the weighting of the likelihood vs. the prior is incorrect, though (actually for the original VAE as well), since the inference is never told how big the full dataset is. Presumably I should do something like passing subsample_size to the data plate?

since you have global variables you need something like

    pyro.plate("data", size=total_number_of_datapoints, subsample_size=len(x))
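i.e. in the model, roughly (a sketch; self.N is my placeholder for the full training set size):

    # the plate then rescales the minibatch log-likelihood by N / batch_size
    with pyro.plate("data", size=self.N, subsample_size=x.shape[0]):
        z = pyro.sample("latent", gmm)
        loc_img = self.decoder(z)
        pyro.sample("obs", dist.Bernoulli(loc_img).to_event(1), obs=x.reshape(-1, 784))

with a matching size/subsample_size on the guide’s plate so the scaling stays consistent.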

for the original VAE it doesn’t matter since there are no global latent variables. or rather, it matters for whether the ELBO is scaled to the whole dataset size or not, but those details don’t matter for optimization

Thanks. You’re right, of course, that it’s fine for the original VAE; it would only matter if you wanted to be (variational) Bayesian about the decoder weights (which no one really seems to do).