I would like to infer covariance matrices using variational inference while keeping control over the variational family (i.e. without using autoguides). As (inverse) Wishart distributions aren’t yet available, I’ve been using a scale/correlation decomposition in the model, shown here. (The gamma prior of course needs replacing with something more flexible.)
import torch
import pyro
import pyro.distributions as dist

# p is the number of dimensions of the covariance matrix.
def model():
    # Sample the scales of the covariance matrix (used as variances below).
    d = dist.Gamma(1, 1).expand([p]).to_event(1)
    scale = pyro.sample('scale', d)
    # Sample the Cholesky factor of the correlation matrix.
    d = dist.LKJCholesky(p, 1)
    cholesky_corr = pyro.sample('cholesky_corr', d)
    # Evaluate the Cholesky decomposition of the covariance matrix by
    # scaling row i of the correlation factor by sqrt(scale[i]).
    cholesky_cov = cholesky_corr * torch.sqrt(scale[:, None])
    # Rest of model goes here...
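To check the algebra behind the last step of the model, here is a small framework-free sketch (plain Python, no Pyro or PyTorch). It assumes `scale` holds variances, which is what the square root in the model implies, and uses made-up example values. The helpers `cholesky` and `outer_product` are hypothetical names written for this illustration only.

```python
import math

def cholesky(a):
    """Return the lower-triangular Cholesky factor of a symmetric
    positive-definite matrix given as nested lists."""
    n = len(a)
    l = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(l[i][k] * l[j][k] for k in range(j))
            if i == j:
                l[i][j] = math.sqrt(a[i][i] - s)
            else:
                l[i][j] = (a[i][j] - s) / l[j][j]
    return l

def outer_product(l):
    """Return l @ l.T for a square matrix given as nested lists."""
    n = len(l)
    return [[sum(l[i][k] * l[j][k] for k in range(n)) for j in range(n)]
            for i in range(n)]

# Made-up example values: variances and a valid 3x3 correlation matrix.
variances = [4.0, 1.0, 0.25]
corr = [[1.0, 0.3, -0.2],
        [0.3, 1.0, 0.5],
        [-0.2, 0.5, 1.0]]

# Covariance implied by the decomposition: cov[i][j] = sd[i] * corr[i][j] * sd[j].
sd = [math.sqrt(v) for v in variances]
cov = [[sd[i] * corr[i][j] * sd[j] for j in range(3)] for i in range(3)]

# Scale row i of the correlation Cholesky factor by sd[i], as the model does.
cholesky_corr = cholesky(corr)
cholesky_cov = [[sd[i] * cholesky_corr[i][j] for j in range(3)] for i in range(3)]

# The scaled factor should reproduce the covariance matrix.
reconstructed = outer_product(cholesky_cov)
max_error = max(abs(reconstructed[i][j] - cov[i][j])
                for i in range(3) for j in range(3))
```

Here `max_error` should be at the level of floating-point round-off, which confirms that scaling the rows of the correlation factor gives a valid Cholesky factor of the covariance.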
In the guide, I’ve used a transformed distribution for the correlation matrix like so.
def guide():
    # Parametrise the variational distribution for the scales.
    concentration = pyro.param('scale_concentration', 10 * torch.ones(p),
                               constraint=dist.constraints.positive)
    rate = pyro.param('scale_rate', 10 * torch.ones(p),
                      constraint=dist.constraints.positive)
    d = dist.Gamma(concentration, rate).to_event(1)
    scale = pyro.sample('scale', d)
    # Parametrise the variational distribution for the correlation matrix.
    # q is the number of parameters needed to parametrise a correlation matrix.
    q = p * (p - 1) // 2
    loc = pyro.param('cholesky_corr_loc', torch.zeros(q))
    cov = pyro.param('cholesky_corr_cov', 1e-2 * torch.eye(q),
                     constraint=dist.constraints.positive_definite)
    d = dist.MultivariateNormal(loc, cov)
    # Push the unconstrained Gaussian through a transform onto Cholesky
    # factors of correlation matrices.
    d = dist.TransformedDistribution(d, dist.transforms.CorrCholeskyTransform())
    cholesky_corr = pyro.sample('cholesky_corr', d)
    # Rest of guide goes here...
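To illustrate why `p * (p - 1) // 2` unconstrained numbers are enough, here is a framework-free sketch of a tanh plus stick-breaking map from a real vector to a Cholesky factor of a correlation matrix. It illustrates the general idea only and is not necessarily identical to what `CorrCholeskyTransform` does internally. The function name `corr_cholesky_from_unconstrained` and the input values are hypothetical.

```python
import math

def corr_cholesky_from_unconstrained(x, p):
    """Map p * (p - 1) // 2 real numbers to a lower-triangular Cholesky
    factor of a p x p correlation matrix via tanh and stick-breaking."""
    assert len(x) == p * (p - 1) // 2
    values = iter(x)
    factor = [[0.0] * p for _ in range(p)]
    for i in range(p):
        remaining = 1.0  # Squared norm still available in row i.
        for j in range(i):
            r = math.tanh(next(values))  # Squash to a value in (-1, 1).
            factor[i][j] = r * math.sqrt(remaining)
            remaining *= 1.0 - r * r
        # The diagonal entry uses up the rest, so the row has unit norm.
        factor[i][i] = math.sqrt(remaining)
    return factor

p_example = 3
x_example = [0.5, -1.0, 2.0]  # Arbitrary unconstrained values.
factor = corr_cholesky_from_unconstrained(x_example, p_example)

# Each row has unit Euclidean norm, so factor @ factor.T has a unit diagonal.
row_norms = [sum(v * v for v in row) for row in factor]

# An all-zero input maps to the identity matrix in this sketch.
identity = corr_cholesky_from_unconstrained([0.0, 0.0, 0.0], p_example)
```

Every row of `factor` has unit norm and a positive diagonal entry, which is what makes it the Cholesky factor of a matrix with ones on the diagonal. In this sketch a zero vector maps to the identity, so a guide initialised with `loc` at zero starts centred on uncorrelated dimensions.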
This seems to work fine, as I can infer the covariance of a multivariate Gaussian (see below), but are there better or easier approaches than the above?