ValueError: Expected the joint log density is a scalar

Hi all,

I am trying to take the model below (which works, from MBML chapter 1), and switch out the hardcoded samples murderer and hair into a plate since the model can leverage the conditional independence.

Minimal Code example
import arviz as az
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
import pandas as pd
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs
from numpyro.infer.util import Predictive

rng_key = jax.random.PRNGKey(2)
rng_key, rng_key_ = jax.random.split(rng_key)

num_warmup = 1000
num_samples = 1000
num_chains = 4

guess = 0.7

def mystery_extend(guess):
    weapon_cpt = jnp.array([[0.9, 0.1], [0.2, 0.8]])
    hair_cpt = jnp.array([[0.5, 0.5], [0.95, 0.05]])
    murderer = numpyro.sample("murderer", dist.Bernoulli(guess))
    weapon = numpyro.sample("weapon", dist.Categorical(weapon_cpt[murderer]))
    hair = numpyro.sample("hair", dist.Categorical(hair_cpt[murderer]))
    return murderer, weapon, hair

conditioned_model_extend = numpyro.handlers.condition(mystery_extend, {"weapon": 0.0, "hair": 1.0})
predictive = Predictive(conditioned_model_extend, num_samples=num_chains*num_samples, infer_discrete=True,)
discrete_samples = predictive(rng_key_, guess)

for key in discrete_samples.keys():
    discrete_samples[key] = np.array(discrete_samples[key].reshape(num_chains,num_samples))
    
az.stats.summary(discrete_samples["murderer"], hdi_prob=0.9)

Output:

mean sd hdi_5% hdi_95% mcse_mean mcse_sd ess_bulk ess_tail r_hat
x 0.052 0.222 0.0 0.0 0.004 0.002 4023.0 4000.0 1.0
az.plot_trace(discrete_samples["murderer"]);

Output:

Yet I cannot figure out this ValueError: Expected the joint log density is a scalar, once I include the plate. Also when I try to just sample the model it seems the model is not returning the expected inference.

Ancestral sampling
def mystery_plate(guess):
    murderer = numpyro.sample("murderer", dist.Bernoulli(guess))
    evidence = jnp.array([[[0.9, 0.1], [0.2, 0.8]], [[0.5, 0.5], [0.95, 0.05]], ])
    size = len(evidence)
    with numpyro.plate(f'i=1..{size}', size=size, dim=-1):
        obs = numpyro.sample("evidence", dist.Categorical(evidence[murderer]))
    return murderer, obs

numpyro.render_model(mystery_plate, (guess,), render_distributions=True)

conditioned_model_plate = numpyro.handlers.condition(mystery_plate, {"evidence": jnp.array([0., 1.])})

with numpyro.handlers.seed(rng_seed=0):
    samples = []
    for _ in range(5000):
        murderer, _ = conditioned_model_plate(guess)
        samples.append((murderer.item(),))
pd.DataFrame(samples, columns=["murderer"])["murderer"].value_counts(normalize=True)

Output:

1    0.6862
0    0.3138
Name: murderer, dtype: float64
ValueError: Expected the joint log density is a scalar
rng_key_, rng_key__ = jax.random.split(rng_key_)

predictive = Predictive(conditioned_model_plate, num_samples=num_chains*num_samples, infer_discrete=True,)
discrete_samples = predictive(rng_key__, guess)

Output:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_48865/2729934207.py in <module>
      1 predictive = Predictive(conditioned_model_plate, num_samples=num_chains*num_samples, infer_discrete=True,)
----> 2 discrete_samples = predictive(rng_key__, guess)

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in __call__(self, rng_key, *args, **kwargs)
    892             )
    893         model = substitute(self.model, self.params)
