Summing out latent discrete parameters in a mixture model

I am trying to implement a univariate Gaussian mixture model that I had previously written in Stan. As described in the Stan documentation, Stan lets you sum out the discrete parameters. Is it possible to do this in Pyro? MixtureOfDiagNormals looks similar to what I want, but according to its documentation it does not support D=1.

Relevant Stan code below:

```stan
data {
    // data
    int<lower=0> N; // number of observations
    int<lower=1> K; // number of clusters
    vector[N] y;

    // priors
    vector<lower=0>[K] mu0;
    real<lower=0> sigma0;

    vector<lower=0>[K] alpha1;
    vector<lower=0>[K] beta1;

    vector<lower=0>[K] alpha;
}

parameters {
    vector[K] mu;
    vector<lower=0>[K] sigma;
    simplex[K] theta;
}

model {
    vector[K] contributions; // old `real contributions[K];` array syntax is rejected by Stan 2.33+

    // prior
    mu ~ normal(mu0, sigma0);
    sigma ~ gamma(alpha1, beta1);
    theta ~ dirichlet(alpha);

    // likelihood: sum out the cluster label of each observation
    for (n in 1:N) {
        for (k in 1:K) {
            contributions[k] = log(theta[k]) + normal_lpdf(y[n] | mu[k], sigma[k]);
        }

        target += log_sum_exp(contributions);
    }
}
```
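For reference, the likelihood loop above computes, for each observation, log p(y_n) = log Σ_k θ_k · Normal(y_n | μ_k, σ_k) = log_sum_exp_k(log θ_k + normal_lpdf(y_n | μ_k, σ_k)), so the cluster label never appears as a parameter. A minimal standard-library Python sketch of that same computation follows; the function names are illustrative, not from Stan or Pyro, and `sigma` is a standard deviation as in Stan's `normal_lpdf`.

```python
import math


def log_sum_exp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def normal_lpdf(x, mu, sigma):
    """Log density of Normal(mu, sigma) at x, with sigma a standard deviation."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def mixture_loglik(y, mu, sigma, theta):
    """Log-likelihood of y with each point's cluster label summed out."""
    total = 0.0
    for y_n in y:
        # One term per cluster k: log theta_k + log Normal(y_n | mu_k, sigma_k)
        contributions = [
            math.log(theta_k) + normal_lpdf(y_n, mu_k, sigma_k)
            for mu_k, sigma_k, theta_k in zip(mu, sigma, theta)
        ]
        total += log_sum_exp(contributions)
    return total


if __name__ == "__main__":
    print(mixture_loglik([-1.0, 0.2, 3.1], mu=[0.0, 3.0], sigma=[1.0, 0.5], theta=[0.4, 0.6]))
```

Working in log space with `log_sum_exp` (rather than summing `theta_k * exp(...)` directly) is what keeps this stable when individual component densities underflow, which is also why the Stan code is written that way.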

Hi, Pyro can automatically enumerate discrete variables, so you don't have to write out the marginal likelihood by hand as in Stan. See the Gaussian Mixture Model tutorial for a complete example.

Also, here’s a notebook for a simpler toy model

Hi @gbernstein, would you be interested in contributing an example or tutorial based on your notebook? Contributions are always welcome!

@martinjankowiak Sure, I’ll DM you.