Potential energy of reparametrized model

Hey, I’m using NUTS to sample a model with ~30 parameters, and I find that I get good convergence when I reparametrize the model with LocScaleReparam and TransformReparam for Gaussian and Uniform prior distributions, respectively.

I would like to extract the potential energy of each sample in the (original) chain. However, I find that if I ask for extra_fields=('potential_energy',), this value is different from the one I can compute manually by adding the log prior and the log likelihood for the posterior samples. I have set numpyro.enable_x64().

Is the reparametrization changing the value of the potential energy? Or are there two sets of potential energy values, one for the “base” parameters and one for the “original” ones?

Many thanks for your help!


Yes, p(x, y) is different from p(transform(x), y). A log-det-Jacobian term is sometimes needed to relate the two.
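In symbols, a sketch of the change of variables behind this, writing z for the variables NUTS actually moves in and x = T(z) for the original parameters (T, z and U_orig are notation introduced here). Since NumPyro's NUTS works in an unconstrained space, T also includes the bijections that map bounded supports to the real line:

```latex
x = T(z), \qquad
p_Z(z \mid y) \;\propto\; p\bigl(T(z), y\bigr)\,\Bigl|\det \tfrac{\partial T(z)}{\partial z}\Bigr|

U(z) = -\log p\bigl(T(z), y\bigr) - \log\Bigl|\det \tfrac{\partial T(z)}{\partial z}\Bigr|
     = U_{\text{orig}}(x) - \log\Bigl|\det \tfrac{\partial T(z)}{\partial z}\Bigr|
```

For example, a non-centered Gaussian x = m + s z contributes log s (a constant only when s is fixed), and a bounded prior mapped as x = a + (b - a) sigmoid(z) contributes log(b - a) + log sigmoid(z) + log(1 - sigmoid(z)), which varies from sample to sample.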


Thank you @fehiepsi! So I assume the potential energy value returned by NUTS by default is the one referring to the reparametrized variables. If you happen to know whether the potential energy for the original parameters is also stored somewhere that would be great.

Many thanks again!

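You can use log_density (from numpyro.infer.util) on the original, non-reparametrized model to compute it, where args and kwargs are the positional and keyword arguments you pass to the model. Something like: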

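```python
import jax
from numpyro.infer.util import log_density
samples = mcmc.get_samples()  # also holds the deterministic sites for the original parameters
pe = jax.vmap(lambda sample: -log_density(orig_model, args, kwargs, sample)[0])(samples)  # [0] is the log joint
```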

Brilliant - thank you!