Reparametrise 2D Gaussian mixture

Hi there, could you please help me reparametrise a 2D Gaussian mixture (GM) model for efficient MCMC with NUTS?

I am fitting both the covariance matrices and the means of the mixture components. The MCMC mixing is quite poor, because the covariance matrices are random variables. Alas, I’ve hit a wall trying to reparametrise the model.

Specifically, in the model below I need to reparametrise the site ‘obs’. The most natural way, IMHO, is via `dist.transforms.LowerCholeskyAffine`, but it does not work inside my plate. I tried to work around it, but to no avail. My three attempts at reparametrising are in the model below; none of them is satisfactory.

Thank you for any tips/help.

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def GM_mix_backward(data, K=3):
    """Mixture of K Gaussians: fitting covariance matrices and means.

    2D Gaussian components, each with a random covariance matrix.
    `data` has shape (N, d).
    """
    N, d = data.shape

    # Mixture weights
    phi = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))

    with numpyro.plate("K", K, dim=-1):
        # Vector of variances for each of the d variables
        theta = numpyro.sample("theta", dist.HalfCauchy(1.0).expand([d]).to_event(1))
        # Lower Cholesky factor of a correlation matrix;
        # concentration < 1 favours strongly correlated variables
        L_omega = numpyro.sample("L_omega", dist.LKJCholesky(d, 0.1))
        sigma = jnp.sqrt(theta)
        # Cholesky factor of the covariance: diag(sigma) @ L_omega
        L_Omega = sigma[..., None] * L_omega

        # Prior on the means
        mu = numpyro.sample("mu", dist.Normal(0.0, 1.0).expand([d]).to_event(1))

    with numpyro.plate("observations", N, dim=-1):
        # Discrete component labels; NumPyro's NUTS enumerates this site (requires funsor)
        z = numpyro.sample("z", dist.Categorical(probs=phi))
        numpyro.sample("obs", dist.MultivariateNormal(mu[z], scale_tril=L_Omega[z]), obs=data)

        # TRY 1 (raises NotImplementedError):
        # numpyro.sample("obs1",
        #                dist.TransformedDistribution(dist.MultivariateNormal(jnp.zeros(d), jnp.eye(d)),
        #                                             dist.transforms.LowerCholeskyAffine(mu[z], L_Omega[z])))

        # TRY 2 (MCMC fails to find valid initial parameters):
        # a = numpyro.sample("a", dist.MultivariateNormal(jnp.zeros(d), jnp.eye(d)))
        # b = mu[z] + jnp.einsum("...ij,...j", L_Omega[z], a)
        # numpyro.sample("obs2", dist.Delta(b, event_dim=1), obs=data)

        # TRY 3 (very wasteful):
        # numpyro.sample("obs3", dist.MultivariateNormal(b, jnp.eye(d) * .001), obs=data)
```


Why do you need to reparam the observed site? Reparam is only useful for latent sites, where we want to improve the geometry of the posterior for inference.

Your issues are probably more readily explained by “mode switching” (label switching between mixture components) and multimodality.