Numpyro Chapter 2 MBML

A couple of notes:

  • The model is not suitable for enumeration. See some discussion here. Your code violates restriction 2: no arrow may go from an enumerated variable to a site outside the plates enclosing it.
  • If you want to use enumeration and can accept the exponential slowdown, you can simply remove the plates enclosing your enumerated parameter and loop in Python instead, giving each sample site a unique name:
    for i in range(n_participants):
        for j in range(n_skills):
            skill[i][j] = sample(...)
  • NUTS is unnecessary for models with no continuous latent variables. Under enumeration, you can use infer_discrete instead (temperature=0 for the MAP assignment, temperature=1 for posterior samples).
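  • Without enumeration, it is better to use the TraceGraph_ELBO objective, which is available in both Pyro and NumPyro; it uses the model's dependency structure to reduce the variance of the gradient estimates for discrete latent variables.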