Hi everyone,
I want to post a question that it is very slow when I use NUT for discrete Bayesian inference. I build a model just for two binary nodes A–> B. The probabilities p(A), p(B|A=0), p(B|A=1) have a prior of Beta distribution. Given the data of A and B, we want to get some samples of p(A), p(B|A=0), p(B|A=1). I checked the previous questions but still can not improve the speed of my codes. If I use jit.compile = True. I got a lot of warnings and I can not get reasonable results. Here are my codes
import torch import pyro import pyro.distributions as dist from torch.distributions import constraints import numpy as np import matplotlib.pyplot as plt ######### generate data############### pyro.set_rng_seed(0) N=10000 data_A = list(np.random.binomial(1,0.5,N)) data_B = [] for i in data_A: if i==0: data_B.append(np.random.binomial(1,0.25,1)[0]) else: data_B.append(np.random.binomial(1,0.75,1)[0]) data_A = torch.tensor(data_A,dtype = torch.float32) data_B = torch.tensor(data_B,dtype = torch.float32) # build the A-->B model pyro.clear_param_store() @pyro.infer.config_enumerate def model(N): # parameter probs_a = pyro.sample("probs_a", dist.Beta(0.5,0.5)) probs_a0b = pyro.sample("probs_a0b", dist.Beta(0.5,0.5)) probs_a1b = pyro.sample("probs_a1b", dist.Beta(0.5,0.5)) probs_b = torch.tensor([[1-probs_a0b, probs_a0b],[1-probs_a1b, probs_a1b]]) with pyro.plate("data", N): a = pyro.sample("a", dist.Bernoulli(logits=probs_a)) b = pyro.sample("b", dist.Bernoulli(logits=probs_b[a.long(),1])) return a, b # inference from pyro.infer import MCMC, NUTS conditioned_model = pyro.condition(model, data={"a": data_A, "b": data_B}) nuts_kernel = NUTS(conditioned_model, step_size = 1, adapt_mass_matrix=True) mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200) mcmc.run(N) hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}
1000 samples take me about one hour. Can someone give me some ideas to improve it? Thank you so much for your help!