Constraints on Parameters with a Non-Centered Parameterization

Hi Everyone!

I’m working with a hierarchical model, using a non-centered parameterization, and sampling with NUTS. The issue I’m running into is that I can’t figure out how to enforce constraints on the parameters after applying the non-centered parameterization.

I have a section of the model that currently looks like this:

        with numpyro.plate('digital_plate', self.n_digital_terms):
            global_digital_coef = numpyro.sample('global_digital', dist.HalfNormal(media_scale))
            with numpyro.plate('categories_digital_plate', self.n_categories):
                cat_digital_coef = numpyro.sample('categories_digital', dist.HalfNormal(media_scale)).T
                global_digital_coef = global_digital_coef[..., jnp.newaxis]

        jnp.einsum('tf..., f... -> t...', transformed_spend, cat_digital_coef + global_digital_coef)

I’m using the two HalfNormal distributions purely because I need to ensure cat_digital_coef + global_digital_coef >= 0 and I don’t know how to enforce that.

I want something more along the lines of the following:

        with numpyro.plate('digital_plate', self.n_digital_terms):
            global_digital_coef = numpyro.sample('global_digital', dist.HalfNormal(media_scale))
            with numpyro.plate('categories_digital_plate', self.n_categories):
                cat_digital_coef = numpyro.sample('categories_digital', dist.Normal(0, media_scale)).T
                global_digital_coef = global_digital_coef[..., jnp.newaxis]

        jnp.einsum('tf..., f... -> t...', transformed_spend, cat_digital_coef + global_digital_coef)

I’ve read some resources on using LocScaleReparam, numpyro.infer.reparam, etc., but I’m really struggling to figure out how to use those tools while still enforcing my constraints.

I’d really appreciate help figuring out how to enforce the constraint cat_digital_coef + global_digital_coef >= 0 in my model.

Thanks a ton for any help!

You can probably use something like

ImproperUniform(constraints.greater_than(0), (6, 8), event_shape=(5,))

Thanks for the idea! I just want to confirm I’m thinking of this correctly based on the Ordinal Regression example. It looks like my code should end up looking something like the following:

        with numpyro.plate('digital_plate', self.n_digital_terms):
            global_digital_coef = numpyro.sample('global_digital', dist.HalfNormal(media_scale))
            with numpyro.plate('categories_digital_plate', self.n_categories):
                const = numpyro.sample(
                    'const',
                    ImproperUniform(numpyro.distributions.constraints.positive, (), ())
                )
                cat_digital_coef = numpyro.sample('categories_digital', dist.Normal(0, media_scale), obs=const).T
                global_digital_coef = global_digital_coef[..., jnp.newaxis]

        jnp.einsum('tf..., f... -> t...', transformed_spend, cat_digital_coef + global_digital_coef)

Am I thinking about that right?