I’m wondering whether there’s a way to use `jax.pmap` efficiently for parallel computation of the log density. It seems that the entire log density is automatically jitted, which incurs a big performance penalty. Here’s a toy example showing this behavior (I know there’s nothing to be gained via `pmap` here; it just demonstrates the issue):
```python
import numpy as np
import jax
import jax.numpy as jnp
import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

## dataset dimensions
M = 100; N = 200

## true parameter values
sbeta_true = np.sqrt(.5)
se_true = np.sqrt(.5)
beta_true = np.random.randn(M) * sbeta_true

## observed data
X = np.random.randn(M * N).reshape(N, M)
X = np.apply_along_axis(lambda x: (x - np.mean(x)) / np.std(x), 0, X) / np.sqrt(M)
e = np.random.randn(N) * se_true
y = X @ beta_true + e
y -= np.mean(y)

## map-reduce computation of the likelihood:
## each device gets one shard of (X, y); beta and s_e are replicated
def y_likelihood(X, b, y, se):
    return jnp.sum(dist.Normal(0., se).log_prob(y - jnp.dot(X, b)))

mapped_y_likelihood = jax.pmap(y_likelihood, in_axes=(0, None, 0, None))

## model specification (X arrives sharded, so the number of
## covariates is the last dimension)
def toy_model(y=None, X=None):
    s_beta = numpyro.sample('s_beta', dist.HalfCauchy(1))
    s_e = numpyro.sample('s_e', dist.HalfCauchy(1), sample_shape=(1,))
    beta = numpyro.sample('beta', dist.Normal(0., s_beta),
                          sample_shape=(X.shape[-1],))
    lpy = mapped_y_likelihood(X, beta, y, s_e)
    numpyro.factor('y', jnp.sum(lpy))

## construct kernel
nuts_kernel = NUTS(toy_model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
rng_key = random.PRNGKey(0)

## run model on data split into two shards
mcmc.run(rng_key,
         y=y.reshape(2, y.shape[0] // 2),
         X=X.reshape(2, X.shape[0] // 2, X.shape[1]))
```
This yields the warning:

```
UserWarning: The jitted function _body_fn includes a pmap. Using jit-of-pmap can lead to inefficient data movement, as the outer jit does not preserve sharded data representations and instead collects input and output arrays onto a single device. Consider removing the outer jit unless you know what you're doing.
```

For reference, the above code corresponds to the generative model
$$\beta \mid \sigma_e, \sigma_\beta \sim N(0, \sigma_\beta),$$
$$y \mid \beta, \sigma_e, \sigma_\beta \sim N(X\beta, \sigma_e).$$
I’m trying to parallelize the computation of the likelihood of y and the computation of X\beta.
- Am I correct to assume that this outer `jit` is the cause of the poor performance?
- Is there a way around this behavior? Based on @fehiepsi’s answer here, I know this is atypical usage, but I’m not sure if it’s infeasible.
My motivating case (not in the toy model above) involves a likelihood with a multiplication between a large fixed matrix and a large vector of latent variables. I’d like to be able to split this across multiple GPUs. Note that since the limiting factor is the dimension of the parameter vector rather than the number of data points, parallelization at the level of multiple chains won’t make a big difference… Thanks!
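To make the motivating case concrete, here’s a minimal sketch of the kind of sharded matrix-vector product I have in mind (names and shapes are illustrative, not from my actual model): the rows of a large fixed matrix are split across devices with `pmap`, the latent vector is replicated, and each device computes its shard of the matvec.

```python
import numpy as np
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
N, M = 8 * n_dev, 4  # toy sizes; in practice M is large

X = np.random.randn(N, M)      # large fixed matrix, sharded by rows
beta = np.random.randn(M)      # latent vector, replicated on all devices

# one row-shard of X @ beta per device; beta is broadcast (in_axes=None)
sharded_matvec = jax.pmap(lambda Xs, b: jnp.dot(Xs, b), in_axes=(0, None))

mu = sharded_matvec(X.reshape(n_dev, N // n_dev, M), beta)
# mu has shape (n_dev, N // n_dev); flattened it matches the full matvec
assert np.allclose(np.asarray(mu).reshape(-1), X @ beta, atol=1e-4)
```

The question is whether this pattern can live inside a NumPyro model’s log density without the jit-of-pmap penalty described above.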