Hi, I have a Bayesian model (specifically the Dawid-Skene model), with the twist that I want to condition on observed true labels for some of the items. The model can be described roughly as follows:
There are N total items, A total annotators, and C total classes. Each annotator classifies each item.
The true label for each item is drawn from a categorical prior, and conditional on this true label, each annotator labels the item according to their own confusion matrix (so if the true label is z=1, they report a label z' sampled according to P(z'|z=1), a row of their confusion matrix).
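For concreteness, here is a minimal NumPy sketch of that generative process (the sizes, the Dirichlet concentrations, and the diagonally-dominant confusion prior are all illustrative assumptions, not part of the model above):

```python
import numpy as np

rng = np.random.default_rng(0)

N, A, C = 200, 5, 3  # items, annotators, classes (illustrative sizes)

# Class proportions for the true labels
pi = rng.dirichlet(np.ones(C))

# One C x C confusion matrix per annotator; boosting the diagonal of the
# Dirichlet concentration makes annotators mostly correct
theta = np.stack(
    [rng.dirichlet(5.0 * np.eye(C)[c] + 1.0) for a in range(A) for c in range(C)]
).reshape(A, C, C)

# Draw true labels, then each annotator labels every item by sampling
# from row z[n] of their own confusion matrix
z = rng.choice(C, size=N, p=pi)
annotations = np.stack(
    [[rng.choice(C, p=theta[a, z[n]]) for a in range(A)] for n in range(N)]
)  # shape (N, A)
```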
I’m trying to fit this model on some simulated data first:
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS


def model(annotations, observed_z, mask):
    num_items, num_raters = annotations.shape
    num_classes = len(torch.unique(annotations))
    # Prior for the class proportions
    pi = pyro.sample("pi", dist.Dirichlet(torch.ones(num_classes)))
    # Priors for the annotator confusion matrices
    with pyro.plate("raters", num_raters):
        theta = pyro.sample(
            "theta",
            dist.Dirichlet(0.5 * torch.eye(num_classes) + 0.25).to_event(1),
        )
    with pyro.plate("items", num_items):
        # with pyro.poutine.mask(mask=mask):
        z = pyro.sample("z", dist.Categorical(pi), obs=observed_z, obs_mask=mask)
        # z = pyro.sample("z", dist.Categorical(pi).mask(mask), obs=observed_z, obs_mask=mask)
        # Condition on the observed values using a mask
        # z = torch.where(mask, observed_z, z)
        for r in pyro.plate("raters_loop", num_raters):
            probs = theta[..., r, z, :]
            pyro.sample(f"y_{r}", dist.Categorical(probs), obs=annotations[..., :, r])


nuts_kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True)
# print(observed_z[mask == False])
# nuts_kernel = NUTS(conditioned_model, jit_compile=True, ignore_jit_warnings=True)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500)
mcmc.run(annotations_simulated, observed_z, mask)
```
Here, annotations_simulated is 1000 examples x 5 annotators, observed_z is 1000 examples x 1 label (unobserved labels are set to -1), and mask is True where observed_z >= 0 and False where it's -1.
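For reference, the observed_z/mask construction looks like this (a 1-D sketch; z_true and the 10% fraction of revealed labels are illustrative assumptions):

```python
import torch

torch.manual_seed(0)

# Hypothetical setup: reveal the true label for a random ~10% of items
z_true = torch.randint(0, 3, (1000,))       # simulated true labels
revealed = torch.rand(1000) < 0.10          # which labels are observed

# Unobserved entries are set to -1, as described above
observed_z = torch.where(revealed, z_true, torch.full_like(z_true, -1))
mask = observed_z >= 0                      # True exactly where a label is known
```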
Yet I get errors that the MCMC algorithm is trying to use -1 as an index into theta. Specifically, it seems the line where I sample z is not working: z is still -1 wherever observed_z is -1. Any ideas on how to fix this? I already verified that I can sample from the model with pyro.infer.Predictive with parallel=True.