where $\text{baseline} \sim \text{TruncatedNormal}(5, 10, \text{low}=0)$ and $\text{delta} \sim N(0, 10)$
and treatment = numpyro.deterministic("treatment", baseline + delta)
I know for a fact that both baseline and treatment are positive. I have enforced the constraint on baseline by using a TruncatedNormal, but how do I make sure that the posterior samples of treatment contain only positive values?
I don’t think there’s any particularly easy way of doing so.
You can either revisit your priors on baseline and delta so that their sum is positive by construction, or you can put a soft constraint on treatment with a factor statement, e.g.
import jax

def my_softplus(x, beta=1, offset=0):
    # mirroring the pytorch implementation https://pytorch.org/docs/stable/generated/torch.nn.Softplus.html
    # (the jax one has no beta, and neither of them has an offset)
    # for x * beta >= 20 the function is treated as linear and x is returned directly
    cond = x * beta < 20
    # See https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
    # and https://github.com/google/jax/issues/1052
    # for why this double-select is necessary to avoid nan gradients
    x_safe = jax.lax.select(cond, x, jax.numpy.ones_like(x))
    cond_true_val = 1 / beta * jax.numpy.log(jax.numpy.exp(beta * offset) + jax.numpy.exp(beta * x_safe))
    cond_false_val = x
    return_val = jax.lax.select(cond, cond_true_val, cond_false_val)
    return return_val
which is my custom, tunable JAX implementation of a slightly generalized version of the softplus function. For beta > 0 it smoothly maps the whole real line to values greater than offset, i.e. to values > 0 with the default offset=0.