# Factor vs constraint on centered and decentered parameterizations

I’m trying to use a `HalfMultivariateNormal` prior on a “non-root” parameter.

The most straightforward attempt can be seen in `model1` below, but this leads to some divergences (especially on certain sets of data) because of Neal’s funnel issues.

When I manually decenter the distribution, as in `model2` below, I have to switch back to a regular `MultivariateNormal` and then place a `factor` on the centered values. However, this approach leads to 100% divergence on negative data. I think because the `factor` constraint doesn’t get “propagated” in the same way as the proper constraint in `HalfMultivariateNormal`, we just end up applying `-np.inf` log-prob to everything.

Is the best approach here to write a proper `Transform` for `TransformReparam` on `HalfMultivariateNormal` that can handle both the decentering during sampling and a constraint on the centered version?

Thanks. Let me know if anything is unclear.

``````
import numpy as np

import numpyro as ny
from numpyro.infer import MCMC, NUTS
import numpyro.distributions as dist
from numpyro import optim

import jax.numpy as jnp
import jax.nn as jnn
from jax import random, ops
import jax

import matplotlib.pyplot as plt

from functools import partial

class HalfMultivariateNormal(dist.Distribution):
    """Zero-mean multivariate normal with support on the positive orthant.

    Samples are the element-wise absolute value of draws from a zero-mean
    ``MultivariateNormal``; ``log_prob`` is that normal's log-density plus
    ``log(2)`` per dimension.

    NOTE(review): ``log(2) * dims`` is the exact normaliser only for a
    diagonal covariance; with correlated components the positive-orthant
    mass is not ``2 ** -dims`` -- confirm this is acceptable for the model.

    :param covariance_matrix: positive-definite covariance of the underlying
        normal; the trailing two axes are the matrix.
    :param validate_args: forwarded to ``dist.Distribution``.
    """

    support = dist.constraints.independent(dist.constraints.positive, 1)
    arg_constraints = {
        "covariance_matrix": dist.constraints.positive_definite,
    }

    # Mostly copied from MultivariateNormal
    def __init__(self, covariance_matrix, validate_args=None):
        # Read the event size from the trailing axis so a covariance with
        # leading batch axes works as well as a single matrix.
        self._dims = jnp.shape(covariance_matrix)[-1]
        self._mv_normal = dist.MultivariateNormal(
            np.zeros(self._dims), covariance_matrix
        )

        # Temporarily append a new axis to loc so it can be promoted against
        # the matrix-shaped covariance.
        loc = np.zeros(self._dims)[..., jnp.newaxis]
        loc, self.covariance_matrix = dist.util.promote_shapes(
            loc, covariance_matrix
        )

        # A Cholesky factor has the same shape as its input, so the shapes
        # can be read from the covariance without factorising it.
        cov_shape = jnp.shape(self.covariance_matrix)
        batch_shape = jax.lax.broadcast_shapes(
            jnp.shape(loc)[:-2], cov_shape[:-2]
        )
        event_shape = cov_shape[-1:]

        super().__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args,
        )

    def sample(self, key, sample_shape=()):
        """Draw from the underlying normal and fold onto the positive orthant."""
        return jnp.abs(self._mv_normal.sample(key, sample_shape))

    @ny.distributions.util.validate_sample
    def log_prob(self, value):
        """Return the underlying normal's log-density plus ``log(2)`` per dimension."""
        return self._mv_normal.log_prob(value) + jnp.log(2) * self._dims

def model1(shape, data=None):
    """Centered model: ``weights`` is sampled directly from ``HalfMultivariateNormal``.

    :param shape: indexable; ``shape[0]`` is the size of the ``data`` plate
        and ``shape[1]`` the dimension of the correlation matrix.
    :param data: observations for the ``final`` site, or ``None``.
    """
    n_obs, n_dims = shape[0], shape[1]

    tau = ny.sample("tau", dist.HalfNormal(scale=1))
    corr = ny.sample("corr", dist.LKJ(n_dims, concentration=1))
    std_devs = ny.sample("std_devs", dist.HalfNormal(jnp.ones(n_dims)))

    # cov = diag(std_devs) @ corr @ diag(std_devs)
    scale = jnp.diag(std_devs)
    cov = scale @ corr @ scale

    weights_prior = HalfMultivariateNormal(
        covariance_matrix=jnp.linalg.inv(cov) * tau
    )
    weights = ny.sample("weights", weights_prior)

    likelihood = dist.MultivariateNormal(loc=cov @ weights, covariance_matrix=cov)
    with ny.plate("data", n_obs):
        ny.sample("final", likelihood, obs=data)

def model2(shape, data=None):
    """Manually decentered model: a standard-normal draw is scaled into ``weights``.

    ``weights_reparam`` is sampled from a standard multivariate normal and
    multiplied by ``cholesky(tau * inv(cov))``; a ``factor`` site then adds
    ``-inf`` to the log-density whenever any resulting weight is negative.

    :param shape: indexable; ``shape[0]`` is the size of the ``data`` plate
        and ``shape[1]`` the dimension of the correlation matrix.
    :param data: observations for the ``final`` site, or ``None``.
    """
    n_obs, n_dims = shape[0], shape[1]

    tau = ny.sample("tau", dist.HalfNormal(scale=1))
    corr = ny.sample("corr", dist.LKJ(n_dims, concentration=1))
    std_devs = ny.sample("std_devs", dist.HalfNormal(jnp.ones(n_dims)))

    # cov = diag(std_devs) @ corr @ diag(std_devs)
    scale = jnp.diag(std_devs)
    cov = scale @ corr @ scale

    standard_normal = dist.MultivariateNormal(
        loc=np.zeros(n_dims), covariance_matrix=jnp.identity(n_dims)
    )
    weights_reparam = ny.sample("weights_reparam", standard_normal)
    chol = jnp.linalg.cholesky(tau * jnp.linalg.inv(cov))
    weights = chol @ weights_reparam

    # Hard positivity constraint expressed as a log-density penalty.
    any_negative = jnp.any(weights < 0)
    penalty = jax.lax.cond(
        any_negative, lambda _: -np.inf, lambda _: 0.0, operand=None
    )
    ny.factor("positive_only", penalty)

    likelihood = dist.MultivariateNormal(loc=cov @ weights, covariance_matrix=cov)
    with ny.plate("data", n_obs):
        ny.sample("final", likelihood, obs=data)
