Vectorizing guides

Hi!

I’m trying to implement the NumPyro example on Bayesian models of annotation in Pyro, using SVI and following the SVI Part I tutorial. The problem is that when I use pyro.plate, I can’t provide a separate guide for every posterior I want to estimate: it only seems possible to provide one guide for all the conditionally independent variables at once.

I wrote code that works, but it’s slow because it isn’t vectorized. Here is what I have for BCC (the same model as dawid_skene in the example above).

```python
import torch
import pyro
import pyro.distributions as dist


def bcc(worker_pos, tasks_pos, worker_labels):
    # Sizes as plain Python ints, inferred from the largest index in each input.
    n_tasks = int(torch.max(tasks_pos)) + 1
    n_workers = int(torch.max(worker_pos)) + 1
    n_labels = int(torch.max(worker_labels)) + 1
    device = worker_labels.device  # follow the data's device instead of hard-coding .cuda()

    # One Dirichlet-distributed confusion-matrix row per (worker, true label).
    beta = []
    for i in range(n_workers):
        confusion_matrix = []
        for j in range(n_labels):
            init = torch.ones(n_labels, device=device)
            init[j] = n_labels - 1 if n_labels > 2 else 2  # extra mass on the correct label
            confusion_matrix.append(pyro.sample(f'beta_{i}_{j}', dist.Dirichlet(init)))
        beta.append(torch.stack(confusion_matrix))
    beta = torch.stack(beta)

    with pyro.plate('items', n_tasks, dim=-1):
        pi = pyro.sample('pi', dist.Dirichlet(torch.ones(n_labels, device=device)))

    # One true-label site per task.
    c = []
    for i in range(n_tasks):
        c.append(pyro.sample(f'c_{i}', dist.Categorical(pi[i])))
    c = torch.stack(c)

    # One observed label per annotation.
    for i in range(len(worker_labels)):
        pyro.sample(f'y_{i}', dist.Categorical(beta[worker_pos[i], c[tasks_pos[i]], :]),
                    obs=worker_labels[i])
```

```python
from torch.distributions import constraints


def bcc_guide(worker_pos, tasks_pos, worker_labels):
    n_tasks = int(torch.max(tasks_pos)) + 1
    n_workers = int(torch.max(worker_pos)) + 1
    n_labels = int(torch.max(worker_labels)) + 1
    device = worker_labels.device

    for i in range(n_workers):
        for j in range(n_labels):
            init = torch.ones(n_labels, device=device)
            init[j] = n_labels - 1 if n_labels > 2 else 2
            # Dirichlet concentrations must stay positive during optimization.
            beta_i_j_q = pyro.param(f'beta_{i}_{j}_q', init, constraint=constraints.positive)
            pyro.sample(f'beta_{i}_{j}', dist.Dirichlet(beta_i_j_q))

    pi_q = pyro.param('pi_q', torch.ones(n_labels, device=device), constraint=constraints.positive)
    with pyro.plate('items', n_tasks, dim=-1):
        pyro.sample('pi', dist.Dirichlet(pi_q))

    # One variational Categorical per task.
    for i in range(n_tasks):
        c_i_q = pyro.param(f'c_{i}_q', torch.ones(n_labels, device=device) / n_labels,
                           constraint=constraints.simplex)
        pyro.sample(f'c_{i}', dist.Categorical(c_i_q))
```

I want to estimate a separate posterior for every c_i, for instance, so I can’t use pyro.plate. I also tried using batches, with something like a Dirichlet whose concentration has shape (n_workers, n_labels, n_labels), and providing the guide parameters as a three-dimensional tensor, but that doesn’t work with SVI.

I’m pretty sure I’m missing something and that there’s an obvious way to vectorize this code, but I haven’t found anything similar to my case in the docs. A tip would be much appreciated.

You can use pyro.plate in your model in exactly the same way as in the original example, and likewise in any guide you write. Have you taken a look at our tensor shape tutorial and the Gaussian mixture model example, which is an even simpler analogue of the mixture models in the annotation example? That might help you clear up any misunderstandings about how plates work in Pyro and NumPyro.

Yes, I’ve taken a look at these tutorials. I think the problem is that I need to generate a sample with an empty batch_shape, i.e. (), because otherwise I won’t be able to provide a guide for every batch item, only for the entire distribution that I sampled the batch from. I didn’t figure out how to do that with pyro.plate, but I’ll look at the links you provided once more. Thank you!