I’m trying to use a HalfMultivariateNormal prior on a “non-root” parameter. The most straightforward attempt is model1 below, but it leads to some divergences (especially on certain data sets) because of Neal’s funnel issues. When I manually decenter the distribution, as in model2 below, I have to switch back to a regular MultivariateNormal and then place a factor on the centered values. However, this approach leads to 100% divergence on negative data. I think that is because the factor constraint doesn’t get “propagated” the way the proper support constraint on HalfMultivariateNormal is, so we just end up applying a -np.inf log-prob to everything.

Is the best approach here to write a proper Transform for TransformReparam on HalfMultivariateNormal, one that can handle both the decentering during sampling and a constraint on the centered version?
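For concreteness, here is roughly the kind of Transform I have in mind. This is an untested sketch: HalfLowerCholeskyAffine is just a name I made up, and since abs() isn’t bijective, the inverse below just pretends the positive branch was taken (the same kind of shortcut numpyro’s AbsTransform makes), so I suspect the log_prob bookkeeping still needs care.

import jax.numpy as jnp
from numpyro.distributions import constraints
from numpyro.distributions.transforms import Transform
from numpyro.infer.reparam import TransformReparam

class HalfLowerCholeskyAffine(Transform):
    # hypothetical transform z -> |L @ z|: the base site stays a standard
    # normal while the folded, correlated values come out deterministically
    domain = constraints.real_vector
    codomain = constraints.independent(constraints.positive, 1)

    def __init__(self, scale_tril):
        self.scale_tril = scale_tril

    def __call__(self, x):
        return jnp.abs(
            jnp.squeeze(jnp.matmul(self.scale_tril, x[..., jnp.newaxis]), axis=-1)
        )

    def _inverse(self, y):
        # abs() has no true inverse; assume the positive branch was taken
        return jnp.squeeze(
            jnp.linalg.solve(self.scale_tril, y[..., jnp.newaxis]), axis=-1
        )

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # the fold is volume-preserving a.e., so only the affine part contributes
        return jnp.broadcast_to(
            jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum(-1),
            jnp.shape(x)[:-1],
        )

The weights site in the model would then become something like:

with ny.handlers.reparam(config={"weights": TransformReparam()}):
    weights = ny.sample(
        "weights",
        dist.TransformedDistribution(
            dist.Normal(0.0, 1.0).expand([width]).to_event(1),
            HalfLowerCholeskyAffine(jnp.linalg.cholesky(jnp.linalg.inv(cov) * tau)),
        ),
    )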
Thanks. Let me know if anything is unclear.
import numpy as np
import numpyro as ny
from numpyro.infer import MCMC, NUTS
import numpyro.distributions as dist
import jax
import jax.numpy as jnp
from jax import random
class HalfMultivariateNormal(dist.Distribution):
    support = dist.constraints.independent(dist.constraints.positive, 1)
    arg_constraints = {
        "covariance_matrix": dist.constraints.positive_definite,
    }

    # Mostly copied from MultivariateNormal
    def __init__(self, covariance_matrix, validate_args=None):
        self._mv_normal = dist.MultivariateNormal(
            np.zeros(covariance_matrix.shape[0]), covariance_matrix
        )
        self._dims = covariance_matrix.shape[0]
        loc = np.zeros(covariance_matrix.shape[0])
        # temporarily append a new axis to loc so it broadcasts against the matrix
        loc = loc[..., jnp.newaxis]
        loc, self.covariance_matrix = dist.util.promote_shapes(loc, covariance_matrix)
        scale_tril = jnp.linalg.cholesky(self.covariance_matrix)
        batch_shape = jax.lax.broadcast_shapes(
            jnp.shape(loc)[:-2], jnp.shape(scale_tril)[:-2]
        )
        event_shape = jnp.shape(scale_tril)[-1:]
        super(HalfMultivariateNormal, self).__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
        )

    def sample(self, key, sample_shape=()):
        # fold draws from the underlying zero-mean MVN onto the positive orthant
        return jnp.abs(self._mv_normal.sample(key, sample_shape))

    @ny.distributions.util.validate_sample
    def log_prob(self, value):
        # MVN log-density plus log(2) per dimension to account for the fold
        return self._mv_normal.log_prob(value) + jnp.log(2) * self._dims
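# Quick sanity check of the distribution above (illustrative only; the 2x2
# covariance here is made up): samples land in the positive orthant and get a
# finite log-density, while any negative component falls outside `support`.
check_cov = jnp.array([[1.0, 0.3], [0.3, 1.0]])
check_dist = HalfMultivariateNormal(covariance_matrix=check_cov)
check_draw = check_dist.sample(random.PRNGKey(0))  # componentwise >= 0
check_lp = check_dist.log_prob(check_draw)         # finite on the support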
def model1(shape, data=None):
    length = shape[0]
    width = shape[1]
    tau = ny.sample("tau", dist.HalfNormal(scale=1))
    corr = ny.sample("corr", dist.LKJ(width, concentration=1))
    std_devs = ny.sample("std_devs", dist.HalfNormal(jnp.ones(width)))
    cov = jnp.matmul(jnp.matmul(jnp.diag(std_devs), corr), jnp.diag(std_devs))
    # centered parameterization: tau and cov enter the prior directly (funnel-prone)
    weights = ny.sample(
        "weights", HalfMultivariateNormal(covariance_matrix=jnp.linalg.inv(cov) * tau)
    )
    with ny.plate("data", length):
        ny.sample(
            "final",
            dist.MultivariateNormal(loc=jnp.matmul(cov, weights), covariance_matrix=cov),
            obs=data,
        )
def model2(shape, data=None):
    length = shape[0]
    width = shape[1]
    tau = ny.sample("tau", dist.HalfNormal(scale=1))
    corr = ny.sample("corr", dist.LKJ(width, concentration=1))
    std_devs = ny.sample("std_devs", dist.HalfNormal(jnp.ones(width)))
    cov = jnp.matmul(jnp.matmul(jnp.diag(std_devs), corr), jnp.diag(std_devs))
    # non-centered parameterization: sample a standard MVN, then color it
    # with the Cholesky factor of the target covariance
    weights_reparam = ny.sample(
        "weights_reparam",
        dist.MultivariateNormal(loc=np.zeros(width), covariance_matrix=jnp.identity(width)),
    )
    weights = jnp.matmul(jnp.linalg.cholesky(tau * jnp.linalg.inv(cov)), weights_reparam)
    # hard positivity constraint: any negative component makes the joint log-density -inf
    ny.factor(
        "positive_only",
        jax.lax.cond(jnp.any(weights < 0), lambda _: -np.inf, lambda _: 0.0, operand=None),
    )
    with ny.plate("data", length):
        ny.sample(
            "final",
            dist.MultivariateNormal(loc=jnp.matmul(cov, weights), covariance_matrix=cov),
            obs=data,
        )