I was curious whether Pyro would easily let me put a Gaussian mixture model (GMM) prior on the latent space of a VAE. I took the VAE tutorial code and changed the model to the following, using `MixtureSameFamily`. It runs (on MNIST), but after 30 or so epochs I get NaNs. I wanted to confirm this implementation is correct in principle before trying to debug the numerical issues. On that front, I wonder whether the HalfCauchy on the scales is a poor choice here and I'd be better served by something friendlier like Gamma(2, 2) (a sketch of that swap follows the code). Thanks!
```python
# define the model p(x|z)p(z)
def model(self, x):
    # register PyTorch module `decoder` with Pyro
    pyro.module("decoder", self.decoder)
    # mixture proportions
    pi = pyro.sample("pi", dist.Dirichlet(torch.ones(self.K, device=x.device)))
    mix = dist.Categorical(pi)
    one = torch.tensor(1., device=x.device)
    # component centers and scales
    mix_locs = pyro.sample("mix_locs", dist.Normal(0 * one, one).expand([self.K, self.z_dim]).to_event(2))
    mix_scales = pyro.sample("mix_scales", dist.HalfCauchy(one).expand([self.K, self.z_dim]).to_event(2))
    # construct the Gaussian mixture model
    comp = dist.Normal(mix_locs, mix_scales).to_event(1)
    gmm = dist.MixtureSameFamily(mix, comp)
    with pyro.plate("data", x.shape[0]):
        # sample from the prior (value will be sampled by the guide when computing the ELBO)
        z = pyro.sample("latent", gmm)
        # decode the latent code z
        loc_img = self.decoder(z)
        # score against the actual images
        pyro.sample("obs", dist.Bernoulli(loc_img).to_event(1), obs=x.reshape(-1, 784))
        return loc_img
```
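For reference, the Gamma(2, 2) swap mentioned above would be a one-line change to the scale prior. This is just an untested sketch (`one` is the same device-pinned constant defined in the model):

```python
# untested sketch: Gamma(2, 2) has mean 1 and far lighter tails than
# HalfCauchy(1), which may help avoid NaNs from extreme scale samples
mix_scales = pyro.sample(
    "mix_scales",
    dist.Gamma(2 * one, 2 * one).expand([self.K, self.z_dim]).to_event(2),
)
```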
Edit: I missed that, doing this naively, `pi`, `mix_locs`, and `mix_scales` aren't learned because they aren't part of `vae.guide`. I thought this would work:
```python
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDelta, AutoGuideList

guide = AutoGuideList(vae.model)
guide.append(vae.guide)
guide.append(AutoDelta(poutine.block(vae.model, expose=["pi", "mix_locs", "mix_scales"])))
svi = SVI(vae.model, guide, optimizer, loss=Trace_ELBO())
```
but it gives me `RuntimeError: Multiple sample sites named 'data'`, which I don't understand.
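My best guess (unconfirmed) is that the name clash comes from `vae.guide` and the autoguide each contributing a site named `data` to the combined guide trace. One way to sidestep `AutoGuideList` entirely is to fold MAP point estimates of the globals into the hand-written guide, which is what `AutoDelta` does under the hood anyway. A sketch, written as a method on the VAE class and assuming the tutorial's encoder that returns `(z_loc, z_scale)`; the param names `pi_map`, `locs_map`, and `scales_map` are hypothetical:

```python
import torch
from torch.distributions import constraints
import pyro
import pyro.distributions as dist

def guide(self, x):
    # register PyTorch module `encoder` with Pyro
    pyro.module("encoder", self.encoder)
    # MAP (point-estimate) guides for the global mixture parameters,
    # mirroring what AutoDelta would do for these sites
    pi_map = pyro.param("pi_map", torch.ones(self.K, device=x.device) / self.K,
                        constraint=constraints.simplex)
    pyro.sample("pi", dist.Delta(pi_map).to_event(1))
    locs_map = pyro.param("locs_map", torch.randn(self.K, self.z_dim, device=x.device))
    pyro.sample("mix_locs", dist.Delta(locs_map).to_event(2))
    scales_map = pyro.param("scales_map", torch.ones(self.K, self.z_dim, device=x.device),
                            constraint=constraints.positive)
    pyro.sample("mix_scales", dist.Delta(scales_map).to_event(2))
    # amortized per-datapoint posterior, as in the tutorial guide
    with pyro.plate("data", x.shape[0]):
        z_loc, z_scale = self.encoder(x)
        pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
```

With that in place, `svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())` works as in the original tutorial setup, and the mixture parameters should get optimized along with the encoder/decoder weights.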