Hi there, could you please help me reparametrise a 2D Gaussian mixture (GM) model for efficient MCMC with NUTS?
I am fitting both the covariance matrices and the means of the mixture components. The MCMC mixing is quite poor, because the covariance matrix is itself a random variable. Alas, I’ve hit a wall trying to reparametrise it.
Specifically, in the model below I need to reparametrise the site ‘obs’. The most natural way, IMHO, is via
dist.transforms.LowerCholeskyAffine, but it does not work inside my plate. I tried to work around it, but to no avail. My three attempts at reparametrising are shown as comments at the end of the model below; none of them is satisfactory.
Thank you for any tips/help.
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def GM_mix_backward(data, K=3):  # data has shape (N, d)
    """
    Mixture of K Gaussians: fitting cov matrices and means
    2D Gaussian distribution with random cov matrix
    """
    # data is required here, since N and d are read from its shape
    N, d = data.shape
    # CATEGORICAL PRIOR
    phi = numpyro.sample('phi', dist.Dirichlet(jnp.ones(K)))
    # PRIOR ON COMPONENT MEANS AND COVS
    with numpyro.plate('K', K, dim=-1):
        # COVARIANCE MATRIX PRIOR
        # Vector of variances for each of the d variables
        theta = numpyro.sample("theta", dist.HalfCauchy(1).expand([d]).to_event(1))
        # Lower cholesky factor of a correlation matrix
        concentration = jnp.ones(1) * 0.1  # Implies correlated variables
        L_omega = numpyro.sample("L_omega", dist.LKJCholesky(d, concentration))
        # Scale the rows of the correlation factor by the standard deviations
        sigma = jnp.sqrt(theta)
        L_Omega = sigma[..., None] * L_omega
        # MEANS PRIOR
        mu = numpyro.sample('mu', dist.Normal(0, 1).expand([d]).to_event(1))
    with numpyro.plate("observations", N, dim=-1):
        z = numpyro.sample('z', dist.Categorical(phi))
        numpyro.sample('obs', dist.MultivariateNormal(mu[z], scale_tril=L_Omega[z]), obs=data)
        # REPARAMETRISING SITE 'obs':
        # TRY 1: (throws not implemented error)
        # numpyro.sample('obs1',
        #                dist.TransformedDistribution(dist.MultivariateNormal(jnp.zeros(d), jnp.eye(d)),
        #                                             dist.transforms.LowerCholeskyAffine(mu[z],
        #                                                                                 L_Omega[z])))
        # TRY 2: (MCMC fails to find initial guess)
        # a = numpyro.sample('a', dist.MultivariateNormal(jnp.zeros(d), jnp.eye(d)))
        # b = mu[z] + jnp.einsum('...ij,...j', L_Omega[z], a)
        # numpyro.sample("obs2", dist.Delta(b, event_dim=1), obs=data)
        # TRY 3: (VERY WASTEFUL)
        # numpyro.sample("obs3", dist.MultivariateNormal(b, jnp.eye(d) * .001), obs=data)
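For reference, the contraction used in TRY 2 can be checked outside the model. Below is a minimal NumPy sketch with made-up parameter values: `mu`, `L`, `z` and `eps` are illustrative stand-ins for `mu`, `L_Omega`, `z` and `a`, not values taken from the model. It assumes plain NumPy arrays rather than traced values.

```python
import numpy as np

rng = np.random.default_rng(0)
K, d, N = 3, 2, 200_000

# Hypothetical per-component parameters (illustrative values, not fitted).
mu = np.array([[0.0, 0.0], [3.0, -1.0], [-2.0, 4.0]])      # shape (K, d)
L = np.array([[[1.0, 0.0], [0.5, 2.0]],
              [[0.7, 0.0], [-0.3, 0.4]],
              [[1.5, 0.0], [1.0, 0.5]]])                   # shape (K, d, d)

z = rng.integers(0, K, size=N)       # component label per observation
eps = rng.standard_normal((N, d))    # standard-normal draws, one per observation

# Same contraction as in TRY 2: b[n] = mu[z[n]] + L[z[n]] @ eps[n]
b = mu[z] + np.einsum('...ij,...j', L[z], eps)

# Empirical covariance of the points assigned to component 0.
cov0 = np.cov(b[z == 0].T)
```

Here `b` has shape `(N, d)`, and `cov0` should be close to `L[0] @ L[0].T`, which is the covariance that `scale_tril=L_Omega[z]` encodes in the original `obs` site.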