How to run SVI on GPU when the default JAX backend is CPU?

Hello devs. I have the default backend for JAX set to CPU, using numpyro.set_platform("cpu"). I want to use the GPU when running SVI, so I transfer all the datasets that go into the model to the CUDA device, but the run time does not change: it is the same as when I leave the datasets on the CPU.

However, if I set the default backend to CUDA with numpyro.set_platform("cuda"), the run time drops massively.

The following toy code (taken from one of NumPyro’s tutorials) reproduces the results. I added the num_regressions plate to make the difference in run time between CPU and GPU noticeable.

import time

import jax
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
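from numpyro.infer import Predictive, SVI, Trace_ELBO

# Default backend; switched to "cuda" for the third run reported below.
numpyro.set_platform("cpu")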


def main(num_regressions=int(1e5), use_cuda=False):
    print(f"num_regressions: {num_regressions}")
    print(f"use_cuda: {use_cuda}")
    CPUS = jax.devices("cpu")
    GPUS = jax.devices("gpu")
    print(f"CPUS: {CPUS}")
    print(f"GPUS: {GPUS}")
    print(f"Default backend: {jax.default_backend()}")

    def model(data):
        with numpyro.plate("num_regressions", num_regressions):
            f = numpyro.sample("latent_fairness", dist.Beta(10, 10))
            with numpyro.plate("N", data.shape[0] if data is not None else 10):
                numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        with numpyro.plate("num_regressions", num_regressions):
            alpha_q = numpyro.param("alpha_q", 15., constraint=constraints.positive)
            beta_q = numpyro.param("beta_q", lambda rng_key: random.exponential(rng_key),
                                   constraint=constraints.positive)
            numpyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))

    data = jnp.concatenate([jnp.ones(6), jnp.zeros(4)])
    data = data[..., None]  # add num_regressions dimension
    if use_cuda:
        data = jax.device_put(data, device=GPUS[0])

    print(f"data is on: {data.devices()}")

    optimizer = numpyro.optim.Adam(step_size=0.0005)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    start = time.time()
    svi_result = svi.run(random.PRNGKey(0), 10000, data)
    end = time.time()
    print(f"Elapsed time: {end - start:.2f}")

if __name__ == "__main__":
    main()
    # main(use_cuda=True)

With main(), the output is:

num_regressions: 100000
use_cuda: False
CPUS: [CpuDevice(id=0)]
GPUS: [cuda(id=0)]
Default backend: cpu
data is on: {CpuDevice(id=0)}
Elapsed time: 1084.24

and, with main(use_cuda=True), the output is:

num_regressions: 100000
use_cuda: True
CPUS: [CpuDevice(id=0)]
GPUS: [cuda(id=0)]
Default backend: cpu
data is on: {cuda(id=0)}
Elapsed time: 1086.58

If I instead replace numpyro.set_platform("cpu") with numpyro.set_platform("cuda") (where the import statements are) and run main(), the output is:

num_regressions: 100000
use_cuda: False
CPUS: [CpuDevice(id=0)]
GPUS: [cuda(id=0)]
Default backend: gpu
data is on: {cuda(id=0)}
Elapsed time: 84.63

The reason I want to keep the default backend on CPU is that I run out of memory when using the Predictive utility on GPU for large datasets. (While we are at it, is there a way to avoid this when using Predictive on GPU?) Once SVI has run, I want to get the posterior samples, move them back to the default CPU backend, and then use them for predictions. Any idea what’s going wrong here?
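For reference, the workflow described above (compute on the GPU, keep CPU as the default, move results back) might be sketched with the jax.default_device context manager. This is only a sketch under assumptions: pick_accelerator and step are hypothetical names, step is a stand-in for a jitted update rather than svi.run, and whether wrapping svi.run in the same context gives the speed-up seen with set_platform("cuda") has not been verified here.

```python
import jax
import jax.numpy as jnp


def pick_accelerator():
    """Return the first GPU if JAX can see one, otherwise fall back to the CPU."""
    try:
        return jax.devices("gpu")[0]
    except RuntimeError:  # raised when no GPU backend is available
        return jax.devices("cpu")[0]


cpu = jax.devices("cpu")[0]
accel = pick_accelerator()


@jax.jit
def step(w, x):
    # Hypothetical stand-in for one optimisation step (not NumPyro's update).
    return w - 0.1 * jnp.mean(x * (w * x - 1.0))


# Inside this block, newly created arrays and jitted computations are placed
# on `accel` instead of the process-wide default device.
with jax.default_device(accel):
    x = jnp.linspace(0.5, 1.5, 1000)
    w = jnp.asarray(0.0)
    for _ in range(100):
        w = step(w, x)

# Copy the result back to the CPU for downstream work.
w_cpu = jax.device_put(w, cpu)
print(w_cpu.devices())
```

The analogous step for the code in this post would be to call svi.run(...) inside the with block and apply jax.device_put(..., cpu) to the resulting parameters, but that is an untested assumption.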

For very large datasets, I would recommend running Predictive in chunks and then copying each chunk's output to the host (using jax.device_get). I haven’t played with mixed devices, so I’m not sure what the best practice is.
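A minimal sketch of that chunking idea, splitting over posterior samples. The helper predict_in_chunks and the toy_predict stand-in are hypothetical names, and toy_predict replaces the real Predictive call so the snippet only needs JAX and NumPy:

```python
import jax
import jax.numpy as jnp
import numpy as np


def predict_in_chunks(predict_fn, posterior_samples, chunk_size):
    """Apply predict_fn to slices of the posterior samples; gather results on the host."""
    num_samples = next(iter(posterior_samples.values())).shape[0]
    host_chunks = []
    for start in range(0, num_samples, chunk_size):
        chunk = {k: v[start:start + chunk_size] for k, v in posterior_samples.items()}
        out = predict_fn(chunk)
        # jax.device_get copies the outputs into host memory as NumPy arrays,
        # so only one chunk of predictions needs to live on the device at a time.
        host_chunks.append(jax.device_get(out))
    # Stitch the per-chunk results back together along the sample axis.
    return {k: np.concatenate([c[k] for c in host_chunks], axis=0)
            for k in host_chunks[0]}


def toy_predict(samples):
    # Hypothetical stand-in for a Predictive call on one chunk of samples.
    x = jnp.arange(5.0)
    return {"obs_mean": samples["w"][:, None] * x}


posterior = {"w": jnp.linspace(0.0, 1.0, 10)}
preds = predict_in_chunks(toy_predict, posterior, chunk_size=4)
print(preds["obs_mean"].shape)  # (10, 5)
```

With NumPyro, predict_fn could presumably be something like lambda s: Predictive(model, posterior_samples=s)(rng_key, data), ideally with a different rng_key per chunk (for example via random.fold_in). Chunking over rows of the dataset instead is another option when the predictions for different rows are independent given the latent variables.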