When I implement a latent-variable GP in NumPyro, I usually use a non-centered parameterization in which the Cholesky decomposition of the covariance matrix is multiplied by a vector of standard normal variables. But I am wondering whether I can use the reparam handler in NumPyro to automatically reparameterize a multivariate normal distribution. I saw the docs for the multivariate normal distribution, which list its reparameterized parameters, but I don't know how to make use of that.
Thank you
Hi @z563751632, I think you can use TransformReparam for this:
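import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer.reparam import TransformReparam

with handlers.reparam(config={"x": TransformReparam()}):
    numpyro.sample("x",
                   dist.TransformedDistribution(
                       dist.Normal(0, 1).expand(loc.shape[-1:]),
                       dist.transforms.LowerCholeskyAffine(loc, scale_tril)))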
FYI, if the dimension is large, using reparam here also speeds up the log density computation, because the log_prob of the base distribution Normal is much cheaper to compute than the log_prob of MultivariateNormal.
If you find this a bit verbose, you can open a feature request for something like a LocScaleTriLReparam (though I think it is better to be more verbose).