I’m wondering if there’s a way to efficiently use `jax.pmap` for parallel computation of the log density. It seems that the entire log density is automatically wrapped in `jax.jit`, which yields a big performance penalty. Here’s a toy example showing this behavior (I know there’s nothing to be gained via `pmap` here; it just demonstrates the issue):

```python
import numpy as np
import jax
import numpyro
import jax.random as random
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import jax.numpy as jnp

## dataset dimensions
M = 100; N = 200
## true parameter values
sbeta_true = np.sqrt(.5)
se_true = np.sqrt(.5)
beta_true = np.random.randn(M) * sbeta_true
## observed data: standardize each column of X, then scale by 1/sqrt(M)
X = np.random.randn(M * N).reshape(N, M)
X = np.apply_along_axis(lambda x: (x - np.mean(x)) / np.std(x), 0, X) / np.sqrt(M)
e = np.random.randn(N) * se_true
y = X @ beta_true + e
y -= np.mean(y)

## map reduce computation of likelihood
def y_likelihood(X, b, y, se):
    # log-likelihood of the rows held by one shard
    return jnp.sum(dist.Normal(0., se).log_prob(y - jnp.dot(X, b)))

# map over the leading (shard) axis of X and y; b and se are broadcast to every device
mapped_y_likelihood = jax.pmap(y_likelihood, in_axes=(0, None, 0, None))

## model specification
def toy_model(y=None, X=None):
    s_beta = numpyro.sample('s_beta', dist.HalfCauchy(1))
    s_e = numpyro.sample('s_e', dist.HalfCauchy(1), sample_shape=(1,))
    # X has shape (shards, rows per shard, M), so the number of coefficients is X.shape[-1]
    beta = numpyro.sample('beta', dist.Normal(0., s_beta), sample_shape=(X.shape[-1],))
    lpy = mapped_y_likelihood(X, beta, y, s_e)  # one partial sum per shard
    numpyro.factor('y', jnp.sum(lpy))

## construct kernel
nuts_kernel = NUTS(toy_model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
rng_key = random.PRNGKey(0)
## run model: split the N rows into 2 shards (pmap over 2 shards needs at least 2 devices)
mcmc.run(rng_key, y=y.reshape(2, y.shape[0] // 2),
         X=X.reshape(2, X.shape[0] // 2, X.shape[1]))
```

This yields the warning `UserWarning: The jitted function _body_fn includes a pmap. Using jit-of-pmap can lead to inefficient data movement, as the outer jit does not preserve sharded data representations and instead collects input and output arrays onto a single device. Consider removing the outer jit unless you know what you're doing.`

For reference, the above code corresponds to the generative model

$$
\begin{aligned}
\sigma_e, \sigma_\beta &\sim \mathrm{priors}, \\
\beta \mid \sigma_e, \sigma_\beta &\sim N(0, \sigma_\beta), \\
y \mid \beta, \sigma_e, \sigma_\beta &\sim N(X\beta, \sigma_e).
\end{aligned}
$$

I’m trying to parallelize the computation of the likelihood of $y$ and of the product $X\beta$.
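Splitting the $N$ rows of $X$ and $y$ into $K$ shards (here $K = 2$, via the `reshape(2, ...)` calls) works because the Gaussian log-likelihood is a sum over observations, so each device can evaluate its own partial sum. Writing $S_k$ for the row indices on shard $k$ and $x_i^\top$ for row $i$ of $X$:

$$
\log p(y \mid \beta, \sigma_e) = \sum_{k=1}^{K} \sum_{i \in S_k} \log N\big(y_i \mid x_i^\top \beta, \sigma_e\big).
$$

The inner sum is what `y_likelihood` computes on each device, and the outer sum is the `jnp.sum(lpy)` passed to `numpyro.factor`.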

Two questions:

- Am I correct to assume that this outer `jit` is the cause of the poor performance?
- Is there a way around this behavior? Based on @fehiepsi’s answer here, I know this is atypical usage, but I’m not sure whether it’s infeasible.
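One possible alternative to calling `pmap` inside the jitted model is to shard the data with `jax.sharding` and write the likelihood as ordinary single-program code, letting the outer `jit` partition it instead of gathering everything onto one device. The sketch below is a minimal standalone illustration under stated assumptions: it uses plain JAX rather than NumPyro, synthetic data, a hypothetical `log_lik` helper and mesh axis name `"data"`, and it is untested whether NumPyro's `MCMC` preserves these input shardings through its own `jit`.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over all visible devices (also runs on a single CPU device).
devices = np.array(jax.devices())
n_dev = len(devices)
mesh = Mesh(devices, axis_names=("data",))

# Synthetic data; N is a multiple of the device count so rows split evenly.
N, M = 100 * n_dev, 100
rng = np.random.default_rng(0)
X = (rng.standard_normal((N, M)) / np.sqrt(M)).astype(np.float32)
beta = rng.standard_normal(M).astype(np.float32)
y = (X @ beta + 0.5 * rng.standard_normal(N)).astype(np.float32)

# Put different rows of X and y on different devices ("data" axis);
# the columns of X stay whole on each device.
X_sh = jax.device_put(X, NamedSharding(mesh, P("data", None)))
y_sh = jax.device_put(y, NamedSharding(mesh, P("data")))

@jax.jit
def log_lik(X, y, b, se):
    # Hypothetical helper. Plain single-program code: the compiler is expected
    # to split the work according to the input shardings, so no pmap is needed.
    resid = y - X @ b
    return jnp.sum(-0.5 * (resid / se) ** 2 - jnp.log(se) - 0.5 * jnp.log(2 * jnp.pi))

ll = log_lik(X_sh, y_sh, jnp.asarray(beta), 0.5)
print(float(ll))
```

The design choice here is to describe *where* the data lives and keep the math unsharded, which is the pattern the jit-of-pmap warning is steering toward.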

My motivating case (not in the toy model above) involves a likelihood containing the product of a large fixed matrix and a large vector of latent variables. I’d like to be able to split this product across multiple GPUs. Note that since the limiting factor is the dimension of the parameter vector rather than the number of data points, running multiple chains in parallel won’t make a big difference. Thanks!
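For a case like this, where the parameter dimension dominates, the same `jax.sharding` idea could in principle be applied along the parameter axis instead: split the columns of the fixed matrix and the matching entries of the latent vector across devices, so that each device computes a partial product. This is an unverified standalone sketch in plain JAX (no NumPyro); `linear_predictor`, the mesh axis name `"param"` and the sizes are hypothetical, and whether this layout would be kept when the latent vector is sampled inside a NumPyro model is an assumption.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
n_dev = len(devices)
mesh = Mesh(devices, axis_names=("param",))

# Hypothetical sizes: the parameter dimension M dominates and is a
# multiple of the device count so the columns split evenly.
N, M = 64, 1024 * n_dev
rng = np.random.default_rng(1)
X = rng.standard_normal((N, M)).astype(np.float32)   # large fixed matrix
beta = rng.standard_normal(M).astype(np.float32)     # latent vector

# Shard the columns of X and the matching entries of beta over devices:
# each device holds an (N, M // n_dev) block of X and M // n_dev entries of beta.
X_sh = jax.device_put(X, NamedSharding(mesh, P(None, "param")))
beta_sh = jax.device_put(beta, NamedSharding(mesh, P("param")))

@jax.jit
def linear_predictor(X, b):
    # The contracted dimension is sharded, so each device computes a partial
    # product and the compiler is expected to sum them across devices.
    return X @ b

mu = linear_predictor(X_sh, beta_sh)

# The block decomposition the sharded matmul relies on, spelled out on the host:
# X @ beta == sum_k X[:, cols_k] @ beta[cols_k]
cols = np.split(np.arange(M), n_dev)
mu_blocks = sum(X[:, c] @ beta[c] for c in cols)
```

The last three lines are only there to make the identity explicit; they would not be needed in practice.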