I realize it may be hard to do much troubleshooting without the full model code, but here's the information that seems relevant and that I can easily share:

I have a hyperparameter `mu_tau` for `mu`, both defined roughly like this:

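```python
import jax.numpy as jnp
import numpyro as ny
import numpyro.distributions as dist
from numpyro.distributions.transforms import AffineTransform, LowerCholeskyAffine
from numpyro.handlers import reparam
from numpyro.infer.reparam import TransformReparam

# `shape.n_endog` (number of endogenous variables) and `mu_chol`
# (a lower-triangular Cholesky factor) are defined elsewhere in the full model.

# Non-centered scale: mu_tau = 0.1 * HalfNormal(1), i.e. distributed as HalfNormal(0.1)
with reparam(config={"mu_tau": TransformReparam()}):
    mu_tau = ny.sample(
        "mu_tau",
        dist.TransformedDistribution(
            dist.HalfNormal(scale=1), AffineTransform(loc=0, scale=0.1)
        ),
    )

# Non-centered multivariate normal: mu = (mu_chol * mu_tau) @ z, with z ~ Normal(0, I)
with reparam(config={"mu": TransformReparam()}):
    mu = ny.sample(
        "mu",
        dist.TransformedDistribution(
            dist.Normal(loc=0, scale=1).expand((shape.n_endog,)),
            LowerCholeskyAffine(loc=jnp.zeros(shape.n_endog), scale_tril=mu_chol * mu_tau),
        ),
    )
```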
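With `TransformReparam`, the sampler does not move on `mu_tau` and `mu` directly; it explores standardized base draws (typically recorded as `mu_tau_base` and `mu_base`), and `mu_tau` and `mu` are deterministic functions of them. The NumPy-only sketch below mimics those two transforms outside NumPyro so the implied prior can be checked by hand. The 3×3 `mu_chol` is a hypothetical stand-in for the real Cholesky factor, and `noncentered_draws` is an illustrative helper, not part of any library.

```python
import numpy as np


def noncentered_draws(rng, mu_chol, n_draws):
    """NumPy mimic of the two TransformReparam sites (illustrative sketch only).

    The sampler works on standardized base variables; mu_tau and mu are
    deterministic transforms of those base draws.
    """
    n = mu_chol.shape[0]
    # Base site for mu_tau: HalfNormal(scale=1)
    tau_base = np.abs(rng.standard_normal(n_draws))
    # AffineTransform(loc=0, scale=0.1)
    mu_tau = 0.1 * tau_base
    # Base site for mu: Normal(0, 1) expanded to shape (n,)
    z = rng.standard_normal((n_draws, n))
    # LowerCholeskyAffine(loc=0, scale_tril=mu_chol * mu_tau): each row is mu_tau * (mu_chol @ z)
    mu = mu_tau[:, None] * (z @ mu_chol.T)
    return tau_base, z, mu_tau, mu


# Hypothetical 3x3 Cholesky factor standing in for the real `mu_chol`
mu_chol = np.array([[1.0, 0.0, 0.0],
                    [0.5, 1.0, 0.0],
                    [0.2, 0.3, 1.0]])
rng = np.random.default_rng(0)
tau_base, z, mu_tau, mu = noncentered_draws(rng, mu_chol, 200_000)

# Marginally E[mu_tau**2] = 0.01, so Cov(mu) = 0.01 * mu_chol @ mu_chol.T
emp_cov = np.cov(mu, rowvar=False)
print(np.round(emp_cov / 0.01, 2))
print(np.round(mu_chol @ mu_chol.T, 2))
```

The two printed matrices should agree to roughly two decimals, which is a quick way to confirm the prior scale on `mu` is what was intended before looking at sampler behavior.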

This model works fine (0 divergences, R-hat values at 1) around 70% of the time. But for some `PRNGKey`s passed to `mcmc.run`, one chain is an outlier from the rest (this also happens to ~2 chains when I bump up to 8 chains). Here are the diagnostic graphs.

[Figure: latent posterior distributions for each chain]

[Figure: prior and posterior distributions]

[Figure: chain autocorrelation]

[Figure: rank plot]

[Figure: trace plot]

Does anyone have any guesses as to what’s going on or what additional diagnostic information I should seek?