Hello, I have a large model with O(100,000) latent variables, so running MCMC for many steps results in very high memory usage. Is it possible in NumPyro to specify which variables should not be tracked, if I do not wish to keep track of them?
Or would the alternative be to run for a number of steps, save what I need, and then restart from the last state?
Thanks!
I don't think there's any mechanism to do that out of the box, so the alternative you propose is probably the easiest path to get what you want.
Please see `extra_fields` in the docs of `MCMC.run` (the "Markov Chain Monte Carlo (MCMC)" page of the NumPyro documentation) for a solution.
Thanks. `extra_fields` works in most cases. However, whenever I include a reparameterization of `xdir`, as in this toy example, the code throws an error after it finishes sampling. If I comment out the reparameterization, it works as desired.
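```python
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
import numpyro.infer.reparam
from numpyro.infer import MCMC, NUTS

def model():
    x = numpyro.sample("x", dist.Normal(0, 1))
    y = numpyro.sample("y", dist.Normal(0, 1))
    with numpyro.handlers.reparam(config={"xdir": numpyro.infer.reparam.ProjectedNormalReparam()}):
        xdir = numpyro.sample("xdir", dist.ProjectedNormal(jnp.zeros(3)))

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), extra_fields=("~z.x",))  # "~z.x": do not collect x
```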
The motivation for that reparameterization is that I want to sample angles on the surface of a sphere.
In fact, the problem seems to appear whenever I use something like `@numpyro.handlers.reparam(config={"xdir_dipole": numpyro.infer.reparam.CircularReparam()})`, so potentially with any reparameterization.
Could you file an issue for this? For reparameterized models, we need to invoke a postprocess function after sampling that replays the model to collect the deterministic samples. If you don't need those deterministic samples, it's better not to replay the model.
I think you can do this:

```python
with numpyro.handlers.block(hide=["xdir"]), numpyro.handlers.reparam(...):
    ...
```
Thanks! This helped, but it is still not exactly what I need, since I need to collect the deterministic samples.
For example, in this toy example I don't want to collect `x`, but I do want to collect the samples of `theta`. However, this approach hides `theta` as well. Is there any way around this?
```python
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
import numpyro.infer.reparam
from numpyro.infer import MCMC, NUTS

def model():
    x = numpyro.sample("x", dist.Normal(0, 1))
    y = numpyro.sample("y", dist.Normal(0, 1))
    with numpyro.handlers.block(hide=["xdir_dipole", "theta"],), numpyro.handlers.reparam(config={"xdir_dipole": numpyro.infer.reparam.ProjectedNormalReparam()}):
        xdir = numpyro.sample("xdir_dipole", dist.ProjectedNormal(jnp.zeros(3)))
        theta = numpyro.deterministic("theta", jnp.arccos(xdir[2]))

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0), extra_fields=("~z.x",))
```
I think you don't want to block "theta" in your code: `numpyro.handlers.block(hide=["xdir_dipole", "theta"],)`
Could you make an issue for this? I think we can do the replaying while collecting samples rather than at the end. I vaguely remember that doing so caused some performance issues (but it could be useful for your use case).
Edit: it was just an API change from early NumPyro development: Add additional options for MCMC sampling by neerajprad · Pull Request #316 · pyro-ppl/numpyro · GitHub
Edit 2: created the new issue Postprocessing while collecting samples · Issue #1908 · pyro-ppl/numpyro · GitHub
Just a note that the issue may now be fixed in the master branch. Could you double-check, @Rich-Sti?
Awesome! This completely fixed it. Thanks so much!