Hi everyone,
I was wondering how to properly reparameterize LogNormals in a hierarchical model.
I observed that, in my case, using the reparam handler
loc = numpyro.sample("loc", Normal(0, 1))
sd = numpyro.sample("sd", HalfNormal(1))  # site name must be unique, so not "loc" again
with numpyro.handlers.reparam(config={"param": LocScaleReparam(0)}):
    param = numpyro.sample("param", LogNormal(loc, sd))
and the more classical approach
loc = numpyro.sample("loc", Normal(0, 1))
sd = numpyro.sample("sd", HalfNormal(1))  # site name must be unique, so not "loc" again
param_raw = numpyro.sample("param_raw", Normal(0, sd))
param = numpyro.deterministic("param", jnp.exp(loc + param_raw))
yield different results: the first gives extremely high r-hat values (in my particular case).
I couldn't find anything in the documentation or the forums, so my questions are:

1. How do we reparameterize the LogNormal correctly?
2. Does reparameterizing the hyperpriors also make sense if they are more informative?
Thank you!
Best
N
Maybe
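import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer.reparam import LocScaleReparam, TransformReparam

def model():
    loc = numpyro.sample("loc", dist.Normal(0, 1))
    sd = numpyro.sample("scale", dist.HalfNormal(1))
    # TransformReparam exposes the Normal base site "param_base"; LocScaleReparam(0) decenters it.
    with handlers.reparam(config={"param_base": LocScaleReparam(0)}):
        with handlers.reparam(config={"param": TransformReparam()}):
            param = numpyro.sample("param", dist.LogNormal(loc, sd))

print(handlers.trace(handlers.seed(model, rng_seed=0)).get_trace())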
or
@handlers.reparam(config={"param_base": LocScaleReparam(0)})
@handlers.reparam(config={"param": TransformReparam()})
def model():
    loc = numpyro.sample("loc", dist.Normal(0, 1))
    sd = numpyro.sample("scale", dist.HalfNormal(1))
    param = numpyro.sample("param", dist.LogNormal(loc, sd))
Note that after pull request #1548 in pyro-ppl/numpyro ("Raise error when using LocScaleReparam for non-real distributions", by fehiepsi), applying LocScaleReparam directly to LogNormal will raise an error.
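A likely explanation for the strange results, assuming that older versions of LocScaleReparam rebuilt the site distribution with loc = 0 and scale = 1 and then applied an affine map (my reading of the implementation, not something confirmed in this thread): for LogNormal, loc and scale parameterize log(param), not param itself, so the direct reparam describes a different model. The symbols z and w are only illustrative.

```latex
% Intended: decenter on the log scale
z \sim \mathcal{N}(0, 1), \qquad \mathrm{param} = \exp(\mathrm{loc} + \mathrm{sd} \cdot z)

% Direct LocScaleReparam(0) on LogNormal, under the assumption above
w \sim \mathrm{LogNormal}(0, 1), \qquad \mathrm{param} = \mathrm{loc} + \mathrm{sd} \cdot w
```

The second form is a shifted and scaled LogNormal(0, 1), which is not LogNormal(loc, sd), and it can even be negative when loc is.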
Hei @fehiepsi,
thanks for the prompt reply; that looks perfect. No wonder I got strange results.
Another small thing I noticed is that people often pass 0 to LocScaleReparam (invoking full decentering), but the default is to learn a “per-site per-element centering”. Is there some guidance on when to use which? I’m, of course, interested in LogNormals in particular, but also curious about more general thoughts on the matter.
Thank you!
Best N
That (learning the centering) is only relevant for variational inference.
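For reference, here is the transformation behind the centered argument as I understand it from the LocScaleReparam implementation; treat it as a sketch. The symbols c, \mu and \sigma are shorthand for centered and the site's loc and scale.

```latex
\tilde{x} \sim \mathcal{N}\!\left(c\,\mu,\; \sigma^{c}\right), \qquad
x = \mu + \sigma^{\,1-c}\,\left(\tilde{x} - c\,\mu\right), \qquad c \in [0, 1]
```

With c = 0 this is the fully decentered form, x = \mu + \sigma \tilde{x} with \tilde{x} standard normal, and with c = 1 it reduces to the original centered site. Leaving centered unset makes c a learnable parameter, which needs an optimizer to update it.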