Hi all,
I’m working on a hierarchical model in NumPyro with a vector-valued parameter (e.g., utility coefficients), and I want to apply a transform (e.g., ExpTransform) to only some of its elements while leaving the others unchanged. I need this, for example, to enforce sign constraints: say, I want the price coefficient to be strictly negative while the other coefficients stay unconstrained.
In my actual model, the prior on the vector is multivariate normal with a non-diagonal covariance, so splitting the parameter and sampling its components separately isn’t an option: I need to preserve the dependence structure of the prior.
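A quick pure-JAX sketch of why the joint prior matters. It runs outside any NumPyro model and reuses the `mu` and `Sigma` values from the minimal example. Drawing `z` jointly from the MVN and then transforming only coordinate 1 keeps the two coordinates dependent, whereas sampling them independently would not. Note that `-exp` is decreasing, so the +0.8 correlation on `z` becomes a negative correlation between `beta[0]` and `beta[1]` (about -0.61 in theory).

```python
import jax
import jax.numpy as jnp

# Same prior as in the model: a correlated MVN on the latent vector z.
mu = jnp.array([0.0, 0.0])
Sigma = jnp.array([[1.0, 0.8], [0.8, 1.0]])

# 20,000 joint draws of z, shape (20000, 2).
z = jax.random.multivariate_normal(jax.random.PRNGKey(0), mu, Sigma, shape=(20_000,))

# Transform only coordinate 1: beta[1] = -exp(z[1]); beta[0] = z[0] is untouched.
beta = z.at[:, 1].set(-jnp.exp(z[:, 1]))

# Sample correlations before and after the coordinate-wise transform.
corr_z = jnp.corrcoef(z[:, 0], z[:, 1])[0, 1]
corr_beta = jnp.corrcoef(beta[:, 0], beta[:, 1])[0, 1]
print(float(corr_z), float(corr_beta))
```
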
Here’s a minimal example:
```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(x, y=None):
    mu = jnp.array([0.0, 0.0])
    Sigma = jnp.array([[1.0, 0.8],
                       [0.8, 1.0]])
    mvn = dist.MultivariateNormal(mu, covariance_matrix=Sigma)

    # Want to apply transform only to beta[1], e.g. beta[1] = -exp(z[1])
    beta = numpyro.sample(
        "beta",
        dist.TransformedDistribution(
            mvn,
            [
                # ??? How to transform only beta[1]?
            ]
        )
    )

    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
    mu_y = beta[0] + beta[1] * x
    numpyro.sample("obs", dist.Normal(mu_y, sigma), obs=y)
```
My question:
What’s the recommended way to apply a transform to just one coordinate of a multivariate parameter in NumPyro?
- I want to keep the multivariate normal prior with non-diagonal covariance
- I want to transform only a subset of the vector (e.g. make beta[1] strictly negative)
- I’d like to use TransformedDistribution if possible so that things like TransformReparam still work
Thanks for your help!