Partial observations with obs_mask

Hi, I have a Bayesian model (specifically the Dawid-Skene model), with the twist that the true labels are observed for some of the items. The model can be described roughly as follows:

There are N total items, A total annotators, and C total classes. Each annotator classifies each item.

The true label of each item is drawn from a categorical prior, and conditional on this true label, each annotator labels the item according to their own confusion matrix (so if the true label is z=1, the annotator samples their label z' from P(z'|z=1), i.e. row 1 of their confusion matrix).
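Written out, the generative model described above (and implemented in the code below) is roughly the following. The Dirichlet concentration for each confusion-matrix row is read off the code's `0.5 * torch.eye(num_classes) + 0.25`; for an item whose true label is not observed, z_i is summed out of the likelihood:

```latex
\begin{aligned}
\pi &\sim \mathrm{Dirichlet}(\mathbf{1}_C) \\
\theta_{a,c,\cdot} &\sim \mathrm{Dirichlet}(0.5\,\mathbf{e}_c + 0.25\,\mathbf{1}_C), && a = 1,\dots,A,\; c = 1,\dots,C \\
z_i \mid \pi &\sim \mathrm{Categorical}(\pi), && i = 1,\dots,N \\
y_{i,a} \mid z_i, \theta &\sim \mathrm{Categorical}(\theta_{a,z_i,\cdot}) \\
p(y_{i,1:A} \mid \pi, \theta) &= \sum_{c=1}^{C} \pi_c \prod_{a=1}^{A} \theta_{a,c,y_{i,a}} && \text{(unobserved } z_i\text{)} \\
p(z_i, y_{i,1:A} \mid \pi, \theta) &= \pi_{z_i} \prod_{a=1}^{A} \theta_{a,z_i,y_{i,a}} && \text{(observed } z_i\text{)}
\end{aligned}
```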

I’m trying to fit this model on some simulated data first:

def model(annotations, observed_z, mask):
    num_items, num_raters = annotations.shape
    num_classes = len(torch.unique(annotations)) 
    
    # Prior for the class proportions
    pi = pyro.sample("pi", dist.Dirichlet(torch.ones(num_classes)))

    # Priors for the annotator confusion matrices
    with pyro.plate("raters", num_raters):
        theta = pyro.sample("theta", dist.Dirichlet(0.5 * torch.eye(num_classes) + 0.25).to_event(1))


    with pyro.plate("items", num_items):
        # with pyro.poutine.mask(mask=mask): 
        z = pyro.sample("z", dist.Categorical(pi), obs=observed_z, obs_mask=mask)

        # z = pyro.sample("z", dist.Categorical(pi).mask(mask), obs=observed_z, obs_mask=mask)
        
        # Condition on the observed values using a mask
        # z = torch.where(mask, observed_z, z)
        
        for r in pyro.plate("raters_loop", num_raters):
            probs = theta[...,r, z, :]
            pyro.sample(f"y_{r}", dist.Categorical(probs), obs=annotations[...,:, r])

nuts_kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True)

# print(observed_z[mask == False])
# nuts_kernel = NUTS(conditioned_model, jit_compile=True, ignore_jit_warnings=True)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500)
mcmc.run(annotations_simulated, observed_z, mask)

Here, annotations_simulated has shape 1000 examples x 5 annotators, observed_z holds one label per example (unobserved labels are set to -1), and mask is True where observed_z >= 0 and False where it is -1.

However, I get errors saying that the MCMC algorithm is trying to use -1 as an index into theta. Specifically, the line where I sample z does not seem to work: z is still -1 where observed_z is -1. Any ideas on how to fix this? I have already verified that I can sample from the model with pyro.infer.Predictive with parallel=True.

@ssadhuka, could you provide a full reproducible example, please?

Sure thing, @ordabayev

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS


def model(annotations, observed_z, mask):
    num_items, num_raters = annotations.shape
    num_classes = len(torch.unique(annotations))
    
    # Prior for the class proportions
    pi = pyro.sample("pi", dist.Dirichlet(torch.ones(num_classes)))

    # Priors for the annotator confusion matrices
    with pyro.plate("raters", num_raters):
        theta = pyro.sample("theta", dist.Dirichlet(0.5 * torch.eye(num_classes) + 0.25).to_event(1))


    with pyro.plate("items", num_items):
        
        # with pyro.poutine.mask(mask=mask): 
        z = pyro.sample("z", dist.Categorical(pi), obs=observed_z, obs_mask=mask)
        
        for r in pyro.plate("raters_loop", num_raters):
            probs = theta[...,r, z, :]
            pyro.sample(f"y_{r}", dist.Categorical(probs), obs=annotations[...,:, r])


def simulate_data_with_observed(num_items, num_raters, true_pi, confusion_diagonal):
    num_classes = len(true_pi)
    
    # Create confusion matrix
    off_diag = (1 - confusion_diagonal) / (num_classes - 1)
    confusion_matrix = off_diag * torch.ones(num_classes, num_classes) + \
                       (confusion_diagonal - off_diag) * torch.eye(num_classes)
    
    # Simulate true labels
    z = dist.Categorical(true_pi).sample([num_items])
    
    # Create the observed_z tensor
    observed_z = torch.full((num_items,), -1, dtype=torch.int64)
    mask = torch.ones(num_items, dtype=torch.bool)
    
    # Randomly select 10% of the items to reveal their true labels
    num_observed = int(0.1 * num_items)
    observed_indices = torch.randperm(num_items)[:num_observed]
    observed_z[observed_indices] = z[observed_indices]
    mask[observed_indices] = False
    
    # Simulate annotations based on true labels and confusion matrices
    annotations = torch.zeros((num_items, num_raters), dtype=torch.int64)
    for i in range(num_items):
        for r in range(num_raters):
            probs = confusion_matrix[:, z[i]]
            annotations[i, r] = dist.Categorical(probs).sample()
    
    return annotations, observed_z, mask


num_items = 1000
num_raters = 2
true_pi = torch.tensor([0.1, 0.3, 0.6])
confusion_diagonal = 0.9

annotations_simulated, observed_z, mask = simulate_data_with_observed(num_items, num_raters, true_pi, confusion_diagonal)

nuts_kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500)
mcmc.run(annotations_simulated, observed_z, mask)

The error I get is:

File /data/ddmg/frank/.conda/envs/frank/lib/python3.9/site-packages/torch/distributions/categorical.py:127, in Categorical.log_prob(self, value)
    125 value, log_pmf = torch.broadcast_tensors(value, self.logits)
    126 value = value[..., :1]
--> 127 return log_pmf.gather(-1, value).squeeze(-1)

RuntimeError: index -1 is out of bounds for dimension 1 with size 3

It seems to be triggered by the line probs = theta[..., r, z, :], which uses the masked entries of observed_z (set to -1) to index into theta. However, I thought that the pyro.sample statement above would take care of the -1s through the obs_mask argument.
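The traceback ends inside Categorical.log_prob rather than at the theta indexing, which fits how PyTorch treats -1. The plain-PyTorch sketch below (not Pyro code; the tensors are made up for illustration) shows that advanced indexing, as in theta[..., r, z, :], silently wraps -1 around to the last class, while gather, which the traceback shows Categorical.log_prob calling, raises the same kind of out-of-bounds error. That suggests the -1 is reaching a log_prob computation rather than failing at the indexing line.

```python
import torch

log_pmf = torch.tensor([0.1, 0.3, 0.6]).log()
z = torch.tensor([2, -1, 0])  # -1 marks an unobserved label

# Advanced indexing (as in theta[..., r, z, :]) wraps -1 around to the last
# class, so this runs without error and silently uses class 2 for the -1 entry
wrapped = log_pmf[z]

# gather, which Categorical.log_prob calls (see the traceback), rejects -1
try:
    log_pmf.expand(len(z), -1).gather(-1, z.unsqueeze(-1))
    gather_failed = False
except RuntimeError:
    gather_failed = True
```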

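Your mask is actually set to False for the observed items in your code (mask[observed_indices] = False), but obs_mask expects True for the entries that should be conditioned on obs, so the mask needs to be inverted.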

Also add this to your code:

        # replace -1 with 0
+       observed_z = torch.max(observed_z, torch.zeros_like(observed_z))
        # invert mask: obs_mask=~mask
        z = pyro.sample("z", dist.Categorical(pi), obs=observed_z, obs_mask=~mask)

The reason is that Pyro first computes the log_prob of z_observed for every item and only then masks out the non-observed entries. You therefore need to provide valid values instead of -1 so that the log_prob can be computed before masking. It doesn't matter which values you provide (they are masked out by obs_mask) as long as they are within the distribution's support (e.g. zeros or ones).

Thank you! Yeah, I had flipped the mask just to check whether that was the issue; sorry, I should have sent the older version. I'll try out the torch.max line.

@ordabayev Thanks for the response. One other quick question about this implementation: I want the observed examples to have twice the weight of the unobserved examples in the likelihood. How can I do that?

You can use poutine.scale (see "Poutine (Effect handlers)" in the Pyro documentation).

It can be used as a context manager:

scale = torch.tensor(...)  # 2 for observed, 1 for unobserved
with pyro.poutine.scale(scale):
    z = pyro.sample("z", dist.Categorical(pi), obs=observed_z, obs_mask=~mask)

@ordabayev Thanks. My last question is about the difference between the above implementation and another implementation of the same model, which I coded up between my initial question and your first response:

def bayesian_model(annotations, observed_z, J, N, num_classes=3):
    # Priors for class probabilities
    pi = pyro.sample("pi", dist.Dirichlet(torch.ones(num_classes)))
    
    # Priors for annotator confusion matrices
    # theta = torch.empty(J, num_classes, num_classes)
    with pyro.plate("raters", J):
        theta = pyro.sample("theta", dist.Dirichlet(0.5 * torch.eye(num_classes) + 0.25).to_event(1))
    
    with pyro.plate("data", N):
        observed_weight = torch.tensor(2.0)
        is_observed = observed_z >= 0
        weight = torch.where(is_observed, observed_weight, torch.tensor(1.0))
        unobserved_z = pyro.sample("unobserved_z", dist.Categorical(probs=pi), infer={"enumerate": "parallel"})
        
        # Combine observed and unobserved labels
        true_labels = torch.where(observed_z >= 0, observed_z, unobserved_z)
        
        # Likelihood for annotations based on true labels
        for r in pyro.plate("raters_loop", J):
            probs = theta[r, true_labels, :]
            annotation = annotations[:, r]

            # Calculate log likelihood
            categorical = dist.Categorical(probs)
            log_likelihood = categorical.log_prob(annotation)

            # Weight the observed data
            is_observed = observed_z >= 0
            weight = torch.where(is_observed, observed_weight, torch.tensor(1.0))
            weighted_log_likelihood = weight * log_likelihood

            # Use pyro.factor to include the weighted log likelihood in the model
            pyro.sample(f"y_{r}", dist.Categorical(probs), obs=annotation)
            pyro.factor(f"obs_{r}", weighted_log_likelihood)

Both of these models converge to the true parameters on purely simulated data; however, when I run them on some other data, they give notably different results. What is the difference between these two implementations under the hood? It seemed to me that the second implementation is the same as the first one, apart from weighting the likelihood manually.

@ssadhuka At least one difference is that in the second example pyro.sample("unobserved_z", ...) is not masked, unlike in the first case where it is implicitly masked by obs_mask (this makes them different models with different log_probs). Also, you are using enumeration in one case and not in the other. That part does not change the model, but inference and convergence might differ.
