Numpyro Chapter 2 MBML

I think the joint density isn’t continuous here? I am not 100% sure though. The model below looks like it should work, but I get a ValueError: Expected the joint log density is a scalar

toy data:

responses_check = jnp.array([[0., 1., 0., 1., 0., 1., 0., 1.], [0., 0., 1., 1., 0., 0., 1., 1.], [0., 0., 0., 0., 1., 1., 1., 1.]])
skills_needed_check = [[0], [1], [0, 1]]
def model(
    graded_responses, skills_needed: list[list[int]], prob_mistake=0.1, prob_guess=0.2
):
    n_questions, n_participants = graded_responses.shape
    n_skills = max(map(max, skills_needed)) + 1
    
    participants_plate = numpyro.plate("participants_plate", n_participants)
    
    with participants_plate:
        with numpyro.plate("skills_plate", n_skills):
            theta = numpyro.sample("theta", dist.Beta(1,1))
    
    with participants_plate:
        skills = []
        for s in range(n_skills):
            skills.append(numpyro.sample("skill_{}".format(s), dist.Bernoulli(theta[s])))

    for q in range(n_questions):
        has_skills = reduce(operator.mul, [skills[i] for i in skills_needed[q]])
        prob_correct = has_skills * (1 - prob_mistake) + (1 - has_skills) * prob_guess
        isCorrect = numpyro.sample(
            "isCorrect_{}".format(q),
            dist.Bernoulli(prob_correct).to_event(1),
            obs=graded_responses[q],
        )

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=1000, num_chains=1, )
mcmc.run(rng_key, responses_check, skills_needed_check)
Output
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_154919/2526860559.py in <module>
      1 nuts_kernel = NUTS(model)
      2 mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=1000, num_chains=1, )
----> 3 mcmc.run(rng_key, responses_check, skills_needed_check)

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    564         map_args = (rng_key, init_state, init_params)
    565         if self.num_chains == 1:
--> 566             states_flat, last_state = partial_map_fn(map_args)
    567             states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
    568         else:

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
    353         rng_key, init_state, init_params = init
    354         if init_state is None:
--> 355             init_state = self.sampler.init(
    356                 rng_key,
    357                 self.num_warmup,

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    687                 vmap(random.split)(rng_key), 0, 1
    688             )
--> 689         init_params = self._init_state(
    690             rng_key_init_model, model_args, model_kwargs, init_params
    691         )

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/hmc.py in _init_state(self, rng_key, model_args, model_kwargs, init_params)
    633     def _init_state(self, rng_key, model_args, model_kwargs, init_params):
    634         if self._model is not None:
--> 635             init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
    636                 rng_key,
    637                 self._model,

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
    616         init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
    617     prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 618     (init_params, pe, grad), is_valid = find_valid_initial_params(
    619         rng_key,
    620         substitute(

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
    372     # Handle possible vectorization
    373     if rng_key.ndim == 1:
--> 374         (init_params, pe, z_grad), is_valid = _find_valid_params(
    375             rng_key, exit_early=True
    376         )

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in _find_valid_params(rng_key, exit_early)
    358             # Early return if valid params found. This is only helpful for single chain,
    359             # where we can avoid compiling body_fn in while_loop.
--> 360             _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
    361             if not_jax_tracer(is_valid):
    362                 if device_get(is_valid):

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in body_fn(state)
    343                 z_grad = jacfwd(potential_fn)(params)
    344             else:
--> 345                 pe, z_grad = value_and_grad(potential_fn)(params)
    346             z_grad_flat = ravel_pytree(z_grad)[0]
    347             is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))

    [... skipping hidden 7 frame]

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in potential_energy(model, model_args, model_kwargs, params, enum)
    225     )
    226     # no param is needed for log_density computation because we already substitute
--> 227     log_joint, model_trace = log_density_(
    228         substituted_model, model_args, model_kwargs, {}
    229     )

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/funsor/infer_util.py in log_density(model, model_args, model_kwargs, params)
    268     :return: log of joint density and a corresponding model trace
    269     """
--> 270     result, model_trace, _ = _enum_log_density(
    271         model, model_args, model_kwargs, params, funsor.ops.logaddexp, funsor.ops.add
    272     )

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/funsor/infer_util.py in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)
    238     result = funsor.optimizer.apply_optimizer(lazy_result)
    239     if len(result.inputs) > 0:
--> 240         raise ValueError(
    241             "Expected the joint log density is a scalar, but got {}. "
    242             "There seems to be something wrong at the following sites: {}.".format(

ValueError: Expected the joint log density is a scalar, but got (2,). There seems to be something wrong at the following sites: {'_pyro_dim_2'}.