The input type and shape for MCMC.run with AIES

Hi,

If I have some function that takes an array params of shape [n, m], and I want to use it with the AIES kernel and MCMC, say with 20 chains, what shape should the inputs be?

aies = AIES(some_function, moves={AIES.DEMove(): 0.5, AIES.StretchMove(): 0.5})
mcmc = MCMC(aies,
            num_warmup=1000,
            num_samples=2000,
            num_chains=num_chains,
            chain_method="vectorized")
mcmc.run(jax.random.PRNGKey(0),
         params=initial_params)

i.e. what shape should initial_params be: [n, m], or something like [num_chains, n, m]?

The docs don't seem to make it clear.

Thanks!

cc @amifalk

Yes, params should be preceded by a leading (num_chains,) axis. The potential function (the negative log density) is evaluated by vmapping across axis 0, the num_chains axis.
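A minimal JAX sketch of that convention, using a hypothetical toy potential in place of some_function: the initial array carries one row per chain, and jax.vmap maps the single-chain function over axis 0. The result matches what a row-by-row loop would give, but it is computed in one vectorized call. This only illustrates plain jax.vmap, not NumPyro's internals.

```python
import jax
import jax.numpy as jnp

num_chains = 4

def potential(x):
    # Hypothetical stand-in for some_function: a toy potential
    # (negative log density) for ONE chain, where x has shape (2,).
    return 0.5 * jnp.sum(x ** 2)

# One row per chain: shape (num_chains, 2).
params = jnp.tile(jnp.array([4.6051702, -5.241507]), (num_chains, 1))

# vmap maps over axis 0, so potential only ever sees a single row.
batched = jax.vmap(potential)(params)

# Same numbers as evaluating each row by hand, just computed in one call.
looped = jnp.stack([potential(params[i]) for i in range(num_chains)])

print(params.shape, batched.shape)  # (4, 2) (4,)
```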


Thanks @amifalk - I’ll give that a go.

So some_function will essentially be called multiple times, with [0,:], [1,:], [2,:], … [num_chains-1,:], etc.?

It seems to be passing the entire [num_chains, 2] initial_params array through to the function.

some_function requires a 1D array of length 2, e.g. [4.6051702, -5.241507].

I tile this with
initial_params = jnp.tile(initial_params, (num_chains, 1))
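The tile call gives the expected (num_chains, 2) shape. One extra consideration, offered as a hedged aside: affine-invariant ensemble moves such as the stretch and DE moves build proposals from differences between walkers, so if every walker starts at exactly the same point those differences are zero and proposals collapse back onto the starting point (up to any small noise term a move adds). A common practice, as in emcee, is to start the walkers in a small ball around the point of interest. The sketch below does that; the 1e-3 jitter scale and the PRNG seed are arbitrary assumptions, not values from this thread.

```python
import jax
import jax.numpy as jnp

num_chains = 20
center = jnp.array([4.6051702, -5.241507])

# Identical rows: every walker starts at the same point, shape (20, 2).
tiled = jnp.tile(center, (num_chains, 1))

# Assumed jitter scale; choose it small relative to the posterior width.
key = jax.random.PRNGKey(1)
initial_params = tiled + 1e-3 * jax.random.normal(key, tiled.shape)

print(initial_params.shape)  # (20, 2)
```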

to get:

[[ 4.6051702, -5.241507 ],
 [ 4.6051702, -5.241507 ],
 [ 4.6051702, -5.241507 ],
 [ 4.6051702, -5.241507 ]]

for, say, 4 chains. But when the function is called, it receives the whole array:

initial_params = Array([[ 4.6051702, -5.241507 ],
                        [ 4.6051702, -5.241507 ],
                        [ 4.6051702, -5.241507 ],
                        [ 4.6051702, -5.241507 ]], dtype=float32)

Clearly I am misunderstanding something here. Any suggestions on how to get it to work as expected?

Under the hood, the function you provide as potential_fn is vmapped: a single vectorized call evaluates all chains at once. It might help to read about the vmap transform, since the function calls aren't executed sequentially, one chain at a time, as you're suggesting: Quickstart — JAX documentation
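A small sketch of what vmap does here, again with a hypothetical toy potential: the Python function is traced once, and inside that trace its argument has the per-chain shape (2,), even though the input array is (4, 2). The calls list records the shape seen at trace time. Because the argument is a tracer during the trace, a plain print inside the function shows a tracer rather than concrete values; jax.debug.print is the tool for printing values from inside transformed code.

```python
import jax
import jax.numpy as jnp

calls = []

def some_function(x):
    # Runs at trace time: under vmap, x is a tracer with the per-chain shape.
    calls.append(x.shape)
    return 0.5 * jnp.sum(x ** 2)

initial_params = jnp.tile(jnp.array([4.6051702, -5.241507]), (4, 1))

energies = jax.vmap(some_function)(initial_params)

print(calls)           # [(2,)]: traced once, seeing one row's shape
print(energies.shape)  # (4,)
```

If the function still receives the full (num_chains, 2) array, it may be worth checking how it is wired in. Going by NumPyro's documented signatures as I understand them, the first positional argument of AIES is model, so AIES(some_function, ...) treats the function as a model, and extra keyword arguments to mcmc.run (such as params=...) are forwarded to the model rather than used as starting positions. The potential-function route would instead look like AIES(potential_fn=some_function, moves=...) together with mcmc.run(jax.random.PRNGKey(0), init_params=initial_params).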


Ah - ok. Thanks @amifalk!