I think the joint density isn’t continuous here? I am not 100% sure though. The model below looks like it should work, but I get a ValueError: Expected the joint log density is a scalar
toy data:
responses_check = jnp.array([[0., 1., 0., 1., 0., 1., 0., 1.], [0., 0., 1., 1., 0., 0., 1., 1.], [0., 0., 0., 0., 1., 1., 1., 1.]])
skills_needed_check = [[0], [1], [0, 1]]
def model(
graded_responses, skills_needed: list[list[int]], prob_mistake=0.1, prob_guess=0.2
):
n_questions, n_participants = graded_responses.shape
n_skills = max(map(max, skills_needed)) + 1
participants_plate = numpyro.plate("participants_plate", n_participants)
with participants_plate:
with numpyro.plate("skills_plate", n_skills):
theta = numpyro.sample("theta", dist.Beta(1,1))
with participants_plate:
skills = []
for s in range(n_skills):
skills.append(numpyro.sample("skill_{}".format(s), dist.Bernoulli(theta[s])))
for q in range(n_questions):
has_skills = reduce(operator.mul, [skills[i] for i in skills_needed[q]])
prob_correct = has_skills * (1 - prob_mistake) + (1 - has_skills) * prob_guess
isCorrect = numpyro.sample(
"isCorrect_{}".format(q),
dist.Bernoulli(prob_correct).to_event(1),
obs=graded_responses[q],
)
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=1000, num_chains=1, )
mcmc.run(rng_key, responses_check, skills_needed_check)
Output
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_154919/2526860559.py in <module>
1 nuts_kernel = NUTS(model)
2 mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=1000, num_chains=1, )
----> 3 mcmc.run(rng_key, responses_check, skills_needed_check)
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
564 map_args = (rng_key, init_state, init_params)
565 if self.num_chains == 1:
--> 566 states_flat, last_state = partial_map_fn(map_args)
567 states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
568 else:
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
353 rng_key, init_state, init_params = init
354 if init_state is None:
--> 355 init_state = self.sampler.init(
356 rng_key,
357 self.num_warmup,
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
687 vmap(random.split)(rng_key), 0, 1
688 )
--> 689 init_params = self._init_state(
690 rng_key_init_model, model_args, model_kwargs, init_params
691 )
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/hmc.py in _init_state(self, rng_key, model_args, model_kwargs, init_params)
633 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
634 if self._model is not None:
--> 635 init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
636 rng_key,
637 self._model,
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
616 init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
617 prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 618 (init_params, pe, grad), is_valid = find_valid_initial_params(
619 rng_key,
620 substitute(
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
372 # Handle possible vectorization
373 if rng_key.ndim == 1:
--> 374 (init_params, pe, z_grad), is_valid = _find_valid_params(
375 rng_key, exit_early=True
376 )
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in _find_valid_params(rng_key, exit_early)
358 # Early return if valid params found. This is only helpful for single chain,
359 # where we can avoid compiling body_fn in while_loop.
--> 360 _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
361 if not_jax_tracer(is_valid):
362 if device_get(is_valid):
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in body_fn(state)
343 z_grad = jacfwd(potential_fn)(params)
344 else:
--> 345 pe, z_grad = value_and_grad(potential_fn)(params)
346 z_grad_flat = ravel_pytree(z_grad)[0]
347 is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
[... skipping hidden 7 frame]
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in potential_energy(model, model_args, model_kwargs, params, enum)
225 )
226 # no param is needed for log_density computation because we already substitute
--> 227 log_joint, model_trace = log_density_(
228 substituted_model, model_args, model_kwargs, {}
229 )
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/funsor/infer_util.py in log_density(model, model_args, model_kwargs, params)
268 :return: log of joint density and a corresponding model trace
269 """
--> 270 result, model_trace, _ = _enum_log_density(
271 model, model_args, model_kwargs, params, funsor.ops.logaddexp, funsor.ops.add
272 )
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/funsor/infer_util.py in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)
238 result = funsor.optimizer.apply_optimizer(lazy_result)
239 if len(result.inputs) > 0:
--> 240 raise ValueError(
241 "Expected the joint log density is a scalar, but got {}. "
242 "There seems to be something wrong at the following sites: {}.".format(
ValueError: Expected the joint log density is a scalar, but got (2,). There seems to be something wrong at the following sites: {'_pyro_dim_2'}.