GMM With Non-Linear Observations

I’m trying to implement a GMM with non-linear observations (a latent Gaussian mixture pushed through a neural-network observation model), as in the paper:

Composing graphical models with neural networks for structured representations and fast inference

I think the model should look like this:

    def model(self, y):
        # self.K = number of mixture components, self.d = latent dimension
        pi = pyro.sample("theta_pi", dist.Dirichlet(0.5 * torch.ones(self.K)))

        with pyro.plate("components", self.K):
            # per-component scales and correlation Cholesky factors
            theta_scale = pyro.sample("theta_scale", dist.HalfCauchy(torch.ones(self.d)).to_event(1))
            eta = (torch.ones(1) * 1e-4).to(y.device)
            theta_cor = pyro.sample("theta_cov", dist.LKJCorrCholesky(self.d, eta))
            # Cholesky factor of each component covariance: diag(sqrt(scale)) @ L_corr
            scale_tril = torch.bmm(theta_scale.sqrt().diag_embed(), theta_cor)
            # per-component means
            mu = pyro.sample("mu", dist.Normal(torch.zeros(self.d), torch.ones(self.d)).to_event(1))

        with pyro.plate("observations", y.shape[0]):
            assignment = pyro.sample("assignment", dist.Categorical(pi))
            x = pyro.sample("latent", dist.MultivariateNormal(mu[assignment], scale_tril=scale_tril[assignment]))
            y_loc, y_scale = self.decode(x)  # MLP decoder
            pyro.sample("obs", dist.Normal(y_loc, y_scale).to_event(1), obs=y)
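
Here `self.decode` is an MLP mapping the latent `x` to a location and a positive scale per observed dimension, roughly along these lines (`hidden_dim` and `obs_dim` are placeholders, not names from my actual code):

    import torch
    import torch.nn as nn

    class Decoder(nn.Module):
        def __init__(self, d, hidden_dim, obs_dim):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(d, hidden_dim), nn.ReLU())
            self.loc = nn.Linear(hidden_dim, obs_dim)
            self.log_scale = nn.Linear(hidden_dim, obs_dim)

        def forward(self, x):
            h = self.net(x)
            # exp keeps the scale strictly positive for the Normal likelihood
            return self.loc(h), self.log_scale(h).exp()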

However, I am a bit unsure about how to write the guide. What should the guide look like to infer cluster probabilities, covariance matrices, and means?

I would start with an AutoDelta guide over everything except the assignments, and use TraceEnum_ELBO and config_enumerate to integrate out the discrete assignment variable. IIUC this should be nearly identical to the GMM tutorial; a minimal sketch is below.
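
Here's a rough sketch of what that could look like, assuming `gmm` is an instance of your class and `y` is the full data tensor (`gmm`, `num_steps`, and the learning rate are placeholders). The discrete `assignment` site is enumerated out in the model, so it has to be hidden from the guide; everything else, including the per-point `latent`, gets a MAP point estimate from `AutoDelta`:

    import pyro
    import pyro.optim
    from pyro import poutine
    from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
    from pyro.infer.autoguide import AutoDelta

    # Enumerate the discrete assignment site in parallel.
    model = config_enumerate(gmm.model, "parallel")

    # MAP guide over theta_pi, theta_scale, theta_cov, mu, and latent;
    # assignment is enumerated in the model, so it must not appear in the guide.
    guide = AutoDelta(poutine.block(model, hide=["assignment"]))

    svi = SVI(model, guide, pyro.optim.Adam({"lr": 1e-3}),
              loss=TraceEnum_ELBO(max_plate_nesting=1))

    for step in range(num_steps):
        loss = svi.step(y)

Once that runs, the SVAE-style refinement from the paper would be to replace the point estimate on `latent` with an amortized guide (an encoder network producing per-point Gaussian parameters), while keeping the globals in `AutoDelta` and the assignments enumerated. The `AutoDelta` version above is just the simplest starting point, as in the GMM tutorial.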