MAP / Autodelta for Annotation Model

I created the dawid_skene model in pyro (following the example numpyro implementation here: Example: Bayesian Models of Annotation — NumPyro documentation) and would like to use MAP / Autodelta to infer the beta and pi parameters (their example implementation uses MCMC/NUTS). However, I’m finding that every row of beta seems to have similar values when I use Autodelta, unlike the MCMC output. How come this is the case? Am I not specifying Autodelta properly?

import numpy as np

import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS, Predictive
from pyro.infer.reparam import LocScaleReparam
from pyro.ops.indexing import Vindex
import torch
from pyro import poutine

def get_data():
    """
    :return: a tuple of annotator indices and class indices. The first term has shape
        `num_positions` whose entries take values from `0` to `num_annotators - 1`.
        The second term has shape `num_items x num_positions` whose entries take values
        from `0` to `num_classes - 1`.
    """
    # NB: the first annotator assessed each item 3 times
    positions = np.array([1, 1, 1, 2, 3, 4, 5])
    # fmt: off
    annotations = np.array(
        [[1, 1, 1, 1, 1, 1, 1], [3, 3, 3, 4, 3, 3, 4], [1, 1, 2, 2, 1, 2, 2],
         [2, 2, 2, 3, 1, 2, 1], [2, 2, 2, 3, 2, 2, 2], [2, 2, 2, 3, 3, 2, 2],
         [1, 2, 2, 2, 1, 1, 1], [3, 3, 3, 3, 4, 3, 3], [2, 2, 2, 2, 2, 2, 3],
         [2, 3, 2, 2, 2, 2, 3], [4, 4, 4, 4, 4, 4, 4], [2, 2, 2, 3, 3, 4, 3],
         [1, 1, 1, 1, 1, 1, 1], [2, 2, 2, 3, 2, 1, 2], [1, 2, 1, 1, 1, 1, 1],
         [1, 1, 1, 2, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1],
         [2, 2, 2, 2, 2, 2, 1], [2, 2, 2, 1, 3, 2, 2], [2, 2, 2, 2, 2, 2, 2],
         [2, 2, 2, 2, 2, 2, 1], [2, 2, 2, 3, 2, 2, 2], [2, 2, 1, 2, 2, 2, 2],
         [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [2, 3, 2, 2, 2, 2, 2],
         [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 1, 2, 1, 1, 2, 1],
         [1, 1, 1, 1, 1, 1, 1], [3, 3, 3, 3, 2, 3, 3], [1, 1, 1, 1, 1, 1, 1],
         [2, 2, 2, 2, 2, 2, 2], [2, 2, 2, 3, 2, 3, 2], [4, 3, 3, 4, 3, 4, 3],
         [2, 2, 1, 2, 2, 3, 2], [2, 3, 2, 3, 2, 3, 3], [3, 3, 3, 3, 4, 3, 2],
         [1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1], [1, 2, 1, 2, 1, 1, 1],
         [2, 3, 2, 2, 2, 2, 2], [1, 2, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2, 2]])
    # fmt: on
    # we minus 1 because in Python, the first index is 0
    return positions - 1, annotations - 1

def dawid_skene(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 2 of reference [1].
    """
    num_annotators = int(np.max(positions.numpy())) + 1
    num_classes = int(np.max(annotations.numpy())) + 1
    num_items, num_positions = annotations.shape

    with pyro.plate("annotator", num_annotators, dim=-2):
        with pyro.plate("class", num_classes):
            beta = pyro.sample(
                "beta",
                dist.Dirichlet(
                    torch.as_tensor(np.ones(num_classes), dtype=torch.float32)
                ),
            )

    pi = pyro.sample(
        "pi", dist.Dirichlet(torch.as_tensor(np.ones(num_classes), dtype=torch.float32))
    )

    with pyro.plate("item", num_items, dim=-2):
        c = pyro.sample("c", dist.Categorical(pi), infer={"enumerate": "parallel"})

        # here we use Vindex to allow broadcasting for the second index `c`
        # ref: http://num.pyro.ai/en/latest/utilities.html#numpyro.contrib.indexing.vindex
        with pyro.plate("position", num_positions):
            pyro.sample(
                "y", dist.Categorical(Vindex(beta)[positions, c, :]), obs=annotations
            )


annotators, annotations = get_data()
annotators = torch.as_tensor(annotators, dtype=torch.int32)
annotations = torch.as_tensor(annotations, dtype=torch.int32)

auto_guide = pyro.infer.autoguide.AutoDelta(poutine.block(dawid_skene, hide=["c"]))
adam = pyro.optim.Adam({"lr": 0.005})
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(dawid_skene, auto_guide, adam, elbo)

losses = []
for step in range(1000):  # Consider running for more steps.
    loss = svi.step(annotators, annotations)
    losses.append(loss)
    if step % 100 == 0:
        print("Elbo loss: {}".format(loss))

for name, value in pyro.get_param_store().items():
    print(name, pyro.param(name).data.cpu().numpy())

The output looks like:

AutoDelta.beta [[[0.38091123 0.44739306 0.13653345 0.03516233]
  [0.3715091  0.4673514  0.12593347 0.035206  ]
  [0.40346456 0.4354736  0.12896016 0.03210165]
  [0.3762617  0.45263806 0.13539347 0.03570675]]

 [[0.35426164 0.3288546  0.24657635 0.07030745]
  [0.3430905  0.34670907 0.24421021 0.06599025]
  [0.37392765 0.3172224  0.24761106 0.06123893]
  [0.34890842 0.3403393  0.2386342  0.07211807]]

 [[0.44422567 0.3692978  0.11658386 0.06989267]
  [0.43524897 0.38630754 0.10851461 0.06992885]
  [0.4625954  0.37054363 0.10382633 0.06303472]
  [0.43369433 0.38469973 0.11444563 0.06716025]]

 [[0.4066831  0.37100366 0.15156713 0.07074611]
  [0.3870014  0.38656577 0.15755187 0.06888096]
  [0.4142693  0.36777595 0.15424019 0.06371451]
  [0.38995415 0.38550267 0.15771072 0.06683248]]

 [[0.46695903 0.3283722  0.15577653 0.04889225]
  [0.45142862 0.34985995 0.15032366 0.04838768]
  [0.48499504 0.31682262 0.15536885 0.04281344]
  [0.4584936  0.33643484 0.15625899 0.04881264]]]
AutoDelta.pi [0.25712258 0.22512354 0.25192493 0.26582897]