Shape changes when using broadcasting with SVI and AutoDelta guide

I’m trying to fit a 2PL IRT model with ordinal responses with fixed cutpoint offsets, like so:

import sys

import numpyro
import numpyro.distributions as dist
from jax import numpy as jnp
from jax import random
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta
from numpyro.infer.initialization import init_to_median
from numpyro.optim import Adam


def irt2pl(ncls, resp, word_subsample_size=None):
    nstud, nitems = resp.shape
    difficulty_offsets = numpyro.sample(
        "difficulty_offsets",
        dist.TransformedDistribution(
            dist.Normal(0, 1).expand([ncls - 1]), dist.transforms.OrderedTransform()
        ),
    )
    with numpyro.plate("nstud", nstud, dim=-1):
        abilities = numpyro.sample("abilities", dist.Normal())
    with numpyro.plate("nitems", nitems, dim=-1):
        difficulties = numpyro.sample("difficulties", dist.Normal())
        discriminations = numpyro.sample("discriminations", dist.HalfNormal())
    offset_difficulties = jnp.expand_dims(difficulties, 1) + jnp.expand_dims(
        difficulty_offsets, 0
    )
    print("abilities.shape", abilities.shape)
    print("discriminations.shape", discriminations.shape)
    predictor = jnp.expand_dims(abilities, 1) * jnp.expand_dims(discriminations, 0)
    cutpoints = jnp.expand_dims(
        offset_difficulties * jnp.expand_dims(discriminations, 0), 0
    )
    with numpyro.plate("nstud", nstud, dim=-2), numpyro.plate("nitems", nitems, dim=-1):
        numpyro.sample("resp", dist.OrderedLogistic(predictor, cutpoints), obs=resp)


resp = jnp.array(
    [[0, 1, 4, 3], [0, 1, 4, 3], [3, 4, 4, 4], [2, 2, 4, 4], [2, 2, 4, 4], [1, 2, 3, 3]]
)

if len(sys.argv) >= 2 and sys.argv[1] == "NUTS":
    kernel = NUTS(irt2pl, init_strategy=init_to_median())
    mcmc = MCMC(kernel, num_warmup=500, num_samples=2000)
    rng_key = random.PRNGKey(42)
    mcmc.run(rng_key, 5, resp, word_subsample_size=1000)
    mcmc.print_summary()
else:
    optim = Adam(0.1, 0.8, 0.99)
    elbo = Trace_ELBO()
    guide = AutoDelta(irt2pl, init_loc_fn=init_to_median())
    rng_key = random.PRNGKey(42)
    svi = SVI(irt2pl, guide, optim, loss=elbo)
    svi.run(rng_key, 200, 5, resp)
    print(guide(5, resp))

It works fine with NUTS. I get the following output:

$ python broken.py NUTS
abilities.shape (6,)
discriminations.shape (4,)
abilities.shape (6,)
discriminations.shape (4,)
abilities.shape (6,)
discriminations.shape (4,)
  0%|                                                                                                                                                                                     | 0/2500 [00:00<?, ?it/s]abilities.shape (6,)
discriminations.shape (4,)
sample: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2500/2500 [00:05<00:00, 483.15it/s, 26 steps of size 1.48e-01. acc. prob=0.83]

                           mean       std    median      5.0%     95.0%     n_eff     r_hat
         abilities[0]     -0.19      0.71     -0.20     -1.31      1.06    583.34      1.00
         abilities[1]     -0.18      0.72     -0.19     -1.36      0.96    691.21      1.00
         abilities[2]      1.31      0.77      1.27      0.04      2.59    609.96      1.00
         abilities[3]      0.76      0.71      0.74     -0.41      1.94    609.01      1.00
         abilities[4]      0.76      0.70      0.76     -0.32      1.91    505.14      1.00
         abilities[5]     -0.29      0.63     -0.31     -1.33      0.70    804.61      1.01
      difficulties[0]      0.74      0.65      0.74     -0.25      1.86    690.99      1.00
      difficulties[1]      0.54      0.66      0.55     -0.53      1.61    640.11      1.00
      difficulties[2]     -1.70      0.72     -1.69     -2.81     -0.44    761.89      1.00
      difficulties[3]     -1.21      0.67     -1.20     -2.26     -0.09    765.05      1.00