--> 894         return _predictive(
    895             rng_key,
    896             model,

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in _predictive(rng_key, model, posterior_samples, batch_shape, return_sites, infer_discrete, parallel, model_args, model_kwargs)
    737     rng_key = rng_key.reshape(batch_shape + (2,))
    738     chunk_size = num_samples if parallel else 1
--> 739     return soft_vmap(
    740         single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size
    741     )

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/util.py in soft_vmap(fn, xs, batch_ndims, chunk_size)
    403         fn = vmap(fn)
    404 
--> 405     ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)
    406     map_ndims = int(num_chunks > 1) + int(chunk_size > 1)
    407     ys = tree_map(

    [... skipping hidden 15 frame]

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in single_prediction(val)
    702             model_trace = prototype_trace
    703             temperature = 1
--> 704             pred_samples = _sample_posterior(
    705                 config_enumerate(condition(model, samples)),
    706                 first_available_dim,

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/funsor/discrete.py in _sample_posterior(model, first_available_dim, temperature, rng_key, *args, **kwargs)
     60     with funsor.adjoint.AdjointTape() as tape:
     61         with block(), enum(first_available_dim=first_available_dim):
---> 62             log_prob, model_tr, log_measures = _enum_log_density(
     63                 model, args, kwargs, {}, sum_op, prod_op
     64             )

~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/funsor/infer_util.py in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)
    238     result = funsor.optimizer.apply_optimizer(lazy_result)
    239     if len(result.inputs) > 0:
--> 240         raise ValueError(
    241             "Expected the joint log density is a scalar, but got {}. "
    242             "There seems to be something wrong at the following sites: {}.".format(

ValueError: Expected the joint log density is a scalar, but got (2,). There seems to be something wrong at the following sites: {'_pyro_dim_3'}.

Here is my list where I think might be the source to both problems, but I am not sure…

  • Misuse use of categorical distribution with probs being N-dim, but this seems correct based off the docs
  • Maybe this needs a VIndex?
  • Maybe I am breaking on of the restrictions?

You are right, because murderer is enumerated, we need to handle the shape of evidence[murderer] properly. One solution is to replace it by evidence[murderer.squeeze()]. A better solution is to use the method at the end of this section in enumeration tutorial:

with numpyro.plate(f'i=1..{size}', size=size, dim=-1) as i:
    obs = numpyro.sample("evidence",
        dist.Categorical(Vindex(evidence)[murderer, i]))

I would recommend to use the second method (also see vindex doc string, which explains well the shape semantic of vindex) because it will work for more complicated cases.

Thank you @fehiepsi, The vindex did the trick. I am going to have to spend a little more time with the examples from the docs to make sure I am solid. Below is the complete minimal code to reproduce:

Using VIndex
import arviz as az
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
from numpyro.contrib.indexing import Vindex
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs
from numpyro.infer.util import Predictive

num_warmup = 1000
num_samples = 1000
num_chains = 4

guess = 0.7

rng_key = jax.random.PRNGKey(2)
rng_key, rng_key_ = jax.random.split(rng_key)
rng_key_, rng_key__ = jax.random.split(rng_key_)

def mystery_plate(guess):
    murderer = numpyro.sample("murderer", dist.Bernoulli(guess))
    evidence = jnp.array([[[0.9, 0.1], [0.2, 0.8]], [[0.5, 0.5], [0.95, 0.05]], ])
    size = len(evidence)
    with numpyro.plate(f'i=1..{size}', size=size, dim=-1) as i:
        obs = numpyro.sample("evidence", dist.Categorical(Vindex(evidence)[murderer, i]))
    return murderer, obs

conditioned_model_plate = numpyro.handlers.condition(mystery_plate, {"evidence": jnp.array([0., 1.])})

predictive = Predictive(conditioned_model_plate, num_samples=num_chains*num_samples, infer_discrete=True,)
discrete_samples = predictive(rng_key__, guess)

discrete_samples["murderer"] = np.array(discrete_samples["murderer"].reshape(num_chains,num_samples))
az.stats.summary(discrete_samples["murderer"], hdi_prob=0.9)

Output:

mean sd hdi_5% hdi_95% mcse_mean mcse_sd ess_bulk ess_tail r_hat
x 0.08 0.271 0.0 0.0 0.004 0.003 3744.0 3744.0 1.0
az.plot_trace(discrete_samples["murderer"]);

Output:

Two follow up questions:

  1. I notice from this tutorial that pyro has a utility to print out the shapes of each site within a model, but the numpyro’s trace doesn’t have same method. Does numpyro have something similar to Trace.format_shapes()?

  2. Do you have any thoughts and or comments on why the sampling under the context with numpyro.handlers.seed(rng_seed=0) wasn’t working for the model above? I get the same results even with vindex.

Could you make a feature request for infer_shapes? Currently we don’t have such utility. About the error with seed handler, could you be more explicit? I don’t know which code you refered to.

I will go ahead and submit the FR.

Example code
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
from numpyro.contrib.indexing import Vindex
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs
from numpyro.infer.util import Predictive
import pandas as pd

guess = 0.7

def mystery_plate(guess):
    murderer = numpyro.sample("murderer", dist.Bernoulli(guess))
    evidence = jnp.array([[[0.9, 0.1], [0.2, 0.8]], [[0.5, 0.5], [0.95, 0.05]], ])
    size = len(evidence)
    with numpyro.plate(f'i=1..{size}', size=size, dim=-1) as i:
        obs = numpyro.sample("evidence", dist.Categorical(Vindex(evidence)[murderer, i]))
    return murderer, obs

conditioned_model_plate = numpyro.handlers.condition(mystery_plate, {"evidence": jnp.array([0., 1.])})

with numpyro.handlers.seed(rng_seed=0):
    samples = []
    for _ in range(5000):
        murderer, _ = conditioned_model_plate(guess)
        samples.append((murderer.item(),))
pd.DataFrame(samples, columns=["murderer"])["murderer"].value_counts(normalize=True)

Output:

1    0.6862
0    0.3138
Name: murderer, dtype: float64

With the code above, I am conditioning then I just sample under the handler, yet it seems the condition has no effect since the sample site murderer seems close to the initial input of 0.7. I would expect a value of ~0.05 given the condition.

I think you’ll need to move seed handler inside the for loop to deal with name conflictions (multiple sites have the same names).

seed handler within the for loop
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
from numpyro.contrib.indexing import Vindex
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs
from numpyro.infer.util import Predictive
import pandas as pd
import pprint as pp

guess = 0.7

def mystery_plate(guess):
    murderer = numpyro.sample("murderer", dist.Bernoulli(guess))
    evidence = jnp.array([[[0.9, 0.1], [0.2, 0.8]], [[0.5, 0.5], [0.95, 0.05]], ])
    size = len(evidence)
    with numpyro.plate(f'i=1..{size}', size=size, dim=-1) as i:
        obs = numpyro.sample("evidence", dist.Categorical(Vindex(evidence)[murderer, i]))
    return murderer, obs

conditioned_model_plate = numpyro.handlers.condition(mystery_plate, {"evidence": jnp.array([0., 1.])})


samples = []
for i in range(5000):
    with numpyro.handlers.seed(rng_seed=i):
        murderer, _ = conditioned_model_plate(guess)
        samples.append((murderer.item(),))
pd.DataFrame(samples, columns=["murderer"])["murderer"].value_counts(normalize=True)

Output:

1    0.6992
0    0.3008
Name: murderer, dtype: float64

The code above shows the same result, the murderer site still returns a probability equal to the input, when I move the seed handler within the for loop. Maybe I am misunderstanding what you are suggesting, but I then had to pass each number from range to the random seed in order to not have same samples returned.

Sorry if it’s obvious, but where is the name conflict? If I look over the trace I see every sample site has a different name.

exec_trace = numpyro.handlers.trace(numpyro.handlers.seed(mystery_plate, jax.random.PRNGKey(0))).get_trace(guess)
pp.pprint(exec_trace) 
Output from pp.print of exec_trace
OrderedDict([('murderer',
              {'args': (),
               'cond_indep_stack': [],
               'fn': <numpyro.distributions.discrete.BernoulliProbs object at 0x7fa5585dfe20>,
               'infer': {},
               'intermediates': [],
               'is_observed': False,
               'kwargs': {'rng_key': array([2718843009, 1272950319], dtype=uint32),
                          'sample_shape': ()},
               'name': 'murderer',
               'scale': None,
               'type': 'sample',
               'value': DeviceArray(1, dtype=int32)}),
             ('i=1..2',
              {'args': (2, None),
               'cond_indep_stack': [],
               'fn': <function _subsample_fn at 0x7fa5182e1c10>,
               'kwargs': {'rng_key': None},
               'name': 'i=1..2',
               'scale': 1.0,
               'type': 'plate',
               'value': DeviceArray([0, 1], dtype=int32)}),
             ('evidence',
              {'args': (),
               'cond_indep_stack': [CondIndepStackFrame(name='i=1..2', dim=-1, size=2)],
               'fn': <numpyro.distributions.discrete.CategoricalProbs object at 0x7fa55858a340>,
               'infer': {},
               'intermediates': [],
               'is_observed': False,
               'kwargs': {'rng_key': array([1278412471, 2182328957], dtype=uint32),
                          'sample_shape': ()},
               'name': 'evidence',
               'scale': None,
               'type': 'sample',
               'value': DeviceArray([1, 0], dtype=int32)})])

For example, the site murder is repeatedly appeared in the loop. I think random seed will be constant for this site.

Edit sorry, it is fine because you didn’t play with the trace handler, which only stores 1 value for 1 site. I overlooked your comment.

the murderer site still returns a probability equal to the input

Is it expected when murderer has distribution Bernoulli(0.7)? If you want to get posterior of murderer, you can use infer_discrete as above.

🤦… aw dude, I am sorry for taking up so much time for this non-error. (so embarressing…)

I was wrong, but this does reveal a misconception I had. I was implicitly assuming that if I sampled points from the model under the seed handler while the model was conditioned I would expect inference, which is incorrect. I guess I was confusing Ancestral sampling [1, 2, 3] as a form of inference.

To sample the model’s posterior you need wrap the model in an inference algorithm. =P