Hello, I read the Pyro tutorial on the 1-D Gaussian Mixture Model.
My problem concerns the Iris data set: I have three different flowers with four attributes, and I'm trying to construct a 4-D GMM with 3 clusters.
I followed the tutorial above and managed to get the 12 different mu's I need (3 clusters × 4 dimensions), but I can't do the same for the sigmas (a multivariate distribution doesn't work there).
This is my model so far:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

K = 3  # number of clusters

def gmm(N, data=None):
    # a single scalar sigma, shared by all clusters and all dimensions
    sigma = numpyro.sample('sigma', dist.LogNormal(1.))
    with numpyro.plate('components', K):
        mu = numpyro.sample('mu', dist.MultivariateNormal(jnp.zeros(4), 10. * jnp.eye(4)))
    with numpyro.plate('N', N):
        category = numpyro.sample('category', dist.Categorical(jnp.ones(K) / K))
        numpyro.sample('obs', dist.MultivariateNormal(mu[category], sigma * jnp.eye(4)), obs=data)
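I want my code to look something like this, with a separate sigma for each cluster (MultivariateLogNormal is only a placeholder for whatever the right distribution is):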
def gmm(N, data=None):
    sigma = numpyro.sample('sigma', dist.MultivariateLogNormal(jnp.zeros(4), 10. * jnp.eye(4)))
    with numpyro.plate('components', K):
        mu = numpyro.sample('mu', dist.MultivariateNormal(jnp.zeros(4), 10. * jnp.eye(4)))
    with numpyro.plate('N', N):
        category = numpyro.sample('category', dist.Categorical(jnp.ones(K) / K))
        numpyro.sample('obs', dist.MultivariateNormal(mu[category], sigma[category]), obs=data)
Could you help me construct it?
Thanks