[MCMC] Pyro speed compared to NumPyro on Google Colab

Thanks to the Pyro contributors for their great work!

I compared the speed of Pyro's MCMC with NumPyro's MCMC. NumPyro turned out to be about 10x faster than Pyro, but I do not understand where that difference comes from.

JAX operations are NOT 10x faster than PyTorch operations, even with JIT compilation on CPU.

I used Google Colab, which has only one CPU, so I ran only one MCMC chain.

The experimental models are the same.

Pyro code:

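# Imports inferred from the names used below
import torch
import pyro.distributions as dist
from pyro import infer, plate, sample

def model_pyro(temp_obs, weather_obs, sells_obs):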
    temp_means = sample(
        "temp_means", 
        dist.Normal(loc=torch.tensor(30.),
                    scale=torch.tensor(2.0))
    )
    temp_stds = sample(
        "temp_stds", 
        dist.LogNormal(loc=torch.tensor(0.),
                       scale=torch.tensor(2.))
    )

    sells_std = sample(
        "sells_std", 
        dist.LogNormal(loc=torch.tensor([0., 0.]),
                    scale=torch.tensor([5.0, 5.0]))
    )

    with plate("temps_", 2):
        temps = sample(
            "temps",
            dist.Normal(temp_means, temp_stds)
        )
    
    temp_coeff = sample("temp_coeff",
                        dist.Normal(torch.tensor([0., 0.]), torch.tensor([1., 1.])))
    temp_bias = sample("temp_bias",
                       dist.Normal(torch.tensor([0., 0.]), torch.tensor([10., 10.])))
    weather_prob = sample("weather_prob",
                         dist.Beta(1.0, 1.0))

    with plate("days", size=len(temp_obs)):
        weather = sample(
            "weather", 
            dist.Bernoulli(probs=weather_prob).expand([len(temp_obs)]),
            obs=weather_obs
        ).long()
        temp = sample(
            "temp", 
            dist.Normal(loc=temps[weather], scale=temp_stds),
            obs=temp_obs
        )
        sells = sample(
            "sells",
            dist.Normal(loc=temp_coeff[weather] * temp + temp_bias[weather], 
                        scale=sells_std[weather]),
            obs=sells_obs
        )
    return temp, weather, sells

nuts = infer.mcmc.NUTS(model_pyro, jit_compile=True, ignore_jit_warnings=True,
                       max_tree_depth=10)
mcmc = infer.mcmc.MCMC(nuts, num_samples=1000, warmup_steps=200)
mcmc.run(temp, weather.float(), sells)

NumPyro code:

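# Imports inferred from the names used below
import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist
from numpyro import infer, plate, sample

def model(temp_obs, weather_obs, sells_obs):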

    temp_means = sample(
        "temp_means", 
        dist.Normal(loc=jnp.array(30.),
                    scale=jnp.array(2.0)),
    )
    temp_stds = sample(
        "temp_stds", 
        dist.LogNormal(loc=jnp.array(0.),
                       scale=jnp.array(2.)),
    )

    sells_std = sample(
        "sells_std", 
        dist.LogNormal(loc=jnp.array([0., 0.]),
                    scale=jnp.array([5.0, 5.0])),
    )

    with plate("temps_", 2):
        temps = sample(
            "temps",
            dist.Normal(temp_means, temp_stds),
        )
    
    temp_coeff = sample("temp_coeff",
                        dist.Normal(jnp.array([0., 0.]), jnp.array([1., 1.])),
                        )
    temp_bias = sample("temp_bias",
                       dist.Normal(jnp.array([0., 0.]), jnp.array([100., 100.])),
                       )
    weather_prob = sample("weather_prob",
                         dist.Beta(1.0, 1.0),
                         )

    with plate("days", size=len(temp_obs)):
        weather = sample(
            "weather", 
            dist.Bernoulli(probs=weather_prob),
            obs=weather_obs,
        ).astype(jnp.int32)

        temp = sample(
            "temp", 
            dist.Normal(loc=temps[weather], scale=temp_stds),
            obs=temp_obs,
        )

        sells = sample(
            "sells",
            dist.Normal(loc=temp_coeff[weather] * temp + temp_bias[weather], 
                           scale=sells_std[weather]),
            obs=sells_obs,
        )
    return temp, weather, sells


mcmc = infer.MCMC(infer.NUTS(model, max_tree_depth=10), 
                  num_warmup=200, 
                  num_samples=1000, 
                  progress_bar=True,)
mcmc.run(random.PRNGKey(2020), temp_data, weather_data, sells_data)

Resulting run times:

Pyro: 153 seconds
NumPyro: 12 seconds
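Wall-clock totals like these can include one-off JIT compilation as well as the sampling itself. As a minimal sketch for separating the two, the helper below times a callable twice. The names `time_twice` and `dummy_run` are hypothetical and not part of Pyro or NumPyro; `dummy_run` is only a stand-in for a function that would call `mcmc.run(...)`.

```python
import time


def time_twice(run):
    """Call `run` twice and return (first_call_seconds, second_call_seconds).

    With a JIT-compiled sampler the first call usually includes compilation,
    so the second call is closer to the pure sampling cost.
    """
    timings = []
    for _ in range(2):
        start = time.perf_counter()
        run()
        timings.append(time.perf_counter() - start)
    return tuple(timings)


# Hypothetical stand-in workload; replace it with a function that runs the sampler.
def dummy_run():
    return sum(i * i for i in range(100_000))


first, second = time_twice(dummy_run)
print(f"first call: {first:.4f}s, second call: {second:.4f}s")
```

If the second call is much faster than the first, a noticeable share of the reported total was probably compilation rather than sampling.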

Where does this difference come from?

Thanks for reading.

I’m not an expert, but it seems to me that the difference comes from the way NUTS is implemented in NumPyro: see the "Iterative NUTS" page of the pyro-ppl/numpyro wiki on GitHub.
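To give a rough feel for what "iterative" means here, the toy sketch below builds the same fixed-depth trajectory once with recursive doubling and once with a flat loop. It is an illustration only, not NumPyro's code: it uses a 1-D standard normal target, and it omits the U-turn check and the proposal sampling that real NUTS performs. All function names are hypothetical.

```python
# Toy illustration: fixed-depth trajectory doubling for a 1-D standard normal.
# Real NUTS also checks a U-turn condition and samples a proposal; both are
# omitted here.

def leapfrog(q, p, step_size):
    """One leapfrog step for potential U(q) = q**2 / 2, so dU/dq = q."""
    p = p - 0.5 * step_size * q
    q = q + step_size * p
    p = p - 0.5 * step_size * q
    return q, p


def build_recursive(q, p, depth, step_size):
    """Recursive doubling: a depth-d tree is two depth-(d-1) subtrees."""
    if depth == 0:
        return leapfrog(q, p, step_size)
    q, p = build_recursive(q, p, depth - 1, step_size)
    return build_recursive(q, p, depth - 1, step_size)


def build_iterative(q, p, depth, step_size):
    """The same 2**depth leapfrog steps, written as one flat loop."""
    for _ in range(2 ** depth):
        q, p = leapfrog(q, p, step_size)
    return q, p


recursive_end = build_recursive(1.0, 0.5, depth=4, step_size=0.1)
iterative_end = build_iterative(1.0, 0.5, depth=4, step_size=0.1)
print(recursive_end, iterative_end)
```

Both versions take the same leapfrog steps in the same order. As I understand it, the flat-loop form is the one that maps naturally onto a JIT compiler's loop primitives, which lets the whole sampler step be compiled as one unit instead of returning to Python at every tree node.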


Thanks for the reply.

NumPyro's NUTS is very fast, especially for hierarchical models. I wish NumPyro many more successful years.

This is discussed briefly in the NumPyro paper; see Sec. 3.1.