difficulty_offsets[0]     -1.56      0.70     -1.56     -2.58     -0.32    566.79      1.00
difficulty_offsets[1]     -0.67      0.76     -0.62     -1.93      0.51    420.50      1.00
difficulty_offsets[2]      0.27      0.63      0.27     -0.72      1.29    606.98      1.00
difficulty_offsets[3]      1.50      0.78      1.45      0.27      2.80    556.69      1.00
   discriminations[0]      1.22      0.58      1.17      0.26      2.11    349.73      1.00
   discriminations[1]      0.88      0.50      0.82      0.04      1.53    932.60      1.00
   discriminations[2]      1.82      0.56      1.76      0.91      2.70    650.47      1.00
   discriminations[3]      1.31      0.47      1.28      0.56      2.05    453.34      1.00

Number of divergences: 692

However, with SVI using the AutoDelta guide:

$ python broken.py
abilities.shape (6,)
discriminations.shape (4,)
abilities.shape (6,)
discriminations.shape (4,)
abilities.shape (6,)
discriminations.shape (4,)
abilities.shape (6,)
discriminations.shape (4,)
abilities.shape (6,)
discriminations.shape (4,)
abilities.shape (6, 6)
discriminations.shape (4,)
Traceback (most recent call last):
  File "broken.py", line 55, in <module>
    svi.run(rng_key, 200, 5, resp)
  File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/numpyro/infer/svi.py", line 201, in run
    svi_state = self.init(rng_key, *args, **kwargs)
  File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/numpyro/infer/svi.py", line 107, in init
    model_trace = trace(replay(model_init, guide_trace)).get_trace(*args, **kwargs, **self.static_kwargs)
  File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/numpyro/handlers.py", line 162, in get_trace
    self(*args, **kwargs)
  File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/numpyro/primitives.py", line 80, in __call__
    return self.fn(*args, **kwargs)
  File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/numpyro/primitives.py", line 80, in __call__
    return self.fn(*args, **kwargs)
  File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/numpyro/primitives.py", line 80, in __call__
    return self.fn(*args, **kwargs)
  File "broken.py", line 31, in irt2pl
    predictor = jnp.expand_dims(abilities, 1) * jnp.expand_dims(discriminations, 0)
  File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5256, in deferring_binary_op
    return binary_op(self, other)
  File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 396, in fn
    x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
  File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 333, in _promote_args
    return _promote_shapes(fun_name, *_promote_dtypes(*args))
  File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 251, in _promote_shapes
    result_rank = len(lax.broadcast_shapes(*shapes))
  File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/jax/_src/util.py", line 198, in wrapper
    return cached(bool(config.x64_enabled), *args, **kwargs)
  File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/jax/_src/util.py", line 191, in cached
    return f(*args, **kwargs)
  File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 97, in broadcast_shapes
    raise ValueError("Incompatible shapes for broadcasting: {}"
ValueError: Incompatible shapes for broadcasting: ((6, 1, 6), (1, 1, 4))

It looks like what happens is after some warm up iterations, abilities.shape changes, causing a broadcasting error. I’m quite new to Pyro/NumPyro and don’t completely understand all the mechanics, so I’m not really sure if where the problem is, but the fact it works with NUTS gives me some confidence that it should with SVI + an autoguide. Is the problem to do with expand_dims()? Is the AutoDelta guide somehow picking up the shape after expansion/broadcasting?

It seems that there is dim conflictions in your model

with numpyro.plate("nstud", nstud, dim=-1):
# then later
with numpyro.plate("nstud", nstud, dim=-2) ...

Could you make a FR for a better error message?

Okay, I see the problem now. I fixed it by renaming the second plates to nstud2 and resp2. This fixes it.

I opened the issue at FR: Fail with informative error message when same plate name used with inconsistent dimensions · Issue #1045 · pyro-ppl/numpyro · GitHub I’m still a bit confused about the inconsistency with MCMC. My working model is that different inference methods should vary in whether the converge, but be similar in terms of the errors they have inside the model. I suppose if this case is rejected early though, the inference algorithms will become more “exception-equivalent”.

Thanks for you help!