Hi.
I was building on this previous post to implement a 2D Gaussian mixture. Instead of a shared covariance with a constant diagonal for all components, I wanted each component to have its own arbitrary diagonal covariance. After a lot of tweaking, I finally got some idea of how batch and event shapes come into play here and got this model to work:
```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate


@config_enumerate
def gaussian_enum(K, data=None):
    dim, num = data.shape
    weights = numpyro.sample('weights', dist.Dirichlet(jnp.ones(K)))
    # per-component, per-dimension variances; sigmas ends up with shape (K, dim)
    with numpyro.plate("variables", dim):
        with numpyro.plate("components", K):
            sigmas = numpyro.sample("sigmas", dist.Exponential(1))
    with numpyro.plate('components', K):
        locs = numpyro.sample('locs', dist.MultivariateNormal(jnp.zeros(dim), 5 * jnp.eye(dim)))
    with numpyro.plate('data', num):
        assignment = numpyro.sample('assignment', dist.Categorical(weights))
        # build one diagonal covariance matrix per data point from the selected variances
        cv = jnp.apply_along_axis(jnp.diag, -1, sigmas[assignment, :])
        numpyro.sample(
            'obs',
            dist.MultivariateNormal(locs[assignment, :], covariance_matrix=cv),
            obs=data.T,
        )
```
Even so, the use of `jnp.apply_along_axis` looks like a hack, and I am left wondering whether there is a more elegant way of achieving the same effect. I've seen examples where an LKJCholesky distribution is used, but I think that may not be necessary when one is only interested in diagonal covariances.
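In case it's useful context, this is roughly how I'm running the model (the toy data, number of components, and sampler settings below are just placeholders, not my actual setup):

```python
from jax import random
from numpyro.infer import MCMC, NUTS

# placeholder data with shape (dim, num), as the model expects
key = random.PRNGKey(0)
data = random.normal(key, (2, 500))

kernel = NUTS(gaussian_enum)
mcmc = MCMC(kernel, num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(1), 3, data=data)  # K = 3 components
mcmc.print_summary()
```

With `config_enumerate`, NUTS marginalizes the discrete `assignment` site, so only the continuous parameters show up in the summary.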
Thank you.