Diagonal Covariances for 2D Gaussian Mixture

Hi.

I was building on this previous post to implement a 2D Gaussian mixture. Instead of a shared covariance with a constant diagonal for all components, I wanted each component to have its own arbitrary diagonal covariance. After a lot of tweaking, I finally got some idea of how batch and event shapes come into play, and I got this model to work:

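```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate


@config_enumerate
def gaussian_enum(K, data=None):
    # data must be passed; its shape is (dim, num)
    dim, num = data.shape

    weights = numpyro.sample('weights', dist.Dirichlet(jnp.ones(K)))

    # Diagonal covariance entries, one per component and dimension: shape (K, dim)
    with numpyro.plate("variables", dim):
        with numpyro.plate("components", K):
            sigmas = numpyro.sample("sigmas", dist.Exponential(1))

    with numpyro.plate('components', K):
        locs = numpyro.sample('locs', dist.MultivariateNormal(jnp.zeros(dim), 5 * jnp.eye(dim)))

    with numpyro.plate('data', num):
        assignment = numpyro.sample('assignment', dist.Categorical(weights))

        # One diagonal covariance matrix per entry of `assignment`
        cv = jnp.apply_along_axis(jnp.diag, -1, sigmas[assignment, :])
        numpyro.sample(
            'obs',
            dist.MultivariateNormal(locs[assignment, :], covariance_matrix=cv),
            obs=data.T
        )
```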

Even so, using jnp.apply_along_axis looks like a hack, and I am left wondering whether there is a more elegant way of achieving the same effect. I've seen examples that use an LKJCholesky distribution, but I think that may not be necessary when one is only interested in diagonal covariances, since LKJCholesky is a prior over correlations and a diagonal covariance has none.

Thank you.

Well, after all the time I spent making the above code work, it has just struck me that there is a simpler solution: build the per-component covariance matrices right after sampling sigmas. Indexing them with assignment inside the last plate then picks the right matrix, and I can rely on NumPyro to handle the dimensions. For completeness, in case it is useful to somebody else, here is the revised code:

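```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate


@config_enumerate
def gaussian_enum(K, data=None):
    # data must be passed; its shape is (dim, num)
    dim, num = data.shape

    weights = numpyro.sample('weights', dist.Dirichlet(jnp.ones(K)))

    # Diagonal covariance entries, one per component and dimension: shape (K, dim)
    with numpyro.plate("variables", dim):
        with numpyro.plate("components", K):
            sigmas = numpyro.sample("sigmas", dist.Exponential(1))
    # One diagonal covariance matrix per component: shape (K, dim, dim)
    cv = jnp.array([jnp.diag(s) for s in sigmas])

    with numpyro.plate('components', K):
        locs = numpyro.sample('locs', dist.MultivariateNormal(jnp.zeros(dim), 5 * jnp.eye(dim)))

    with numpyro.plate('data', num):
        # Enumerated in parallel because of @config_enumerate
        assignment = numpyro.sample('assignment', dist.Categorical(weights))

        # Indexing with assignment selects each component's covariance matrix
        numpyro.sample(
            'obs',
            dist.MultivariateNormal(locs[assignment, :], covariance_matrix=cv[assignment]),
            obs=data.T
        )
```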
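The revised version works because integer-array indexing selects whole `(dim, dim)` matrices, so building the per-component matrices first and then indexing with `assignment` gives the same result as indexing first and building afterwards (as the first model does). The sketch below checks this with made-up values. It also shows two vectorized ways to build the `(K, dim, dim)` stack instead of the Python list comprehension: `jax.vmap(jnp.diag)` and broadcasting against `jnp.eye(dim)`. The `(K, 1)` index at the end is a hypothetical stand-in for the extra leading axis that parallel enumeration adds to `assignment`; the real enumerated shape depends on the plate setup, but the result always has the index's shape followed by `(dim, dim)`.

```python
import jax
import jax.numpy as jnp

K, dim = 3, 2
# Made-up per-component diagonal entries with the same shape as `sigmas`: (K, dim).
sigmas = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# Three ways to build the (K, dim, dim) stack of diagonal covariance matrices.
cv_loop = jnp.array([jnp.diag(s) for s in sigmas])  # as in the revised model
cv_vmap = jax.vmap(jnp.diag)(sigmas)                # vectorized over components
cv_bcast = sigmas[..., None] * jnp.eye(dim)         # (K, dim, 1) * (dim, dim)

# A sampled assignment: one component index per data point, shape (num,).
assignment = jnp.array([2, 0, 0, 1])

# Revised model: build first, then index -> shape (num, dim, dim).
picked = cv_bcast[assignment]
# First model: index first, then build with apply_along_axis -> same result.
picked_first = jnp.apply_along_axis(jnp.diag, -1, sigmas[assignment, :])

# Hypothetical enumeration-style index with an extra leading axis.
enum_assignment = jnp.arange(K).reshape(K, 1)
picked_enum = cv_bcast[enum_assignment]             # shape (K, 1, dim, dim)
```

Broadcasting and `jnp.vectorize`-style tricks also work for any number of leading axes, which is why they are a natural replacement for `jnp.apply_along_axis` when enumeration adds dimensions.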