GMM in 4-D

Hello, I read the Pyro tutorial on the 1-D Gaussian Mixture Model.

My problem is with the Iris data set: there are three flower species with four attributes each, and I’m trying to construct a 4-D GMM with 3 clusters.

I followed the tutorial above and managed to get the 12 different mu values I need (3 clusters × 4 dimensions), but I can’t do the same for the sigmas (a multivariate distribution doesn’t work there).

This is my model so far:

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

K = 3  # number of clusters

def gmm(N, data=None):
    sigma = numpyro.sample('sigma', dist.LogNormal(1.))  # one variance shared by all clusters
    with numpyro.plate('components', K):
        mu = numpyro.sample('mu', dist.MultivariateNormal(jnp.zeros(4), 10. * jnp.eye(4)))  # shape (K, 4)
    with numpyro.plate('N', N):
        category = numpyro.sample('category', dist.Categorical(jnp.ones(K) / K))
        numpyro.sample('obs', dist.MultivariateNormal(mu[category], sigma * jnp.eye(4)), obs=data)
```

I want my code to look something like this:

```python
def gmm(N, data=None):
    # pseudocode: the prior I would like here is a multivariate log-normal
    sigma = numpyro.sample('sigma', dist.MultivariateLogNormal(jnp.zeros(4), 10. * jnp.eye(4)))
    with numpyro.plate('components', K):
        mu = numpyro.sample('mu', dist.MultivariateNormal(jnp.zeros(4), 10. * jnp.eye(4)))
    with numpyro.plate('N', N):
        category = numpyro.sample('category', dist.Categorical(jnp.ones(K) / K))
        numpyro.sample('obs', dist.MultivariateNormal(mu[category], sigma[category]), obs=data)
```

Could you help me construct it?

Thanks

Hi @sandalik, if you want a multivariate log-normal here, I think you can do:

```python
from numpyro.distributions.transforms import ExpTransform

sigma = numpyro.sample("sigma",
    dist.TransformedDistribution(
        dist.MultivariateNormal(...), ExpTransform()))
```

By the way, could you format your code using the pattern

```python
your code here
```