Batch processing numpyro models using Ray

Hello again,

Related post: Batch processing Pyro models. cc @fonnesbeck, as I think he’ll be interested in batch processing Bayesian models anyway.

I want to run lots of numpyro models in parallel. I created a new post because:

  • this post uses numpyro instead of pyro
  • I’m doing sampling instead of SVI
  • I’m using Ray instead of Dask
  • that post is from 2021

I’m running a simple Neal’s funnel model as a test. I have set up a small Ray cluster of 3 instances with 4 CPUs each. I run the Neal’s funnel model for dim in range(2, 20), once sequentially (one dim after another via plain calls to run_model) and once in parallel (via ray_run_model, which is decorated with @ray.remote). Full code is at the bottom.

I think numpyro on Ray is working. When I look at the logs, I get a Speedup: 5.02x, and I can see the 3 instances and their CPUs being used. Below is the CPU usage plot. Spin-up lasts until 21:45, the sequential runs go until 21:47, and the parallel run kicks in after that. The blue line is the head node: it runs the sequential tasks itself and, for the parallel tasks, does the scheduling across the rest of its CPUs and the CPUs on the other instances. So I’m fairly convinced it’s running a numpyro model as expected.

EDIT: adding a little bit of colour. I re-ran the parallel test with 1000 copies of the same dim model and looked at the CPU utilisation. This looks really good. In this case the head node is in red and I made the worker nodes blue. The head node fills itself up before sending tasks to worker nodes, peaking at ~394% CPU on each node. A minimal sketch of that saturation test is below.
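For reference, the driver loop for that test looked roughly like this. It’s a minimal sketch: dim is fixed at 10 here (an arbitrary choice for illustration), and ray_run_model is the remote function from the full code at the bottom.

import time

import ray

# Saturation test: 1000 copies of the same model, left to Ray to
# spread across the head node and the worker nodes.
N_TASKS = 1000

start = time.time()
futures = [ray_run_model.remote(10) for _ in range(N_TASKS)]
results = ray.get(futures)  # blocks until every task has finished
print(f"{N_TASKS} runs took {time.time() - start:.2f} seconds")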

I’m experienced with numpyro, but I’m unfamiliar with what goes on under the hood in XLA, and I’m inexperienced with Ray too. I have a few questions for those who might know more about how to get the most out of this setup:

  • In the post linked above, the dev mentioned “several pieces of global state in Pyro’s internals… I would also not be surprised to see problems with parameters from different runs overwriting each other in the global parameter store.” Is this something I should be concerned with in numpyro too?
  • Further on this, a question about environment variables. I am just running chains in sequence, because I know that numpyro.set_host_device_count(X) changes the XLA_FLAGS environment variable, and that this might limit how many numpyro chains can run simultaneously on an instance. But I am interested to hear whether you think parallel chains could work anyway, or if that environment variable will mess things up. (Ray might be smart enough to deal with the environment variable stuff, but I have no idea; there’s a sketch of what I have in mind after this list.)
  • Although numpyro runs one chain per core, I assume there’s some multi-threading going on, e.g. for matrix multiplication in XLA (or more generally in numpy), where I might use more cores. Is this the case, and if so, is it a bad idea (wrt jax) to use up all the rest of the cores for batch model runs?
    • And on the Ray side of this: if I set @ray.remote(num_cpus=2), is Ray smart enough to allocate those CPUs to the multithreading of that particular model, or are processes going to compete with each other regardless of the resources I specify to Ray? (No worries if it’s not obvious, as this is more a Ray question than an XLA one.)
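On the environment-variable question, below is an untested sketch of what I have in mind. The idea: each Ray task runs in its own worker process, so the XLA_FLAGS value that numpyro.set_host_device_count(X) sets (--xla_force_host_platform_device_count=X) could be scoped to a single task via Ray’s runtime_env, with num_cpus matched to the chain count. To be clear, this is my guess at how the pieces compose, not something I’ve verified; ray_run_model_parallel_chains and CHAINS are just names made up for the sketch.

import ray

# Untested: give each task its own XLA_FLAGS so in-task parallel chains
# don't depend on (or clobber) the driver's environment.
CHAINS = 2

@ray.remote(
    num_cpus=CHAINS,  # a scheduling hint to Ray, one CPU per chain
    runtime_env={
        "env_vars": {
            # same flag that numpyro.set_host_device_count(CHAINS) sets,
            # but only for this task's worker process
            "XLA_FLAGS": f"--xla_force_host_platform_device_count={CHAINS}"
        }
    },
)
def ray_run_model_parallel_chains(dim):
    # import inside the task so jax initialises after XLA_FLAGS is set
    import jax
    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS

    def model(dim=10):
        y = numpyro.sample("y", dist.Normal(0, 3))
        numpyro.sample("x", dist.Normal(jnp.zeros(dim - 1), jnp.exp(y / 2)))

    mcmc = MCMC(
        NUTS(model),
        num_warmup=10000,
        num_samples=10000,
        num_chains=CHAINS,
        chain_method="parallel",  # run chains across the forced host devices
        progress_bar=False,
    )
    mcmc.run(jax.random.PRNGKey(0), dim=dim)
    return mcmc.get_samples()["x"].shape

If that turns out not to play nicely with XLA, running num_chains=1 per task and treating each chain as its own Ray task would sidestep the flag entirely.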

Happy to provide any more code to try things out and see how numpyro and Ray can be optimally used together. I think it’s an exciting way to scale many models (and potentially chains) across a cluster.

Thanks as always,
Theo

import os
import time

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import ray


def model(dim: int = 10) -> None:
    """Neal's funnel: the scale of x depends on y via exp(y / 2)."""
    y = numpyro.sample("y", dist.Normal(0, 3))
    numpyro.sample("x", dist.Normal(jnp.zeros(dim - 1), jnp.exp(y / 2)))


def run_inference(model, rng_key: jax.Array, dim: int) -> MCMC:
    kernel = NUTS(model)
    mcmc = MCMC(
        kernel,
        num_warmup=10000,
        num_samples=10000,
        num_chains=2,  # without set_host_device_count, chains run in sequence
        progress_bar=True,
    )
    mcmc.run(rng_key, dim=dim)
    return mcmc


def run_model(dim):
    """Sequential version: a plain function call on the driver."""
    rng_key = jax.random.PRNGKey(0)
    mcmc = run_inference(model, rng_key, dim=dim)
    samples = mcmc.get_samples()
    return samples["x"].shape


@ray.remote
def ray_run_model(dim):
    """Remote version of the model with some ray variables printed"""
    rng_key = jax.random.PRNGKey(0)
    mcmc = run_inference(model, rng_key, dim=dim)
    samples = mcmc.get_samples()

    context = ray.get_runtime_context()  # one handle reused for all the metadata

    return {
        "result": samples["x"].shape,
        "node_id": context.get_node_id(),
        "job_id": context.get_job_id(),
        "actor_id": context.get_actor_id(),  # None here, since this is a task, not an actor
        "task_id": context.get_task_id(),
        "worker_id": context.get_worker_id(),
        "assigned_resources": context.get_assigned_resources(),
        "hostname": os.uname().nodename,
        "pid": os.getpid(),
    }


def test_ray_parallel():
    # Test sequential execution first
    print("Running sequential test...")
    start_time = time.time()
    sequential_results = [
        run_model(dim) for dim in range(2, 20)
    ]  # direct function calls
    sequential_time = time.time() - start_time

    print(f"Sequential execution took: {sequential_time:.2f} seconds")
    print(f"Sequential results: {sequential_results}")

    # Test parallel execution
    print("\nRunning parallel test...")
    start_time = time.time()
    futures = [ray_run_model.remote(dim) for dim in range(2, 20)]  # tasks start immediately
    parallel_results = ray.get(futures)  # blocks until all tasks finish
    parallel_time = time.time() - start_time

    print(f"Parallel execution took: {parallel_time:.2f} seconds")
    print("Parallel results: ")
    for results in parallel_results:
        print(results)

    return sequential_time, parallel_time


if __name__ == "__main__":
    ray.init(address="auto")  # attach to the existing cluster; substitute your own address
    try:
        seq_time, par_time = test_ray_parallel()
        print("\nSummary:")
        print(f"Sequential time: {seq_time:.2f} seconds")
        print(f"Parallel time: {par_time:.2f} seconds")
        print(f"Speedup: {seq_time / par_time:.2f}x")

        if par_time < seq_time:
            print("Ray parallel processing is working correctly!")
        else:
            print("Parallel execution was not faster than sequential!")

    except Exception as e:
        print(f"Error testing Ray: {str(e)}")
    finally:
        ray.shutdown()