I’m trying to implement a GMM with nonlinear observations (a Gaussian mixture on a latent space, decoded to the observations by a neural network), as described in:
Composing graphical models with neural networks for structured representations and fast inference
I think the model should look like this:
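```python
def model(self, y):
    # Mixture weights: symmetric Dirichlet(0.5) prior over the K components
    pi = pyro.sample("theta_pi", dist.Dirichlet(0.5 * torch.ones(self.K, device=y.device)))
    with pyro.plate("components", self.K):
        # Per-dimension scale parameters for each component (event dim d)
        theta_scale = pyro.sample(
            "theta_scale", dist.HalfCauchy(torch.ones(self.d, device=y.device)).to_event(1))
        eta = torch.tensor(1e-4, device=y.device)
        # Cholesky factor of each component's correlation matrix, shape (K, d, d)
        theta_cor = pyro.sample("theta_cov", dist.LKJCholesky(self.d, eta))
        # Cholesky factor of each covariance: diag(sqrt(theta_scale)) @ theta_cor
        scale_tril = torch.bmm(theta_scale.sqrt().diag_embed(), theta_cor)
        mu = pyro.sample("mu", dist.Normal(torch.zeros(self.d, device=y.device), 1.0).to_event(1))
    with pyro.plate("observations", y.shape[0]):
        assignment = pyro.sample("assignment", dist.Categorical(pi))
        x = pyro.sample("latent", dist.MultivariateNormal(mu[assignment], scale_tril=scale_tril[assignment]))
        y_loc, y_scale = self.decode(x)  # MLP decoder
        # to_event(1) belongs on the distribution, not on the sampled value
        pyro.sample("obs", dist.Normal(y_loc, y_scale).to_event(1), obs=y)
```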
However, I am a bit unsure about how to write the guide. What should the guide look like to infer cluster probabilities, covariance matrices, and means?