MCMC gets stuck in Google Colab High-RAM / GPU / TPU runtime

I noticed that when I run MCMC with multiple chains in Google Colab using either the High-RAM runtime or the GPU (CUDA) runtime, the process gets stuck. For example, with this simple model:

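import torch
import pyro
import pyro.distributions as dist
dim = 1  # hypothetical value: the post does not say what `dim` is
def model(obs=None):
    mu = pyro.sample("theta", dist.Uniform(-10*torch.ones(dim), high=10*torch.ones(dim)))
    ber = pyro.sample("mix", dist.Bernoulli(0.5))
    # sigma is 1.0 when mix == 1 and 0.1 when mix == 0
    sigma = ber*torch.ones(()) + (1-ber)*0.1*torch.ones(())
    return pyro.sample("obs", dist.Normal(mu, sigma), obs=obs)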

Even with very few observations, the run stays stuck for an unreasonably long time. It also seems to be stuck on something unrelated to the actual algorithm, namely the multithreading/multiprocessing machinery.

The same happened to me with the GPU environment.

There is no problem if I set num_chains=1, or if I use multiple chains in a regular Google Colab environment, although in that case I get a warning that num_chains “is more than available_cpu=1. Chains will be drawn sequentially.”
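The warning suggests that Pyro only hands chains to worker processes when each chain can get its own CPU. The sketch below mirrors that decision in plain Python. The helper name `uses_worker_processes` is hypothetical, and the rule `available_cpu = max(cpu_count - 1, 1)` (one core reserved for the main process) is an assumption inferred from the warning text, so check the `MCMC` source for the exact logic. If it holds, it would explain why a runtime that reports `available_cpu=1` never reaches the multiprocessing code path, while a runtime with more cores does.

```python
import os
import warnings


def uses_worker_processes(num_chains, cpu_count=None):
    """Hypothetical helper: would num_chains chains run in separate worker processes?

    Assumption (not Pyro's verified code): one CPU is reserved for the main
    process, so available_cpu = max(cpu_count - 1, 1).
    """
    if num_chains <= 1:
        # A single chain always runs in the main process.
        return False
    if cpu_count is None:
        cpu_count = os.cpu_count() or 1
    available_cpu = max(cpu_count - 1, 1)
    if num_chains <= available_cpu:
        # Enough CPUs: one worker process per chain (the path that hangs in the post).
        return True
    # Not enough CPUs: fall back to drawing chains one after another in-process.
    warnings.warn(
        f"num_chains={num_chains} is more than available_cpu={available_cpu}. "
        "Chains will be drawn sequentially."
    )
    return False
```

For example, `uses_worker_processes(4, cpu_count=2)` emits the same warning as in the post (`available_cpu=1`) and returns `False`, while `uses_worker_processes(4, cpu_count=8)` returns `True`.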

Could this be a multiprocessing bug, possibly related to running in Colab?

Or am I doing something wrong?

If you are using CUDA, you might try setting MCMC(..., mp_context="spawn"), as described in the MCMC docs. Also, if your primary inference method will be MCMC (rather than SVI), we’d recommend switching to NumPyro, which is much faster, especially for small models.