How to apply a NumPyro transform to selected elements of a vector parameter?

Hi all,

I’m working on a hierarchical model in NumPyro where I have a vector-valued parameter (e.g., utility coefficients) and I want to apply a transform (e.g., ExpTransform) to only some of its elements, while leaving the others unchanged. This is needed, for example, to enforce sign constraints — say, I want the price coefficient to be strictly negative, but allow the others to be unconstrained.

In my actual model, the prior on the vector is multivariate normal with non-diagonal covariance, so splitting the parameter and sampling components separately isn’t an option — I need to preserve the dependency structure in the prior.

Here’s a minimal example:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(x, y=None):
    mu = jnp.array([0.0, 0.0])
    Sigma = jnp.array([[1.0, 0.8],
                       [0.8, 1.0]])
    mvn = dist.MultivariateNormal(mu, covariance_matrix=Sigma)

    # Want to apply transform only to beta[1], e.g. beta[1] = -exp(z[1])
    beta = numpyro.sample(
        "beta",
        dist.TransformedDistribution(
            mvn,
            [
                # ??? How to transform only beta[1]?
            ]
        )
    )

    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
    mu_y = beta[0] + beta[1] * x
    numpyro.sample("obs", dist.Normal(mu_y, sigma), obs=y)

My question:

What’s the recommended way to apply a transform to just one coordinate of a multivariate parameter in NumPyro?

  • I want to keep the multivariate normal prior with non-diagonal covariance
  • I want to transform only a subset of the vector (e.g. make beta[1] strictly negative)
  • I’d like to use TransformedDistribution if possible so that things like TransformReparam still work

Thanks for your help!

I’m not sure I understand what you’re asking. Not everything needs to be done through a transform. For example, this could be a viable model:

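import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(obs):
    z = numpyro.sample("z", dist.MultivariateNormal(...))  # correlated, unconstrained prior
    z_even, z_odd = z[0::2], z[1::2]
    # Only the odd-indexed entries are transformed; the result is reordered (evens, then odds).
    z_transform = jnp.concatenate([z_even, jnp.exp(z_odd)])
    numpyro.sample("obs", dist.Normal(z_transform, 1.0), obs=obs)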

Thanks! What’s the general advice on when to use a transform and when not to?

When it is convenient.
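One way to see what "convenient" trades off: with the deterministic approach the sample site is z, so no Jacobian bookkeeping is needed. If you instead want beta itself to be the sample site (for example to use TransformReparam), you need a transform that also supplies log|det J|, which TransformedDistribution uses when computing log_prob. For a partial elementwise map like beta[idx] = -exp(z[idx]), the Jacobian is diagonal, so that term is just the sum of z over the transformed indices. The sketch below is plain JAX, not a NumPyro Transform subclass; neg_exp_on, log_abs_det_jacobian and idx are hypothetical names, and the jax.jacfwd comparison is only there to confirm the formula.

```python
import jax
import jax.numpy as jnp


def neg_exp_on(z, idx):
    """Hypothetical helper: beta[idx] = -exp(z[idx]); other entries unchanged."""
    return z.at[idx].set(-jnp.exp(z[idx]))


def log_abs_det_jacobian(z, idx):
    # Diagonal Jacobian: 1 for untouched entries, -exp(z_i) for transformed ones,
    # so log|det J| = sum_i log(exp(z_i)) = sum of z over the transformed indices.
    return jnp.sum(z[idx])


z = jnp.array([0.3, -1.2, 0.7])
idx = jnp.array([1, 2])  # hypothetical: coordinates that must be strictly negative
beta = neg_exp_on(z, idx)

# Brute-force check against the full Jacobian computed by JAX (w.r.t. z only).
J = jax.jacfwd(neg_exp_on)(z, idx)
brute = jnp.log(jnp.abs(jnp.linalg.det(J)))
print(beta, log_abs_det_jacobian(z, idx), brute)
```

The closed form is cheap and exact. The determinant check is only practical for small vectors and is meant as a sanity test, not as something to put inside a model.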