Hi all,

I am trying to take the model below (which works, from MBML chapter 1), and switch out the hardcoded samples `murderer`

and `hair`

into a `plate`

since the model can leverage the conditional independence.

## Minimal Code example

```
import arviz as az
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
import pandas as pd
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs
from numpyro.infer.util import Predictive
rng_key = jax.random.PRNGKey(2)
rng_key, rng_key_ = jax.random.split(rng_key)
num_warmup = 1000
num_samples = 1000
num_chains = 4
guess = 0.7
def mystery_extend(guess):
weapon_cpt = jnp.array([[0.9, 0.1], [0.2, 0.8]])
hair_cpt = jnp.array([[0.5, 0.5], [0.95, 0.05]])
murderer = numpyro.sample("murderer", dist.Bernoulli(guess))
weapon = numpyro.sample("weapon", dist.Categorical(weapon_cpt[murderer]))
hair = numpyro.sample("hair", dist.Categorical(hair_cpt[murderer]))
return murderer, weapon, hair
conditioned_model_extend = numpyro.handlers.condition(mystery_extend, {"weapon": 0.0, "hair": 1.0})
predictive = Predictive(conditioned_model_extend, num_samples=num_chains*num_samples, infer_discrete=True,)
discrete_samples = predictive(rng_key_, guess)
for key in discrete_samples.keys():
discrete_samples[key] = np.array(discrete_samples[key].reshape(num_chains,num_samples))
az.stats.summary(discrete_samples["murderer"], hdi_prob=0.9)
```

Output:

mean | sd | hdi_5% | hdi_95% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat | |
---|---|---|---|---|---|---|---|---|---|

x | 0.052 | 0.222 | 0.0 | 0.0 | 0.004 | 0.002 | 4023.0 | 4000.0 | 1.0 |

```
az.plot_trace(discrete_samples["murderer"]);
```

Output:

Yet I cannot figure out this `ValueError: Expected the joint log density is a scalar,`

once I include the `plate`

. Also when I try to just sample the model it seems the model is not returning the expected inference.

## Ancestral sampling

```
def mystery_plate(guess):
murderer = numpyro.sample("murderer", dist.Bernoulli(guess))
evidence = jnp.array([[[0.9, 0.1], [0.2, 0.8]], [[0.5, 0.5], [0.95, 0.05]], ])
size = len(evidence)
with numpyro.plate(f'i=1..{size}', size=size, dim=-1):
obs = numpyro.sample("evidence", dist.Categorical(evidence[murderer]))
return murderer, obs
numpyro.render_model(mystery_plate, (guess,), render_distributions=True)
conditioned_model_plate = numpyro.handlers.condition(mystery_plate, {"evidence": jnp.array([0., 1.])})
with numpyro.handlers.seed(rng_seed=0):
samples = []
for _ in range(5000):
murderer, _ = conditioned_model_plate(guess)
samples.append((murderer.item(),))
pd.DataFrame(samples, columns=["murderer"])["murderer"].value_counts(normalize=True)
```

Output:

```
1 0.6862
0 0.3138
Name: murderer, dtype: float64
```

## ValueError: Expected the joint log density is a scalar

```
rng_key_, rng_key__ = jax.random.split(rng_key_)
predictive = Predictive(conditioned_model_plate, num_samples=num_chains*num_samples, infer_discrete=True,)
discrete_samples = predictive(rng_key__, guess)
```

Output:

```
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_48865/2729934207.py in <module>
1 predictive = Predictive(conditioned_model_plate, num_samples=num_chains*num_samples, infer_discrete=True,)
----> 2 discrete_samples = predictive(rng_key__, guess)
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in __call__(self, rng_key, *args, **kwargs)
892 )
893 model = substitute(self.model, self.params)
--> 894 return _predictive(
895 rng_key,
896 model,
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in _predictive(rng_key, model, posterior_samples, batch_shape, return_sites, infer_discrete, parallel, model_args, model_kwargs)
737 rng_key = rng_key.reshape(batch_shape + (2,))
738 chunk_size = num_samples if parallel else 1
--> 739 return soft_vmap(
740 single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size
741 )
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/util.py in soft_vmap(fn, xs, batch_ndims, chunk_size)
403 fn = vmap(fn)
404
--> 405 ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)
406 map_ndims = int(num_chunks > 1) + int(chunk_size > 1)
407 ys = tree_map(
[... skipping hidden 15 frame]
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/infer/util.py in single_prediction(val)
702 model_trace = prototype_trace
703 temperature = 1
--> 704 pred_samples = _sample_posterior(
705 config_enumerate(condition(model, samples)),
706 first_available_dim,
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/funsor/discrete.py in _sample_posterior(model, first_available_dim, temperature, rng_key, *args, **kwargs)
60 with funsor.adjoint.AdjointTape() as tape:
61 with block(), enum(first_available_dim=first_available_dim):
---> 62 log_prob, model_tr, log_measures = _enum_log_density(
63 model, args, kwargs, {}, sum_op, prod_op
64 )
~/anaconda3/envs/mbml_numpyro/lib/python3.9/site-packages/numpyro/contrib/funsor/infer_util.py in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)
238 result = funsor.optimizer.apply_optimizer(lazy_result)
239 if len(result.inputs) > 0:
--> 240 raise ValueError(
241 "Expected the joint log density is a scalar, but got {}. "
242 "There seems to be something wrong at the following sites: {}.".format(
ValueError: Expected the joint log density is a scalar, but got (2,). There seems to be something wrong at the following sites: {'_pyro_dim_3'}.
```

Here is my list where I think might be the source to both problems, but I am not sure…

- Misuse use of categorical distribution with
`probs`

being N-dim, but this seems correct based off the docs - Maybe this needs a VIndex?
- Maybe I am breaking on of the restrictions?