Numpyro complaining about pmap inside jit?

I want to use numpyro to infer the parameters of an ODE system. The complication is that I want to solve the same ODE system for 100s/1000s of different objects (different initial conditions) within a single MCMC iteration, i.e., for a single realization of the free parameter values. In principle I could use vmap for this, but I want to use pmap (or shard_map) to exploit the embarrassingly parallel nature of the problem and assign each ODE system to a different CPU core on a single node. (vmap over my ODE system takes forever to compile, and while the thousands of cores on a single GPU might help, I would likely need multiple GPUs anyway due to their limited memory.)

I have adapted the official predator-prey numpyro example (Example: Predator-Prey Model — NumPyro documentation) to illustrate how I imagine using pmap within the numpyro model function. The following code works for this simple example, but I get a warning:

UserWarning: The jitted function scan includes a pmap. Using jit-of-pmap can lead to inefficient data movement, as the outer jit does not preserve sharded data representations and instead collects input and output arrays onto a single device. Consider removing the outer jit unless you know what you're doing. See https://github.com/google/jax/issues/2926.

For my actual, more complicated ODE system, where the pmapped solve takes longer to finish (~30 sec), I get that same warning plus MANY more log messages that look like this:

E0324 17:04:20.251938  815459 collective_ops_utils.h:269] This thread has been waiting for 5000ms for and may be stuck: participant AllReduceParticipantData{rank=0, element_count=280, type=F64, rendezvous_key=RendezvousKey{run_id=RunId: 248276, global_devices=[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39], num_local_participants=40, collective_op_kind=cross_replica, op_id=220}} waiting for all participants to arrive at rendezvous RendezvousKey{run_id=RunId: 248276, global_devices=[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39], num_local_participants=40, collective_op_kind=cross_replica, op_id=220}

Is there a recommended best practice for using pmap/shard_map within a numpyro model where, within a single iteration of MCMC (single realization of free parameters), we need to evaluate an expensive but embarrassingly parallel function (ODE solver in this case) over 1000s of other fixed parameters?

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" # assume we have 8 CPU cores for pmap

from jax import config
config.update("jax_enable_x64", True)
from jax.random import PRNGKey, uniform, lognormal
import jax.numpy as jnp
from jax import jit, pmap
import numpyro
import numpyro.distributions as dist
from numpyro.examples.datasets import LYNXHARE, load_dataset
from numpyro.infer import MCMC, NUTS, Predictive
from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5

### download pre-packaged numpyro data for Lotka-Volterra problem
_, fetch = load_dataset(LYNXHARE, shuffle=False)
year, data = fetch()  # data is in hare -> lynx order

### Lotka-Volterra ODE function
def vector_field(t, y, args):
    prey, predator = y
    alpha, beta, gamma, delta = args    
    dprey_dt = alpha * prey - beta * prey * predator
    dpredator_dt = -gamma * predator + delta * prey * predator
    return dprey_dt, dpredator_dt

### setup for diffrax
term = ODETerm(vector_field)
solver = Tsit5()
t0 = jnp.asarray(year[0],dtype=jnp.float64)
t1 = jnp.asarray(year[-1],dtype=jnp.float64)
dt0 = 0.1 
saveat = SaveAt(ts=jnp.asarray(year,dtype=jnp.float64)) # evaluate at same times as observations

"""
pretend we have 8 different objects aka 8 different ICs we want to pmap over 
yes this could be done with vmap, but for my actual problem, I need pmap/shard_map 
since I have 1000s of objects/ICs (embarrassingly parallel problem) and vmap is slow to compile
"""

# generate 8 random combinations of (alpha,beta,gamma,delta,init_prey,init_predator)
rand_params = uniform(PRNGKey(1),shape=(8,4),minval=0.01,maxval=1.0) # (alpha,beta,gamma,delta) ~ U(0.01,1)
rand_ICs = lognormal(PRNGKey(1),jnp.log(10),(8,2)) # init_prey, init_predator ~ LogNormal(sigma=log(10))
rand_args = jnp.hstack((rand_params,rand_ICs)) # each row is (alpha,beta,gamma,delta,init_prey,init_predator) -- 8 rows for 8 fixed ICs

# function that solves ODE problem for a given array of (alpha,beta,gamma,delta,init_prey,init_predator)
def solve_diffrax(args):
    
    # returning transposed solution so it has same dims as observed data: 91x2
    return jnp.asarray(diffeqsolve(term, solver, t0, t1, dt0, tuple(args[4:]), args=args[:4], saveat=saveat).ys).T

# create a pmapped version of that ODE solver function -- this will operate over each row of rand_args in parallel (8 ICs for 8 devices)
pmap_solve_diffrax = pmap(solve_diffrax)

# run it once to compile
ptest = pmap_solve_diffrax(rand_args)


### numpyro model where the 8 ICs will be solved in parallel for different (alpha,beta,gamma,delta) 
### obs_data is the observed time series of (prey,predator) vs time and we will fix uncertainty obs_sigma=0.8 instead of fitting for it
def model(rand_args, obs_data, obs_sigma=0.8):

    # sample our free parameters (alpha,beta,gamma,delta)
    alpha = numpyro.sample('alpha',dist.TruncatedNormal(low=0.0,loc=1.0,scale=0.5))
    beta = numpyro.sample('beta',dist.TruncatedNormal(low=0.0,loc=0.05,scale=0.05))
    gamma = numpyro.sample('gamma',dist.TruncatedNormal(low=0.0,loc=1.0,scale=0.5))
    delta = numpyro.sample('delta',dist.TruncatedNormal(low=0.0,loc=0.05,scale=0.05))    

    # update rand_args with the random samples of (alpha,beta,gamma,delta) 
    rand_args = rand_args.at[:,0].set(alpha)
    rand_args = rand_args.at[:,1].set(beta)
    rand_args = rand_args.at[:,2].set(gamma)
    rand_args = rand_args.at[:,3].set(delta)    
    
    # use pmap to solve ODEs for all 8 ICs in parallel for current set of (alpha,beta,gamma,delta)
    psol = pmap_solve_diffrax(rand_args)
    
    # simple placeholder: compute mean solution of (prey,predator) vs time averaging over the 8 ICs
    mean_sol = jnp.mean(psol,axis=0)
    
    # likelihood similar to official example https://num.pyro.ai/en/stable/examples/ode.html
    numpyro.sample('obs', dist.LogNormal(jnp.log(mean_sol), obs_sigma), obs=obs_data)


### Finally, prior predictive check to test
### This works but raises "jit-of-pmap" warning
prior_predictive = Predictive(model, num_samples=100)
prior_predictions = prior_predictive(PRNGKey(5), rand_args=rand_args, obs_data=None, obs_sigma=0.8)['obs']

I would recommend experimenting with plain JAX code first to see whether you can use pmap that way. You might want to use pjit with sharding constraints, but I suspect that would still be inefficient. Both data parallelism and tensor parallelism do not apply to your case.

@fehiepsi thanks, I think I got it working with shard_map plus a nested vmap inside. The shard_map splits the ~1000s of fixed parameters into batches, one per available device, and the inner vmap tells diffrax to vectorize the ODE solve over the parameter combinations within each batch.
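A minimal sketch of that setup, reusing solve_diffrax from my first post (the mesh construction and the all_args / batched_solve names here are just illustrative, not my actual code):

import numpy as np
import jax
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# 1D device mesh over all available devices (CPU cores in my case)
mesh = Mesh(np.array(jax.devices()), axis_names=('batch',))

# vmap vectorizes the diffrax solve within each device's shard, while
# shard_map splits the leading axis of all_args across the 'batch' mesh axis
batched_solve = shard_map(
    jax.vmap(solve_diffrax),
    mesh=mesh,
    in_specs=P('batch'),
    out_specs=P('batch'),
)

# all_args has shape (n_objects, 6); n_objects must be divisible by the device count
# psol = batched_solve(all_args)   # -> (n_objects, 91, 2)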

Problem: numpyro sampling with AIES runs fast for this simple predator-prey ODE system with the nested shard_map + vmap approach. For my more complicated actual ODE system, the nested shard_map + vmap also works (~500 objects solved in parallel in 0.01 sec for a given set of model parameters), both outside of numpyro and for the prior predictive check. However, when trying to sample (with AIES, for example), it prints many of these messages (I think 40 times, since I have 40 CPU devices):

E0401 22:54:46.762880  421724 collective_ops_utils.h:269] This thread has been waiting for 5000ms for and may be stuck: participant AllReduceParticipantData{rank=34, element_count=40, type=F64, rendezvous_key=RendezvousKey{run_id=RunId: 166676, global_devices=[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39], num_local_participants=40, collective_op_kind=cross_module, op_id=1}} waiting for all participants to arrive at rendezvous RendezvousKey{run_id=RunId: 166676, global_devices=[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39], num_local_participants=40, collective_op_kind=cross_module, op_id=1}

Any ideas? Could it be because I'm maxing out threads (batching with shard_map over Ndevices, and also setting Nchains=10 or higher with chain_method='vectorized')?

Would using .block_until_ready() help / be wise inside the numpyro model function?

What do you mean by “Both data parallelism and tensor parallelism do not apply to your case.”?

Re data & tensor (or model) parallelism: you can find a reference here

I think shard_map might also work with jit and grad. I don't think block_until_ready helps. Could you try to vmap (over the MCMC chains) the value_and_grad of your non-numpyro model? If that works, then I think you can use the lower-level API for sampling, as in Example: Stochastic Volatility — NumPyro documentation.
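Something along these lines, where potential_fn is just a placeholder for your plain-JAX log density that wraps the ODE solves:

import jax
import jax.numpy as jnp

def potential_fn(theta):
    # placeholder: in practice this would call the shard_map'd ODE solver
    # and return the (negative) log joint density for one set of parameters
    return jnp.sum(theta ** 2)

# one (value, grad) pair per chain, vectorized over the leading chain dimension
chain_theta = jnp.ones((10, 4))  # e.g. 10 chains, 4 free parameters
values, grads = jax.vmap(jax.value_and_grad(potential_fn))(chain_theta)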

Thanks @fehiepsi, I will look into that soon. Currently, as a placeholder, I am just using vmap to solve a smaller number of ODE systems with different ICs/forcing terms for a given set of MCMC parameters within numpyro. I made some optimizations to my solve_diffrax function for my actual ODE problem, so vmap alone is better now (though I would still prefer the much faster shard_map + nested vmap).
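For the simple example above, that just means swapping the pmapped solver for a vmapped one (sketch):

from jax import vmap

# vectorized (single-device) stand-in for pmap_solve_diffrax from my first post
vmap_solve_diffrax = vmap(solve_diffrax)

# and inside model(), the pmapped call becomes
# psol = vmap_solve_diffrax(rand_args)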

However, even when running just num_samples=100 with 10 vectorized chains (so mcmc.get_samples() returns 1000 parameter combinations), drawing posterior predictive samples takes a long time. Any ideas for how to speed up the posterior predictive check? I assume this should be parallelizable since it's post-processing…

post_samples = mcmc.get_samples()
post_predictive = Predictive(model, post_samples)
post_obs = post_predictive(PRNGKey(1), rand_args=rand_args, obs_data=None, obs_sigma=0.8)['obs']

Edit: in addition, I am trying to convert the mcmc object to arviz InferenceData so I can save it as netcdf and use the arviz plots, but this is also taking a long time: ~30 sec with only ~10K samples for the simple predator-prey ODE example. I suspect that neither the posterior predictive nor az.from_numpyro is taking advantage of parallelism?

import arviz as az

numpyro_data = az.from_numpyro(mcmc, prior=prior_predictions, posterior_predictive=post_obs)

Maybe Predictive(..., parallel=True)?
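i.e. something like this (untested with your model; per the docs, parallel=True vectorizes across the posterior samples with vmap instead of mapping over them sequentially):

# parallel=True vectorizes the predictive draws across posterior samples
post_predictive = Predictive(model, post_samples, parallel=True)
post_obs = post_predictive(PRNGKey(1), rand_args=rand_args, obs_data=None, obs_sigma=0.8)['obs']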