I’m trying to fit a 2PL IRT model with ordinal responses with fixed cutpoint offsets, like so:
import sys
import numpyro
import numpyro.distributions as dist
from jax import numpy as jnp
from jax import random
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta
from numpyro.infer.initialization import init_to_median
from numpyro.optim import Adam
def irt2pl(ncls, resp, word_subsample_size=None):
nstud, nitems = resp.shape
difficulty_offsets = numpyro.sample(
"difficulty_offsets",
dist.TransformedDistribution(
dist.Normal(0, 1).expand([ncls - 1]), dist.transforms.OrderedTransform()
),
)
with numpyro.plate("nstud", nstud, dim=-1):
abilities = numpyro.sample("abilities", dist.Normal())
with numpyro.plate("nitems", nitems, dim=-1):
difficulties = numpyro.sample("difficulties", dist.Normal())
discriminations = numpyro.sample("discriminations", dist.HalfNormal())
offset_difficulties = jnp.expand_dims(difficulties, 1) + jnp.expand_dims(
difficulty_offsets, 0
)
print("abilities.shape", abilities.shape)
print("discriminations.shape", discriminations.shape)
predictor = jnp.expand_dims(abilities, 1) * jnp.expand_dims(discriminations, 0)
cutpoints = jnp.expand_dims(
offset_difficulties * jnp.expand_dims(discriminations, 0), 0
)
with numpyro.plate("nstud", nstud, dim=-2), numpyro.plate("nitems", nitems, dim=-1):
numpyro.sample("resp", dist.OrderedLogistic(predictor, cutpoints), obs=resp)
resp = jnp.array(
[[0, 1, 4, 3], [0, 1, 4, 3], [3, 4, 4, 4], [2, 2, 4, 4], [2, 2, 4, 4], [1, 2, 3, 3]]
)
if len(sys.argv) >= 2 and sys.argv[1] == "NUTS":
kernel = NUTS(irt2pl, init_strategy=init_to_median())
mcmc = MCMC(kernel, num_warmup=500, num_samples=2000)
rng_key = random.PRNGKey(42)
mcmc.run(rng_key, 5, resp, word_subsample_size=1000)
mcmc.print_summary()
else:
optim = Adam(0.1, 0.8, 0.99)
elbo = Trace_ELBO()
guide = AutoDelta(irt2pl, init_loc_fn=init_to_median())
rng_key = random.PRNGKey(42)
svi = SVI(irt2pl, guide, optim, loss=elbo)
svi.run(rng_key, 200, 5, resp)
print(guide(5, resp))
It works fine with NUTS. I get the following output:
$ python broken.py NUTS
abilities.shape (6,)
discriminations.shape (4,)
abilities.shape (6,)
discriminations.shape (4,)
abilities.shape (6,)
discriminations.shape (4,)
0%| | 0/2500 [00:00<?, ?it/s]abilities.shape (6,)
discriminations.shape (4,)
sample: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2500/2500 [00:05<00:00, 483.15it/s, 26 steps of size 1.48e-01. acc. prob=0.83]
mean std median 5.0% 95.0% n_eff r_hat
abilities[0] -0.19 0.71 -0.20 -1.31 1.06 583.34 1.00
abilities[1] -0.18 0.72 -0.19 -1.36 0.96 691.21 1.00
abilities[2] 1.31 0.77 1.27 0.04 2.59 609.96 1.00
abilities[3] 0.76 0.71 0.74 -0.41 1.94 609.01 1.00
abilities[4] 0.76 0.70 0.76 -0.32 1.91 505.14 1.00
abilities[5] -0.29 0.63 -0.31 -1.33 0.70 804.61 1.01
difficulties[0] 0.74 0.65 0.74 -0.25 1.86 690.99 1.00
difficulties[1] 0.54 0.66 0.55 -0.53 1.61 640.11 1.00
difficulties[2] -1.70 0.72 -1.69 -2.81 -0.44 761.89 1.00
difficulties[3] -1.21 0.67 -1.20 -2.26 -0.09 765.05 1.00
difficulty_offsets[0] -1.56 0.70 -1.56 -2.58 -0.32 566.79 1.00
difficulty_offsets[1] -0.67 0.76 -0.62 -1.93 0.51 420.50 1.00
difficulty_offsets[2] 0.27 0.63 0.27 -0.72 1.29 606.98 1.00
difficulty_offsets[3] 1.50 0.78 1.45 0.27 2.80 556.69 1.00
discriminations[0] 1.22 0.58 1.17 0.26 2.11 349.73 1.00
discriminations[1] 0.88 0.50 0.82 0.04 1.53 932.60 1.00
discriminations[2] 1.82 0.56 1.76 0.91 2.70 650.47 1.00
discriminations[3] 1.31 0.47 1.28 0.56 2.05 453.34 1.00
Number of divergences: 692
However, with SVI using the AutoDelta guide:
$ python broken.py
abilities.shape (6,)
discriminations.shape (4,)
abilities.shape (6,)
discriminations.shape (4,)
abilities.shape (6,)
discriminations.shape (4,)
abilities.shape (6,)
discriminations.shape (4,)
abilities.shape (6,)
discriminations.shape (4,)
abilities.shape (6, 6)
discriminations.shape (4,)
Traceback (most recent call last):
File "broken.py", line 55, in <module>
svi.run(rng_key, 200, 5, resp)
File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/numpyro/infer/svi.py", line 201, in run
svi_state = self.init(rng_key, *args, **kwargs)
File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/numpyro/infer/svi.py", line 107, in init
model_trace = trace(replay(model_init, guide_trace)).get_trace(*args, **kwargs, **self.static_kwargs)
File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/numpyro/handlers.py", line 162, in get_trace
self(*args, **kwargs)
File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/numpyro/primitives.py", line 80, in __call__
return self.fn(*args, **kwargs)
File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/numpyro/primitives.py", line 80, in __call__
return self.fn(*args, **kwargs)
File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/numpyro/primitives.py", line 80, in __call__
return self.fn(*args, **kwargs)
File "broken.py", line 31, in irt2pl
predictor = jnp.expand_dims(abilities, 1) * jnp.expand_dims(discriminations, 0)
File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5256, in deferring_binary_op
return binary_op(self, other)
File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 396, in fn
x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 333, in _promote_args
return _promote_shapes(fun_name, *_promote_dtypes(*args))
File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 251, in _promote_shapes
result_rank = len(lax.broadcast_shapes(*shapes))
File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/jax/_src/util.py", line 198, in wrapper
return cached(bool(config.x64_enabled), *args, **kwargs)
File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/jax/_src/util.py", line 191, in cached
return f(*args, **kwargs)
File "/home/frankier/edu/doc/vocabmodel/.direnv/python-3.8.6/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 97, in broadcast_shapes
raise ValueError("Incompatible shapes for broadcasting: {}"
ValueError: Incompatible shapes for broadcasting: ((6, 1, 6), (1, 1, 4))
It looks like what happens is after some warm up iterations, abilities.shape
changes, causing a broadcasting error. I’m quite new to Pyro/NumPyro and don’t completely understand all the mechanics, so I’m not really sure if where the problem is, but the fact it works with NUTS gives me some confidence that it should with SVI + an autoguide. Is the problem to do with expand_dims()? Is the AutoDelta guide somehow picking up the shape after expansion/broadcasting?