I tried to use but it seems that requires small special handling for the Plate. Here is my model (a simplified version I’m using now to test some concepts before I dive deep) and I’m using some auto normal guide.
def model(y_tensor, n_tensor):
"""
All obs are splited into M segmetns. Each segment follows an NB distribution with different
(alpha, beta) as in the dist.GammaPoisson dist. The purpose is to estimate all alpha s and beta s
for all segments independtly but in this sigle code block (there is no dependency among segmetns for now, but later on
I may add partial pooling, features, etc).
y_tensor: size N. Standing for N obs, each obs belongs to one defined segment. Data are ordered from seg 1 to seg N.
n_tenosor: contains the obs count for seg 1, 2, ..., N with sum(n_tensor) = len(y_tensor).
"""
n_segments = len(n_tensor)
b0 = pyro.sample("b0", dist.Normal(loc=torch.zeros((n_segments,1)), scale=5))
phi = pyro.sample("phi", dist.HalfCauchy(2.0 * torch.ones((n_segments,1))))
mu = torch.exp(b0)
mu_reshape = mu.reshape(2, -1)
uplift = pyro.deterministic('uplift', mu_reshape[1, :] - mu_reshape[0, :])
beta = phi/mu
alpha = phi
# broad cast
alpha_cast = alpha.repeat_interleave(n_tensor)
beta_cast = beta.repeat_interleave(n_tensor)
with pyro.plate("data", len(y_tensor)):
pyro.sample("obs", dist.GammaPoisson(alpha_cast, beta_cast), obs=y_tensor)
# compute guide
%%time
y_tensor = torch.tensor(y_list, dtype=torch.float)
n_tensor = torch.tensor(n_list)
from pyro.infer.autoguide import AutoNormal, init_to_mean
from pyro.infer import SVI, Trace_ELBO
num_iters = 10000
guide = AutoNormal(model, init_loc_fn=init_to_mean)
svi = SVI(model,
guide,
optim.Adam({"lr": .001}),
loss=Trace_ELBO())
pyro.clear_param_store()
loss = []
for i in range(num_iters):
elbo = svi.step(y_tensor, n_tensor)
loss.append(elbo)
if i % 500 == 0:
print("Elbo loss: {}".format(elbo))
# sample nodes
%%time
from pyro.infer import Predictive
num_samples = 100000
predictive = Predictive(model, guide=guide, num_samples=num_samples, return_sites=('b0', 'phi', 'uplift'))
pred_res_raw = predictive(y_tensor, n_tensor)
It takes 6 mins to get 100k samples. If I set parallel=True
, it compalins
ValueError: Shape mismatch inside plate('_num_predictive_samples') at site b0 dim -2, 1000 vs 40