I am trying to implement, in Pyro, a univariate Gaussian mixture model that I previously wrote in Stan. As described in the Stan User's Guide, the discrete parameters (the per-observation cluster assignments) can be summed out, i.e. marginalized out of the likelihood. Is it possible to do this in Pyro? The documentation for `MixtureOfDiagNormals`
looks similar to what I want to do, but that distribution does not support D=1 (univariate data).
Relevant Stan code below:
data {
  // data
  int<lower=0> N;            // number of observations (bounded so bad sizes fail validation clearly)
  int<lower=1> K;            // number of clusters / mixture components (at least one)
  vector[N] y;               // observed univariate values
  // priors (hyperparameters)
  // NOTE(review): lower=0 forbids negative prior means; drop the bound if y can be negative.
  vector<lower=0>[K] mu0;    // prior location for each component mean mu[k]
  real<lower=0> sigma0;      // prior scale shared by all component means
  vector<lower=0>[K] alpha1; // gamma shape for each component sd sigma[k]
  vector<lower=0>[K] beta1;  // gamma rate (inverse scale) for each component sd sigma[k]
  vector<lower=0>[K] alpha;  // Dirichlet concentration for the mixing weights theta
}
parameters {
// Continuous parameters only: the discrete cluster assignment of each
// observation is not declared here; the model block sums it out.
// NOTE(review): the components are not ordered, so label switching between
// chains is possible; declaring `ordered[K] mu` is a common remedy -- confirm
// whether that is wanted before changing it.
vector[K] mu;            // component means
vector<lower=0>[K] sigma; // component standard deviations
simplex[K] theta;        // mixing weights (non-negative, sum to 1)
}
model {
  // log mixing weights are identical for every observation, so compute them
  // once per log-density evaluation instead of inside the N*K loop
  vector[K] log_theta = log(theta);

  // priors
  mu ~ normal(mu0, sigma0);
  sigma ~ gamma(alpha1, beta1);
  theta ~ dirichlet(alpha);

  // likelihood: marginalize the discrete cluster assignment of each y[n],
  // log p(y[n]) = log_sum_exp_k( log theta[k] + normal_lpdf(y[n] | mu[k], sigma[k]) )
  for (n in 1:N) {
    // a vector (not the removed `real x[K]` array syntax) keeps this valid on
    // both older Stan releases and Stan >= 2.33
    vector[K] contributions = log_theta;
    for (k in 1:K) {
      contributions[k] += normal_lpdf(y[n] | mu[k], sigma[k]);
    }
    target += log_sum_exp(contributions);
  }
}