A couple of notes:
- The model is not suitable for enumeration. See some discussion here. Your code violates restriction 2: no arrow is allowed to go from an enumerated variable to outside of the plates enclosing it.
- If you want to use enumeration and accept the exponential slowness, you can simply remove the plates enclosing your enumerated parameter and loop over the indices instead:
    for i in range(n_participants):
        for j in range(n_skills):
            skill[i][j] = sample(...)
- Using NUTS is unnecessary for models with no continuous latent variables. Under enumeration, you can just use infer_discrete.
- Without enumeration, it is better to use the TraceGraph_ELBO objective, which is available in both Pyro and NumPyro.
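
To give a feel for the "exponential slowness" in the note about removing the plates, here is a minimal plain-Python sketch that sums out every binary skill variable by brute force. It does not use Pyro or NumPyro, and the toy model is an assumption made up for illustration (it is not your model): each skill is Bernoulli(p_skill), and a single observation depends on all skills jointly, which is the worst case for enumeration. The function name marginal_by_enumeration and its parameters are hypothetical.

```python
from itertools import product


def marginal_by_enumeration(n_participants, n_skills, p_skill=0.5):
    """Sum out every binary skill variable by brute force.

    Hypothetical toy model: skill[i][j] ~ Bernoulli(p_skill), and one
    observation y = 1 occurs with probability equal to the fraction of
    skills that are present, so y depends on all skills jointly.

    Returns (p(y = 1), number of joint assignments visited).
    """
    n_sites = n_participants * n_skills
    total, n_terms = 0.0, 0
    # One term per joint assignment of all skill variables: 2 ** n_sites terms.
    for assignment in product((0, 1), repeat=n_sites):
        n_on = sum(assignment)
        prior = p_skill ** n_on * (1 - p_skill) ** (n_sites - n_on)
        total += prior * (n_on / n_sites)
        n_terms += 1
    return total, n_terms


# The number of terms grows as 2 ** (n_participants * n_skills).
for shape in [(1, 2), (2, 2), (2, 3), (3, 4)]:
    p_y, n_terms = marginal_by_enumeration(*shape)
    print(shape, n_terms, round(p_y, 6))
```

The number of visited assignments doubles with every extra skill variable, so this only stays tractable for small n_participants * n_skills. Pyro's enumeration is smarter than this brute-force loop when the dependency structure allows it, but when a downstream site couples all the enumerated variables, the joint table it has to build is of this size.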