Factor vs constraint on centered and decentered parameterizations

I’m trying to use a HalfMultivariateNormal prior on a “non-root” parameter.

The most straightforward attempt can be seen in model1 below, but this leads to some divergences (especially on certain sets of data) because of Neal’s funnel issues.

When I manually decenter the distribution, as in model2 below, I have to switch back to a regular MultivariateNormal and then place a factor on the centered values. However, this approach leads to 100% divergences on negative data. I think that because the factor isn’t “propagated” into the sampler’s unconstrained parameterization the way the support of HalfMultivariateNormal is, we just end up applying a -np.inf log-prob to everything.

Is the best approach here to write a proper Transform for TransformReparam on HalfMultivariateNormal that can handle both the decentering during sampling and a constraint on the centered version?

Thanks. Let me know if anything is unclear.

import numpy as np

import numpyro as ny
from numpyro.infer import MCMC, NUTS
import numpyro.distributions as dist
from numpyro import optim

import jax.numpy as jnp
import jax.nn as jnn
from jax import random, ops
import jax

import matplotlib.pyplot as plt

from functools import partial

class HalfMultivariateNormal(dist.Distribution):
    support = dist.constraints.independent(dist.constraints.positive, 1)
    arg_constraints = {
        "covariance_matrix": dist.constraints.positive_definite,
    }

    # Mostly copied from MultivariateNormal
    def __init__(self, covariance_matrix, validate_args=None):
        self._mv_normal = dist.MultivariateNormal(np.zeros(covariance_matrix.shape[0]), covariance_matrix)
        self._dims = covariance_matrix.shape[0]
        
        loc = np.zeros(covariance_matrix.shape[0])
        # temporarily append a new axis to loc so promote_shapes can align it with covariance_matrix
        loc = loc[..., jnp.newaxis]
        
        loc, self.covariance_matrix = dist.util.promote_shapes(loc, covariance_matrix)
        scale_tril = jnp.linalg.cholesky(self.covariance_matrix)
        
        batch_shape = jax.lax.broadcast_shapes(
            jnp.shape(loc)[:-2], jnp.shape(scale_tril)[:-2]
        )
        event_shape = jnp.shape(scale_tril)[-1:]
        
        super(HalfMultivariateNormal, self).__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=validate_args
        )
        
    def sample(self, key, sample_shape=()):
        return jnp.abs(self._mv_normal.sample(key, sample_shape))

    @ny.distributions.util.validate_sample
    def log_prob(self, value):
        return self._mv_normal.log_prob(value) + jnp.log(2) * self._dims

def model1(shape, data=None):
    length = shape[0]
    width = shape[1]

    tau = ny.sample("tau", dist.HalfNormal(scale=1))
    corr = ny.sample("corr", dist.LKJ(width, concentration=1))
    std_devs = ny.sample("std_devs", dist.HalfNormal(jnp.ones(width)))
    cov = jnp.matmul(jnp.matmul(jnp.diag(std_devs), corr), jnp.diag(std_devs))
    
    weights = ny.sample("weights", HalfMultivariateNormal(covariance_matrix=jnp.linalg.inv(cov) * tau))
    
    with ny.plate("data", length):
        ny.sample("final", dist.MultivariateNormal(loc=jnp.matmul(cov, weights), covariance_matrix=cov), obs=data)

def model2(shape, data=None):
    length = shape[0]
    width = shape[1]

    tau = ny.sample("tau", dist.HalfNormal(scale=1))
    corr = ny.sample("corr", dist.LKJ(width, concentration=1))
    std_devs = ny.sample("std_devs", dist.HalfNormal(jnp.ones(width)))
    cov = jnp.matmul(jnp.matmul(jnp.diag(std_devs), corr), jnp.diag(std_devs))
    
    weights_reparam = ny.sample("weights_reparam", dist.MultivariateNormal(loc=np.zeros(width), covariance_matrix=jnp.identity(width)))
    weights = jnp.matmul(jnp.linalg.cholesky(tau * jnp.linalg.inv(cov)), weights_reparam)
    ny.factor("positive_only", jax.lax.cond(jnp.any(weights < 0), lambda _: -np.inf, lambda _: 0.0, operand=None))
    
    with ny.plate("data", length):
        ny.sample("final", dist.MultivariateNormal(loc=jnp.matmul(cov, weights), covariance_matrix=cov), obs=data)

I was eventually able to get the desired behavior via:

import numpyro.distributions.transforms as trans

class CustomConstraint(dist.constraints.Constraint):
    # constrains x such that cholesky(tau * inv(cov)) @ x is elementwise positive
    def __init__(self, cov, tau):
        self.cov = cov
        self.tau = tau
    @property
    def event_dim(self):
        return dist.constraints.independent(dist.constraints.positive, 1).event_dim
    def __call__(self, x):
        return dist.constraints.independent(dist.constraints.positive, 1)(
            jnp.matmul(jnp.linalg.cholesky(self.tau * jnp.linalg.inv(self.cov)), x))

def constraintToTransform(constraint):
    # unconstrained u -> exp(u) (positive) -> chol^{-1} @ exp(u),
    # i.e. a value whose centered counterpart chol @ x is positive
    return trans.ComposeTransform([
        trans.IndependentTransform(trans.ExpTransform(), 1),
        CustomTransform(constraint.cov, constraint.tau).inv,
    ])

trans.biject_to.register(CustomConstraint, constraintToTransform)

class CustomTransform(dist.transforms.Transform):
    # maps x -> cholesky(tau * inv(cov)) @ x
    domain = dist.constraints.real_vector
    codomain = dist.constraints.independent(dist.constraints.positive, 1)

    def __init__(self, cov, tau):
        self.chol = jnp.linalg.cholesky(tau * jnp.linalg.inv(cov))
    def __call__(self, x):
        return jnp.matmul(self.chol, x)
    def _inverse(self, y):
        return jnp.matmul(jnp.linalg.inv(self.chol), y)
    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # one scalar per batch element; the event dimension is summed out
        return jnp.broadcast_to(jnp.log(jnp.abs(jnp.linalg.det(self.chol))), jnp.shape(x)[:-1])
        
    
class CustomMultivariateNormal(dist.Distribution):
    
    @property
    def support(self):
        return CustomConstraint(self.covariance_matrix, self.tau)
    
    def __init__(self, covariance_matrix, tau, validate_args=None):
        self.covariance_matrix = covariance_matrix
        self.tau = tau
        
        # standard normal base distribution; positivity of chol @ x is
        # enforced through `support` and the registered biject_to transform
        loc = np.zeros(covariance_matrix.shape[0])
        identity_cov = np.identity(covariance_matrix.shape[0])
        self._mv_normal = dist.MultivariateNormal(loc, identity_cov)
        
        super().__init__(
            batch_shape=self._mv_normal._batch_shape,
            event_shape=self._mv_normal._event_shape,
            validate_args=validate_args)

        
    def sample(self, key, sample_shape=()):
        return self._mv_normal.sample(key, sample_shape)
    
    @ny.distributions.util.validate_sample
    def log_prob(self, value):
        return self._mv_normal.log_prob(value)
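
For completeness, it slots into the model roughly like this (a sketch; model3 is just model2 with the factor replaced by the custom distribution):

def model3(shape, data=None):
    length = shape[0]
    width = shape[1]

    tau = ny.sample("tau", dist.HalfNormal(scale=1))
    corr = ny.sample("corr", dist.LKJ(width, concentration=1))
    std_devs = ny.sample("std_devs", dist.HalfNormal(jnp.ones(width)))
    cov = jnp.matmul(jnp.matmul(jnp.diag(std_devs), corr), jnp.diag(std_devs))

    # the custom support triggers the registered biject_to transform, so NUTS
    # samples in unconstrained space and no ny.factor is needed
    weights_reparam = ny.sample("weights_reparam", CustomMultivariateNormal(covariance_matrix=cov, tau=tau))
    weights = jnp.matmul(jnp.linalg.cholesky(tau * jnp.linalg.inv(cov)), weights_reparam)

    with ny.plate("data", length):
        ny.sample("final", dist.MultivariateNormal(loc=jnp.matmul(cov, weights), covariance_matrix=cov), obs=data)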

But this is quite ugly and I’d welcome any simplifications.

Interesting! I guess a solution is to reparameterize your model:

from numpyro import handlers
from numpyro.infer.reparam import TransformReparam
from numpyro.distributions.transforms import LowerCholeskyAffine

with handlers.reparam(config={"weights": TransformReparam()}):
    # instead of cholesky(cov), you can use LKJCholesky for `corr` and then
    # multiply it by `std_devs[..., None]` - this is faster and more stable
    weights = ny.sample("weights", dist.TransformedDistribution(
        dist.Normal(0., 1.).expand([width]),
        LowerCholeskyAffine(jnp.zeros(width), jnp.linalg.cholesky(cov))))
weights_pos = soft_abs(weights)

where soft_abs is a differentiable approximation of abs (sketched in code after the list), e.g.

  • sqrt(x^2 + eps) - when eps is a small positive number, say 0.0001
  • (softplus(cx) + softplus(-cx)) / c - where c is a large positive number, say 1000
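
For concreteness, here are minimal versions of both (a sketch, using the jnp and jnn imports from the snippet above):

def soft_abs(x, eps=1e-4):
    # smooth approximation of |x|; approaches abs(x) as eps -> 0
    return jnp.sqrt(x ** 2 + eps)

def soft_abs_softplus(x, c=1000.0):
    # (softplus(cx) + softplus(-cx)) / c; approaches abs(x) as c grows
    return (jnn.softplus(c * x) + jnn.softplus(-c * x)) / c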

This approach seems to work really nicely. Thanks so much!