``````

I was eventually able to get the desired behavior via:

``````class CustomConstraint(dist.constraints.Constraint):
def __init__(self, cov, tau):
self.cov = cov
self.tau = tau
@property
def event_dim(self):
return dist.constraints.independent(dist.constraints.positive, 1).event_dim
def __call__(self, x):
return dist.constraints.independent(dist.constraints.positive, 1)(jnp.matmul(jnp.linalg.cholesky(self.tau * jnp.linalg.inv(self.cov)), x))

def constraintToTransform(constraint):
    """Build the bijection from unconstrained reals onto ``constraint``'s support.

    An element-wise ``exp`` first produces a positive vector; the inverse of
    ``CustomTransform`` then maps that vector back to the decentered value.

    :param constraint: object exposing ``cov`` and ``tau`` attributes.
    :return: a ``ComposeTransform`` of the two steps.
    """
    to_positive = trans.IndependentTransform(trans.ExpTransform(), 1)
    # Use the public ``inv`` property rather than the private
    # ``_InverseTransform`` wrapper class.
    decenter = CustomTransform(constraint.cov, constraint.tau).inv
    return trans.ComposeTransform([to_positive, decenter])


trans.biject_to.register(CustomConstraint, constraintToTransform)

class CustomTransform(dist.transforms.Transform):
    """Linear map ``x -> chol @ x`` with ``chol = cholesky(tau * inv(cov))``.

    :param cov: covariance matrix used to build the Cholesky factor.
    :param tau: scalar multiplier applied to ``inv(cov)``.
    """

    domain = dist.constraints.real_vector
    codomain = dist.constraints.independent(dist.constraints.positive, 1)

    def __init__(self, cov, tau):
        self.chol = jnp.linalg.cholesky(tau * jnp.linalg.inv(cov))

    def __call__(self, x):
        """Apply the forward map ``chol @ x``."""
        return jnp.matmul(self.chol, x)

    def _inverse(self, y):
        """Apply the inverse map ``inv(chol) @ y``."""
        return jnp.matmul(jnp.linalg.inv(self.chol), y)

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        """Return ``log|det(chol)|``, the log-Jacobian of the linear map."""
        # The map is linear, so its Jacobian is ``chol`` itself. A Cholesky
        # factor is lower-triangular with a positive diagonal, so the
        # log-determinant is the sum of the log of the diagonal.
        return jnp.log(jnp.diagonal(self.chol, axis1=-2, axis2=-1)).sum(-1)

class CustomMultivariateNormal(dist.Distribution):
    """Standard multivariate normal whose support is a ``CustomConstraint``.

    Sampling and ``log_prob`` delegate to a zero-mean, identity-covariance
    ``MultivariateNormal``; ``covariance_matrix`` and ``tau`` are only used
    to build the support constraint.

    :param covariance_matrix: matrix whose leading dimension sets the event size.
    :param tau: scalar passed through to ``CustomConstraint``.
    :param validate_args: forwarded to ``dist.Distribution``.
    """

    def __init__(self, covariance_matrix, tau, validate_args=None):
        self.covariance_matrix = covariance_matrix
        self.tau = tau

        n_dims = covariance_matrix.shape[0]
        self._mv_normal = dist.MultivariateNormal(
            np.zeros(n_dims), np.identity(n_dims)
        )

        super().__init__(
            batch_shape=self._mv_normal.batch_shape,
            event_shape=self._mv_normal.event_shape,
            validate_args=validate_args,
        )

    @property
    def support(self):
        """Constraint built from this instance's ``covariance_matrix`` and ``tau``."""
        return CustomConstraint(self.covariance_matrix, self.tau)

    def sample(self, key, sample_shape=()):
        """Draw from the underlying standard multivariate normal."""
        return self._mv_normal.sample(key, sample_shape)

    @ny.distributions.util.validate_sample
    def log_prob(self, value):
        """Return the standard multivariate normal log-density of ``value``."""
        return self._mv_normal.log_prob(value)
``````

But this is quite ugly and I’d welcome any simplifications.

Interesting! I guess a solution is to reparameterize your model:

``````with handlers.reparam(config=infer.reparam.TransformReparam()):
# instead of cholesky(cov), you can use LKJCholesky for `corr`
# then multiply it with `std_devs[..., None]` - this is faster and more stable
weights = sample("weights", TransformedDistribution(dist.Normal(0,1).expand([width]),
LowerCholeskyAffine(0., cholesky(cov)))
weights_pos = soft_abs(weights)
``````

where `soft_abs` is a differentiable approximation of `abs`, e.g.

• sqrt(x^2 + eps) - when eps is a small positive number, say 0.0001
• (softplus(cx) + softplus(-cx)) / c - where c is a large positive number, say 1000

This approach seems to work really nicely. Thanks so much!