MCMC Intermediate computation saving

Hello,
I admit my title may look cryptic, so here is what I'm looking for.
When you run MCMC (HMC/NUTS, Langevin or whatever…), you set a series of parameters such as the number of chains, the number of burn-in samples, the total number of samples per chain, and the thinning interval.

For instance, I put together a toy φ⁴ sampling scenario using Langevin dynamics in JAX (here a single-chain use case), and I have this piece of code:

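import jax
import jax.numpy as jnp

# langevin(key, phi, uparams, N, dt) and lattice_init(key, N, d) are my own helpers (not shown).
def simul_scan(k, N=32, d=2, l=0.02, seed=4, n_burn=1000, n_spls=100_000, n_thinning=10, dt=0.01):
    uparams = (l, k)

    # number of recorded measurements; n_spls must be a multiple of n_thinning
    n_steps = n_spls // n_thinning

    def body_in(i, carry):
        # one Langevin update using the i-th key of the current key block
        subkeys, phi = carry
        phi = langevin(subkeys[i], phi, uparams, N, dt)
        return subkeys, phi

    def body_out(phi, keys):
        # n_thinning updates, then one measurement; scan stacks only the observables
        _, phi = jax.lax.fori_loop(0, n_thinning, body_in, (keys, phi))
        mag = jnp.mean(phi)
        mag2 = mag * mag
        mag4 = mag2 * mag2
        return phi, (mag, mag2, mag4)

    # init
    key = jax.random.PRNGKey(seed)
    key, subkey = jax.random.split(key)
    phi = lattice_init(subkey, N, d)

    # burn-in: updates whose states are discarded
    keys = jax.random.split(key, num=n_burn + 1)
    key = keys[0]
    subkeys = keys[1:]
    _, phi = jax.lax.fori_loop(0, n_burn, body_in, (subkeys, phi))

    # sampling: one block of n_thinning keys per measurement
    keys = jax.random.split(key, num=n_spls + 1)
    key = keys[0]
    subkeys = keys[1:]
    subkeys = subkeys.reshape(n_steps, n_thinning, *subkeys.shape[1:])

    phi_last, mags = jax.lax.scan(body_out, phi, subkeys)

    return phi_last, mags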
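The helpers `langevin` and `lattice_init` are not shown in the post. So that the function above can actually be run, here is a hypothetical sketch of both. It assumes the common lattice φ⁴ action S = Σ_x [ -2κ Σ_μ φ_x φ_{x+μ̂} + φ_x² + λ(φ_x² - 1)² ] with periodic boundaries, the ordering `uparams = (λ, κ)`, a small Gaussian random start, and an overdamped Langevin (Euler-Maruyama) update. All of these are assumptions, not the original author's helpers.

```python
import jax
import jax.numpy as jnp


def phi4_action(phi, uparams):
    # Hypothetical lattice phi^4 action; uparams = (lam, kappa) mirrors `uparams = (l, k)`.
    lam, kappa = uparams
    # Sum of backward neighbours along every lattice axis (periodic boundaries).
    hop = sum(jnp.roll(phi, 1, axis=mu) for mu in range(phi.ndim))
    return jnp.sum(-2.0 * kappa * phi * hop + phi**2 + lam * (phi**2 - 1.0) ** 2)


def lattice_init(key, N, d):
    # Hypothetical: small random Gaussian start on an N^d lattice.
    return 0.1 * jax.random.normal(key, (N,) * d)


def langevin(key, phi, uparams, N, dt):
    # Hypothetical overdamped Langevin (Euler-Maruyama) step:
    #   phi <- phi - dt * dS/dphi + sqrt(2 dt) * eta,  eta ~ N(0, 1) per site.
    # N is unused; kept only to match the call langevin(subkeys[i], phi, uparams, N, dt).
    drift = jax.grad(phi4_action)(phi, uparams)
    noise = jax.random.normal(key, phi.shape)
    return phi - dt * drift + jnp.sqrt(2.0 * dt) * noise
```

Using `jax.grad` on the action avoids hand-coding the force; for this action it equals -2κ Σ_μ (φ_{x+μ̂} + φ_{x-μ̂}) + 2φ_x + 4λ φ_x (φ_x² - 1).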

My point concerns the use of jax.lax.scan(body_out, phi, subkeys), which triggers the computation of the intermediate quantities (mag, mag2, mag4) every n_thinning Langevin updates.
As a result, I never need to store the samples themselves (here the phi configurations, each a field with N^d sites).
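The pattern is not specific to Langevin. A minimal sketch that factors it out, assuming a hypothetical interface `step_fn(key, state) -> state` for one sampler update and `obs_fn(state)` for the summaries to keep. Here the inner loop is a `lax.scan` over the block of keys rather than a `fori_loop` that indexes `subkeys[i]`, so the key block does not have to travel in the carry.

```python
import jax
import jax.numpy as jnp


def thinned_observables(step_fn, obs_fn, state, key, n_samples, n_thinning):
    """Run n_samples updates but only return obs_fn(state) every n_thinning updates.

    step_fn(key, state) -> state : one sampler update (hypothetical interface)
    obs_fn(state) -> pytree      : cheap summaries to keep instead of the states
    n_samples, n_thinning        : Python ints, n_samples a multiple of n_thinning
    """
    n_blocks = n_samples // n_thinning
    keys = jax.random.split(key, n_blocks * n_thinning)
    keys = keys.reshape(n_blocks, n_thinning, *keys.shape[1:])

    def inner(state, k):
        # One update; nothing is emitted, so no per-step output is stored.
        return step_fn(k, state), None

    def outer(state, block_keys):
        state, _ = jax.lax.scan(inner, state, block_keys)
        return state, obs_fn(state)  # only the observables are stacked

    return jax.lax.scan(outer, state, keys)


# Toy usage: an AR(1)-style random update on a 4x4 field.
def toy_step(key, x):
    return 0.9 * x + 0.1 * jax.random.normal(key, x.shape)


def toy_obs(x):
    m = jnp.mean(x)
    return m, m**2


final, (m, m2) = thinned_observables(
    toy_step, toy_obs, jnp.zeros((4, 4)), jax.random.PRNGKey(0), 100, 10
)
```

For scale, with the defaults of simul_scan (N=32, d=2, n_spls=100_000, n_thinning=10) there are 10,000 measurements; stacking the fields instead would mean 10,000 × 1,024 = 10,240,000 values (about 41 MB in float32), versus 30,000 values for the three observables.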

Of course, I would like to replace Langevin with NUTS, for instance, which is why I'm wondering whether there is such a mechanism (an equivalent of lax.scan(body_out, ...)) that can be set up with NumPyro's MCMC. Thanks

not sure what you’re asking but the MCMC class takes a thinning argument:

https://num.pyro.ai/en/stable/mcmc.html#numpyro.infer.mcmc.MCMC

Hi, the point is not to save the samples themselves but to store the “mag” variables computed on the thinned samples.

you don’t want to store the latents at all but you want to store thinned deterministic functions of latents?

Yes: the (mag, mag2, mag4) functions, which I store to get their distribution over the thinned samples.

if you add mag = numpyro.deterministic("mag", my_deterministic_function(...)) and use thinning you’ll get at least some of what you want. doing something else would probably require subclassing some numpyro code and making custom modifications