I am new to Pyro and probabilistic inference, and started by running some simple examples inspired by tutorials in other languages (Church and WebPPL).
This example draws a coin that is fair with high prior probability (0.99) and then flips it with a heads probability that depends on the coin type (different Bernoulli weights for fair and unfair coins). The goal is to infer the posterior probability of "fairness" given observed data suggesting the coin is unlikely to be fair. The code runs, but extremely slowly (sometimes ~10 s per iteration). Can someone give me a hint about what I'm coding wrong that is slowing down inference?
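For reference, the generative story can be sketched forward in plain Python (a minimal sketch with my own helper names, separate from the Pyro model below):

```python
import random

def simulate_flips(n_flips=10, fair_prior=0.99):
    """Forward-simulate the coin story: pick fair/biased, then flip."""
    # 1 = fair coin (prior probability 0.99), 0 = biased coin
    fair = 1 if random.random() < fair_prior else 0
    # a fair coin lands heads ~50% of the time, a biased one ~95%
    heads_prob = 0.5 if fair else 0.95
    flips = [1 if random.random() < heads_prob else 0 for _ in range(n_flips)]
    return fair, flips

fair, flips = simulate_flips()
print(fair, flips)
```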
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

def model(observed_data):
    fair_prior = 0.99  # prior probability that the coin is fair
    # 1 = fair coin, 0 = biased coin
    fair_coin = pyro.sample("coin", dist.Bernoulli(fair_prior)).long()
    # heads probability: ~0.5 for a fair coin, ~0.95 for a biased one
    flip_prob = pyro.sample(
        "flip_prob",
        dist.Normal(torch.tensor([0.95, 0.5])[fair_coin],
                    torch.tensor([0.01, 0.01])[fair_coin]),
    )
    with pyro.plate("data_plate", len(observed_data)):
        pyro.sample("flip", dist.Bernoulli(flip_prob), obs=observed_data)

def guide(observed_data):
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=100)
    mcmc.run(observed_data)
    return {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}

# conditioning on the observed data -> coin is probably not fair
observed_data = torch.tensor([1., 0., 1., 1., 1., 1., 0., 1., 1., 1.])
posterior_samples = guide(observed_data)
print(posterior_samples)
```
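As a quick sanity check on the "probably not fair" comment (a back-of-the-envelope helper of my own, not part of the model), one can compare how likely the observed flips are under each coin type; with 8 heads in 10 flips the data are somewhat more likely under the biased coin, and whether that outweighs the strong 0.99 prior is exactly what the MCMC run is meant to decide:

```python
import math

def bernoulli_loglik(data, p):
    # log-likelihood of 0/1 flips under a coin with heads probability p
    return sum(math.log(p) if x else math.log(1.0 - p) for x in data)

data = [1, 0, 1, 1, 1, 1, 0, 1, 1, 1]  # the observed flips: 8 heads, 2 tails
ll_fair = bernoulli_loglik(data, 0.5)     # fair coin
ll_biased = bernoulli_loglik(data, 0.95)  # biased coin
print(ll_fair, ll_biased)  # the biased coin assigns the higher likelihood
```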