How to reparameterize LogNormals in Hierarchical Models?

Hi everyone,

I was wondering about how to properly re-parameterize LogNormals in a hierarchical model.

I observed that, in my case, using the reparam handler

loc = numpyro.sample("loc", Normal(0,1))
sd = numpyro.sample("scale", HalfNormal(1))

with numpyro.handlers.reparam(config={'param': LocScaleReparam(0)}):
    param = numpyro.sample("param", LogNormal(loc,sd))

and the more classical approach

loc = numpyro.sample("loc", Normal(0,1))
sd = numpyro.sample("scale", HalfNormal(1))

param_raw = numpyro.sample("param_raw", Normal(0,sd))
param = numpyro.deterministic("param", jnp.exp(loc + param_raw))

yield different results: the first produces extremely high r-hat values (in my particular case).

I couldn't find anything in the documentation or the forums, so my questions are:

  1. How do we reparameterize the LogNormal correctly? And
  2. Does reparameterizing the hyperpriors, if they are more informative, also make sense?

Thank you!

Best
N

Maybe

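import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer.reparam import LocScaleReparam, TransformReparam

def model():
    loc = numpyro.sample("loc", dist.Normal(0, 1))
    sd = numpyro.sample("scale", dist.HalfNormal(1))

    # Outer handler: decenter the Normal base site created by the inner handler.
    with handlers.reparam(config={'param_base': LocScaleReparam(0)}):
        # Inner handler: split the LogNormal into a Normal base site plus an exp transform.
        with handlers.reparam(config={'param': TransformReparam()}):
            param = numpyro.sample("param", dist.LogNormal(loc, sd))

print(handlers.trace(handlers.seed(model, rng_seed=0)).get_trace())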

or

@handlers.reparam(config={'param_base': LocScaleReparam(0)})
@handlers.reparam(config={'param': TransformReparam()})
def model():
    loc = numpyro.sample("loc", dist.Normal(0,1))
    sd = numpyro.sample("scale", dist.HalfNormal(1))
    param = numpyro.sample("param", dist.LogNormal(loc, sd))

Note that after pyro-ppl/numpyro Pull Request #1548 ("Raise error when using LocScaleReparam for non-real distributions", by fehiepsi), applying LocScaleReparam directly to LogNormal will raise an error.


Hei @fehiepsi,

thanks for the prompt reply; that looks perfect. No wonder I got strange results.

Another small thing I noticed: people often pass 0 to LocScaleReparam (invoking full decentering), but the default is to learn a “per-site per-element centering”. Is there some guidance on when to use which? I’m, of course, interested in LogNormals in particular, but also curious about more general thoughts on the matter.

Thank you!
Best N

Learning the centering (the default) is only relevant for variational inference.