Hi!
I’m trying to implement an example from NumPyro on Bayesian models of annotation in Pyro using SVI and following the SVI Part I tutorial. The problem I’ve faced is that I cannot provide a guide for every posterior I want to estimate when I use pyro.plate
— it’s possible to provide a guide for all conditionally independent variables at the same time.
I wrote a code that actually works, but it’s slow because it’s not vectorized. Here is what I’ve got for BCC (the same model as the dawid_skene
in the example above).
def bcc(worker_pos, tasks_pos, worker_labels):
n_tasks = torch.max(tasks_pos) + 1
n_workers = torch.max(worker_pos) + 1
n_labels = torch.max(worker_labels) + 1
beta = []
for i in range(n_workers):
confusion_matrix = []
for j in range(n_labels):
init = torch.ones(n_labels).cuda()
init[j] = n_labels - 1 if n_labels > 2 else 2
confusion_matrix.append(pyro.sample(f'beta_{i}_{j}', dist.Dirichlet(init)))
confusion_matrix = torch.stack(confusion_matrix)
beta.append(confusion_matrix)
beta = torch.stack(beta)
with pyro.plate('items', n_tasks, dim=-1):
pi = pyro.sample('pi', dist.Dirichlet(torch.ones(n_labels).cuda()))
c = []
for i in range(n_tasks):
c_i = pyro.sample(f'c_{i}', dist.Categorical(pi[i].squeeze())).squeeze()
c.append(c_i)
c = torch.stack(c)
for i in range(len(worker_labels)):
y = pyro.sample(f'y_{i}', dist.Categorical(beta[worker_pos[i], c[tasks_pos[i]], :]), obs=worker_labels[i])
def bcc_guide(worker_pos, tasks_pos, worker_labels):
n_tasks = torch.max(tasks_pos) + 1
n_workers = torch.max(worker_pos) + 1
n_labels = torch.max(worker_labels) + 1
beta = []
for i in range(n_workers):
confusion_matrix = []
confusion_matrix_q = []
for j in range(n_labels):
init = torch.ones(n_labels).cuda()
init[j] = n_labels - 1 if n_labels > 2 else 2
beta_i_j_q = pyro.param(f'beta_{i}_{j}_q', init)
confusion_matrix.append(pyro.sample(f'beta_{i}_{j}', dist.Dirichlet(beta_i_j_q)))
confusion_matrix = torch.stack(confusion_matrix)
beta.append(confusion_matrix)
beta = torch.stack(beta)
pi_q = pyro.param('pi_q', torch.ones(n_labels).cuda())
with pyro.plate('items', n_tasks, dim=-1):
pi = pyro.sample('pi', dist.Dirichlet(pi_q))
c = []
for i in range(n_tasks):
c_i_q = pyro.param(f'c_{i}_q', torch.ones(n_labels).cuda() / n_labels, constraint=constraints.simplex)
c_i = pyro.sample(f'c_{i}', dist.Categorical(c_i_q)).squeeze()
c.append(c_i)
c = torch.stack(c)
I want to estimate a posterior for every, for instance, c_i
, so I cannot use pyro.plate
. I also tried to use batches with something like dist.Dirichlet(n_workers, n_labels, n_labels)
and provide a guide as a three-dimensional vector, but it doesn’t work with SVI.
I’m pretty sure that I’m missing something, and there is an obvious way to vectorize this code, but I haven’t found anything similar to my case in docs. It would be great if you could provide a tip.