Reparametrization for Truncated Normal

Hello devs, I’m trying to reparametrize a truncated normal in my NumPyro model (I’m sampling with NUTS MCMC).

Originally, my model sampled a directly from a truncated normal, which gives R-hat values of 1 for a:

a = numpyro.sample(
    "a",
    dist.TruncatedNormal(mu_a, sigma_a, low=0)
)

I reparametrized it as follows:

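import numpyro
import numpyro.distributions as dist
from numpyro.distributions.transforms import AbsTransform, AffineTransform
from numpyro.infer.reparam import TransformReparam

# TransformReparam samples a standard normal site "a_base" and computes
# a = |mu_a + sigma_a * a_base| deterministically
reparam_config = {"a": TransformReparam()}
with numpyro.handlers.reparam(config=reparam_config):
    a = numpyro.sample(
        "a",
        dist.TransformedDistribution(
            dist.Normal(0, 1),
            [AffineTransform(mu_a, sigma_a), AbsTransform()]
        )
    )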
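One property of this transform chain may matter when reading the diagnostics: it computes a = |mu_a + sigma_a * a_base|, which follows a folded normal rather than a truncated normal, and the mapping is two-to-one, since a_base = z and a_base = -z - 2 * mu_a / sigma_a give the same a. The stdlib-only sketch below checks both points numerically. The values of mu and sigma are hypothetical, and the helper names (`affine_abs`, `mirror_base`, `folded_normal_pdf`, `truncated_normal_pdf`) are illustrative, not NumPyro API.

```python
import math
from statistics import NormalDist

STD_NORMAL = NormalDist(0.0, 1.0)


def affine_abs(z, mu, sigma):
    # Same chain as the reparametrized site: AffineTransform(mu, sigma), then AbsTransform
    return abs(mu + sigma * z)


def mirror_base(z, mu, sigma):
    # The second base value with the same image: mu + sigma * z' = -(mu + sigma * z)
    return -z - 2.0 * mu / sigma


def folded_normal_pdf(a, mu, sigma):
    # Density of |mu + sigma * Z| for a >= 0: both preimages contribute
    return (STD_NORMAL.pdf((a - mu) / sigma) + STD_NORMAL.pdf((a + mu) / sigma)) / sigma


def truncated_normal_pdf(a, mu, sigma, low=0.0):
    # Normal(mu, sigma) restricted to [low, inf) and renormalized
    mass_above_low = 1.0 - STD_NORMAL.cdf((low - mu) / sigma)
    return STD_NORMAL.pdf((a - mu) / sigma) / (sigma * mass_above_low)


mu, sigma = 10.0, 5.0  # hypothetical values for mu_a and sigma_a
z = 0.2
z_mirror = mirror_base(z, mu, sigma)  # -4.2 (up to floating-point rounding)
print(affine_abs(z, mu, sigma), affine_abs(z_mirror, mu, sigma))  # both approximately 11

# Compare the two densities at a = mu: close when mu/sigma is large, different when small
for m, s in [(10.0, 1.0), (1.0, 2.0)]:
    print(m, s, folded_normal_pdf(m, m, s), truncated_normal_pdf(m, m, s))
```

Because both base values produce the same a, the likelihood cannot distinguish them, so the posterior of a_base can have two mirror-image modes even when a itself is well identified.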

This is what it looks like in the full model:

def model(...):
    ...
    with numpyro.plate("plate_r", 1, dim=-1):
        with numpyro.plate("plate_s", 2, dim=-2):
            mu_a = numpyro.sample(
                "mu_a",
                dist.TruncatedNormal(10, 5, low=0)
            )
            sigma_a = numpyro.sample("sigma_a", dist.HalfNormal(5))

            with numpyro.plate("plate_f", 3, dim=-3):
                # Earlier: Rhats for all `a` = 1
                # a = numpyro.sample(
                #     "a",
                #     dist.TruncatedNormal(mu_a, sigma_a, low=0)
                # )

                # Now: Rhats for all `a` = 1, but for `a_base` ~ 1.5 or 2
                reparam_config = {"a": TransformReparam()}
                with numpyro.handlers.reparam(config=reparam_config):
                    a = numpyro.sample(
                        "a",
                        dist.TransformedDistribution(
                            dist.Normal(0, 1),
                            [AffineTransform(mu_a, sigma_a), AbsTransform()]
                        )
                    )
    ...

With the new model, the R-hat values for a are again 1, but those for a_base (the standard normal base site that TransformReparam adds) are around 1.5 to 2. The final recruitment curves (which use a) all look fine, just as in the initial model.
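A toy sketch of how such a split could arise, under loudly stated assumptions: the values of mu and sigma are hypothetical, independent Gaussian draws stand in for (autocorrelated) NUTS chains, and `gelman_rubin_rhat` is a hypothetical helper implementing the basic non-split Gelman-Rubin formula rather than the split R-hat shown in NumPyro's summary. If each chain settles in a different mirror-image mode of a_base, R-hat for a_base is very large while the corresponding a values are indistinguishable across chains.

```python
import random
import statistics


def gelman_rubin_rhat(chains):
    """Basic (non-split) Gelman-Rubin R-hat for equal-length chains."""
    n = len(chains[0])
    means = [statistics.mean(c) for c in chains]
    # W: average within-chain variance
    w = statistics.mean(statistics.variance(c) for c in chains)
    # B: between-chain variance of the chain means, scaled by chain length
    b = n * statistics.variance(means)
    var_hat = (n - 1) / n * w + b / n
    return (var_hat / w) ** 0.5


mu, sigma = 10.0, 5.0  # hypothetical values of mu_a and sigma_a
rng = random.Random(0)

# Chain 1 stays in the mode near a_base = 0.2 (a near 11)
base_chain1 = [rng.gauss(0.2, 0.1) for _ in range(1000)]
# Chain 2 stays in the mirror mode a_base' = -a_base - 2 * mu / sigma (also a near 11)
base_chain2 = [-rng.gauss(0.2, 0.1) - 2.0 * mu / sigma for _ in range(1000)]

a_chain1 = [abs(mu + sigma * z) for z in base_chain1]
a_chain2 = [abs(mu + sigma * z) for z in base_chain2]

rhat_base = gelman_rubin_rhat([base_chain1, base_chain2])
rhat_a = gelman_rubin_rhat([a_chain1, a_chain2])
print(f"R-hat(a_base) = {rhat_base:.2f}, R-hat(a) = {rhat_a:.3f}")
```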

How should I interpret this? Can I safely ignore the R-hat values for a_base, or do they point to a real problem? I’m not sure what’s going on here. Is this a common occurrence with AbsTransform?