I created the dawid_skene model in pyro (following the example numpyro implementation here: Example: Bayesian Models of Annotation — NumPyro documentation) and would like to use MAP / Autodelta to infer the beta
and pi
parameters (their example implementation uses MCMC/NUTS). However, I’m finding that every row of beta
seems to have similar values when I use Autodelta, unlike the MCMC output. How come this is the case? Am I not specifying Autodelta properly?
import numpy as np
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS, Predictive
from pyro.infer.reparam import LocScaleReparam
from pyro.ops.indexing import Vindex
import torch
from pyro import poutine
def get_data():
"""
:return: a tuple of annotator indices and class indices. The first term has shape
`num_positions` whose entries take values from `0` to `num_annotators - 1`.
The second term has shape `num_items x num_positions` whose entries take values
from `0` to `num_classes - 1`.
"""
# NB: the first annotator assessed each item 3 times
positions = np.array([1, 1, 1, 2, 3, 4, 5])
# fmt: off
annotations = np.array(
[[1, 1, 1, 1, 1, 1, 1], [3, 3, 3, 4, 3, 3, 4], [1, 1, 2, 2, 1, 2, 2],
[2, 2, 2, 3, 1, 2, 1], [2, 2, 2, 3, 2, 2, 2], [2, 2, 2, 3, 3, 2, 2],
[1, 2, 2, 2, 1, 1, 1], [3, 3, 3, 3, 4, 3, 3], [2, 2, 2, 2, 2, 2, 3],
[2, 3, 2, 2, 2, 2, 3], [4, 4, 4, 4, 4, 4, 4], [2, 2, 2, 3, 3, 4, 3],
[1, 1, 1, 1, 1, 1, 1], [2, 2, 2, 3, 2, 1, 2], [1, 2, 1, 1, 1, 1, 1],
[1, 1, 1, 2, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2, 1], [2, 2, 2, 1, 3, 2, 2], [2, 2, 2, 2, 2, 2, 2],
[2, 2, 2, 2, 2, 2, 1], [2, 2, 2, 3, 2, 2, 2], [2, 2, 1, 2, 2, 2, 2],
[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [2, 3, 2, 2, 2, 2, 2],
[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 2, 1, 1, 2, 1],
[1, 1, 1, 1, 1, 1, 1], [3, 3, 3, 3, 2, 3, 3], [1, 1, 1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2, 2], [2, 2, 2, 3, 2, 3, 2], [4, 3, 3, 4, 3, 4, 3],
[2, 2, 1, 2, 2, 3, 2], [2, 3, 2, 3, 2, 3, 3], [3, 3, 3, 3, 4, 3, 2],
[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 2, 1, 2, 1, 1, 1],
[2, 3, 2, 2, 2, 2, 2], [1, 2, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2, 2]])
# fmt: on
# we minus 1 because in Python, the first index is 0
return positions - 1, annotations - 1
def dawid_skene(positions, annotations):
"""
This model corresponds to the plate diagram in Figure 2 of reference [1].
"""
num_annotators = int(np.max(positions.numpy())) + 1
num_classes = int(np.max(annotations.numpy())) + 1
num_items, num_positions = annotations.shape
with pyro.plate("annotator", num_annotators, dim=-2):
with pyro.plate("class", num_classes):
beta = pyro.sample(
"beta",
dist.Dirichlet(
torch.as_tensor(np.ones(num_classes), dtype=torch.float32)
),
)
pi = pyro.sample(
"pi", dist.Dirichlet(torch.as_tensor(np.ones(num_classes), dtype=torch.float32))
)
with pyro.plate("item", num_items, dim=-2):
c = pyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})
# here we use Vindex to allow broadcasting for the second index `c`
# ref: http://num.pyro.ai/en/latest/utilities.html#numpyro.contrib.indexing.vindex
with pyro.plate("position", num_positions):
pyro.sample(
"y", dist.Categorical(Vindex(beta)[positions, c, :]), obs=annotations
)
annotators, annotations = get_data()
annotators = torch.as_tensor(annotators, dtype=torch.int32)
annotations = torch.as_tensor(annotations, dtype=torch.int32)
auto_guide = pyro.infer.autoguide.AutoDelta(poutine.block(dawid_skene, hide=["c"]))
adam = pyro.optim.Adam({"lr": 0.005})
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(dawid_skene, auto_guide, adam, elbo)
losses = []
for step in range(1000): # Consider running for more steps.
loss = svi.step(annotators, annotations)
losses.append(loss)
if step % 100 == 0:
print("Elbo loss: {}".format(loss))
for name, value in pyro.get_param_store().items():
print(name, pyro.param(name).data.cpu().numpy())
The output looks like:
AutoDelta.beta [[[0.38091123 0.44739306 0.13653345 0.03516233]
[0.3715091 0.4673514 0.12593347 0.035206 ]
[0.40346456 0.4354736 0.12896016 0.03210165]
[0.3762617 0.45263806 0.13539347 0.03570675]]
[[0.35426164 0.3288546 0.24657635 0.07030745]
[0.3430905 0.34670907 0.24421021 0.06599025]
[0.37392765 0.3172224 0.24761106 0.06123893]
[0.34890842 0.3403393 0.2386342 0.07211807]]
[[0.44422567 0.3692978 0.11658386 0.06989267]
[0.43524897 0.38630754 0.10851461 0.06992885]
[0.4625954 0.37054363 0.10382633 0.06303472]
[0.43369433 0.38469973 0.11444563 0.06716025]]
[[0.4066831 0.37100366 0.15156713 0.07074611]
[0.3870014 0.38656577 0.15755187 0.06888096]
[0.4142693 0.36777595 0.15424019 0.06371451]
[0.38995415 0.38550267 0.15771072 0.06683248]]
[[0.46695903 0.3283722 0.15577653 0.04889225]
[0.45142862 0.34985995 0.15032366 0.04838768]
[0.48499504 0.31682262 0.15536885 0.04281344]
[0.4584936 0.33643484 0.15625899 0.04881264]]]
AutoDelta.pi [0.25712258 0.22512354 0.25192493 0.26582897]