MCMC inference running too slowly for a simple example

I am new to Pyro and probabilistic inference, and I started by running some simple examples inspired by tutorials in other languages (Church and WebPPL).

This example creates a coin that is fair with some prior probability (here a strong prior of 0.99) and then flips it using a Bernoulli weight that depends on whether the coin is fair or not. The goal is to infer the posterior distribution of “fairness” given observed data suggesting that the coin is unlikely to be fair. The code runs, but extremely slowly (sometimes ~10 s per iteration). Could someone point out what I am coding wrong that slows down inference?

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

def model(observed_data):
    fair_prior = 0.99  # the probability of the coin being fair
    fair_coin = pyro.sample("coin", dist.Bernoulli(fair_prior)).long()
    # index 1 (fair) -> weight 0.5, index 0 (unfair) -> weight 0.95
    flip_prob = pyro.sample("flip_prob", dist.Normal(torch.tensor([0.95, 0.5])[fair_coin], torch.tensor([0.01, 0.01])[fair_coin]))
    with pyro.plate("data_plate", len(observed_data)):
        pyro.sample("flip", dist.Bernoulli(flip_prob), obs=observed_data)

def guide(observed_data):
    # despite its name, this runs NUTS; it is not a variational guide
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=100)
    mcmc.run(observed_data)  # sampling must be run before get_samples()
    mcmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}
    return mcmc_samples

# conditioning on the observed data -> coin is probably not fair
observed_data = torch.tensor([1., 0., 1., 1., 1., 1., 0., 1., 1., 1.])
posterior_samples = guide(observed_data)
```

I should add that my posterior samples only contain values for flip_prob, not for coin as I had hoped. Any help would be greatly appreciated!
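Because coin can only take two values, the posterior over fairness can be checked exactly by summing over both cases. The plain-Python sketch below (not Pyro code) assumes fixed flip weights of 0.5 for a fair coin and 0.95 for an unfair one, i.e. it ignores the small Normal jitter on flip_prob; `posterior_fair` is a hypothetical helper name.

```python
import math

def posterior_fair(data, prior_fair=0.99, p_fair=0.5, p_unfair=0.95):
    """Exact P(coin is fair | data), obtained by summing over both coin values.

    Assumption: fixed Bernoulli weights per coin type (the model's Normal jitter is ignored).
    """
    heads = sum(data)
    tails = len(data) - heads

    def log_lik(p):
        # Bernoulli log-likelihood of the whole flip sequence for weight p
        return heads * math.log(p) + tails * math.log(1 - p)

    log_fair = math.log(prior_fair) + log_lik(p_fair)
    log_unfair = math.log(1 - prior_fair) + log_lik(p_unfair)
    # normalize in log space for numerical stability
    m = max(log_fair, log_unfair)
    z = math.exp(log_fair - m) + math.exp(log_unfair - m)
    return math.exp(log_fair - m) / z

observed_data = [1, 0, 1, 1, 1, 1, 0, 1, 1, 1]
print(round(posterior_fair(observed_data), 3))  # prints 0.983
```

Under these assumptions the result is about 0.98: with a 0.99 prior on fairness, eight heads in ten flips are not enough evidence to make the coin probably unfair.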

You can’t put a Normal prior on the probs that enter the Bernoulli: probs must lie between 0 and 1, while a Normal has support on the whole real line.
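One way to keep flip_prob inside (0, 1) while preserving the intended means and spreads is to swap each Normal for a Beta with the same mean and standard deviation. The sketch below is a standard method-of-moments conversion in plain Python; `beta_from_mean_sd` is a hypothetical helper, and the resulting pair would presumably be passed to `dist.Beta(a, b)` where the model currently uses `dist.Normal(...)`.

```python
def beta_from_mean_sd(mean, sd):
    """Method-of-moments Beta(a, b) with the given mean and standard deviation."""
    var = sd ** 2
    if not (0 < mean < 1) or var >= mean * (1 - mean):
        raise ValueError("need 0 < mean < 1 and sd**2 < mean * (1 - mean)")
    # For Beta(a, b): mean = a / (a + b) and var = mean * (1 - mean) / (a + b + 1)
    concentration = mean * (1 - mean) / var - 1  # this is a + b
    return mean * concentration, (1 - mean) * concentration

print(beta_from_mean_sd(0.95, 0.01))  # unfair coin
print(beta_from_mean_sd(0.5, 0.01))   # fair coin
```

This only fixes the support problem: the resulting Beta(450.3, 23.7) and Beta(1249.5, 1249.5) are still narrow and far apart.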

If your inference is slow, it’s probably because your prior on flip_prob has two very narrow modes that are very far apart.
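A rough plain-Python check shows how far apart those modes are. Once coin is summed out, the prior on flip_prob is the mixture 0.99 * N(0.5, 0.01) + 0.01 * N(0.95, 0.01); the sketch below (helper names are hypothetical) compares its log density at the fair mode with the value halfway between the two modes.

```python
import math

def log_normal(x, loc, scale):
    # log density of Normal(loc, scale) at x
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale * math.sqrt(2 * math.pi))

def log_prior_flip_prob(x, prior_fair=0.99):
    # marginal prior on flip_prob with the discrete coin summed out (log-sum-exp)
    terms = [math.log(prior_fair) + log_normal(x, 0.5, 0.01),
             math.log(1 - prior_fair) + log_normal(x, 0.95, 0.01)]
    m = max(terms)
    return m + math.log(sum(math.exp(t - m) for t in terms))

gap = log_prior_flip_prob(0.5) - log_prior_flip_prob(0.725)
print(f"log-density drop between modes: {gap:.1f} nats")
```

The drop is roughly 253 nats, a density ratio on the order of 1e-110, so a gradient-based sampler started in one mode is very unlikely to ever visit the other.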

Thanks for your insights. The Normal prior was actually a workaround for the discrete inference problem, with a small standard deviation to keep the Bernoulli weight from exceeding 1, but I understand it may not be good practice. Do you have any ideas about my other issue? Basically, I want to infer the posterior distribution over the fairness of the coin generated at each iteration, but my posterior_samples don’t contain any values for coin.

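NUTS enumerates out (marginalizes) discrete sites such as coin, which is why they never show up in the samples. Search the forum/docs for infer_discrete to sample them afterwards.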
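Roughly speaking, infer_discrete draws each enumerated-out discrete site from its conditional distribution given the other sampled values. The plain-Python sketch below mimics that idea by hand for this model; it is not Pyro's implementation, and the helper names are hypothetical. Since the flips depend on coin only through flip_prob, the conditional reduces to P(coin | flip_prob) ∝ P(coin) * N(flip_prob | loc_coin, 0.01).

```python
import math
import random

LOCS = {0: 0.95, 1: 0.5}  # coin = 1 means fair, matching the model's indexing
SCALE = 0.01

def log_normal(x, loc, scale):
    # log density of Normal(loc, scale) at x
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale * math.sqrt(2 * math.pi))

def prob_fair_given_flip_prob(flip_prob, prior_fair=0.99):
    # P(coin = 1 | flip_prob); the observed flips add nothing once flip_prob is known
    log_fair = math.log(prior_fair) + log_normal(flip_prob, LOCS[1], SCALE)
    log_unfair = math.log(1 - prior_fair) + log_normal(flip_prob, LOCS[0], SCALE)
    # logistic function of the log-odds, computed without overflow
    log_odds = log_fair - log_unfair
    if log_odds >= 0:
        return 1.0 / (1.0 + math.exp(-log_odds))
    e = math.exp(log_odds)
    return e / (1.0 + e)

def sample_coins(flip_prob_samples, rng=random):
    # one coin draw per posterior sample of flip_prob
    return [int(rng.random() < prob_fair_given_flip_prob(p)) for p in flip_prob_samples]

print(sample_coins([0.49, 0.51, 0.94], random.Random(0)))  # prints [1, 1, 0]
```

In practice the flip_prob values would come from the NUTS samples (e.g. `mcmc.get_samples()["flip_prob"]`), and `pyro.infer.infer_discrete` handles this bookkeeping inside Pyro itself.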

You can also use DiscreteHMCGibbs or MixedHMC in NumPyro.