NumPyro: MBML Chapter 2

Hi all,

I am trying out the next chapter of MBML, chapter 2. I am starting from the awesome discussions found here on the forum, and the code below is adapted from @fritzo’s suggestion here. Unfortunately, I cannot get the model to work in NumPyro. I think the problem stems from enumeration and requires a Vindex, but I do not know where.

I can get some results in plain Pyro (below) without any enumeration, but I would like to use NumPyro and take advantage of enumeration.

I am unsure how to debug either of the models below. Model 1 tries to use enumeration, and I couldn’t get that to work. For model 2, I took the enumeration out of model 1 (the non-enumerated version worked in Pyro), and I ran into a similar but different error.

Any help would be appreciated. =)

  1. Model 1, with @config_enumerate and infer={"enumerate": "parallel"}; this is the closest to @fritzo’s answer. It fails with:
ValueError: Invalid shape: expected (), actual (1,)
import operator
from functools import reduce

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate, enum, infer_discrete
from numpyro.handlers import seed, trace
from numpyro.infer import MCMC, NUTS
from numpyro.infer.util import Predictive

rng_key = jax.random.PRNGKey(2)

responses_check = jnp.array(
    [
        [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0],
        [0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0],
        [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0],
    ]
)
skills_needed_check = [[0], [1], [0, 1]]


@config_enumerate
def model(
    graded_responses, skills_needed: list[list[int]], prob_mistake=0.1, prob_guess=0.2
):
    n_questions, n_participants = graded_responses.shape
    n_skills = max(map(max, skills_needed)) + 1

    with numpyro.plate("participants_plate", n_participants):
        with numpyro.plate("skills_plate", n_skills):
            skills = numpyro.sample(
                "skills", dist.Bernoulli(0.5), infer={"enumerate": "parallel"}
            )

    for q in range(n_questions):
        has_skills = reduce(operator.mul, [skills[i] for i in skills_needed[q]])
        prob_correct = has_skills * (1 - prob_mistake) + (1 - has_skills) * prob_guess
        isCorrect = numpyro.sample(
            "isCorrect_{}".format(q),
            dist.Bernoulli(prob_correct).to_event(1),
            obs=graded_responses[q],
        )


nuts_kernel = NUTS(model)
mcmc = MCMC(
    nuts_kernel,
    num_warmup=100,
    num_samples=1000,
    num_chains=1,
)
mcmc.run(rng_key, responses_check, skills_needed_check)

Output:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_131019/3237141500.py in <module>
     54 )
     55 # ValueError: Invalid shape: expected (), actual (1,)
---> 56 mcmc.run(rng_key, responses_check, skills_needed_check)

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    564         map_args = (rng_key, init_state, init_params)
    565         if self.num_chains == 1:
--> 566             states_flat, last_state = partial_map_fn(map_args)
    567             states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
    568         else:

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
    353         rng_key, init_state, init_params = init
    354         if init_state is None:
--> 355             init_state = self.sampler.init(
    356                 rng_key,
    357                 self.num_warmup,

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    687                 vmap(random.split)(rng_key), 0, 1
    688             )
--> 689         init_params = self._init_state(
    690             rng_key_init_model, model_args, model_kwargs, init_params
    691         )

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/hmc.py in _init_state(self, rng_key, model_args, model_kwargs, init_params)
    633     def _init_state(self, rng_key, model_args, model_kwargs, init_params):
    634         if self._model is not None:
--> 635             init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
    636                 rng_key,
    637                 self._model,

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
    616         init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
    617     prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 618     (init_params, pe, grad), is_valid = find_valid_initial_params(
    619         rng_key,
    620         substitute(

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
    372     # Handle possible vectorization
    373     if rng_key.ndim == 1:
--> 374         (init_params, pe, z_grad), is_valid = _find_valid_params(
    375             rng_key, exit_early=True
    376         )

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in _find_valid_params(rng_key, exit_early)
    358             # Early return if valid params found. This is only helpful for single chain,
    359             # where we can avoid compiling body_fn in while_loop.
--> 360             _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
    361             if not_jax_tracer(is_valid):
    362                 if device_get(is_valid):

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in body_fn(state)
    343                 z_grad = jacfwd(potential_fn)(params)
    344             else:
--> 345                 pe, z_grad = value_and_grad(potential_fn)(params)
    346             z_grad_flat = ravel_pytree(z_grad)[0]
    347             is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))

    [... skipping hidden 7 frame]

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in potential_energy(model, model_args, model_kwargs, params, enum)
    225     )
    226     # no param is needed for log_density computation because we already substitute
--> 227     log_joint, model_trace = log_density_(
    228         substituted_model, model_args, model_kwargs, {}
    229     )

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/funsor/infer_util.py in log_density(model, model_args, model_kwargs, params)
    268     :return: log of joint density and a corresponding model trace
    269     """
--> 270     result, model_trace, _ = _enum_log_density(
    271         model, model_args, model_kwargs, params, funsor.ops.logaddexp, funsor.ops.add
    272     )

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/funsor/infer_util.py in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)
    179 
    180             dim_to_name = site["infer"]["dim_to_name"]
--> 181             log_prob_factor = funsor.to_funsor(
    182                 log_prob, output=funsor.Real, dim_to_name=dim_to_name
    183             )

~/anaconda3/envs/mbml_numpyro/lib/python3.9/functools.py in wrapper(*args, **kw)
    875                             '1 positional argument')
    876 
--> 877         return dispatch(args[0].__class__)(*args, **kw)
    878 
    879     funcname = getattr(func, '__name__', 'singledispatch function')

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/funsor/tensor.py in tensor_to_funsor(x, output, dim_to_name)
    489         result = Tensor(x, dtype=output.dtype)
    490         if result.output != output:
--> 491             raise ValueError(
    492                 "Invalid shape: expected {}, actual {}".format(
    493                     output.shape, result.output.shape

ValueError: Invalid shape: expected (), actual (1,)
  2. Model 2, with @config_enumerate removed and one sample site per skill. It fails with:
ValueError: Expected the joint log density is a scalar, but got (2,). There seems to be something wrong at the following sites: {'_pyro_dim_1'}.
import operator
from functools import reduce

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor import config_enumerate, enum, infer_discrete
from numpyro.contrib.indexing import Vindex
from numpyro.handlers import seed, trace
from numpyro.infer import MCMC, NUTS
from numpyro.infer.util import Predictive

rng_key = jax.random.PRNGKey(2)

responses_check = jnp.array(
    [
        [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0],
        [0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0],
        [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0],
    ]
)
skills_needed_check = [[0], [1], [0, 1]]


def model(
    graded_responses, skills_needed: list[list[int]], prob_mistake=0.1, prob_guess=0.2
):
    n_questions, n_participants = graded_responses.shape
    n_skills = max(map(max, skills_needed)) + 1

    with numpyro.plate("participants_plate", n_participants):
        skills = []
        for s in range(n_skills):
            skills.append(
                numpyro.sample(
                    "skill_{}".format(s),
                    dist.Bernoulli(0.5),
                    infer={"enumerate": "parallel"},
                )
            )


    for q in range(n_questions):
        has_skills = reduce(operator.mul, [skills[i] for i in skills_needed[q]])
        prob_correct = has_skills * (1 - prob_mistake) + (1 - has_skills) * prob_guess


        isCorrect = numpyro.sample(
            "isCorrect_{}".format(q),
            dist.Bernoulli(prob_correct).to_event(1),
            obs=graded_responses[q],
        )

nuts_kernel = NUTS(model)
mcmc = MCMC(
    nuts_kernel,
    num_warmup=100,
    num_samples=1000,
    num_chains=1,
)

mcmc.run(rng_key, responses_check, skills_needed_check)

Output:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_131605/1969942747.py in <module>
     61 )
     62 
---> 63 mcmc.run(rng_key, responses_check, skills_needed_check)

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    564         map_args = (rng_key, init_state, init_params)
    565         if self.num_chains == 1:
--> 566             states_flat, last_state = partial_map_fn(map_args)
    567             states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
    568         else:

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
    353         rng_key, init_state, init_params = init
    354         if init_state is None:
--> 355             init_state = self.sampler.init(
    356                 rng_key,
    357                 self.num_warmup,

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    687                 vmap(random.split)(rng_key), 0, 1
    688             )
--> 689         init_params = self._init_state(
    690             rng_key_init_model, model_args, model_kwargs, init_params
    691         )

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/hmc.py in _init_state(self, rng_key, model_args, model_kwargs, init_params)
    633     def _init_state(self, rng_key, model_args, model_kwargs, init_params):
    634         if self._model is not None:
--> 635             init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
    636                 rng_key,
    637                 self._model,

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
    616         init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
    617     prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 618     (init_params, pe, grad), is_valid = find_valid_initial_params(
    619         rng_key,
    620         substitute(

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
    372     # Handle possible vectorization
    373     if rng_key.ndim == 1:
--> 374         (init_params, pe, z_grad), is_valid = _find_valid_params(
    375             rng_key, exit_early=True
    376         )

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in _find_valid_params(rng_key, exit_early)
    358             # Early return if valid params found. This is only helpful for single chain,
    359             # where we can avoid compiling body_fn in while_loop.
--> 360             _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
    361             if not_jax_tracer(is_valid):
    362                 if device_get(is_valid):

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in body_fn(state)
    343                 z_grad = jacfwd(potential_fn)(params)
    344             else:
--> 345                 pe, z_grad = value_and_grad(potential_fn)(params)
    346             z_grad_flat = ravel_pytree(z_grad)[0]
    347             is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))

    [... skipping hidden 7 frame]

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in potential_energy(model, model_args, model_kwargs, params, enum)
    225     )
    226     # no param is needed for log_density computation because we already substitute
--> 227     log_joint, model_trace = log_density_(
    228         substituted_model, model_args, model_kwargs, {}
    229     )

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/funsor/infer_util.py in log_density(model, model_args, model_kwargs, params)
    268     :return: log of joint density and a corresponding model trace
    269     """
--> 270     result, model_trace, _ = _enum_log_density(
    271         model, model_args, model_kwargs, params, funsor.ops.logaddexp, funsor.ops.add
    272     )

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/funsor/infer_util.py in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)
    238     result = funsor.optimizer.apply_optimizer(lazy_result)
    239     if len(result.inputs) > 0:
--> 240         raise ValueError(
    241             "Expected the joint log density is a scalar, but got {}. "
    242             "There seems to be something wrong at the following sites: {}.".format(

ValueError: Expected the joint log density is a scalar, but got (2,). There seems to be something wrong at the following sites: {'_pyro_dim_1'}.

Below are the Pyro SVI results, without @config_enumerate:

Pyro SVI results

The full code is a little long to share on the forum, but you can find the full code here at this repo

|   | IsCorrect1 | IsCorrect2 | IsCorrect3 | P(csharp) | P(sql) | SVI skill_01 P(csharp) | SVI skill_02 P(sql) |
|---|---|---|---|---|---|---|---|
| 0 | False | False | False | 0.101 | 0.101 | 0.0789368 | 0.055503 |
| 1 | True | False | False | 0.802 | 0.034 | 0.939442 | 0.00339268 |
| 2 | False | True | False | 0.034 | 0.802 | 0.0224401 | 0.846974 |
| 3 | True | True | False | 0.561 | 0.561 | 0.547225 | 0.86883 |
| 4 | False | False | True | 0.148 | 0.148 | 0.0347185 | 0.0237291 |
| 5 | True | False | True | 0.862 | 0.326 | 0.883048 | 0.0900076 |
| 6 | False | True | True | 0.326 | 0.862 | 0.346815 | 0.838184 |
| 7 | True | True | True | 0.946 | 0.946 | 0.988479 | 0.874144 |

A couple of notes:

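  • The model is not suitable for enumeration. See some discussion here. Your code violates restriction 2: no arrow is allowed to go from an enumerated variable to outside of the plates enclosing it. (Here, skills is sampled inside participants_plate, but each isCorrect_q site sits outside that plate and depends on skills.)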
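  • If you want to use enumeration and accept the exponential slowness, you can simply remove the plates enclosing your enumerated parameter and use Python loops instead:
    skills = [[None] * n_skills for _ in range(n_participants)]
    for i in range(n_participants):
        for j in range(n_skills):
            skills[i][j] = numpyro.sample(f"skill_{i}_{j}", dist.Bernoulli(0.5), infer={"enumerate": "parallel"})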
  • Using NUTS is unnecessary for models with no latent variables. Under enumeration, you can just use infer_discrete to sample the skills from their posterior.
  • Without enumeration, it is better to use the TraceGraph_ELBO objective, which is available in both Pyro and NumPyro.

A few follow-up questions:

What data type are you suggesting for skills here? At first glance I thought it was a 2D jnp.array, but of course this won’t work under enumeration.

Sorry if this is obvious, but I don’t understand why you think either model defined in the OP has no latent variables. Can you explain? Each person’s skill is an unobserved Bernoulli(θ_s) variable. I would think NUTS is no good because it will marginalize out all discrete variables, but the model still has latent variables.

If I gave each skill a beta prior, then NUTS should be an okay choice for inference?

What data type are you suggesting for skills here?

I think you can make it a list of lists of integers.

why you think either model defined in the OP has no latent variables

Sorry, my last reply was not clear. NUTS only works for continuous latent variables, so it won’t work for your model. If you marginalize out the discrete latent variable skills, your model no longer has any latent variables. NUTS will run on that enumerated model, but it will return an empty result because there is no latent site to sample from.
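To make "marginalizing out skills" concrete, here is a plain-Python sketch (no NumPyro; the function name is hypothetical) that sums each participant's binary skills out by brute force, assuming the Bernoulli(0.5) skill prior and the prob_mistake=0.1, prob_guess=0.2 defaults from the models above. The result is a single number, log p(responses), that depends only on the data and fixed constants, which is why nothing is left for NUTS to sample.

```python
import itertools
import math


def log_marginal_likelihood(responses, skills_needed, prob_mistake=0.1, prob_guess=0.2, prior=0.5):
    """log p(responses) with every participant's binary skills summed out.

    responses[q][p] is 1 if participant p answered question q correctly.
    """
    n_questions, n_participants = len(responses), len(responses[0])
    n_skills = max(map(max, skills_needed)) + 1
    total = 0.0
    for p in range(n_participants):
        # Sum over all 2**n_skills skill configurations for this participant.
        marginal = 0.0
        for skills in itertools.product([0, 1], repeat=n_skills):
            weight = 1.0
            for s in skills:
                weight *= prior if s else 1 - prior
            for q in range(n_questions):
                has_skills = all(skills[i] for i in skills_needed[q])
                p_correct = 1 - prob_mistake if has_skills else prob_guess
                weight *= p_correct if responses[q][p] else 1 - p_correct
            marginal += weight
        # Participants are independent, so their log-marginals add up.
        total += math.log(marginal)
    return total


responses_check = [
    [0, 1, 0, 1, 0, 1, 0, 1],
    [0, 0, 1, 1, 0, 0, 1, 1],
    [0, 0, 0, 0, 1, 1, 1, 1],
]
skills_needed_check = [[0], [1], [0, 1]]
print(log_marginal_likelihood(responses_check, skills_needed_check))
```

The cost per participant grows as 2**n_skills, which is the exponential blow-up mentioned above; for two skills it is trivial.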

If I gave each skill a beta prior, then NUTS should be an okay choice for inference?

I think so, as long as the joint density of your model is continuous. (If you use something like skills_cont ~ Beta and then skills = int(skills_cont), the density is no longer continuous and NUTS might get into trouble.)

I think the joint density isn’t continuous here? I am not 100% sure, though. The model below looks like it should work, but I get ValueError: Expected the joint log density is a scalar.

toy data:

responses_check = jnp.array([[0., 1., 0., 1., 0., 1., 0., 1.], [0., 0., 1., 1., 0., 0., 1., 1.], [0., 0., 0., 0., 1., 1., 1., 1.]])
skills_needed_check = [[0], [1], [0, 1]]
def model(
    graded_responses, skills_needed: list[list[int]], prob_mistake=0.1, prob_guess=0.2
):
    n_questions, n_participants = graded_responses.shape
    n_skills = max(map(max, skills_needed)) + 1
    
    participants_plate = numpyro.plate("participants_plate", n_participants)
    
    with participants_plate:
        with numpyro.plate("skills_plate", n_skills):
            theta = numpyro.sample("theta", dist.Beta(1,1))
    
    with participants_plate:
        skills = []
        for s in range(n_skills):
            skills.append(numpyro.sample("skill_{}".format(s), dist.Bernoulli(theta[s])))

    for q in range(n_questions):
        has_skills = reduce(operator.mul, [skills[i] for i in skills_needed[q]])
        prob_correct = has_skills * (1 - prob_mistake) + (1 - has_skills) * prob_guess
        isCorrect = numpyro.sample(
            "isCorrect_{}".format(q),
            dist.Bernoulli(prob_correct).to_event(1),
            obs=graded_responses[q],
        )

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=1000, num_chains=1, )
mcmc.run(rng_key, responses_check, skills_needed_check)
Output:
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_154919/2526860559.py in <module>
      1 nuts_kernel = NUTS(model)
      2 mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=1000, num_chains=1, )
----> 3 mcmc.run(rng_key, responses_check, skills_needed_check)

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    564         map_args = (rng_key, init_state, init_params)
    565         if self.num_chains == 1:
--> 566             states_flat, last_state = partial_map_fn(map_args)
    567             states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
    568         else:

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
    353         rng_key, init_state, init_params = init
    354         if init_state is None:
--> 355             init_state = self.sampler.init(
    356                 rng_key,
    357                 self.num_warmup,

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    687                 vmap(random.split)(rng_key), 0, 1
    688             )
--> 689         init_params = self._init_state(
    690             rng_key_init_model, model_args, model_kwargs, init_params
    691         )

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/hmc.py in _init_state(self, rng_key, model_args, model_kwargs, init_params)
    633     def _init_state(self, rng_key, model_args, model_kwargs, init_params):
    634         if self._model is not None:
--> 635             init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
    636                 rng_key,
    637                 self._model,

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
    616         init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
    617     prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 618     (init_params, pe, grad), is_valid = find_valid_initial_params(
    619         rng_key,
    620         substitute(

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
    372     # Handle possible vectorization
    373     if rng_key.ndim == 1:
--> 374         (init_params, pe, z_grad), is_valid = _find_valid_params(
    375             rng_key, exit_early=True
    376         )

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in _find_valid_params(rng_key, exit_early)
    358             # Early return if valid params found. This is only helpful for single chain,
    359             # where we can avoid compiling body_fn in while_loop.
--> 360             _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
    361             if not_jax_tracer(is_valid):
    362                 if device_get(is_valid):

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in body_fn(state)
    343                 z_grad = jacfwd(potential_fn)(params)
    344             else:
--> 345                 pe, z_grad = value_and_grad(potential_fn)(params)
    346             z_grad_flat = ravel_pytree(z_grad)[0]
    347             is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))

    [... skipping hidden 7 frame]

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in potential_energy(model, model_args, model_kwargs, params, enum)
    225     )
    226     # no param is needed for log_density computation because we already substitute
--> 227     log_joint, model_trace = log_density_(
    228         substituted_model, model_args, model_kwargs, {}
    229     )

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/funsor/infer_util.py in log_density(model, model_args, model_kwargs, params)
    268     :return: log of joint density and a corresponding model trace
    269     """
--> 270     result, model_trace, _ = _enum_log_density(
    271         model, model_args, model_kwargs, params, funsor.ops.logaddexp, funsor.ops.add
    272     )

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/funsor/infer_util.py in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)
    238     result = funsor.optimizer.apply_optimizer(lazy_result)
    239     if len(result.inputs) > 0:
--> 240         raise ValueError(
    241             "Expected the joint log density is a scalar, but got {}. "
    242             "There seems to be something wrong at the following sites: {}.".format(

ValueError: Expected the joint log density is a scalar, but got (2,). There seems to be something wrong at the following sites: {'_pyro_dim_2'}.

NUTS does not work with models that have discrete latent sites. If you are using enumeration, please make sure to use two for loops over the skills site, as explained in my comments above: no arrow is allowed from a discrete site to outside of the plates surrounding it. I would recommend using SVI (with TraceGraph_ELBO) or DiscreteHMCGibbs for your model, because enumeration will be exponentially slow.

If you want to use NUTS, you can either:

  • Use two for loops as above; this is unlikely to be slow if the ranges of those loops are less than 5.
  • Approximate skills by a continuous variable, then convert that continuous variable to a discrete one by casting x to int(x) (in your code, skills = int(theta)). But then the density is not continuous, so NUTS will have trouble giving any reliable result.
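A plain-Python sketch of why the continuous-variable trick gives NUTS little to work with, assuming a hypothetical participant who answers three questions that each need the same single skill, with prob_mistake=0.1 and prob_guess=0.2 as in the models above. The skill is obtained by thresholding a continuous theta at 0.5, so the log-likelihood is piecewise constant in theta: its gradient is zero almost everywhere, and the only change is a jump that the gradient never sees.

```python
import functools
import math


def log_lik_given_theta(theta, answers, prob_mistake=0.1, prob_guess=0.2):
    """Log-likelihood of one participant's answers to questions that all need a
    single skill, with the binary skill derived from a continuous theta."""
    # Threshold at 0.5: plain int(theta) would truncate every theta in (0, 1) to 0.
    skill = 1 if theta > 0.5 else 0
    p_correct = 1 - prob_mistake if skill else prob_guess
    return sum(math.log(p_correct if a else 1 - p_correct) for a in answers)


def central_difference(f, x, eps=1e-4):
    """Numerical derivative, standing in for the gradient NUTS relies on."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)


answers = [1, 1, 0]  # hypothetical answers to three single-skill questions
f = functools.partial(log_lik_given_theta, answers=answers)

print(central_difference(f, 0.3), central_difference(f, 0.7))  # 0.0 0.0: flat on both sides
print(f(0.49), f(0.51))  # the only change is a jump at theta = 0.5
print(int(0.99))  # 0
```

The sketch thresholds rather than calling int() because truncation would map the whole (0, 1) range of a Beta-distributed theta to 0.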

I tried a lot of your suggestions. You can find the full code in the notebook in this repo.

I apologize if you have to repeat yourself, I am still learning. I thought I could run NUTS using the beta priors and, at worst, marginalize out the discrete latent variables and be left with the beta priors as a good proxy for the latent skills, but this is wrong.

I think I understand this restriction from your earlier post, but I was able to get a model (model_05, link to notebook) to work with @config_enumerate and infer={"enumerate": "parallel"}. I am not sure of the significance of this… 🤷

I did try the two for loops (model_01, link to the notebook); of course it’s very slow, but it still recovers the expected answer.

Trying to avoid the two for loops, I was able to get the original model I posted earlier to work and recover the expected values. Both models use DiscreteHMCGibbs; they are found here as model_02 and model_03, respectively.

I apologize, I should have mentioned earlier that I was able to get SVI to run in NumPyro, but I wanted to push on with other inference methods since SVI was not recovering the values as well, as shown in the table below.

Using DiscreteHMCGibbs was the game changer; it got many of my early iterations of the model to run.

Gotcha, thank you for the clarification about the density not being continuous. I think I might save this trick for later =)

Final Results

|   | IsCorrect1 | IsCorrect2 | IsCorrect3 | P(csharp) | P(sql) | model_00 P(csharp) | model_00 P(sql) | model_01a P(csharp) | model_01a P(sql) | model_01b P(csharp) | model_01b P(sql) | model_02 P(csharp) | model_02 P(sql) | model_03 P(csharp) | model_03 P(sql) | model_04 skill_01 P(csharp) | model_04 skill_02 P(sql) |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | False | False | False | 0.101 | 0.101 | 0.097 | 0.098 | 0.099 | 0.091 | 0.00133333 | 0.493667 | 0.094 | 0.099 | 0.099 | 0.091 | 0.0579186 | 0.0781705 |
| 1 | True | False | False | 0.802 | 0.034 | 0.799 | 0.036 | 0.795 | 0.031 | 0.497333 | 0.491333 | 0.804 | 0.033 | 0.795 | 0.031 | 0.716431 | 0.0181244 |
| 2 | False | True | False | 0.034 | 0.802 | 0.035 | 0.779 | 0.034 | 0.804 | 0.503667 | 0.495 | 0.032 | 0.796 | 0.034 | 0.804 | 0.218091 | 0.493445 |
| 3 | True | True | False | 0.561 | 0.561 | 0.544 | 0.55 | 0.553 | 0.547 | 0.503333 | 0.501333 | 0.56 | 0.535 | 0.553 | 0.547 | 0.56913 | 0.49497 |
| 4 | False | False | True | 0.148 | 0.148 | 0.124 | 0.14 | 0.15 | 0.151 | 0.509 | 0.489 | 0.129 | 0.14 | 0.15 | 0.151 | 0.0370603 | 0.11286 |
| 5 | True | False | True | 0.862 | 0.326 | 0.838 | 0.323 | 0.862 | 0.325 | 0.494 | 0.499 | 0.867 | 0.34 | 0.862 | 0.325 | 0.952262 | 0.870823 |
| 6 | False | True | True | 0.326 | 0.862 | 0.313 | 0.863 | 0.33 | 0.864 | 0.511333 | 0.516667 | 0.317 | 0.857 | 0.33 | 0.864 | 0.0602989 | 0.942486 |
| 7 | True | True | True | 0.946 | 0.946 | 0.943 | 0.944 | 0.93 | 0.931 | 0.496 | 0.496667 | 0.941 | 0.943 | 0.93 | 0.931 | 0.968268 | 0.962681 |

Thank you, @fehiepsi, for your patience. I know I have run into the same errors and misconceptions as in my previous posts, and I can see you are still willing to help. Thank you for your time.

Very respectfully,
Ben

I think you should define nuts_kernel = NUTS(model_05) rather than nuts_kernel = NUTS(model_03). If you use enumeration here, I think DiscreteHMCGibbs will raise errors (if it doesn’t, there’s likely a bug in the current implementation).

Glad that DiscreteHMCGibbs works for your model. For SVI, I guess we need some more advanced variance reduction techniques, which are currently not supported in NumPyro. 😃

Updated the notebook at the repo.

Whoops… yeah, looks like DiscreteHMCGibbs is doing its job:

AssertionError: Cannot detect any discrete latent variables in the model.

Does Pyro’s SVI do this natively? I tried the same model with SVI in Pyro, with similar results. I should compute some metrics to make this more concrete.

For SVI, you should use TraceGraph_ELBO, which is available in both Pyro and NumPyro (in the notebook you used Trace_ELBO, which does not work for models with discrete latent sites). If you follow the link in my last comment, you will see that in Pyro you can further reduce variance by using a baseline. You might need to tune the baseline parameters a bit to get better results, or incorporate neural baselines. (I’m not familiar with baselines, so I can’t give a better suggestion :()
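A toy, plain-Python sketch of why a baseline helps (this is not Pyro's implementation; the cost function and constants are invented for illustration). It forms single-sample score-function gradient estimates for one Bernoulli latent variable, once as-is and once after subtracting a decaying average of past costs, which is the idea behind a decay parameter like baseline_beta. Because the baseline only uses past samples, the estimator stays unbiased, while its variance drops sharply when the cost has a large constant offset.

```python
import math
import random
import statistics


def score_function_grads(logit, cost_fn, num_samples, baseline_beta=None, seed=0):
    """Single-sample score-function estimates of
    d/dlogit E_{z ~ Bernoulli(sigmoid(logit))}[cost_fn(z)].

    If baseline_beta is given, a decaying average of *past* costs is subtracted
    before multiplying by the score; it does not depend on the current z, so the
    estimator stays unbiased.
    """
    rng = random.Random(seed)
    p = 1.0 / (1.0 + math.exp(-logit))
    baseline = 0.0
    grads = []
    for _ in range(num_samples):
        z = 1 if rng.random() < p else 0
        score = z - p  # d/dlogit log Bernoulli(z | sigmoid(logit))
        cost = cost_fn(z)
        b = baseline if baseline_beta is not None else 0.0
        grads.append((cost - b) * score)
        if baseline_beta is not None:
            # Decaying (exponential moving) average of the observed costs.
            baseline = baseline_beta * baseline + (1 - baseline_beta) * cost
    return grads


def cost(z):
    return 10.0 + z  # large constant offset makes the raw estimator noisy


plain = score_function_grads(0.0, cost, 2000)
with_baseline = score_function_grads(0.0, cost, 2000, baseline_beta=0.9)
# Exact gradient at logit = 0: p * (1 - p) * (cost(1) - cost(0)) = 0.25
print(statistics.mean(plain), statistics.pvariance(plain))
print(statistics.mean(with_baseline), statistics.pvariance(with_baseline))
```

Both estimators target the same gradient; only the spread differs, which is the kind of noise that can make SVI on discrete sites need tuning.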

I updated the notebook, and with no tuning of baseline_beta to reduce the variance, Trace_ELBO and TraceGraph_ELBO seemed to give similar results. I am comparing the SVI results to Table 2.4 from the book, which used belief propagation (which I think is exact?).
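For a model this small, exact posterior marginals can be checked without any inference library. Below is a plain-Python sketch (not the book's belief-propagation code; the function name is hypothetical) that enumerates every skill configuration for each participant, assuming the Bernoulli(0.5) prior and prob_mistake=0.1, prob_guess=0.2 from the models above. Its printed values should match the P(csharp) and P(sql) columns in the tables of this thread, which makes it a convenient exact reference for judging SVI error.

```python
import itertools


def posterior_skill_marginals(answers, skills_needed, prob_mistake=0.1, prob_guess=0.2, prior=0.5):
    """Exact P(skill_s = 1 | answers) for one participant, by brute-force enumeration.

    answers[q] is 1 if the participant answered question q correctly.
    """
    n_skills = max(map(max, skills_needed)) + 1
    evidence = 0.0
    mass_with_skill = [0.0] * n_skills
    for skills in itertools.product([0, 1], repeat=n_skills):
        # Unnormalized posterior weight of this skill configuration.
        weight = 1.0
        for s in skills:
            weight *= prior if s else 1 - prior
        for q, needed in enumerate(skills_needed):
            has_skills = all(skills[i] for i in needed)
            p_correct = 1 - prob_mistake if has_skills else prob_guess
            weight *= p_correct if answers[q] else 1 - p_correct
        evidence += weight
        for s in range(n_skills):
            if skills[s]:
                mass_with_skill[s] += weight
    return [m / evidence for m in mass_with_skill]


responses_check = [
    [0, 1, 0, 1, 0, 1, 0, 1],
    [0, 0, 1, 1, 0, 0, 1, 1],
    [0, 0, 0, 0, 1, 1, 1, 1],
]
skills_needed_check = [[0], [1], [0, 1]]
for p in range(len(responses_check[0])):
    answers = [row[p] for row in responses_check]
    p_csharp, p_sql = posterior_skill_marginals(answers, skills_needed_check)
    print(p, round(p_csharp, 3), round(p_sql, 3))
```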

Anyway, I was able to tune the results to a reasonable error using baseline_beta. I will have to verify if this optimization is any good going forward with the rest of the chapter. I will update when I can.

Results

  • Note: the "decay TraceGraph_ELBO" columns use the default baseline_beta value of 0.9.

|   | IsCorrect1 | IsCorrect2 | IsCorrect3 | P(csharp) | P(sql) | Trace_ELBO P(csharp) | Trace_ELBO P(sql) | TraceGraph_ELBO P(csharp) | TraceGraph_ELBO P(sql) | decay TraceGraph_ELBO P(csharp) | decay TraceGraph_ELBO P(sql) | best decay TraceGraph_ELBO P(csharp) | best decay TraceGraph_ELBO P(sql) |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | False | False | False | 0.101 | 0.101 | 0.0789368 | 0.055503 | 0.0566014 | 0.188106 | 0.0855226 | 0.167065 | 0.159373 | 0.212234 |
| 1 | True | False | False | 0.802 | 0.034 | 0.939442 | 0.00339268 | 0.899589 | 0.167753 | 0.907181 | 0.0250883 | 0.754841 | 0.028081 |
| 2 | False | True | False | 0.034 | 0.802 | 0.0224401 | 0.846974 | 0.0126467 | 0.883312 | 0.00550104 | 0.91804 | 0.0283127 | 0.803002 |
| 3 | True | True | False | 0.561 | 0.561 | 0.547225 | 0.86883 | 0.522785 | 0.631924 | 0.500015 | 0.625142 | 0.577624 | 0.668692 |
| 4 | False | False | True | 0.148 | 0.148 | 0.0347185 | 0.0237291 | 0.0324805 | 0.272412 | 0.0588521 | 0.0657203 | 0.134006 | 0.0728378 |
| 5 | True | False | True | 0.862 | 0.326 | 0.883048 | 0.0900076 | 0.803443 | 0.475151 | 0.960243 | 0.437125 | 0.834419 | 0.349507 |
| 6 | False | True | True | 0.326 | 0.862 | 0.346815 | 0.838184 | 0.435617 | 0.997614 | 0.272805 | 0.910495 | 0.359426 | 0.893284 |
| 7 | True | True | True | 0.946 | 0.946 | 0.988479 | 0.874144 | 0.932451 | 0.710658 | 0.965964 | 0.965729 | 0.951077 | 0.949261 |