# Saving intermediate computations during MCMC

Hello,
My title may look cryptic, so here is what I'm looking for.
When you run MCMC (HMC/NUTS, Langevin or whatever…), you set a series of parameters such as the number of chains, the number of burn-in samples, the total number of samples per chain, and the thinning factor.

When I write my own φ⁴ sampling toy scenario using Langevin dynamics in JAX (here a single-chain use case), I have this piece of code:

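```python
import jax
import jax.numpy as jnp

# `langevin(key, phi, uparams, N, dt)` and `lattice_init(key, N, d)` are the
# author's own helpers (not shown in the post).


def simul_scan(k, N=32, d=2, l=0.02, seed=4, n_burn=1000, n_spls=100_000, n_thinning=10, dt=0.01):
    uparams = (l, k)

    n_steps = n_spls // n_thinning  # number of recorded measurements

    def body_in(i, carry):
        # One Langevin update using the i-th pre-split key
        subkeys, phi = carry
        phi = langevin(subkeys[i], phi, uparams, N, dt)
        return subkeys, phi

    def body_out(phi, keys):
        # Advance n_thinning updates, then record only the observables
        _, phi = jax.lax.fori_loop(0, n_thinning, body_in, (keys, phi))
        mag = jnp.mean(phi)
        mag2 = mag * mag
        mag4 = mag2 * mag2
        return phi, (mag, mag2, mag4)

    # init
    key = jax.random.PRNGKey(seed)
    key, subkey = jax.random.split(key)
    phi = lattice_init(subkey, N, d)

    # burn-in
    keys = jax.random.split(key, num=n_burn + 1)
    key = keys[0]
    subkeys = keys[1:]
    _, phi = jax.lax.fori_loop(0, n_burn, body_in, (subkeys, phi))

    # sampling: one row of n_thinning keys per recorded measurement
    # (assumes n_spls is a multiple of n_thinning)
    keys = jax.random.split(key, num=n_spls + 1)
    key = keys[0]
    subkeys = keys[1:]

    subkeys = subkeys.reshape(n_steps, subkeys.shape[0] // n_steps, 2)

    phi_last, mags = jax.lax.scan(body_out, phi, subkeys)

    return phi_last, mags
```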

My point concerns the use of `jax.lax.scan(body_out, phi, subkeys)`, which triggers the computation of the intermediate quantities `(mag, mag2, mag4)` every `n_thinning` Langevin steps.
This way I never need to store all the samples (here the `phi` fields, each with N^d sites).
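Because `langevin` and `lattice_init` are not shown, here is a self-contained sketch of the same pattern that actually runs. The helpers `phi4_action`, `langevin_sketch`, `lattice_init_sketch` and the driver `simul_scan_sketch` are hypothetical stand-ins: the action assumes the common hopping-parameter form of lattice φ⁴ with periodic boundaries (reading `l` as λ and `k` as κ), and the update is a plain unadjusted (Euler-Maruyama) Langevin step. What matters is the control flow: the inner `fori_loop` advances `n_thinning` steps, and the outer `lax.scan` only stacks three scalars per block, so stored output grows as `3 * n_steps` instead of `n_steps * N**d`.

```python
import jax
import jax.numpy as jnp


def phi4_action(phi, lam, kappa):
    """Hypothetical lattice phi^4 action (hopping-parameter form, periodic BCs)."""
    # Nearest-neighbour term: sum over directions mu of phi_x * phi_{x+mu}
    hop = sum(jnp.sum(phi * jnp.roll(phi, 1, axis=mu)) for mu in range(phi.ndim))
    local = jnp.sum(phi**2 + lam * (phi**2 - 1.0) ** 2)
    return -2.0 * kappa * hop + local


def langevin_sketch(key, phi, lam, kappa, dt):
    """One unadjusted Langevin step: phi <- phi - dt * dS/dphi + sqrt(2 dt) * noise."""
    drift = -jax.grad(phi4_action)(phi, lam, kappa)
    noise = jax.random.normal(key, phi.shape)
    return phi + dt * drift + jnp.sqrt(2.0 * dt) * noise


def lattice_init_sketch(key, N, d):
    """Hypothetical initial field: small Gaussian noise on an N^d lattice."""
    return 0.1 * jax.random.normal(key, (N,) * d)


def simul_scan_sketch(kappa, N=8, d=2, lam=0.02, seed=4,
                      n_burn=100, n_spls=1000, n_thinning=10, dt=0.01):
    n_steps = n_spls // n_thinning
    key_init, key_burn, key_spl = jax.random.split(jax.random.PRNGKey(seed), 3)
    phi = lattice_init_sketch(key_init, N, d)

    def step(i, carry):
        keys, phi = carry
        return keys, langevin_sketch(keys[i], phi, lam, kappa, dt)

    # Burn-in: advance n_burn steps and keep only the final field.
    burn_keys = jax.random.split(key_burn, n_burn)
    _, phi = jax.lax.fori_loop(0, n_burn, step, (burn_keys, phi))

    def block(phi, keys):
        # Advance n_thinning steps, then emit only the scalar observables.
        _, phi = jax.lax.fori_loop(0, n_thinning, step, (keys, phi))
        mag = jnp.mean(phi)
        return phi, (mag, mag**2, mag**4)

    # One row of n_thinning keys per recorded observation.
    keys = jax.random.split(key_spl, n_steps * n_thinning).reshape(n_steps, n_thinning, -1)
    return jax.lax.scan(block, phi, keys)
```

For example, `phi_last, (mag, mag2, mag4) = simul_scan_sketch(0.2)` returns the final 8×8 field plus three length-100 arrays; the lattice and sample sizes are deliberately small so the sketch runs quickly.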

Of course, I would like to replace Langevin with NUTS, for instance, which is why I'm wondering whether NumPyro's MCMC offers a similar mechanism (an equivalent of `lax.scan(body_out, ...)`). Thanks

not sure what you’re asking but the `MCMC` class takes a `thinning` argument:

https://num.pyro.ai/en/stable/mcmc.html#numpyro.infer.mcmc.MCMC

Hi, the point is not to save the samples themselves, but to store the "mag" variables computed on the thinned samples.

you don’t want to store the latents at all but you want to store thinned deterministic functions of latents?

Yes, the `(mag, mag2, mag4)` functions, which I store to get their distribution over the thinned samples.

if you add `mag = numpyro.deterministic("mag", my_deterministic_function(...))` and use thinning you’ll get at least some of what you want. doing something else would probably require subclassing some numpyro code and making custom modifications