Thanks to all the great Pyro contributors!
I compared Pyro MCMC with NumPyro MCMC, focusing on speed. NumPyro turned out to be about 10x faster than Pyro, but I do not understand where that difference comes from.
JAX operations are *not* 10x faster than PyTorch operations, even with JIT compilation on CPU.
I used Google Colab, which provides only one CPU, so I ran a single MCMC chain.
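The models used in the experiment are intended to be exactly the same (note: the `temp_bias` prior scale is 10 in the Pyro model but 100 in the NumPyro model).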
Pyro code:
def model_pyro(temp_obs, weather_obs, sells_obs):
    """Pyro model of daily temperature and sales conditioned on weather.

    Each day has an observed binary weather state drawn from
    ``Bernoulli(weather_prob)``. The day's temperature is Normal around a
    weather-specific mean ``temps[weather]``. Sales are a weather-specific
    linear function of temperature plus Normal noise with a weather-specific
    scale.

    Args:
        temp_obs: observed temperatures, one entry per day; ``len(temp_obs)``
            sizes the ``days`` plate.
        weather_obs: observed weather states (0/1), one per day. They are
            passed as floats for the Bernoulli likelihood and cast to ``long``
            below so they can index the length-2 per-weather parameters.
        sells_obs: observed sales, one entry per day.

    Returns:
        Tuple ``(temp, weather, sells)`` of the values at the three
        observed sites.
    """
    temp_means = sample(
        "temp_means",
        dist.Normal(loc=torch.tensor(30.), scale=torch.tensor(2.0)),
    )
    temp_stds = sample(
        "temp_stds",
        dist.LogNormal(loc=torch.tensor(0.), scale=torch.tensor(2.)),
    )
    sells_std = sample(
        "sells_std",
        dist.LogNormal(loc=torch.tensor([0., 0.]),
                       scale=torch.tensor([5.0, 5.0])),
    )
    # One mean temperature per weather state (index 0 and 1).
    with plate("temps_", 2):
        temps = sample("temps", dist.Normal(temp_means, temp_stds))
    # NOTE(review): the original paste lost its indentation, so whether these
    # two sites sat inside the ``temps_`` plate is reconstructed. Their
    # parameters are explicitly length 2, so the summed log density should be
    # the same either way.
    temp_coeff = sample(
        "temp_coeff",
        dist.Normal(torch.tensor([0., 0.]), torch.tensor([1., 1.])),
    )
    # NOTE(review): the NumPyro version of this model uses scale 100 here,
    # so the two models are not identical. Confirm which prior is intended.
    temp_bias = sample(
        "temp_bias",
        dist.Normal(torch.tensor([0., 0.]), torch.tensor([10., 10.])),
    )
    weather_prob = sample("weather_prob", dist.Beta(1.0, 1.0))
    with plate("days", size=len(temp_obs)):
        # The enclosing plate already broadcasts the scalar-probability
        # Bernoulli over the days, so no explicit ``.expand`` is needed.
        weather = sample(
            "weather",
            dist.Bernoulli(probs=weather_prob),
            obs=weather_obs,
        ).long()  # integer dtype so it can index the per-weather parameters
        temp = sample(
            "temp",
            dist.Normal(loc=temps[weather], scale=temp_stds),
            obs=temp_obs,
        )
        sells = sample(
            "sells",
            dist.Normal(
                loc=temp_coeff[weather] * temp + temp_bias[weather],
                scale=sells_std[weather],
            ),
            obs=sells_obs,
        )
    return temp, weather, sells
# Single-chain NUTS run of the Pyro model with a JIT-compiled kernel.
nuts = infer.mcmc.NUTS(
    model_pyro,
    jit_compile=True,
    ignore_jit_warnings=True,
    max_tree_depth=10,
)
# Keyword form of the original positional call MCMC(nuts, 1000, 200).
mcmc = infer.mcmc.MCMC(nuts, num_samples=1000, warmup_steps=200)
# NOTE(review): ``wether`` looks like a typo for ``weather``, but it is
# defined outside this snippet. Confirm before renaming.
mcmc.run(temp, wether.float(), sells)
NumPyro code:
def model(temp_obs, weather_obs, sells_obs):
    """NumPyro model of daily temperature and sales conditioned on weather.

    Each day has an observed binary weather state drawn from
    ``Bernoulli(weather_prob)``. The day's temperature is Normal around a
    weather-specific mean ``temps[weather]``. Sales are a weather-specific
    linear function of temperature plus Normal noise with a weather-specific
    scale.

    Args:
        temp_obs: observed temperatures, one entry per day; ``len(temp_obs)``
            sizes the ``days`` plate.
        weather_obs: observed weather states (0/1), one per day. They are cast
            to ``int32`` below so they can index the length-2 per-weather
            parameters.
        sells_obs: observed sales, one entry per day.

    Returns:
        Tuple ``(temp, weather, sells)`` of the values at the three
        observed sites.
    """
    temp_means = sample(
        "temp_means",
        dist.Normal(loc=jnp.array(30.), scale=jnp.array(2.0)),
    )
    temp_stds = sample(
        "temp_stds",
        dist.LogNormal(loc=jnp.array(0.), scale=jnp.array(2.)),
    )
    sells_std = sample(
        "sells_std",
        dist.LogNormal(loc=jnp.array([0., 0.]), scale=jnp.array([5.0, 5.0])),
    )
    # One mean temperature per weather state (index 0 and 1).
    with plate("temps_", 2):
        temps = sample("temps", dist.Normal(temp_means, temp_stds))
    # NOTE(review): the original paste lost its indentation, so whether these
    # two sites sat inside the ``temps_`` plate is reconstructed. Their
    # parameters are explicitly length 2, so the summed log density should be
    # the same either way.
    temp_coeff = sample(
        "temp_coeff",
        dist.Normal(jnp.array([0., 0.]), jnp.array([1., 1.])),
    )
    # NOTE(review): the Pyro version of this model uses scale 10 here,
    # so the two models are not identical. Confirm which prior is intended.
    temp_bias = sample(
        "temp_bias",
        dist.Normal(jnp.array([0., 0.]), jnp.array([100., 100.])),
    )
    weather_prob = sample("weather_prob", dist.Beta(1.0, 1.0))
    # Size the plate from this call's own argument, not the global
    # ``temp_data``, so the model is correct for whatever data is passed in.
    with plate("days", size=len(temp_obs)):
        weather = sample(
            "weather",
            dist.Bernoulli(probs=weather_prob),
            obs=weather_obs,
        ).astype(jnp.int32)  # integer dtype for indexing per-weather params
        temp = sample(
            "temp",
            dist.Normal(loc=temps[weather], scale=temp_stds),
            obs=temp_obs,
        )
        sells = sample(
            "sells",
            dist.Normal(
                loc=temp_coeff[weather] * temp + temp_bias[weather],
                scale=sells_std[weather],
            ),
            obs=sells_obs,
        )
    return temp, weather, sells
# Single-chain NUTS run of the NumPyro model: 200 warmup iterations followed
# by 1000 retained samples.
mcmc = infer.MCMC(
    infer.NUTS(model, max_tree_depth=10),
    num_warmup=200,
    num_samples=1000,
    progress_bar=True,
)
mcmc.run(
    random.PRNGKey(2020),  # fixed seed so the run is reproducible
    temp_data,
    weather_data,
    sells_data,
)
Run times:
Pyro: 153 seconds
NumPyro: 12 seconds
Where does this difference come from?
Thanks for reading.