Hi there, can you please help me reparametrise a 2D Gaussian mixture (GM) model for efficient MCMC with NUTS?
I am fitting both the covariance matrices and the means of the mixture components. The MCMC mixing is quite poor because the covariance matrix is a random variable. Alas, I've hit a wall trying to reparametrise it.
Specifically, in the model below I need to reparametrise the site 'obs'. The most natural way, IMHO, is via dist.transforms.LowerCholeskyAffine, but it does not work inside my plate. I tried to work around it, but to no avail. My three attempts at reparametrising are included as comments in the model below; none of them is satisfactory.
Thank you for any tips/help.
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro import plate, sample


def GM_mix_backward(data, K=3):  # data has shape (N, d)
    """
    Mixture of K Gaussians: fitting cov matrices and means.
    2D Gaussian distribution with random cov matrix.
    """
    if data is not None:
        N = data.shape[0]
        d = data.shape[1]

    # CATEGORICAL PRIOR
    phi = sample('phi', dist.Dirichlet(jnp.ones(K)))

    # PRIOR ON COMPONENT MEANS AND COVS
    with plate('K', K, dim=-1):
        # COVARIANCE MATRIX PRIOR
        # Vector of variances for each of the d variables
        theta = numpyro.sample("theta", dist.HalfCauchy(1).expand([d]).to_event(1))
        # Lower cholesky factor of a correlation matrix
        concentration = jnp.ones(1) * .1  # Implies correlated variables
        L_omega = numpyro.sample("L_omega", dist.LKJCholesky(d, concentration))
        sigma = jnp.sqrt(theta)
        # Scale the rows of the correlation factor by the standard deviations
        L_Omega = sigma[..., None] * L_omega

        # MEANS PRIOR
        mu = sample('mu', dist.Normal(0, 1).expand([d]).to_event(1))

    with numpyro.plate("observations", N, dim=-1):
        z = sample('z', dist.Categorical(phi))
        numpyro.sample('obs', dist.MultivariateNormal(mu[z], scale_tril=L_Omega[z]), obs=data)

        # REPARAMETRISING SITE 'obs':
        # TRY 1: (throws not implemented error)
        # numpyro.sample('obs1',
        #     dist.TransformedDistribution(dist.MultivariateNormal(jnp.zeros(d), jnp.eye(d)),
        #                                  dist.transforms.LowerCholeskyAffine(mu[z],
        #                                                                      L_Omega[z])))

        # TRY 2: (MCMC fails to find initial guess)
        # a = numpyro.sample('a', dist.MultivariateNormal(jnp.zeros(d), jnp.eye(d)))
        # b = mu[z] + jnp.einsum('...ij,...j', L_Omega[z], a)
        # numpyro.sample("obs2", dist.Delta(b, event_dim=1), obs=data)

        # TRY 3: (VERY WASTEFUL)
        # numpyro.sample("obs3", dist.MultivariateNormal(b, jnp.eye(d) * .001), obs=data)