Stop tracking of variables

Hello, I have a large model with O(100,000) latent variables, so running MCMC for many steps results in very high memory usage. Is it possible in NumPyro to specify which variables should not be tracked?

Or would the alternative be to run for a number of steps, save what I need, and then restart from the last state?

Thanks!

I don’t think there’s any mechanism to do that out of the box, so the alternative you propose is probably the easiest way to get what you want.

Please see the extra_fields argument of MCMC.run in the Markov Chain Monte Carlo (MCMC) section of the NumPyro documentation for a solution.

Thanks. extra_fields works in most cases. However, whenever I reparameterize xdir, as in this toy example, the code throws an error after sampling finishes. If I comment out the reparameterization, it works as desired.

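```python
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
import numpyro.infer.reparam
from numpyro.infer import MCMC, NUTS

def model():
    x = numpyro.sample("x", dist.Normal(0, 1))
    y = numpyro.sample("y", dist.Normal(0, 1))

    with numpyro.handlers.reparam(config={"xdir": numpyro.infer.reparam.ProjectedNormalReparam()}):
        xdir = numpyro.sample("xdir", dist.ProjectedNormal(jnp.zeros(3)))

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), extra_fields=("~z.x",))  # fails after sampling when the reparam is active
```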

The motivation for that reparameterization is that I want to sample angles on the surface of a sphere.

In fact, the problem seems to appear whenever I use something like @numpyro.handlers.reparam(config={"xdir_dipole": numpyro.infer.reparam.CircularReparam()}), so it potentially affects any reparameterization.

Could you file a bug for this? For reparam models, we need to invoke a postprocess function that replays the model after sampling to collect the deterministic samples. If you don’t need those deterministic samples, it’s better not to replay the model.

I think you can do this:

```python
with numpyro.handlers.block(hide=["xdir"]), numpyro.handlers.reparam(...):
    ...
```

Thanks! This helped, but it is still not quite what I need, since I need to collect the deterministic samples.

For example, in this toy example I don’t want to collect x, but I do want to collect the samples of theta. However, this approach hides theta as well. Is there any way around this?

```python
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
import numpyro.infer.reparam
from numpyro.infer import MCMC, NUTS

def model():
    x = numpyro.sample("x", dist.Normal(0, 1))
    y = numpyro.sample("y", dist.Normal(0, 1))

    with numpyro.handlers.block(hide=["xdir_dipole", "theta"],), numpyro.handlers.reparam(config={"xdir_dipole": numpyro.infer.reparam.ProjectedNormalReparam()}):
        xdir = numpyro.sample("xdir_dipole", dist.ProjectedNormal(jnp.zeros(3)))
        theta = numpyro.deterministic("theta", jnp.arccos(xdir[2]))

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0), extra_fields=("~z.x",))
```

I guess you don’t want to block “theta” in your code: numpyro.handlers.block(hide=["xdir_dipole", "theta"]). Hiding only "xdir_dipole" keeps theta visible.

Could you make an issue for this? I think we can do the replaying while collecting samples, rather than at the end. I vaguely remember that doing so caused some performance issues (but it could be useful for your use case).

Edit: it was just an API change from early NumPyro development: "Add additional options for MCMC sampling" by neerajprad (pyro-ppl/numpyro Pull Request #316 on GitHub).

Edit 2: I created a new issue, "Postprocessing while collecting samples" (pyro-ppl/numpyro Issue #1908 on GitHub).

Just a note that the issue may have been fixed in the master branch. Could you double-check, @Rich-Sti?

Awesome! This completely fixed it. Thanks so much!