Issue when combining reparametrization and automatic enumeration

Hi there,

I am trying to replicate NumPyro's annotators.py example in Pyro.

The reparametrization causes problems when combined with Pyro's automatic enumeration of discrete variables.

The code below is exactly the same as the NumPyro example, apart from the obvious translations from JAX to PyTorch.

```python
import torch
import torch.nn.functional as F

import pyro
import pyro.distributions as dist
from pyro.infer.reparam import LocScaleReparam
from pyro.ops.indexing import Vindex
from pyro.poutine import reparam


def hierarchical_dawid_skene(positions: torch.Tensor, annotations: torch.Tensor) -> None:
    """
    This model corresponds to the plate diagram in Figure 4 of reference [1].
    """
    num_annotators = positions.unique().numel()
    num_classes = annotations.unique().numel()
    num_items, num_positions = annotations.shape

    # debugging
    print(f"{num_classes=}, {num_annotators=}, {num_items=}, {num_positions=}")

    with pyro.plate("class", num_classes):
        # NB: we define `beta` as the `logits` of `y` likelihood; but `logits` is
        # invariant up to a constant, so we'll follow [1]: fix the last term of `beta`
        # to 0 and only define hyperpriors for the first `num_classes - 1` terms.
        zeta = pyro.sample("zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        omega = pyro.sample("Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    with pyro.plate("annotator", num_annotators, dim=-2):
        with pyro.plate("class_abilities", num_classes):
            # non-centered parameterization
            with reparam(config={"beta": LocScaleReparam(centered=0.)}):  # <- with this it does not work: beta is reshaped
                beta = pyro.sample("beta", dist.Normal(zeta, omega).to_event(1))
            # pad the last dimension with a fixed 0 logit
            beta = F.pad(beta, [0, 1] + [0, 0] * (beta.dim() - 1))

    pi = pyro.sample("pi", dist.Dirichlet(torch.ones(num_classes)))

    with pyro.plate("item", num_items, dim=-2):
        c = pyro.sample("c", dist.Categorical(probs=pi))

        # debugging
        print(f"{c.shape=}, {beta.shape=}")

        with pyro.plate("position", num_positions):
            logits = Vindex(beta)[positions, c, :]
            pyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
```
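The comment about fixing the last term of `beta` to 0 relies on softmax being invariant to adding the same constant to every logit. A minimal stdlib-only sketch of that argument; the helpers `softmax` and `pad_zero` are hypothetical stand-ins for the normalization inside `Categorical(logits=...)` and for `F.pad(beta, [0, 1])`:

```python
import math


def softmax(logits):
    """Turn a list of logits into probabilities, as Categorical(logits=...) does."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def pad_zero(free_logits):
    """Append a fixed 0 as the last logit (hypothetical analogue of F.pad on the last dim)."""
    return list(free_logits) + [0.0]


# Any logit vector can be shifted so that its last entry is exactly 0 ...
full = [1.5, -0.3, 2.0, 0.7]
shifted = [x - full[-1] for x in full]

# ... so only the first num_classes - 1 entries need priors; padding a 0
# back on reproduces the same class probabilities.
p_full = softmax(full)
p_shifted = softmax(pad_zero(shifted[:-1]))

print([round(p, 6) for p in p_full])
print([round(p, 6) for p in p_shifted])
```

This is why `zeta` and `Omega` only have `num_classes - 1` entries: the last logit carries no information and is pinned to 0.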

In the first MCMC iteration with NUTS (following the NumPyro example), I get the following debug output:

```
num_classes=4, num_annotators=5, num_items=10, num_positions=7
c.shape=torch.Size([10, 1]), beta.shape=torch.Size([5, 4, 4])
```

In the second iteration, when `c` is enumerated, I get the following debug output:

```
num_classes=4, num_annotators=5, num_items=10, num_positions=7
c.shape=torch.Size([4, 1, 1]), beta.shape=torch.Size([4, 4])
```

What puzzles me is that the shape of `beta` changes from (5, 4, 4) to (4, 4). This does not happen when I remove the reparametrization.
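To see which shapes the `Vindex(beta)[positions, c, :]` line should produce, here is a small stdlib sketch of NumPy-style broadcasting, which as far as I understand is how `Vindex` combines the index tensors `positions` and `c`. The helper `broadcast_shapes` is hypothetical (`torch.broadcast_shapes` does the same job for real tensors), and it assumes `positions` has shape `(num_positions,)`, consistent with `num_annotators=5` and `num_positions=7` in the debug output:

```python
def broadcast_shapes(*shapes):
    """Combine shapes with NumPy broadcasting rules (right-aligned, size-1 dims stretch)."""
    ndim = max(len(s) for s in shapes)
    result = []
    for i in range(-1, -ndim - 1, -1):  # walk dims from the right
        sizes = {s[i] for s in shapes if len(s) >= -i and s[i] != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible shapes: {shapes}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


num_classes, num_positions = 4, 7
positions_shape = (num_positions,)  # assumption: one annotator index per position

# Shapes of `c` taken from the debug output above.
c_sampled = (10, 1)       # first iteration: item plate at dim -2
c_enumerated = (4, 1, 1)  # enumeration adds a size-num_classes dim on the left

# Expected logits shape: broadcast(positions, c) + (num_classes,)
print(broadcast_shapes(positions_shape, c_sampled) + (num_classes,))     # (10, 7, 4)
print(broadcast_shapes(positions_shape, c_enumerated) + (num_classes,))  # (4, 1, 7, 4)
```

Both results assume `beta` keeps its leading annotator dimension (size 5), which `positions` indexes; that is exactly the dimension missing from the `(4, 4)` shape in the second iteration.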

Any suggestions on where to look to understand what is happening?

Thanks a lot in advance for your time.

Best,
Pietro

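I think you can try replacing the `reparam` block above (the lines that sample `beta`) with:

```python
# manual non-centered parameterization: beta ~ Normal(zeta, omega)
beta_base = pyro.sample("beta_base", dist.Normal(0., 1.).expand([num_classes - 1]).to_event(1))
beta = beta_base * omega + zeta
```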

If that code works, then there is likely an issue in LocScaleReparam.
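The suggested replacement is the fully non-centered parameterization written by hand: if `beta_base ~ Normal(0, 1)`, then `beta_base * omega + zeta` is distributed as `Normal(zeta, omega)`, but the sampler works with `beta_base`, whose prior does not depend on `omega` (which typically helps NUTS with funnel-shaped hierarchical posteriors). A quick stdlib Monte Carlo sanity check of that equivalence, using arbitrary illustrative values for `zeta` and `omega` rather than values from the model:

```python
import random
import statistics

random.seed(0)
zeta, omega = 1.5, 0.4  # illustrative values only, not taken from the model
n = 50_000

# Centered: draw beta directly from Normal(zeta, omega).
centered = [random.gauss(zeta, omega) for _ in range(n)]

# Non-centered: draw standard normal noise, then scale and shift it,
# as in `beta = beta_base * omega + zeta`.
non_centered = [random.gauss(0.0, 1.0) * omega + zeta for _ in range(n)]

for name, xs in [("centered", centered), ("non-centered", non_centered)]:
    print(f"{name:>12}: mean={statistics.fmean(xs):.3f}, std={statistics.stdev(xs):.3f}")
```

Both lines should report a mean close to 1.5 and a standard deviation close to 0.4; the two parameterizations define the same prior on `beta` and differ only in which variable the sampler moves.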


Hi @fehiepsi,

Thanks a lot for your answer. It works indeed! So I guess there is a problem with LocScaleReparam.

Also, as noted here, hierarchical_dawid_skene runs much slower than the NumPyro version (>5 min vs. 30 s). Perhaps the manual parametrization causes the performance gap, but it is really surprising to me.

Anyway, thank you very much - I hope this was useful :slight_smile:

Best,
Pietro

Thanks @pietrolesci! Could you open a GitHub issue for this with reproducible code? I can take a look. Manual reparameterization should have the same performance.


Hi @fehiepsi, you can find the issue here.