Hi there,
I am trying to replicate the NumPyro `annotators.py` example in Pyro.
The reparametrization (`LocScaleReparam`) causes problems when combined with Pyro's automatic enumeration of discrete latent variables.
The code below is exactly the same as the NumPyro example, apart from the obvious translations from JAX to PyTorch.
```python
import torch
import torch.nn.functional as F
import pyro
import pyro.distributions as dist
from pyro.infer.reparam import LocScaleReparam
from pyro.ops.indexing import Vindex
from pyro.poutine import reparam


def hierarchical_dawid_skene(positions: torch.Tensor, annotations: torch.Tensor) -> None:
    """
    This model corresponds to the plate diagram in Figure 4 of reference [1].
    """
    num_annotators = positions.unique().numel()
    num_classes = annotations.unique().numel()
    num_items, num_positions = annotations.shape
    # debugging
    print(f"{num_classes=}, {num_annotators=}, {num_items=}, {num_positions=}")

    with pyro.plate("class", num_classes):
        # NB: we define `beta` as the `logits` of `y` likelihood; but `logits` is
        # invariant up to a constant, so we'll follow [1]: fix the last term of `beta`
        # to 0 and only define hyperpriors for the first `num_classes - 1` terms.
        zeta = pyro.sample("zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        omega = pyro.sample("Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    with pyro.plate("annotator", num_annotators, dim=-2):
        with pyro.plate("class_abilities", num_classes):
            # non-centered parameterization
            with reparam(config={"beta": LocScaleReparam(centered=0.)}):  # <- with this it does not work, beta is reshaped
                beta = pyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
            # pad 0 to the last dimension
            beta = F.pad(beta, [0, 1] + [0, 0] * (beta.dim() - 1))

    pi = pyro.sample("pi", dist.Dirichlet(torch.ones(num_classes)))

    with pyro.plate("item", num_items, dim=-2):
        c = pyro.sample("c", dist.Categorical(probs=pi))
        # debugging
        print(f"{c.shape=}, {beta.shape=}")
        with pyro.plate("position", num_positions):
            logits = Vindex(beta)[positions, c, :]
            pyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
```
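For reference, `LocScaleReparam(centered=0.)` replaces `beta ~ Normal(zeta, Omega)` with standard-normal noise and recovers `beta` deterministically as `zeta + Omega * noise`. The sketch below writes that transform out with plain torch tensors, assuming the shapes from the debug prints (4 classes, 5 annotators); the name `eps` and the random draws are illustrative stand-ins, not what Pyro stores internally.

```python
import torch
from torch.distributions import HalfNormal, Normal

num_classes, num_annotators = 4, 5

# zeta and Omega live in the "class" plate: shape (num_classes, num_classes - 1)
zeta = Normal(0.0, 1.0).sample((num_classes, num_classes - 1))
omega = HalfNormal(1.0).sample((num_classes, num_classes - 1))

# Hypothetical decentered noise, one draw per (annotator, class) pair: (5, 4, 3)
eps = Normal(0.0, 1.0).sample((num_annotators, num_classes, num_classes - 1))

# Non-centered transform: zeta and omega broadcast over the leading annotator dim
beta = zeta + omega * eps
print(beta.shape)  # torch.Size([5, 4, 3])
```

If the noise carries both plate dimensions, `beta` keeps the annotator dimension, which is what the non-reparametrized model produces.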
In the first call to the model when running MCMC with NUTS (following the NumPyro example), I get the following debug output:
```
num_classes=4, num_annotators=5, num_items=10, num_positions=7
c.shape=torch.Size([10, 1]), beta.shape=torch.Size([5, 4, 4])
```
In the second call, when `c` is enumerated, I get:
```
num_classes=4, num_annotators=5, num_items=10, num_positions=7
c.shape=torch.Size([4, 1, 1]), beta.shape=torch.Size([4, 4])
```
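The enumerated shape of `c` is consistent with what Pyro's parallel enumeration should produce: the model nests at most two plates (`dim=-2` and `dim=-1`), so the four values of `c` go into the next free dimension, giving `(4, 1, 1)`. The sketch below reproduces the indexing step with plain torch advanced indexing, which should match `Vindex` here because `beta` has no extra batch dimensions; random integers stand in for the real `positions` data.

```python
import torch

num_annotators, num_classes = 5, 4
num_items, num_positions = 10, 7

beta = torch.randn(num_annotators, num_classes, num_classes)  # padded beta: (5, 4, 4)
# Hypothetical stand-in for the data: which annotator filled each position
positions = torch.randint(0, num_annotators, (num_items, num_positions))

# Ordinary run: one sampled class per item, shape (10, 1) from the "item" plate
c_sampled = torch.randint(0, num_classes, (num_items, 1))
logits = beta[positions, c_sampled, :]
print(logits.shape)  # torch.Size([10, 7, 4])

# Enumerated run: all classes stacked in dim -3, shape (4, 1, 1)
c_enum = torch.arange(num_classes).reshape(num_classes, 1, 1)
logits_enum = beta[positions, c_enum, :]
print(logits_enum.shape)  # torch.Size([4, 10, 7, 4])
```

With a `(5, 4, 4)` `beta`, the enumerated logits simply gain a leading dimension of size 4; indexing a `(4, 4)` `beta` with three indices cannot work.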
What puzzles me is that the size of `beta` changes from `(5, 4, 4)` to `(4, 4)`. This does not happen when I remove the reparametrization.
Any suggestions on where to look to understand what is happening?
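One detail that narrows this down: the `F.pad` call only appends a single zero to the last dimension (the trailing `[0, 0]` pairs pad nothing), so a padded shape of `(4, 4)` means `beta` was already `(4, 3)` before padding. In other words, the annotator dimension is missing from the sampled value itself; the padding does not reshape anything. A torch-only check, where `pad_last` is a hypothetical helper wrapping the model's padding line:

```python
import torch
import torch.nn.functional as F


def pad_last(beta: torch.Tensor) -> torch.Tensor:
    # Same spec as the model: (left, right) pairs start from the LAST dimension,
    # so only the last dimension grows by one; all other pairs are (0, 0)
    return F.pad(beta, [0, 1] + [0, 0] * (beta.dim() - 1))


with_annotators = pad_last(torch.ones(5, 4, 3))
without_annotators = pad_last(torch.ones(4, 3))
print(with_annotators.shape)     # torch.Size([5, 4, 4])
print(without_annotators.shape)  # torch.Size([4, 4])
```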
Thanks a lot in advance for your time.
Best,
Pietro