Possible to use `pmap` within likelihood computation?

I’m wondering if there’s a way to efficiently use jax.pmap for parallel computation of the log density. It seems that the entire log density is automatically jit-compiled, which incurs a big performance penalty. Here’s a toy example showing this behavior (I know there’s nothing to be gained via pmap here; it just demonstrates the issue):

import numpy as np
import jax
import numpyro
import jax.random as random
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import jax.numpy as jnp

## dataset dimensions
M=100; N=200

## true parameter values
sbeta_true = np.sqrt(.5)
se_true = np.sqrt(.5)
beta_true = np.random.randn(M)*sbeta_true

## observed data
X = np.random.randn(M*N).reshape(N,M)
X = np.apply_along_axis(lambda x: (x-np.mean(x))/np.std(x), 0, X) / np.sqrt(M)
e = np.random.randn(N)*se_true
y = X @ beta_true + e
y -= np.mean(y)

## map reduce computation of likelihood
def y_likelihood(X,b,y,se):
    return jnp.sum(dist.Normal(0., se).log_prob(y-jnp.dot(X,b)))

mapped_y_likelihood=jax.pmap(y_likelihood, in_axes=(0,None,0,None))


## model specification
def toy_model(y=None, X=None):
    s_beta = numpyro.sample('s_beta', dist.HalfCauchy(1))
    s_e = numpyro.sample('s_e', dist.HalfCauchy(1),sample_shape=(1,))
    beta = numpyro.sample('beta', dist.Normal(0.,s_beta),sample_shape=(X.shape[1],))
    lpy=mapped_y_likelihood(X, beta, y, s_e)
    numpyro.factor('y',jnp.sum(lpy))

## construct kernel
nuts_kernel = NUTS(toy_model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)

rng_key = random.PRNGKey(0)

## run model
mcmc.run(rng_key, y=y.reshape(2,y.shape[0]//2),
         X=X.reshape(2,X.shape[0]//2,X.shape[1]))

This yields the warning:

UserWarning: The jitted function _body_fn includes a pmap. Using jit-of-pmap can lead to inefficient data movement, as the outer jit does not preserve sharded data representations and instead collects input and output arrays onto a single device. Consider removing the outer jit unless you know what you're doing.

For reference, the above code corresponds to the generative model:

$$
\begin{aligned}
\sigma_e, \sigma_\beta &\sim \mathrm{priors}, \\
\beta \mid \sigma_e, \sigma_\beta &\sim N(0, \sigma_\beta), \\
y \mid \beta, \sigma_e, \sigma_\beta &\sim N(X\beta, \sigma_e).
\end{aligned}
$$

I’m trying to parallelize the computation of the likelihood of $y$ and the computation of $X\beta$.
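Row-sharding is valid because, given $\beta$ and $\sigma_e$, the observations are conditionally independent. Writing $x_i^\top$ for the $i$-th row of $X$, $n$ for the number of rows, and $S_1, \dots, S_K$ for a split of the row indices into one disjoint shard per device (notation introduced here for illustration), the log-likelihood decomposes into per-device partial sums that need only the local rows and a copy of $\beta$:

$$
\log p(y \mid \beta, \sigma_e)
  = \sum_{i=1}^{n} \log N\!\left(y_i \mid x_i^\top \beta, \sigma_e\right)
  = \sum_{k=1}^{K} \underbrace{\sum_{i \in S_k} \log N\!\left(y_i \mid x_i^\top \beta, \sigma_e\right)}_{\text{device } k}.
$$

This is what y_likelihood computes per shard, with the outer jnp.sum doing the final reduction.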

Two questions:

  1. Am I correct to assume that this outer jit is the cause of the poor performance?
  2. Is there a way around this behavior? Based on @fehiepsi’s answer here, I know this is atypical usage, but I’m not sure if it’s infeasible.

My motivating case (not in the toy model above) involves a likelihood with a multiplication between a large fixed matrix and a large vector of latent variables. I’d like to be able to split this across multiple GPUs. Note that since the limiting factor is the dimension of the parameter vector rather than the number of data points, parallelization at the level of multiple chains won’t make a big difference… Thanks!

I think you can run HMC with a fixed num_steps to estimate the average time per leapfrog step. Then time one evaluation of the log density and its gradient to see how long it takes. This will give you an overview of the picture.

I’m not sure I entirely understand your suggestion. I’ve compared the pmap'd versus vectorized performance using HMC with a fixed number of steps and get the following:

from numpyro.infer import HMC  ## needed for the HMC kernels below

## map reduce computation of likelihood
def y_likelihood(X,b,y,se):
    return jnp.sum(dist.Normal(0., se).log_prob(y-jnp.dot(X,b)))

mapped_y_likelihood=jax.pmap(y_likelihood, in_axes=(0,None,0,None))

## model specification
def toy_model_vec(y=None, X=None):
    s_beta = numpyro.sample('s_beta', dist.HalfCauchy(1))
    s_e = numpyro.sample('s_e', dist.HalfCauchy(1),sample_shape=(1,))
    beta = numpyro.sample('beta', dist.Normal(0.,s_beta),sample_shape=(X.shape[1],))
    lpy=y_likelihood(X, beta, y, s_e)
    numpyro.factor('y',lpy)

## model specification with pmap likelihood
def toy_model_pmap(y=None, X=None):
    s_beta = numpyro.sample('s_beta', dist.HalfCauchy(1))
    s_e = numpyro.sample('s_e', dist.HalfCauchy(1),sample_shape=(1,))
    beta = numpyro.sample('beta', dist.Normal(0.,s_beta),sample_shape=(X.shape[2],))
    lpy=mapped_y_likelihood(X, beta, y, s_e)
    numpyro.factor('y',jnp.sum(lpy))


rng_key = random.PRNGKey(0)
nuts_kernel_vec = NUTS(toy_model_vec)
mcmc_nuts_vec = MCMC(nuts_kernel_vec, num_warmup=500, num_samples=500)

## run NUTS kernel to get initial state for HMC comparison
mcmc_nuts_vec.run(rng_key, y=y, X=X) ## 17 sec

step_size = mcmc_nuts_vec.last_state.adapt_state.step_size
inv_mm = mcmc_nuts_vec.last_state.adapt_state.inverse_mass_matrix

## construct HMC kernels
hmc_kernel_vec = HMC(toy_model_vec,
                     num_steps=15,
                     step_size=step_size,
                     inverse_mass_matrix=inv_mm)

hmc_kernel_pmap = HMC(toy_model_pmap,
                      num_steps=15,
                      step_size=step_size,
                      inverse_mass_matrix=inv_mm)

mcmc_hmc_vec = MCMC(hmc_kernel_vec, num_warmup=500, num_samples=500)
mcmc_hmc_pmap = MCMC(hmc_kernel_pmap, num_warmup=500, num_samples=500)

## run vanilla HMC
mcmc_hmc_vec.run(rng_key, ## 6 seconds
                 y=y,
                 X=X)

## run vanilla with pmap
mcmc_hmc_pmap.run(rng_key, ## 23 seconds
                  y=y.reshape(2,y.shape[0]//2),
                  X=X.reshape(2,X.shape[0]//2,X.shape[1]))

I get the "Using jit-of-pmap can lead to inefficient data movement" warning with the pmap'd HMC version, and it runs about 4x slower than the vectorized version (23 s vs. 6 s). But I’m not sure whether this is really caused by the jit-of-pmap. Is there a way to disable the jitting of the sampler?

Given a model, you can compute the log density with log_density (see the Runtime Utilities section of the NumPyro documentation). Something like:

def get_log_prob(params):
    # log_density returns (log_joint, model_trace); keep only the log joint
    return log_density(model, ...)[0]

%time get_log_prob(params)
%time jax.grad(get_log_prob)(params)

That will give you the time for one leapfrog step without jit.

Then, for HMC, take the total run time and divide it by (num_warmup + num_samples) * num_steps. This gives you the time for one leapfrog step under HMC's jit.
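Applied to the HMC timings reported above (6 s vectorized and 23 s pmap'd, with 500 warmup + 500 samples and `num_steps=15`), this rule gives roughly 0.4 ms versus 1.5 ms per leapfrog step. The totals include compilation, so treat the per-step figures as upper bounds. The helper name below is hypothetical.

```python
def time_per_leapfrog_step(total_seconds, num_warmup, num_samples, num_steps):
    # Each HMC iteration (warmup or sampling) runs num_steps leapfrog steps,
    # and each step needs roughly one log-density + gradient evaluation.
    return total_seconds / ((num_warmup + num_samples) * num_steps)

vec_step = time_per_leapfrog_step(6.0, 500, 500, 15)    # vectorized model
pmap_step = time_per_leapfrog_step(23.0, 500, 500, 15)  # pmap'd model
```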

You can also jit get_log_prob and compare.

I remember that the JAX team also provides a context manager to disable jit (jax.disable_jit). That might be helpful for you.
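A minimal sketch of what `jax.disable_jit()` does, shown on a toy function rather than on NumPyro's sampler (which becomes very slow without jit): inside the block, a `jax.jit`-decorated function runs its Python body directly, so a Python side effect fires on every call instead of only when the function is traced.

```python
import jax
import jax.numpy as jnp

calls = []

@jax.jit
def f(x):
    calls.append("traced")  # Python side effect: runs only when f is traced
    return jnp.sum(x ** 2)

x = jnp.arange(3.0)
f(x)  # first call: traced and compiled
f(x)  # second call: served from the compile cache, body not re-run
n_jit = len(calls)

with jax.disable_jit():
    f(x)  # jit disabled: the Python body executes eagerly
    f(x)  # ... and again on every call
n_total = len(calls)
```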

This is super helpful. I think this is equivalent to what you suggested: using the regular and pmap'd versions of the model defined above (toy_model_vec and toy_model_pmap), I compare the times to evaluate jit'd and non-jit'd functions after increasing the sample size to N=32000:

def test_vec(params):
    return numpyro.infer.util.log_density(toy_model_vec, model_args=(),
                                   model_kwargs={'y':y,'X':X},
                                   params=params)[0]
def test_pmap(params):
    return numpyro.infer.util.log_density(toy_model_pmap, model_args=(),
                                   model_kwargs={'y':y.reshape(2,y.shape[0]//2),
                                                 'X':X.reshape(2,X.shape[0]//2,X.shape[1])},
                                   params=params)[0]

test_vals = {'s_e':.7,'s_beta':.7,'beta':beta_true}

## no jit
%timeit null=test_pmap(test_vals) ## 125 ms
%timeit null=jax.grad(test_pmap)(test_vals)## 144 ms
%timeit null=test_vec(test_vals) ## 226 ms
%timeit null=jax.grad(test_vec)(test_vals)## 197 ms

## with jit
%timeit null=jax.jit(test_pmap)(test_vals) ## 1.46 ms
%timeit null=jax.grad(jax.jit(test_pmap))(test_vals) ## 6.36 ms
%timeit null=jax.jit(test_vec)(test_vals) ## 456 µs
%timeit null=jax.grad(jax.jit(test_vec))(test_vals) ## 2.76 ms

So without jit, the pmap version is roughly 1.4 to 1.8x faster, but with jit it’s roughly 2.3 to 3.2x slower! I think this is in line with my hypothesis that the jit-of-pmap is the issue…

I think JAX’s pjit will be useful here. Something like pjit(mcmc_sample_fn)(batch_of_data), where we use psum in the model to aggregate the per-device log probabilities. In principle it will work and will be faster, but it would require a bit of engineering effort. If you really need this feature, please make a GitHub issue and I will add some pointers there.

After some experimentation I was able to implement a pjit’d log density and gradient. The setup is: observed data $X\in \mathbb{R}^{n\times m}$ and $y\in \mathbb{R}^{n\times 1}$, latent variables $\beta\in \mathbb{R}^{m\times 1}$, and parameters s_beta and s_e.

I can parallelize over the n rows of my observed data by sampling s_e, s_beta and beta on a single GPU and then computing the likelihood of y - jnp.dot(X, beta) across multiple GPUs.

import numpy as np
import jax
import numpyro
import jax.random as random
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, HMC
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.experimental import maps
from jax.experimental import PartitionSpec

## dataset dimensions
M=4000; N=10000

## true parameter values
sbeta_true = np.sqrt(.5)
se_true = np.sqrt(.5)
beta_true = np.random.randn(M)*sbeta_true

## simulate observed data
X = np.random.randn(M*N).reshape(N,M)
X = np.apply_along_axis(lambda x: (x-np.mean(x))/np.std(x), 0, X) / np.sqrt(M)
e = np.random.randn(N)*se_true
y = X @ beta_true + e
y -= np.mean(y)
y=y.reshape(y.shape[0],1)
beta_true=beta_true.reshape(beta_true.shape[0],1)

## define log density
def _lpdf(X,y,b,sb,se):
    ## priors
    ll = dist.HalfCauchy(1).log_prob(sb) + dist.HalfCauchy(1).log_prob(se)
    ## latent variables
    ll += jnp.sum(dist.Normal(0., sb).log_prob(b))
    ## distributed part involving observed data
    ll += jax.lax.psum(dist.Normal(0., se).log_prob(y-jnp.dot(X,b)),0)[0]
    return ll

## function to distribute rows of X,y across mesh
shard = pjit(
    lambda x: x,
    in_axis_resources=None,
    out_axis_resources=PartitionSpec('n', 'm'))

## single device version for comparison
jit_lpdf = jax.jit(_lpdf)
grad_lpdf = jax.grad(_lpdf, argnums=(2,3,4))
jgrad_lpdf = jax.jit(grad_lpdf)

## distributed likelihood
pjit_lpdf = pjit(_lpdf,
                 in_axis_resources=(PartitionSpec('n', 'm'), PartitionSpec('n', 'm'),
                                    None, None, None),
                 out_axis_resources=None)
pgrad_lpdf = pjit(grad_lpdf,
                 in_axis_resources=(PartitionSpec('n', 'm'), PartitionSpec('n', 'm'),
                                    None, None, None),
                 out_axis_resources=None)

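## device mesh (assumes two GPUs): rows ('n') split across devices, columns ('m') not split
mesh = maps.Mesh(np.array(jax.devices()[:2]).reshape(2, 1), ('n', 'm'))

## shard data across devices using the shard function defined above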
with maps.Mesh(mesh.devices, mesh.axis_names):
    X_sharded = shard(X)
    y_sharded = shard(y)

%%timeit
## evaluate multi-gpu lpdf: 1.78 ms
with maps.Mesh(mesh.devices, mesh.axis_names):
    lp = pjit_lpdf(X_sharded,y_sharded,beta_true,.7,.7)
    lgrad = pgrad_lpdf(X_sharded,y_sharded,beta_true,.7,.7)

%%timeit
## evaluate single-gpu lpdf: 322 ms
lp=jit_lpdf(X,y,beta_true,.7,.7)
lgrad=jgrad_lpdf(X,y,beta_true,.7,.7)

This works very well with two GPUs, though the comparison above can’t be fair, since the two-GPU version is roughly 180x faster (perhaps because the data are already sharded across the GPUs?). But as far as I can tell there’s no way to sample from this lpdf, because NumPyro will attempt to jit the already pjit’d density, which throws an error.

I’d be happy to open a GitHub issue if you think this might be feasible to add.

I guess you need to use block_until_ready to have a fair comparison.
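JAX dispatches work asynchronously, so a call can return before the device has finished computing, and a timer without a synchronization point may measure only the dispatch. A sketch of a fairer timing helper (the name `timed` and the toy `loss` are hypothetical): it compiles once outside the timed region and blocks on every output array, including pytrees such as the dict returned by `jax.grad`.

```python
import time
import jax
import jax.numpy as jnp

def timed(fn, *args, repeats=10):
    # Warm-up call: triggers compilation so it is excluded from the timing.
    jax.tree_util.tree_map(lambda a: a.block_until_ready(), fn(*args))
    start = time.perf_counter()
    for _ in range(repeats):
        out = fn(*args)
        # Wait until every output array has actually been computed.
        jax.tree_util.tree_map(lambda a: a.block_until_ready(), out)
    return (time.perf_counter() - start) / repeats, out

@jax.jit
def loss(params):
    return jnp.sum(params["w"] ** 2) + params["b"] ** 2

params = {"w": jnp.ones(1000), "b": jnp.array(2.0)}
seconds, grads = timed(jax.jit(jax.grad(loss)), params)
```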


Yes, please. I think we just need to expose an option to change this jit to pjit (you might want to try it first if you like).