How do I enforce a positive constraint on a deterministic parameter?

I’m expressing a model as

\text{treatment} = \text{baseline} + \text{delta}

where \text{baseline} \sim \text{TruncatedNormal(5, 10, low=0)} and \text{delta} \sim N(0, 10)

and treatment = numpyro.deterministic(“treatment”, baseline + delta)

I know for a fact that both baseline and treatment are positive. I have enforced the constraint on baseline by using a TruncatedNormal, but how do I make sure that the posterior samples of treatment contain only positive values?
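
For reference, here is a minimal sketch of the model as described above (the likelihood is omitted, so treat this as a placeholder rather than my full model):

import numpyro
import numpyro.distributions as dist

def model():
    baseline = numpyro.sample("baseline", dist.TruncatedNormal(5.0, 10.0, low=0.0))
    delta = numpyro.sample("delta", dist.Normal(0.0, 10.0))
    treatment = numpyro.deterministic("treatment", baseline + delta)
    # ... likelihood involving treatment goes here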

I don't think there's any particularly easy way of doing so.

You can either revisit your priors on baseline and delta so that their sum is positive by construction, or you can put a soft constraint on treatment with a factor statement, e.g.

penalty = penalty_strength * (jnp.fabs(treatment) - treatment)
numpyro.factor("soft_constraint", -penalty)

but note that this is not a hard constraint, so treatment may still be (slightly) negative depending on the strength of the penalty.
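
For concreteness, a sketch of how the factor statement might sit inside the model (penalty_strength is just a tuning knob; the value below is arbitrary):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model():
    baseline = numpyro.sample("baseline", dist.TruncatedNormal(5.0, 10.0, low=0.0))
    delta = numpyro.sample("delta", dist.Normal(0.0, 10.0))
    treatment = numpyro.deterministic("treatment", baseline + delta)
    # |treatment| - treatment is 0 when treatment >= 0 and 2 * |treatment| otherwise,
    # so the factor only penalizes negative values of treatment
    penalty_strength = 100.0  # arbitrary; tune to taste
    penalty = penalty_strength * (jnp.fabs(treatment) - treatment)
    numpyro.factor("soft_constraint", -penalty)
    # ... likelihood goes here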

Hi @martinjankowiak. Thank you, this helped. Could you please tell me where this penalty is getting added?

The factor term is added directly to the log density of the model.
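
You can verify this with numpyro.infer.util.log_density: the difference between the two log densities below is exactly the factor term (a small self-contained sketch, not your model):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer.util import log_density

def model_plain():
    numpyro.sample("x", dist.Normal(0.0, 1.0))

def model_with_factor():
    x = numpyro.sample("x", dist.Normal(0.0, 1.0))
    numpyro.factor("soft_constraint", -2.0 * (jnp.fabs(x) - x))

params = {"x": jnp.array(-1.5)}
lp_plain, _ = log_density(model_plain, (), {}, params)
lp_with, _ = log_density(model_with_factor, (), {}, params)
# lp_with - lp_plain == -2.0 * (|x| - x) == -6.0 at x = -1.5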

I have no idea whether this is a great idea or not, but I have converged on a pattern like the following:


param_that_must_be_positive = numpyro.deterministic("...", my_softplus(param_that_might_be_negative, beta=10, offset=0.01))

with

import jax
import jax.numpy  # the helper uses the full jax.numpy.* / jax.lax.* paths below

def my_softplus(x, beta=1, offset=0):
    # mirroring the pytorch implementation https://pytorch.org/docs/stable/generated/torch.nn.Softplus.html
    # (the jax one has no beta, and neither of them has an offset)
    cond = x * beta < 20
    # See https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
    # and https://github.com/google/jax/issues/1052
    # for why this double-select is necessary to avoid nan gradients
    x_safe = jax.lax.select(cond, x, jax.numpy.ones_like(x))
    cond_true_val = 1 / beta * jax.numpy.log(jax.numpy.exp(beta * offset) + jax.numpy.exp(beta * x_safe))
    cond_false_val = x
    return jax.lax.select(cond, cond_true_val, cond_false_val)

which is my custom tunable jax implementation of (a slightly generalized version of) the softplus function that smoothly maps values from the whole real line to values > 0.
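
As a quick sanity check (a sketch using the same beta and offset as above): the helper stays strictly positive and its gradient is finite on both sides of the x * beta < 20 cutoff.

import jax
import jax.numpy as jnp

x = jnp.array([-50.0, -1.0, 0.0, 1.0, 50.0])
print(my_softplus(x, beta=10, offset=0.01))  # strictly positive, approximately x for large x
grads = jax.vmap(jax.grad(lambda v: my_softplus(v, beta=10, offset=0.01)))(x)
print(grads)  # finite everywhere, no nans from the masked branch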