Hi all,
I am trying out the next chapter of MBML, chapter 2. I am starting from the awesome discussions found here on the forum and the code below is adopted from @fritzo’s suggestion, here. Unfortunately, I cannot get the model to work in numpyro. I think the problem stems from enumeration and requires a Vindex
, but I do not know where.
I can get some results in just Pyro, below, without any enumeration, but I would like to use numpyro and take advantage of enumeration.
Both models below, I am unsure how to debug. Model 1.
is trying to use enumeration, I couldn’t get that to work. I tried to take out the enumeration from 1.
to model 2.
, which worked in pyro, and I run into a similar but different error.
Any help would be appericated. =)
- Here is the model with
@config_enumeration
andinfer={"enumerate": "parallel"}
, most similar tofritzo
answer
ValueError: Invalid shape: expected (), actual (1,)
import operator
from functools import reduce
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate, enum, infer_discrete
from numpyro.handlers import seed, trace
from numpyro.infer import MCMC, NUTS
from numpyro.infer.util import Predictive
rng_key = jax.random.PRNGKey(2)
responses_check = jnp.array(
[
[0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0],
[0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0],
[0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0],
]
)
skills_needed_check = [[0], [1], [0, 1]]
@config_enumerate
def model(
graded_responses, skills_needed: list[list[int]], prob_mistake=0.1, prob_guess=0.2
):
n_questions, n_participants = graded_responses.shape
n_skills = max(map(max, skills_needed)) + 1
with numpyro.plate("participants_plate", n_participants):
with numpyro.plate("skills_plate", n_skills):
skills = numpyro.sample(
"skills", dist.Bernoulli(0.5), infer={"enumerate": "parallel"}
)
for q in range(n_questions):
has_skills = reduce(operator.mul, [skills[i] for i in skills_needed[q]])
prob_correct = has_skills * (1 - prob_mistake) + (1 - has_skills) * prob_guess
isCorrect = numpyro.sample(
"isCorrect_{}".format(q),
dist.Bernoulli(prob_correct).to_event(1),
obs=graded_responses[q],
)
nuts_kernel = NUTS(model)
mcmc = MCMC(
nuts_kernel,
num_warmup=100,
num_samples=1000,
num_chains=1,
)
mcmc.run(rng_key, responses_check, skills_needed_check)
Output:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_131019/3237141500.py in <module>
54 )
55 # ValueError: Invalid shape: expected (), actual (1,)
---> 56 mcmc.run(rng_key, responses_check, skills_needed_check)
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
564 map_args = (rng_key, init_state, init_params)
565 if self.num_chains == 1:
--> 566 states_flat, last_state = partial_map_fn(map_args)
567 states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
568 else:
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
353 rng_key, init_state, init_params = init
354 if init_state is None:
--> 355 init_state = self.sampler.init(
356 rng_key,
357 self.num_warmup,
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
687 vmap(random.split)(rng_key), 0, 1
688 )
--> 689 init_params = self._init_state(
690 rng_key_init_model, model_args, model_kwargs, init_params
691 )
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/hmc.py in _init_state(self, rng_key, model_args, model_kwargs, init_params)
633 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
634 if self._model is not None:
--> 635 init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
636 rng_key,
637 self._model,
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
616 init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
617 prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 618 (init_params, pe, grad), is_valid = find_valid_initial_params(
619 rng_key,
620 substitute(
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
372 # Handle possible vectorization
373 if rng_key.ndim == 1:
--> 374 (init_params, pe, z_grad), is_valid = _find_valid_params(
375 rng_key, exit_early=True
376 )
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in _find_valid_params(rng_key, exit_early)
358 # Early return if valid params found. This is only helpful for single chain,
359 # where we can avoid compiling body_fn in while_loop.
--> 360 _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
361 if not_jax_tracer(is_valid):
362 if device_get(is_valid):
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in body_fn(state)
343 z_grad = jacfwd(potential_fn)(params)
344 else:
--> 345 pe, z_grad = value_and_grad(potential_fn)(params)
346 z_grad_flat = ravel_pytree(z_grad)[0]
347 is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
[... skipping hidden 7 frame]
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in potential_energy(model, model_args, model_kwargs, params, enum)
225 )
226 # no param is needed for log_density computation because we already substitute
--> 227 log_joint, model_trace = log_density_(
228 substituted_model, model_args, model_kwargs, {}
229 )
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/funsor/infer_util.py in log_density(model, model_args, model_kwargs, params)
268 :return: log of joint density and a corresponding model trace
269 """
--> 270 result, model_trace, _ = _enum_log_density(
271 model, model_args, model_kwargs, params, funsor.ops.logaddexp, funsor.ops.add
272 )
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/funsor/infer_util.py in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)
179
180 dim_to_name = site["infer"]["dim_to_name"]
--> 181 log_prob_factor = funsor.to_funsor(
182 log_prob, output=funsor.Real, dim_to_name=dim_to_name
183 )
~/anaconda3/envs/mbml_numpyro/lib/python3.9/functools.py in wrapper(*args, **kw)
875 '1 positional argument')
876
--> 877 return dispatch(args[0].__class__)(*args, **kw)
878
879 funcname = getattr(func, '__name__', 'singledispatch function')
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/funsor/tensor.py in tensor_to_funsor(x, output, dim_to_name)
489 result = Tensor(x, dtype=output.dtype)
490 if result.output != output:
--> 491 raise ValueError(
492 "Invalid shape: expected {}, actual {}".format(
493 output.shape, result.output.shape
ValueError: Invalid shape: expected (), actual (1,)
- Remove
@config_enumeration
andinfer={"enumerate": "parallel"}
ValueError: Expected the joint log density is a scalar, but got (2,). There seems to be something wrong at the following sites: {'_pyro_dim_1'}.
import operator
from functools import reduce
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate, enum, infer_discrete
from numpyro.contrib.indexing import Vindex
from numpyro.handlers import seed, trace
from numpyro.infer import MCMC, NUTS
from numpyro.infer.util import Predictive
rng_key = jax.random.PRNGKey(2)
responses_check = jnp.array(
[
[0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0],
[0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0],
[0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0],
]
)
skills_needed_check = [[0], [1], [0, 1]]
def model(
graded_responses, skills_needed: list[list[int]], prob_mistake=0.1, prob_guess=0.2
):
n_questions, n_participants = graded_responses.shape
n_skills = max(map(max, skills_needed)) + 1
with numpyro.plate("participants_plate", n_participants):
skills = []
for s in range(n_skills):
skills.append(
numpyro.sample(
"skill_{}".format(s),
dist.Bernoulli(0.5),
infer={"enumerate": "parallel"},
)
)
for q in range(n_questions):
has_skills = reduce(operator.mul, [skills[i] for i in skills_needed[q]])
prob_correct = has_skills * (1 - prob_mistake) + (1 - has_skills) * prob_guess
isCorrect = numpyro.sample(
"isCorrect_{}".format(q),
dist.Bernoulli(prob_correct).to_event(1),
obs=graded_responses[q],
)
nuts_kernel = NUTS(model)
mcmc = MCMC(
nuts_kernel,
num_warmup=100,
num_samples=1000,
num_chains=1,
)
mcmc.run(rng_key, responses_check, skills_needed_check)
Output:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_131605/1969942747.py in <module>
61 )
62
---> 63 mcmc.run(rng_key, responses_check, skills_needed_check)
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
564 map_args = (rng_key, init_state, init_params)
565 if self.num_chains == 1:
--> 566 states_flat, last_state = partial_map_fn(map_args)
567 states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
568 else:
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
353 rng_key, init_state, init_params = init
354 if init_state is None:
--> 355 init_state = self.sampler.init(
356 rng_key,
357 self.num_warmup,
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
687 vmap(random.split)(rng_key), 0, 1
688 )
--> 689 init_params = self._init_state(
690 rng_key_init_model, model_args, model_kwargs, init_params
691 )
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/hmc.py in _init_state(self, rng_key, model_args, model_kwargs, init_params)
633 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
634 if self._model is not None:
--> 635 init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
636 rng_key,
637 self._model,
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
616 init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
617 prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 618 (init_params, pe, grad), is_valid = find_valid_initial_params(
619 rng_key,
620 substitute(
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
372 # Handle possible vectorization
373 if rng_key.ndim == 1:
--> 374 (init_params, pe, z_grad), is_valid = _find_valid_params(
375 rng_key, exit_early=True
376 )
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in _find_valid_params(rng_key, exit_early)
358 # Early return if valid params found. This is only helpful for single chain,
359 # where we can avoid compiling body_fn in while_loop.
--> 360 _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
361 if not_jax_tracer(is_valid):
362 if device_get(is_valid):
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in body_fn(state)
343 z_grad = jacfwd(potential_fn)(params)
344 else:
--> 345 pe, z_grad = value_and_grad(potential_fn)(params)
346 z_grad_flat = ravel_pytree(z_grad)[0]
347 is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
[... skipping hidden 7 frame]
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in potential_energy(model, model_args, model_kwargs, params, enum)
225 )
226 # no param is needed for log_density computation because we already substitute
--> 227 log_joint, model_trace = log_density_(
228 substituted_model, model_args, model_kwargs, {}
229 )
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/funsor/infer_util.py in log_density(model, model_args, model_kwargs, params)
268 :return: log of joint density and a corresponding model trace
269 """
--> 270 result, model_trace, _ = _enum_log_density(
271 model, model_args, model_kwargs, params, funsor.ops.logaddexp, funsor.ops.add
272 )
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/funsor/infer_util.py in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)
238 result = funsor.optimizer.apply_optimizer(lazy_result)
239 if len(result.inputs) > 0:
--> 240 raise ValueError(
241 "Expected the joint log density is a scalar, but got {}. "
242 "There seems to be something wrong at the following sites: {}.".format(
ValueError: Expected the joint log density is a scalar, but got (2,). There seems to be something wrong at the following sites: {'_pyro_dim_1'}.
Below is the pyro result without @config_enumeration
:
Pyro SVI results
The full code is a little long to share on the forum, but you can find the full code here at this repo
IsCorrect1 | IsCorrect2 | IsCorrect2 | P(csharp) | P(sql) | SVI skill_01 P(csharp) | SVI skill_02 P(sql) | |
---|---|---|---|---|---|---|---|
0 | False | False | False | 0.101 | 0.101 | 0.0789368 | 0.055503 |
1 | True | False | False | 0.802 | 0.034 | 0.939442 | 0.00339268 |
2 | False | True | False | 0.034 | 0.802 | 0.0224401 | 0.846974 |
3 | True | True | False | 0.561 | 0.561 | 0.547225 | 0.86883 |
4 | False | False | True | 0.148 | 0.148 | 0.0347185 | 0.0237291 |
5 | True | False | True | 0.862 | 0.326 | 0.883048 | 0.0900076 |
6 | False | True | True | 0.326 | 0.862 | 0.346815 | 0.838184 |
7 | True | True | True | 0.946 | 0.946 | 0.988479 | 0.874144 |