Hi Everyone!
I’m working with a hierarchical model, using a non-centered parameterization, and sampling with NUTS. The issue I’m running into is that I can’t figure out how to enforce constraints on the parameters once I’ve switched to the non-centered parameterization.
I have a section of the model that currently looks like this:
```python
with numpyro.plate('digital_plate', self.n_digital_terms):
    global_digital_coef = numpyro.sample('global_digital', dist.HalfNormal(media_scale))
    with numpyro.plate('categories_digital_plate', self.n_categories):
        cat_digital_coef = numpyro.sample('categories_digital', dist.HalfNormal(media_scale)).T
global_digital_coef = global_digital_coef[..., jnp.newaxis]
jnp.einsum('tf..., f... -> t...', transformed_spend, cat_digital_coef + global_digital_coef)
```
I’m using the two HalfNormal distributions purely because I need to ensure `cat_digital_coef + global_digital_coef >= 0`, and I don’t know how else to enforce that.
I want something more along the lines of the following:
```python
with numpyro.plate('digital_plate', self.n_digital_terms):
    global_digital_coef = numpyro.sample('global_digital', dist.HalfNormal(media_scale))
    with numpyro.plate('categories_digital_plate', self.n_categories):
        cat_digital_coef = numpyro.sample('categories_digital', dist.Normal(0, media_scale)).T
global_digital_coef = global_digital_coef[..., jnp.newaxis]
jnp.einsum('tf..., f... -> t...', transformed_spend, cat_digital_coef + global_digital_coef)
```
I’ve read some resources on using `LocScaleReparam`, `numpyro.infer.reparam`, etc., but I’m really struggling to figure out how to use those tools and enforce my constraints.
I’d really appreciate help figuring out how to enforce the constraint `cat_digital_coef + global_digital_coef >= 0` in my model.
Thanks a ton for any help!