Hello,
My title may look cryptic, so let me explain what I'm looking for.
When you run MCMC (HMC/NUTS, Langevin, or anything else), you set a number of parameters: the number of chains, the number of burn-in samples, the total number of samples per chain, and the thinning factor.
For instance, in a toy phi^4 sampling scenario that I put together using Langevin dynamics in JAX (here a single-chain use case), I have this piece of code:
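```python
import jax
import jax.numpy as jnp

# `langevin` and `lattice_init` are defined elsewhere (not shown in this post).


def simul_scan(k, N=32, d=2, l=0.02, seed=4, n_burn=1000, n_spls=100_000, n_thinning=10, dt=0.01):
    uparams = (l, k)
    n_steps = n_spls // n_thinning  # number of recorded measurements

    def body_in(i, carry):
        # one Langevin update using the i-th subkey
        subkeys, phi = carry
        phi = langevin(subkeys[i], phi, uparams, N, dt)
        return subkeys, phi

    def body_out(phi, keys):
        # n_thinning updates, then measure; only the observables are stacked
        _, phi = jax.lax.fori_loop(0, n_thinning, body_in, (keys, phi))
        mag = jnp.mean(phi)
        mag2 = mag * mag
        mag4 = mag2 * mag2
        return phi, (mag, mag2, mag4)

    # init
    key = jax.random.PRNGKey(seed)
    key, subkey = jax.random.split(key)
    phi = lattice_init(subkey, N, d)

    # burn-in
    keys = jax.random.split(key, num=n_burn + 1)
    key, subkeys = keys[0], keys[1:]
    _, phi = jax.lax.fori_loop(0, n_burn, body_in, (subkeys, phi))

    # sampling: exactly n_steps * n_thinning keys, grouped as (n_steps, n_thinning, 2)
    # (avoids a reshape error when n_spls is not a multiple of n_thinning)
    keys = jax.random.split(key, num=n_steps * n_thinning + 1)
    key, subkeys = keys[0], keys[1:]
    subkeys = subkeys.reshape(n_steps, n_thinning, 2)
    phi_last, mags = jax.lax.scan(body_out, phi, subkeys)
    return phi_last, mags
```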
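`langevin` and `lattice_init` are not shown in the post. If you want to run `simul_scan` as is, here is a hypothetical pair with the same call signatures. The action is an assumption (a common hopping-parameter form of lattice phi^4, reading `uparams = (l, k)` as coupling and hopping parameter), and the update is a plain unadjusted Langevin step; substitute your own versions.

```python
import jax
import jax.numpy as jnp


def phi4_action(phi, uparams):
    # HYPOTHETICAL action (hopping-parameter form of lattice phi^4):
    # S = sum_x [ -2k sum_mu phi_x phi_{x+mu} + phi_x^2 + l (phi_x^2 - 1)^2 ]
    # with periodic boundaries (jnp.roll along every lattice axis).
    l, k = uparams
    hop = sum(phi * jnp.roll(phi, 1, axis=mu) for mu in range(phi.ndim))
    return jnp.sum(-2.0 * k * hop + phi**2 + l * (phi**2 - 1.0) ** 2)


def lattice_init(key, N, d):
    # HYPOTHETICAL: small random start on an N^d lattice
    return 0.1 * jax.random.normal(key, (N,) * d)


def langevin(key, phi, uparams, N, dt):
    # HYPOTHETICAL unadjusted Langevin step:
    # phi <- phi - dt * dS/dphi + sqrt(2 dt) * eta, with eta ~ N(0, 1).
    # N is unused; it is kept only to match the call in simul_scan.
    grad = jax.grad(phi4_action)(phi, uparams)
    noise = jax.random.normal(key, phi.shape)
    return phi - dt * grad + jnp.sqrt(2.0 * dt) * noise
```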
My point concerns the use of `jax.lax.scan(body_out, phi, subkeys)`, which triggers the computation of the intermediate quantities `(mag, mag2, mag4)` every `n_thinning` samples.
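The same mechanism can be written as a kernel-agnostic helper: any transition `step(key, state) -> state` is applied `n_thinning` times inside an inner scan, and only `summary(state)` is stacked by the outer scan. `thinned_scan` is a hypothetical name used only for this sketch.

```python
import jax
import jax.numpy as jnp


def thinned_scan(step, state, key, n_spls, n_thinning, summary):
    """Apply `step(key, state)` n_steps * n_thinning times and return
    (last_state, stacked summaries), one summary every n_thinning
    transitions. Intermediate states are never stacked."""
    n_steps = n_spls // n_thinning  # any remainder transitions are dropped
    keys = jax.random.split(key, n_steps * n_thinning)
    keys = keys.reshape((n_steps, n_thinning) + keys.shape[1:])

    def inner(state, k):
        return step(k, state), None  # no per-transition output

    def outer(state, ks):
        state, _ = jax.lax.scan(inner, state, ks)
        return state, summary(state)  # the only thing that gets stacked

    return jax.lax.scan(outer, state, keys)


# Usage: a deterministic "kernel" that adds 1 per transition makes the
# thinning visible: the record is 10, 20, ..., 100.
last, record = thinned_scan(
    lambda k, x: x + 1.0, jnp.float32(0.0), jax.random.PRNGKey(0),
    n_spls=100, n_thinning=10, summary=lambda x: x,
)
```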
As a result, I do not need to store all the samples (here the `phi`s, which are fields with N^d sites).
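To put a number on the savings, assuming float32 fields (JAX's default dtype) and the default arguments of `simul_scan` (N=32, d=2, n_spls=100_000, n_thinning=10), a quick back-of-the-envelope count:

```python
N, d = 32, 2
n_spls, n_thinning = 100_000, 10
bytes_f32 = 4  # assumption: float32 fields

site_count = N**d                                          # 1024 sites per field
all_fields = n_spls * site_count * bytes_f32               # every phi kept
thinned_fields = (n_spls // n_thinning) * site_count * bytes_f32  # every 10th phi kept
scalars_only = (n_spls // n_thinning) * 3 * bytes_f32     # only (mag, mag2, mag4)

print(all_fields / 1e6, thinned_fields / 1e6, scalars_only / 1e6)  # -> 409.6 40.96 0.12 (MB)
```

Thinning alone divides storage by 10, while keeping only the three observables brings it down to about 0.1 MB.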
Of course, I would like to replace Langevin with NUTS, for instance, which is why I'm wondering whether there is an equivalent mechanism (something like `lax.scan(body_out, ...)`) that can be set up with NumPyro's `MCMC`. Thanks!