I can recover posteriors for lambda1 and lambda2 that are similar to the ones from TFP, and sampling runs fairly quickly.
import torch
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import MCMC, NUTS

count_data = torch.tensor([
    13, 24, 8, 24, 7, 35, 14, 11, 15, 11, 22, 22, 11, 57,
    11, 19, 29, 6, 19, 12, 22, 12, 18, 72, 32, 9, 7, 13,
    19, 23, 27, 20, 6, 17, 13, 10, 14, 6, 16, 15, 7, 2,
    15, 15, 19, 70, 49, 7, 53, 22, 21, 31, 19, 11, 18, 20,
    12, 35, 17, 23, 17, 4, 2, 31, 30, 13, 27, 0, 39, 37,
    5, 14, 13, 22,
], dtype=torch.float)

def model(data):
    # Exponential priors on both rates, scaled by the mean count.
    alpha = (1. / data.mean())
    lambda1 = pyro.sample("lambda1", dist.Exponential(rate=alpha))
    lambda2 = pyro.sample("lambda2", dist.Exponential(rate=alpha))
    # tau is the switch point, expressed as a fraction of the series length.
    tau = pyro.sample("tau", dist.Uniform(0, 1))
    # Convert tau to an integer number of leading points that use lambda1.
    lambda1_size = int(tau.item() * data.size(0)) + 1
    lambda2_size = data.size(0) - lambda1_size
    lambda_ = torch.cat([lambda1.expand((lambda1_size,)), lambda2.expand((lambda2_size,))])
    with pyro.plate("data", data.size(0)):
        pyro.sample("obs", dist.Poisson(lambda_), obs=data)

nuts_kernel = NUTS(model, jit_compile=True)
posterior = MCMC(nuts_kernel, num_samples=10000, warmup_steps=5000, num_chains=1).run(count_data)
marginal = posterior.marginal(sites=["lambda1", "lambda2", "tau"]).support(flatten=True)
lambda_1_samples = marginal["lambda1"]
lambda_2_samples = marginal["lambda2"]
tau_samples = marginal["tau"]
However, the posterior for tau in Pyro looks roughly Uniform over the interval [0, 1]. @eb8680_2 @neerajprad, do you have any ideas on how to backpropagate loss information to tau?
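
To make the question concrete, here is a small plain-Python sketch (no Pyro or PyTorch, just `math`) of what I suspect is happening. It is only my guess, not a confirmed diagnosis. `hard_loglik` mirrors the `int(tau * N) + 1` indexing from the model, so the log-likelihood is piecewise constant in tau and its slope is zero almost everywhere. `soft_loglik` replaces the hard cut with a sigmoid weight, which is one possible relaxation and not something Pyro provides. The function names, the `temperature` parameter, and the fixed values of `lam1`, `lam2`, and `tau` are all hypothetical and chosen only for illustration.

```python
import math


def poisson_logpmf(k, lam):
    """Log-probability of count k under Poisson(lam)."""
    return k * math.log(lam) - lam - math.lgamma(k + 1)


def hard_loglik(tau, data, lam1, lam2):
    # Same indexing as the model: the first int(tau * N) + 1 points use lam1.
    n1 = int(tau * len(data)) + 1
    return sum(
        poisson_logpmf(k, lam1 if i < n1 else lam2)
        for i, k in enumerate(data)
    )


def soft_loglik(tau, data, lam1, lam2, temperature=1.0):
    # Hypothetical relaxation: blend the two rates with a sigmoid weight
    # that is close to 1 before the switch point and close to 0 after it.
    n = len(data)
    total = 0.0
    for i, k in enumerate(data):
        w = 1.0 / (1.0 + math.exp(-(tau * n - i) / temperature))
        lam = w * lam1 + (1.0 - w) * lam2
        total += poisson_logpmf(k, lam)
    return total


def finite_diff(f, x, eps=1e-4):
    """Central finite-difference estimate of df/dx."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)


# First 12 counts from count_data; the rates and tau are arbitrary test values.
data = [13, 24, 8, 24, 7, 35, 14, 11, 15, 11, 22, 22]
lam1, lam2, tau = 15.0, 25.0, 0.52

hard_grad = finite_diff(lambda t: hard_loglik(t, data, lam1, lam2), tau)
soft_grad = finite_diff(lambda t: soft_loglik(t, data, lam1, lam2), tau)

print("slope w.r.t. tau, hard switch:", hard_grad)
print("slope w.r.t. tau, soft switch:", soft_grad)
```

If this picture is right, NUTS gets no gradient signal for tau from the hard switch, which would be consistent with the flat posterior. Whether a smooth relaxation like the one above is the right fix, or whether tau should be handled some other way, is exactly what I am unsure about.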