Plates vs vectorized factor statements

I am trying to understand the difference between plates and factor statements. I have a multi-stage decision making model, and the convergence looks very different with the same data when I use plates vs when I don’t. For instance, I’ve replaced the following line:

with numpyro.handlers.mask(mask=final_stage_mask):
    numpyro.sample('obs', dist.Bernoulli(probs=hit_rate), obs=y)

with this:

# Plate for observations
with numpyro.plate('observations', M):
    with numpyro.handlers.mask(mask=final_stage_mask):
        numpyro.sample('obs', dist.Bernoulli(probs=hit_rate), obs=y)

I understand, at a high level, that plates are used to denote independent observations, but I don’t understand why these two implementations give such different convergence. I’ve uploaded a notebook (including the outputs) comparing the non-plated and plated versions to Colab: 24_11_27_compare_implementations.ipynb - Google Drive.

Mathematically, or in terms of the backend implementation, what is the difference between using a plate and not using one?

Could you post a rendered link, e.g. using https://gist.github.com/?

> why these two implementations are providing such different convergences

Your interpretation is correct: the behavior should be the same (except for some specific inference algorithms that exploit conditional independence). There might be a bug somewhere.

Sure, here’s a rendered link: Google Colab. I am using standard MCMC as my inference algorithm.

Your model is fairly complex to parse, but it seems that you are using enumeration. If so, plates are required: we need named dimensions to perform the sum-product computation correctly (basically sum(logsumexp(x, 0), 1) != logsumexp(sum(x, 1), 0)). I’m surprised that you didn’t see a warning about the missing plate annotation.
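To make the inequality concrete, here is a toy numerical version with a made-up 2 × 2 table of log factors `x`, where axis 0 enumerates the values of a discrete latent and axis 1 indexes observations. It only illustrates the two reduction orders using `jax.scipy.special.logsumexp`; it is not how NumPyro's funsor-based enumeration is actually implemented.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Hypothetical log-factor table: axis 0 = values of a discrete latent (K = 2),
# axis 1 = observations (N = 2).
x = jnp.log(jnp.array([[0.9, 0.1],
                       [0.1, 0.9]]))

# One latent per observation: marginalize the latent for each observation
# first, then add up the independent per-observation log terms.
plated = jnp.sum(logsumexp(x, axis=0))

# Latent treated as shared by all observations: add up the log terms first,
# then marginalize the latent once.
unplated = logsumexp(jnp.sum(x, axis=1), axis=0)

print(plated, unplated)
```

Here the first order gives approximately 0 and the second gives log 0.18 ≈ -1.715, so the same table yields two different log-marginals depending on which dimension is known to be independent.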


Got it, thank you! Enumeration only applies to discrete latent variables, is that correct?

That’s right.