# Gaussian mixture model in latent space of VAE

I was curious whether `pyro` would easily let me put a Gaussian mixture model (GMM) as the prior on the latent space of a VAE. I took the VAE tutorial code and changed the `model` to the following, using `MixtureSameFamily`. It runs (on MNIST), but after 30 or so epochs I get NaNs. I wanted to confirm this implementation is correct in principle before trying to debug the numerical issues. For the latter, I wonder whether the HalfCauchy prior on the scales is ideal here, or whether I'd be better served by something friendlier like a Gamma(2, 2). Thanks!

```python
# define the model p(x|z)p(z)
def model(self, x):
    # register PyTorch module `decoder` with Pyro
    pyro.module("decoder", self.decoder)

    # mixture proportions
    pi = pyro.sample("pi", dist.Dirichlet(torch.ones(self.K, device=x.device)))
    mix = dist.Categorical(pi)
    one = torch.tensor(1., device=x.device)

    # component centers and scales
    mix_locs = pyro.sample(
        "mix_locs",
        dist.Normal(0 * one, one).expand([self.K, self.z_dim]).to_event(2))
    mix_scales = pyro.sample(
        "mix_scales",
        dist.HalfCauchy(one).expand([self.K, self.z_dim]).to_event(2))

    # construct the Gaussian mixture model
    comp = dist.Normal(mix_locs, mix_scales).to_event(1)
    gmm = dist.MixtureSameFamily(mix, comp)

    with pyro.plate("data", x.shape[0]):
        # sample from the prior (the value is sampled by the guide when computing the ELBO)
        z = pyro.sample("latent", gmm)
        # decode the latent code z
        loc_img = self.decoder(z)
        # score against the actual images
        pyro.sample("obs", dist.Bernoulli(loc_img).to_event(1), obs=x.reshape(-1, 784))
    return loc_img
```
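
For the numerics, the scale-prior swap I mentioned would be a one-line change, e.g. (a sketch of the Gamma(2, 2) alternative; `one` is the device-placed scalar from the model above):

```python
# hypothetical alternative to the HalfCauchy: a lighter-tailed Gamma(2, 2)
# prior on the component scales, which may be numerically friendlier
mix_scales = pyro.sample(
    "mix_scales",
    dist.Gamma(2 * one, 2 * one).expand([self.K, self.z_dim]).to_event(2))
```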

Edit: I missed that, doing this naively, `pi`, `mix_locs`, and `mix_scales` aren't learnt because they aren't part of `vae.guide`. I thought this would work:

```python
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDelta, AutoGuideList

guide = AutoGuideList(vae.model)
guide.append(vae.guide)
guide.append(AutoDelta(poutine.block(vae.model, expose=["pi", "mix_locs", "mix_scales"])))
svi = SVI(vae.model, guide, optimizer, loss=Trace_ELBO())
```

but it's giving me `RuntimeError: Multiple sample sites named 'data'`, which I don't understand.


looks fine to me. it might be hard to get to work. the easier version would do MAP w.r.t. `pi`, `mix_locs`, and `mix_scales`.


Thanks Martin, I just edited the post after realizing I wasn't actually learning the GMM parameters (and I agree, MAP for those feels like it should be enough). Any idea where that error is coming from?

you probably need to use `create_plates`; see here for usage examples.
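
i.e. something along these lines (a sketch; `create_plates` is a callable that, given the model's arguments, returns the plate(s) the guide should reuse):

```python
# sketch: share a single "data" plate between the model and the autoguide
def create_plates(x):
    return pyro.plate("data", x.shape[0])

guide = AutoGuideList(vae.model)
guide.append(vae.guide)
guide.append(AutoDelta(
    poutine.block(vae.model, expose=["pi", "mix_locs", "mix_scales"]),
    create_plates=create_plates))
```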


Hmm, that was giving me the same error.

What seems to be working is creating the guide for the GMM parameters in `VAE.__init__` and then calling it from `vae.guide`; full notebook here.
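
In outline, the pattern looks like this (a sketch of what's in the notebook, with the encoder/decoder details elided and placeholder constructor arguments):

```python
import pyro
import pyro.distributions as dist
import torch.nn as nn
from pyro import poutine
from pyro.infer.autoguide import AutoDelta


class VAE(nn.Module):
    def __init__(self, z_dim=50, K=10):
        super().__init__()
        # ... encoder/decoder and self.model set up as in the VAE tutorial ...
        # build the MAP guide for the global GMM parameters once, here
        self.gmm_guide = AutoDelta(
            poutine.block(self.model, expose=["pi", "mix_locs", "mix_scales"]))

    def guide(self, x):
        # delegate the global sites to the AutoDelta (MAP) ...
        self.gmm_guide(x)
        # ... and keep the usual amortized guide for the local latents
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            z_loc, z_scale = self.encoder(x)
            pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
```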

I suspect the weighting of the likelihood vs. the prior is incorrect though (actually for the original VAE as well), since the inference is never told how big the full dataset is. Presumably I should do something like passing `subsample_size` to the `data` plate?

since you have global variables you need something like

`pyro.plate("data", size=total_number_of_datapoints, subsample_size=len(x))`
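
in the model that would look something like this (a sketch, assuming the total dataset size is available as `self.N`):

```python
# sketch: declare the full dataset size so the minibatch likelihood
# is rescaled correctly relative to the global priors
with pyro.plate("data", size=self.N, subsample_size=x.shape[0]):
    z = pyro.sample("latent", gmm)
    loc_img = self.decoder(z)
    pyro.sample("obs", dist.Bernoulli(loc_img).to_event(1), obs=x.reshape(-1, 784))
```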

for the original VAE it doesn't matter since there are no global latent variables. or rather, it matters for whether the ELBO is scaled to the whole dataset size or not, but those details don't matter for the optimization.


Thanks. You're right, of course, that it's fine for the original VAE; it would only matter if you were interested in being (variational) Bayesian about the decoder weights (which no one really seems to do).