I want to use numpyro to infer the parameters of an ODE system. The complication is that I want to solve the same ODE system for hundreds or thousands of different objects (different initial conditions) for each single draw of the free parameter values. In principle I could use vmap for this, but I want to use pmap (or shard_map) to exploit the embarrassingly parallel nature of the problem and assign each ODE system to a different CPU core on a single node. (vmap over my ODE system takes forever to compile, and while the thousands of cores on a single GPU may help, I might need multiple GPUs anyway due to their limited memory.)
I have adapted the official predator-prey numpyro example (https://num.pyro.ai/en/stable/examples/ode.html) to illustrate how I imagine using pmap within the numpyro model function. The following code works for this simple example, but I get a warning: "UserWarning: The jitted function scan includes a pmap. Using jit-of-pmap can lead to inefficient data movement, as the outer jit does not preserve sharded data representations and instead collects input and output arrays onto a single device. Consider removing the outer jit unless you know what you're doing. See https://github.com/google/jax/issues/2926."
For my actual, more complicated ODE system, where pmap takes longer to finish (~30 sec), I get that same warning plus many more log messages that look like this:
E0324 17:04:20.251938 815459 collective_ops_utils.h:269] This thread has been waiting for 5000ms for and may be stuck: participant AllReduceParticipantData{rank=0, element_count=280, type=F64, rendezvous_key=RendezvousKey{run_id=RunId: 248276, global_devices=[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39], num_local_participants=40, collective_op_kind=cross_replica, op_id=220}} waiting for all participants to arrive at rendezvous RendezvousKey{run_id=RunId: 248276, global_devices=[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39], num_local_participants=40, collective_op_kind=cross_replica, op_id=220}
Is there a recommended best practice for using pmap/shard_map within a numpyro model where, within a single iteration of MCMC (single realization of free parameters), we need to evaluate an expensive but embarrassingly parallel function (ODE solver in this case) over 1000s of other fixed parameters?
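To make the batching concrete: with thousands of objects but only a handful of devices, the batch has to be split into one chunk per device, with each chunk handled by vmap inside pmap. Below is a minimal sketch of that layout. The `rk4_solve` helper is a hypothetical fixed-step stand-in for the real ODE solver (not diffrax), and the object count, step size, and initial conditions are made up for illustration.

```python
import jax
import jax.numpy as jnp


def vector_field(y, params):
    # Lotka-Volterra right-hand side; y = (prey, predator)
    prey, predator = y
    alpha, beta, gamma, delta = params
    return jnp.stack([
        alpha * prey - beta * prey * predator,
        -gamma * predator + delta * prey * predator,
    ])


def rk4_solve(y0, params, dt=0.01, n_steps=100):
    # Hypothetical stand-in solver: classic fixed-step RK4, returns the final state.
    def step(y, _):
        k1 = vector_field(y, params)
        k2 = vector_field(y + 0.5 * dt * k1, params)
        k3 = vector_field(y + 0.5 * dt * k2, params)
        k4 = vector_field(y + dt * k3, params)
        return y + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4), None

    y_final, _ = jax.lax.scan(step, y0, None, length=n_steps)
    return y_final


n_dev = jax.local_device_count()  # 8 if XLA_FLAGS forces 8 host devices
per_dev = 4                       # objects handled by each device (illustrative)
n_objects = n_dev * per_dev

# made-up initial conditions, one row per object
y0s = jnp.linspace(1.0, 2.0, n_objects * 2).reshape(n_objects, 2)
params = jnp.array([1.0, 0.05, 1.0, 0.05])  # (alpha, beta, gamma, delta)

# leading axis = device, second axis = objects on that device
y0s_chunked = y0s.reshape(n_dev, per_dev, 2)

solve_chunk = jax.vmap(rk4_solve, in_axes=(0, None))  # within one device
solve_all = jax.pmap(solve_chunk, in_axes=(0, None))  # across devices

out = solve_all(y0s_chunked, params).reshape(n_objects, 2)
```

Each device only compiles and runs a vmap over `per_dev` objects, so the leading pmap axis never exceeds the number of devices. Note that this still uses pmap, so calling it inside a jitted function would be expected to raise the same jit-of-pmap warning.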
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8" # assume we have 8 CPU cores for pmap
from jax import config
config.update("jax_enable_x64", True)
from jax.random import PRNGKey, uniform, lognormal
import jax.numpy as jnp
from jax import jit, pmap
import numpyro
import numpyro.distributions as dist
from numpyro.examples.datasets import LYNXHARE, load_dataset
from numpyro.infer import MCMC, NUTS, Predictive
from diffrax import diffeqsolve, ODETerm, SaveAt, Tsit5
### download pre-packaged numpyro data for Lotka-Volterra problem
_, fetch = load_dataset(LYNXHARE, shuffle=False)
year, data = fetch() # data is in hare -> lynx order
### Lotka-Volterra ODE function
def vector_field(t, y, args):
    prey, predator = y
    alpha, beta, gamma, delta = args
    dprey_dt = alpha * prey - beta * prey * predator
    dpredator_dt = -gamma * predator + delta * prey * predator
    return dprey_dt, dpredator_dt
### setup for diffrax
term = ODETerm(vector_field)
solver = Tsit5()
t0 = jnp.asarray(year[0],dtype=jnp.float64)
t1 = jnp.asarray(year[-1],dtype=jnp.float64)
dt0 = 0.1
saveat = SaveAt(ts=jnp.asarray(year,dtype=jnp.float64)) # evaluate at same times as observations
"""
pretend we have 8 different objects aka 8 different ICs we want to pmap over
yes this could be done with vmap, but for my actual problem, I need pmap/shard_map
since I have 1000s of objects/ICs (embarrassingly parallel problem) and vmap is slow to compile
"""
# generate 8 random combinations of (alpha,beta,gamma,delta,init_prey,init_predator)
rand_params = uniform(PRNGKey(1),shape=(8,4),minval=0.01,maxval=1.0) # (alpha,beta,gamma,delta) ~ U(0.01,1)
rand_ICs = lognormal(PRNGKey(1),jnp.log(10),(8,2)) # init_prey, init_predator ~ LogNormal(sigma=log(10))
rand_args = jnp.hstack((rand_params,rand_ICs)) # each row is (alpha,beta,gamma,delta,init_prey,init_predator) -- 8 rows for 8 fixed ICs
# function that solves ODE problem for a given array of (alpha,beta,gamma,delta,init_prey,init_predator)
def solve_diffrax(args):
    # returning transposed solution so it has same dims as observed data: 91x2
    return jnp.asarray(diffeqsolve(term, solver, t0, t1, dt0, tuple(args[4:]), args=args[:4], saveat=saveat).ys).T
# create a pmapped version of that ODE solver function -- this will operate over each row of rand_args in parallel (8 ICs for 8 devices)
pmap_solve_diffrax = pmap(solve_diffrax)
# run it once to compile
ptest = pmap_solve_diffrax(rand_args)
### numpyro model where the 8 ICs will be solved in parallel for different (alpha,beta,gamma,delta)
### obs_data is the observed time series of (prey,predator) vs time and we will fix uncertainty obs_sigma=0.8 instead of fitting for it
def model(rand_args, obs_data, obs_sigma=0.8):
    # sample our free parameters (alpha,beta,gamma,delta)
    alpha = numpyro.sample('alpha',dist.TruncatedNormal(low=0.0,loc=1.0,scale=0.5))
    beta = numpyro.sample('beta',dist.TruncatedNormal(low=0.0,loc=0.05,scale=0.05))
    gamma = numpyro.sample('gamma',dist.TruncatedNormal(low=0.0,loc=1.0,scale=0.5))
    delta = numpyro.sample('delta',dist.TruncatedNormal(low=0.0,loc=0.05,scale=0.05))
    # update rand_args with the random samples of (alpha,beta,gamma,delta)
    rand_args = rand_args.at[:,0].set(alpha)
    rand_args = rand_args.at[:,1].set(beta)
    rand_args = rand_args.at[:,2].set(gamma)
    rand_args = rand_args.at[:,3].set(delta)
    # use pmap to solve ODEs for all 8 ICs in parallel for current set of (alpha,beta,gamma,delta)
    psol = pmap_solve_diffrax(rand_args)
    # simple placeholder: compute mean solution of (prey,predator) vs time averaging over the 8 ICs
    mean_sol = jnp.mean(psol,axis=0)
    # likelihood similar to official example https://num.pyro.ai/en/stable/examples/ode.html
    numpyro.sample('obs', dist.LogNormal(jnp.log(mean_sol), obs_sigma), obs=obs_data)
### Finally, prior predictive check to test
### This works but raises "jit-of-pmap" warning
prior_predictive = Predictive(model, num_samples=100)
prior_predictions = prior_predictive(PRNGKey(5), rand_args=rand_args, obs_data=None, obs_sigma=0.8)['obs']
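For comparison, here is a minimal sketch of the shard_map variant mentioned at the top. shard_map is designed to be called under jit, so this layout might avoid the jit-of-pmap warning; whether it also removes the rendezvous messages for a larger system is untested here. The `rk4_solve` helper is again a hypothetical fixed-step stand-in for diffrax, the mesh axis name "obj" and the array sizes are arbitrary choices, and the import fallback assumes that shard_map lives either at `jax.shard_map` (newer releases) or `jax.experimental.shard_map` (older releases).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX releases


def vector_field(y, params):
    # Lotka-Volterra right-hand side; y = (prey, predator)
    prey, predator = y
    alpha, beta, gamma, delta = params
    return jnp.stack([
        alpha * prey - beta * prey * predator,
        -gamma * predator + delta * prey * predator,
    ])


def rk4_solve(y0, params, dt=0.01, n_steps=100):
    # Hypothetical stand-in solver: classic fixed-step RK4, returns the final state.
    def step(y, _):
        k1 = vector_field(y, params)
        k2 = vector_field(y + 0.5 * dt * k1, params)
        k3 = vector_field(y + 0.5 * dt * k2, params)
        k4 = vector_field(y + dt * k3, params)
        return y + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4), None

    y_final, _ = jax.lax.scan(step, y0, None, length=n_steps)
    return y_final


devices = np.array(jax.devices())  # 8 CPU devices if XLA_FLAGS forces 8
mesh = Mesh(devices, axis_names=("obj",))

n_dev = devices.size
per_dev = 4                        # objects per device (illustrative)
n_objects = n_dev * per_dev        # must be divisible by the mesh size

y0s = jnp.linspace(1.0, 2.0, n_objects * 2).reshape(n_objects, 2)
params = jnp.array([1.0, 0.05, 1.0, 0.05])  # (alpha, beta, gamma, delta)


@jax.jit
def solve_all(y0s, params):
    # Each device gets a (per_dev, 2) slice of y0s and a full copy of params,
    # then runs a plain vmap over its own slice.
    sharded_solve = shard_map(
        jax.vmap(rk4_solve, in_axes=(0, None)),
        mesh=mesh,
        in_specs=(P("obj"), P()),
        out_specs=P("obj"),
    )
    return sharded_solve(y0s, params)


out = solve_all(y0s, params)  # shape (n_objects, 2), no manual reshape needed
```

Unlike the pmap version, the inputs keep their flat `(n_objects, 2)` shape and the split across devices is described by the partition specs. Inside a numpyro model, `solve_all` would take the place of `pmap_solve_diffrax`, with the sampled (alpha, beta, gamma, delta) passed as the replicated `params` argument.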