By default, NumPyro saves all random variables. However, I want to run a very long MCMC chain, and storing that many random variables will take up too much memory. Is there a way to save only a subset of the variables instead of all of them? Thanks in advance!
I think you can subclass HMC/NUTS and drop the unneeded variables from the output of postprocess_fn.
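Here is a minimal sketch of the filtering such an override would perform. It uses plain Python rather than NumPyro's actual classes: `keep_only` and `toy_postprocess` are hypothetical names, and it assumes the postprocessing step returns a dict mapping site names to values. In a real subclass, the wrapped function would be whatever the parent class's `postprocess_fn` returns. That wiring is not shown here.

```python
from typing import Any, Callable, Dict, Iterable

Sites = Dict[str, Any]


def keep_only(postprocess: Callable[[Sites], Sites],
              keep: Iterable[str]) -> Callable[[Sites], Sites]:
    """Wrap a postprocessing function so it returns only the sites named in `keep`."""
    wanted = frozenset(keep)

    def filtered(state: Sites) -> Sites:
        out = postprocess(state)
        # Drop every site the caller did not ask for.
        return {name: value for name, value in out.items() if name in wanted}

    return filtered


# Hypothetical stand-in for a sampler's postprocessing step:
# it returns the latent sites plus one derived "deterministic" site.
def toy_postprocess(state: Sites) -> Sites:
    return {**state, "mu_plus_tau": state["mu"] + state["tau"]}


filtered = keep_only(toy_postprocess, keep=["mu"])
print(filtered({"mu": 1.0, "tau": 2.0, "theta": [0.1, 0.2, 0.3]}))  # {'mu': 1.0}
```

As the next reply points out, filtering at this stage only helps if it runs before the samples are accumulated.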
Actually, we currently apply postprocess_fn after collecting all states, so the above solution won't save memory. I think we can add a flag to apply postprocess_fn while collecting the states, at this place. Please open a feature request (FR) or pull request (PR) for this.
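To see why the order matters, here is a toy collection loop in plain Python, not NumPyro's internals. `toy_step`, `collect`, and its `transform` parameter are hypothetical names. If the filter runs only after the loop, every full state is held in memory first. If it runs inside the loop, which is roughly what the proposed flag would do, only the kept sites are ever stored.

```python
from typing import Callable, Dict, List, Optional

State = Dict[str, List[float]]


def toy_step(state: State) -> State:
    """Stand-in for one MCMC transition: shift every value by 1 (not a real sampler)."""
    return {name: [v + 1.0 for v in values] for name, values in state.items()}


def collect(step: Callable[[State], State], init: State, num_samples: int,
            transform: Optional[Callable[[State], State]] = None) -> List[State]:
    """Run `num_samples` steps, storing each state (optionally transformed first)."""
    samples = []
    state = init
    for _ in range(num_samples):
        state = step(state)  # the sampler still carries the full state forward
        # Applying the transform here means only its output is kept in memory.
        samples.append(transform(state) if transform else state)
    return samples


def stored_floats(samples: List[State]) -> int:
    """Count how many scalar values are held across all stored samples."""
    return sum(len(values) for sample in samples for values in sample.values())


def keep_mu(state: State) -> State:
    return {"mu": state["mu"]}


init = {"mu": [0.0], "theta": [0.0] * 1000}  # one scalar site, one large site

# Collect everything, then filter: the full history exists before filtering.
full = collect(toy_step, init, num_samples=100)
print(stored_floats(full))  # 100100
after = [keep_mu(s) for s in full]

# Filter during collection: the large site is never accumulated.
during = collect(toy_step, init, num_samples=100, transform=keep_mu)
print(stored_floats(during))  # 100
```

Both routes end with the same `mu` values. Only the amount of data held during collection differs.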
Note that there is also a thinning argument here. It won't let you subset variables, but it will let you keep only, e.g., every 10th sample, which also leads to large memory savings.