MCMC Sampling Slow for Discrete Bayesian Inference

Hi everyone,

I'd like to ask why NUTS is very slow for a discrete Bayesian inference problem. I built a model with just two binary nodes, A -> B. The probabilities p(A), p(B|A=0) and p(B|A=1) each have a Beta prior. Given observed data for A and B, I want to draw posterior samples of p(A), p(B|A=0) and p(B|A=1). I checked previous questions but still cannot speed up my code. If I set jit_compile=True, I get a lot of warnings and the results are not reasonable. Here is my code:

import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
import numpy as np
import matplotlib.pyplot as plt
# generate data
data_A = list(np.random.binomial(1,0.5,N))
data_B = []

for i in data_A:
    if i==0:

data_A = torch.tensor(data_A,dtype = torch.float32)
data_B = torch.tensor(data_B,dtype = torch.float32)

# build the A-->B model
def model(N):

    # parameters
    probs_a = pyro.sample("probs_a", dist.Beta(0.5,0.5))
    probs_a0b = pyro.sample("probs_a0b", dist.Beta(0.5,0.5))
    probs_a1b = pyro.sample("probs_a1b", dist.Beta(0.5,0.5))
    probs_b = torch.tensor([[1-probs_a0b, probs_a0b],[1-probs_a1b, probs_a1b]])

    with pyro.plate("data", N):
        a = pyro.sample("a", dist.Bernoulli(logits=probs_a))
        b = pyro.sample("b", dist.Bernoulli(logits=probs_b[a.long(),1]))

    return a, b

# inference
from pyro.infer import MCMC, NUTS

conditioned_model = pyro.condition(model, data={"a": data_A, "b": data_B})
nuts_kernel = NUTS(conditioned_model, step_size=1, adapt_mass_matrix=True)

mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200)
mcmc.run(N)
hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}

Drawing 1000 samples takes about one hour. Can anyone suggest how to speed this up? Thank you so much for your help!
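The loop that fills data_B is not shown in full above, so the true values of p(B|A) are unknown. A minimal, vectorized sketch of the same data-generation step follows. It assumes N = 1000, keeps p(A=1) = 0.5 as in the post, and uses p(B=1|A=0) = 0.2 and p(B=1|A=1) = 0.8 as hypothetical ground-truth values chosen only for illustration.

```python
import numpy as np
import torch

# Hypothetical ground-truth values; only p(A = 1) = 0.5 comes from the post.
N = 1000
TRUE_P_A = 0.5     # p(A = 1)
TRUE_P_B_A0 = 0.2  # p(B = 1 | A = 0), assumed for illustration
TRUE_P_B_A1 = 0.8  # p(B = 1 | A = 1), assumed for illustration

rng = np.random.default_rng(0)
a = rng.binomial(1, TRUE_P_A, N)
# Pick p(B = 1 | A) for each row, then draw all B values in one vectorized call
# instead of looping over data_A in Python.
p_b = np.where(a == 0, TRUE_P_B_A0, TRUE_P_B_A1)
b = rng.binomial(1, p_b)

data_A = torch.tensor(a, dtype=torch.float32)
data_B = torch.tensor(b, dtype=torch.float32)
```

Vectorizing with np.where avoids the per-element Python loop. This mainly matters for large N, since data generation runs only once and is not the bottleneck during sampling.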

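I think you should replace

probs_b = torch.tensor([[1-probs_a0b, probs_a0b],[1-probs_a1b, probs_a1b]])

with

probs_b = torch.stack([torch.stack([1 - probs_a0b, probs_a0b]),
                       torch.stack([1 - probs_a1b, probs_a1b])])

torch.tensor(...) copies the values into a new tensor that is detached from the autograd graph, so NUTS gets no gradient through probs_b with respect to probs_a0b and probs_a1b. torch.stack keeps that graph intact.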
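The difference between the two constructions is easy to check directly. The sketch below uses two hypothetical scalar parameters, p0 and p1, as stand-ins for probs_a0b and probs_a1b:

```python
import torch

# Hypothetical stand-ins for probs_a0b and probs_a1b.
p0 = torch.tensor(0.3, requires_grad=True)
p1 = torch.tensor(0.7, requires_grad=True)

# torch.tensor copies the *values*; the result has no autograd history.
copied = torch.tensor([[1 - p0, p0], [1 - p1, p1]])
print(copied.requires_grad)  # False

# torch.stack builds the matrix from the tensors themselves, so gradients flow back.
stacked = torch.stack([torch.stack([1 - p0, p0]), torch.stack([1 - p1, p1])])
stacked[:, 1].sum().backward()  # column 1 is [p0, p1], so d(sum)/dp = 1 for each
print(p0.grad, p1.grad)  # tensor(1.) tensor(1.)
```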

Thank you so much for your quick response. It is much faster now! A follow-up question: I only changed probs_b and reran the code, but the results are not reasonable. For example, the ground truth for probs_a is around 0.5, but I got many samples near 0.001. When I changed dist.Bernoulli to dist.Categorical, I got samples near 0.5, so I am confused.

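I didn't check, but I guess you want dist.Bernoulli(probs=...) instead of dist.Bernoulli(logits=...). With logits=, the value is read as log-odds, so the effective probability is sigmoid(probs_a), which lies between 0.5 and about 0.73 for probs_a in (0, 1). To fit data whose frequency is near 0.5, the sampler has to push probs_a toward 0.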