Hi everyone,
I want to ask about NUTS being very slow when I use it for discrete Bayesian inference. I built a model with just two binary nodes, A --> B. The probabilities p(A), p(B|A=0), and p(B|A=1) each have a Beta prior, and given observed data for A and B I want to draw posterior samples of p(A), p(B|A=0), and p(B|A=1). I checked the previous questions on this topic but still cannot improve the speed of my code. If I set jit_compile=True, I get a lot of warnings and cannot get reasonable results.
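To be concrete, the generative model I have in mind is the following (writing p_A for p(A=1), p_{B|A=0} for p(B=1|A=0), and p_{B|A=1} for p(B=1|A=1)):

p_A ~ Beta(0.5, 0.5)
p_{B|A=0} ~ Beta(0.5, 0.5)
p_{B|A=1} ~ Beta(0.5, 0.5)
A_i ~ Bernoulli(p_A),                 i = 1, ..., N
B_i | A_i ~ Bernoulli(p_{B|A=A_i}),   i = 1, ..., N

Here is my code: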
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
import numpy as np
import matplotlib.pyplot as plt
######### generate data###############
pyro.set_rng_seed(0)
N=10000
data_A = list(np.random.binomial(1,0.5,N))
data_B = []
for i in data_A:
    if i == 0:
        data_B.append(np.random.binomial(1, 0.25, 1)[0])
    else:
        data_B.append(np.random.binomial(1, 0.75, 1)[0])
data_A = torch.tensor(data_A,dtype = torch.float32)
data_B = torch.tensor(data_B,dtype = torch.float32)
# build the A-->B model
pyro.clear_param_store()
@pyro.infer.config_enumerate
def model(N):
    # Beta(0.5, 0.5) priors on p(A=1), p(B=1|A=0), p(B=1|A=1)
    probs_a = pyro.sample("probs_a", dist.Beta(0.5, 0.5))
    probs_a0b = pyro.sample("probs_a0b", dist.Beta(0.5, 0.5))
    probs_a1b = pyro.sample("probs_a1b", dist.Beta(0.5, 0.5))
    # build the 2x2 CPT for B with torch.stack so gradients flow back to the
    # sampled parameters (wrapping them in torch.tensor would detach them)
    probs_b = torch.stack([torch.stack([1 - probs_a0b, probs_a0b]),
                           torch.stack([1 - probs_a1b, probs_a1b])])
    with pyro.plate("data", N):
        # the Beta samples are probabilities, so pass them as probs=, not logits=
        a = pyro.sample("a", dist.Bernoulli(probs=probs_a))
        b = pyro.sample("b", dist.Bernoulli(probs=probs_b[a.long(), 1]))
    return a, b
# inference
from pyro.infer import MCMC, NUTS
conditioned_model = pyro.condition(model, data={"a": data_A, "b": data_B})
nuts_kernel = NUTS(conditioned_model, step_size=1, adapt_mass_matrix=True)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200)
mcmc.run(N)
hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}
Drawing 1000 samples takes me about one hour. Can someone give me some ideas to speed this up? Thank you so much for your help!