[MCMC] Pyro speed compared to NumPyro on Google Colab

Thanks to all the Pyro contributors for their great work!

I compared Pyro MCMC to NumPyro MCMC, focusing on speed. NumPyro turned out to be about 10x faster than Pyro, but I do not understand where that difference comes from.

JAX operations are NOT 10x faster than PyTorch operations, even with JIT compilation on CPU.

I used Google Colab, which has only one CPU, so I ran a single MCMC chain.

The two models are the same except for the prior scale of temp_bias (10 in the Pyro code, 100 in the NumPyro code).

Pyro code:

```python
import torch
import pyro.distributions as dist
from pyro import infer, plate, sample


def model_pyro(temp_obs, weather_obs, sells_obs):
    temp_means = sample(
        "temp_means",
        dist.Normal(loc=torch.tensor(30.),
                    scale=torch.tensor(2.0))
    )
    temp_stds = sample(
        "temp_stds",
        dist.LogNormal(loc=torch.tensor(0.),
                       scale=torch.tensor(2.))
    )

    sells_std = sample(
        "sells_std",
        dist.LogNormal(loc=torch.tensor([0., 0.]),
                       scale=torch.tensor([5.0, 5.0]))
    )

    # One latent mean temperature per weather state (0 or 1).
    with plate("temps_", 2):
        temps = sample(
            "temps",
            dist.Normal(temp_means, temp_stds)
        )

    temp_coeff = sample("temp_coeff",
                        dist.Normal(torch.tensor([0., 0.]), torch.tensor([1., 1.])))
    temp_bias = sample("temp_bias",
                       dist.Normal(torch.tensor([0., 0.]), torch.tensor([10., 10.])))
    weather_prob = sample("weather_prob",
                          dist.Beta(1.0, 1.0))

    with plate("days", size=len(temp_obs)):
        # Observed weather (0/1) is cast to long so it can index the per-state parameters.
        weather = sample(
            "weather",
            dist.Bernoulli(probs=weather_prob).expand([len(temp_obs)]),
            obs=weather_obs
        ).long()
        temp = sample(
            "temp",
            dist.Normal(loc=temps[weather], scale=temp_stds),
            obs=temp_obs
        )
        sells = sample(
            "sells",
            dist.Normal(loc=temp_coeff[weather] * temp + temp_bias[weather],
                        scale=sells_std[weather]),
            obs=sells_obs
        )
    return temp, weather, sells


# temp, weather and sells are the observed data tensors (data loading not shown).
nuts = infer.mcmc.NUTS(model_pyro, jit_compile=True, ignore_jit_warnings=True,
                       max_tree_depth=10)
mcmc = infer.mcmc.MCMC(nuts, num_samples=1000, warmup_steps=200)
mcmc.run(temp, weather.float(), sells)
```

NumPyro code:

```python
import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist
from numpyro import infer, plate, sample


def model(temp_obs, weather_obs, sells_obs):
    temp_means = sample(
        "temp_means",
        dist.Normal(loc=jnp.array(30.),
                    scale=jnp.array(2.0)),
    )
    temp_stds = sample(
        "temp_stds",
        dist.LogNormal(loc=jnp.array(0.),
                       scale=jnp.array(2.)),
    )

    sells_std = sample(
        "sells_std",
        dist.LogNormal(loc=jnp.array([0., 0.]),
                       scale=jnp.array([5.0, 5.0])),
    )

    # One latent mean temperature per weather state (0 or 1).
    with plate("temps_", 2):
        temps = sample(
            "temps",
            dist.Normal(temp_means, temp_stds),
        )

    temp_coeff = sample("temp_coeff",
                        dist.Normal(jnp.array([0., 0.]), jnp.array([1., 1.])))
    temp_bias = sample("temp_bias",
                       dist.Normal(jnp.array([0., 0.]), jnp.array([100., 100.])))
    weather_prob = sample("weather_prob",
                          dist.Beta(1.0, 1.0))

    # Use the model argument, not a global, for the plate size.
    with plate("days", size=len(temp_obs)):
        # Observed weather (0/1) is cast to int so it can index the per-state parameters.
        weather = sample(
            "weather",
            dist.Bernoulli(probs=weather_prob),
            obs=weather_obs,
        ).astype(jnp.int32)

        temp = sample(
            "temp",
            dist.Normal(loc=temps[weather], scale=temp_stds),
            obs=temp_obs,
        )

        sells = sample(
            "sells",
            dist.Normal(loc=temp_coeff[weather] * temp + temp_bias[weather],
                        scale=sells_std[weather]),
            obs=sells_obs,
        )
    return temp, weather, sells


# temp_data, weather_data and sells_data are the observed data arrays (data loading not shown).
mcmc = infer.MCMC(infer.NUTS(model, max_tree_depth=10),
                  num_warmup=200,
                  num_samples=1000,
                  progress_bar=True)
mcmc.run(random.PRNGKey(2020), temp_data, weather_data, sells_data)
```

Timing results:

Pyro: 153 seconds
NumPyro: 12 seconds
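For scale, these timings can be broken down with a little arithmetic. Both runs perform 200 warmup plus 1000 sampling iterations, and with max_tree_depth=10 each NUTS iteration can take at most 2**10 - 1 = 1023 leapfrog steps, each needing one gradient evaluation. The sketch below only does arithmetic on the numbers reported above; it is not a benchmark, and the actual number of leapfrog steps per iteration depends on the posterior and was not reported.

```python
# Back-of-the-envelope arithmetic on the reported timings (not a benchmark).
num_warmup, num_samples = 200, 1000
max_tree_depth = 10
iterations = num_warmup + num_samples

# Upper bound on leapfrog steps (one gradient evaluation each) per NUTS iteration.
max_steps_per_iter = 2 ** max_tree_depth - 1
max_total_steps = iterations * max_steps_per_iter

reported_seconds = {"pyro": 153.0, "numpyro": 12.0}
seconds_per_iter = {k: v / iterations for k, v in reported_seconds.items()}
speedup = reported_seconds["pyro"] / reported_seconds["numpyro"]

print(max_steps_per_iter)   # 1023
print(max_total_steps)      # 1227600
print(seconds_per_iter)     # {'pyro': 0.1275, 'numpyro': 0.01}
print(round(speedup, 2))    # 12.75
```

The real number of gradient evaluations lies somewhere below that upper bound, but the point stands: NUTS makes many small, cheap gradient calls on a model this size, so a fixed per-call overhead can matter more than the raw speed of individual JAX or PyTorch operations.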

Where does this difference come from?

Thanks for reading.

I'm not an expert, but it seems to me that the difference comes from the way NUTS is implemented in NumPyro: https://github.com/pyro-ppl/numpyro/wiki/Iterative-NUTS
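The linked wiki page describes building the NUTS trajectory iteratively instead of recursively. A plausible reason this matters: a loop can be written with jax.lax.while_loop and compiled, together with the gradient computations, into a single XLA program, whereas Pyro's recursive tree building runs in the Python interpreter (with jit_compile=True, only the potential energy computation is JIT-compiled, as far as I understand). The sketch below is NOT NUTS: there is no tree doubling and no sampling from the trajectory, only leapfrog steps with a simplified U-turn test against the starting point. The standard normal target and all function names are hypothetical. It writes the same data-dependent trajectory once as a Python-driven loop and once as a compiled jax.lax.while_loop.

```python
import jax
import jax.numpy as jnp


def potential(q):
    # Hypothetical target for illustration: standard normal, U(q) = 0.5 * |q|^2.
    return 0.5 * jnp.sum(q ** 2)


grad_potential = jax.grad(potential)


def leapfrog(q, p, step_size):
    # One leapfrog step for H(q, p) = U(q) + 0.5 * |p|^2.
    p = p - 0.5 * step_size * grad_potential(q)
    q = q + step_size * p
    p = p - 0.5 * step_size * grad_potential(q)
    return q, p


def is_u_turn(q0, q, p):
    # Simplified U-turn test: the momentum now points back toward the start.
    return jnp.dot(q - q0, p) < 0


def run_until_u_turn_python(q0, p0, step_size, max_steps):
    # Trajectory length decided in Python: every step is dispatched from the interpreter.
    q, p, n = q0, p0, 0
    while n < max_steps:
        q, p = leapfrog(q, p, step_size)
        n += 1
        if is_u_turn(q0, q, p):
            break
    return q, p, n


@jax.jit
def run_until_u_turn_compiled(q0, p0, step_size, max_steps):
    # Same loop as lax.while_loop, so the whole trajectory is one compiled program.
    def cond(state):
        q, p, n = state
        return (n < max_steps) & ~is_u_turn(q0, q, p)

    def body(state):
        q, p, n = state
        q, p = leapfrog(q, p, step_size)
        return q, p, n + 1

    return jax.lax.while_loop(cond, body, (q0, p0, jnp.int32(0)))


q0 = jnp.array([1.0, 0.0])
p0 = jnp.array([0.0, 1.0])
q_py, p_py, n_py = run_until_u_turn_python(q0, p0, 0.1, 100)
q_jit, p_jit, n_jit = run_until_u_turn_compiled(q0, p0, 0.1, 100)
```

Both versions stop after the same number of steps. When timing them, call the compiled one once beforehand, since its first call includes tracing and compilation.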


Thanks for the reply.

NumPyro's NUTS is very fast, especially for hierarchical models. I hope NumPyro has many more successful years.

This is discussed briefly in the NumPyro paper; see Section 3.1.