Hello devs, I'm trying to reparametrize a Truncated Normal in my NumPyro model (I'm sampling with NUTS MCMC).
Initially, my model had the following (it gives R-hats equal to 1 for `a`):
```python
a = numpyro.sample(
    "a",
    dist.TruncatedNormal(mu_a, sigma_a, low=0)
)
```
I reparametrized it as follows:
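```python
import numpyro
import numpyro.distributions as dist
from numpyro.distributions.transforms import AbsTransform, AffineTransform
from numpyro.infer.reparam import TransformReparam

# TransformReparam samples the base Normal(0, 1) as the site `a_base`
# and records `a` as a deterministic function of it.
reparam_config = {"a": TransformReparam()}
with numpyro.handlers.reparam(config=reparam_config):
    a = numpyro.sample(
        "a",
        dist.TransformedDistribution(
            dist.Normal(0, 1),
            [AffineTransform(mu_a, sigma_a), AbsTransform()]
        )
    )
```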
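One thing that may be worth checking (a minimal sketch, not part of the original model): pushing `Normal(0, 1)` through `AffineTransform(mu_a, sigma_a)` and then `AbsTransform()` yields `|mu_a + sigma_a * z|`, which is a *folded* normal rather than a normal truncated at 0. The plain-Python snippet below compares the two by Monte Carlo for hypothetical values `mu = 0.5`, `sigma = 1.0`, alongside the standard closed-form means; the helper names (`phi`, `Phi`, `truncated_normal_mean`, `folded_normal_mean`) are illustrative only.

```python
import math
import random


def phi(x):
    """Standard normal density."""
    return math.exp(-0.5 * x * x) / math.sqrt(2 * math.pi)


def Phi(x):
    """Standard normal CDF."""
    return 0.5 * (1 + math.erf(x / math.sqrt(2)))


def truncated_normal_mean(mu, sigma):
    """Mean of Normal(mu, sigma) truncated to [0, inf)."""
    alpha = -mu / sigma
    return mu + sigma * phi(alpha) / (1 - Phi(alpha))


def folded_normal_mean(mu, sigma):
    """Mean of |X| for X ~ Normal(mu, sigma)."""
    return (sigma * math.sqrt(2 / math.pi) * math.exp(-mu**2 / (2 * sigma**2))
            + mu * (1 - 2 * Phi(-mu / sigma)))


mu, sigma = 0.5, 1.0  # hypothetical values, not taken from the model
rng = random.Random(0)
draws = [mu + sigma * rng.gauss(0, 1) for _ in range(200_000)]

folded = [abs(x) for x in draws]          # what AffineTransform then AbsTransform produce
truncated = [x for x in draws if x >= 0]  # rejection sampling of TruncatedNormal(low=0)
folded_mc = sum(folded) / len(folded)
truncated_mc = sum(truncated) / len(truncated)

print(f"folded:    MC {folded_mc:.3f}, exact {folded_normal_mean(mu, sigma):.3f}")
print(f"truncated: MC {truncated_mc:.3f}, exact {truncated_normal_mean(mu, sigma):.3f}")
```

With these values the two means differ by roughly 0.11 (about 0.90 vs 1.01), which suggests the `AbsTransform` version puts a different prior on `a` than `TruncatedNormal(mu_a, sigma_a, low=0)` does.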
This is what the full model looks like:
```python
def model(...):
    ...
    with numpyro.plate("plate_r", 1, dim=-1):
        with numpyro.plate("plate_s", 2, dim=-2):
            mu_a = numpyro.sample(
                "mu_a",
                dist.TruncatedNormal(10, 5, low=0)
            )
            sigma_a = numpyro.sample("sigma_a", dist.HalfNormal(5))

            with numpyro.plate("plate_f", 3, dim=-3):
                # Earlier: R-hats for all `a` = 1
                # a = numpyro.sample(
                #     "a",
                #     dist.TruncatedNormal(mu_a, sigma_a, low=0)
                # )

                # Now: R-hats for all `a` = 1, but for `a_base` ~ 1.5 or 2
                reparam_config = {"a": TransformReparam()}
                with numpyro.handlers.reparam(config=reparam_config):
                    a = numpyro.sample(
                        "a",
                        dist.TransformedDistribution(
                            dist.Normal(0, 1),
                            [AffineTransform(mu_a, sigma_a), AbsTransform()]
                        )
                    )
    ...
```
For the new model, the R-hats for `a` are again 1, but the R-hats for `a_base` are close to 2 or 1.5. The final recruitment curves (which use `a`) are all fine, just as in the initial model.
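A plausible explanation, offered as a hypothesis rather than a diagnosis: `AbsTransform` is not one-to-one, so each value of `a` corresponds to two base values, `a_base = (a - mu_a) / sigma_a` and `a_base = (-a - mu_a) / sigma_a`. If different chains settle in different mirror modes, `a_base` gets a large R-hat while `a` still mixes fine. The sketch below reproduces that pattern with synthetic draws (hypothetical `mu = 2`, `sigma = 1`, two chains) and a hand-written split-R-hat; it is not NumPyro's own diagnostic code, and the draws are not from the model above.

```python
import random
import statistics


def split_rhat(chains):
    """Split-R-hat: halve each chain, then compare between- and within-half variance."""
    halves = []
    for chain in chains:
        half = len(chain) // 2
        halves.append(chain[:half])
        halves.append(chain[half:2 * half])
    n = len(halves[0])
    between = n * statistics.variance([statistics.fmean(h) for h in halves])  # B
    within = statistics.fmean([statistics.variance(h) for h in halves])        # W
    var_plus = (n - 1) / n * within + between / n
    return (var_plus / within) ** 0.5


rng = random.Random(0)
mu, sigma, a_true = 2.0, 1.0, 3.0  # hypothetical values, not taken from the model

# Chain 1 sits in the mode a_base ~ (a - mu) / sigma = 1,
# chain 2 in the mirror mode a_base ~ (-a - mu) / sigma = -5.
a_base_chains = [
    [(a_true - mu) / sigma + rng.gauss(0, 0.1) for _ in range(1000)],
    [(-a_true - mu) / sigma + rng.gauss(0, 0.1) for _ in range(1000)],
]
# Deterministic site a = |mu + sigma * a_base|, as the reparam records it.
a_chains = [[abs(mu + sigma * z) for z in chain] for chain in a_base_chains]

rhat_a_base = split_rhat(a_base_chains)
rhat_a = split_rhat(a_chains)
print(f"R-hat a_base: {rhat_a_base:.2f}")  # far above 1: the chains disagree
print(f"R-hat a:      {rhat_a:.3f}")       # about 1: both chains give a ~ 3
```

If this is what is happening, plotting the per-chain traces of `a_base` should show chains sitting at different levels while the corresponding traces of `a` overlap.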
How should I evaluate this situation? Can I ignore the R-hats for `a_base`? I'm not sure what's going on here. Is this a common occurrence with `AbsTransform`?