Hi all,

I am trying to reproduce a *Bayesian* naive Bayes model from the plate diagram in Kevin Murphy’s *Machine Learning: A Probabilistic Perspective* (2012), Chapter 10, Section 10.4.1, Figure 10.8(b), page 322. The figure is taken from the course Probabilistic Graphical Models, slide 16 of Lecture 2: Directed Graphical Models:

[Figure: plate diagram of the Bayesian naive Bayes model, Figure 10.8(b)]

I got the model to work on a toy dataset and compared the results to scikit-learn’s `BernoulliNB`. The model seems to recover the parameters from the toy data, and I have no divergences or any `r_hat > 1`. The complete code is here, and the minimal code for the model is below.

## Minimal code, working example

```
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 2)

# Toy data: 10 Bernoulli features and a binary label, 300 rows each
d_X = dist.Bernoulli(jnp.array((0.91, 0.02, 0.90, 0.32, 0.66, 0.71, 0.32, 0.70, 0.94, 0.49)))
d_y = dist.Bernoulli(jnp.array((0.8)))
X = d_X.sample(keys[0], (300,))  # shape (300, 10)
y = d_y.sample(keys[1], (300,))  # shape (300,)


def model(X, num_classes, y=None,):
    num_items, num_features = X.shape
    # Class prior
    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))
    # One Bernoulli parameter per (class, feature) pair
    with numpyro.plate("D", num_features, dim=None):
        with numpyro.plate("C", num_classes):
            theta = numpyro.sample("theta", dist.Beta(1,1))
    # One label per item, and one observation per (item, feature) pair
    with numpyro.plate("N", num_items, dim=-2):
        y = numpyro.sample("y", dist.Categorical(pi), obs=y)
        with numpyro.plate("D", num_features, dim=None):
            x = numpyro.sample("x", dist.Bernoulli(theta[y]), obs=X)


nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000, num_chains=4)
mcmc.run(key, X, 2, y, extra_fields=("potential_energy",))
mcmc.print_summary()
```

Output:

```
sample: 100%|██████████| 1500/1500 [00:11<00:00, 129.06it/s, 7 steps of size 3.95e-01. acc. prob=0.87]
sample: 100%|██████████| 1500/1500 [00:06<00:00, 246.59it/s, 15 steps of size 4.08e-01. acc. prob=0.86]
sample: 100%|██████████| 1500/1500 [00:05<00:00, 270.71it/s, 7 steps of size 4.91e-01. acc. prob=0.85]
sample: 100%|██████████| 1500/1500 [00:06<00:00, 240.69it/s, 7 steps of size 4.40e-01. acc. prob=0.86]
                 mean       std    median      5.0%     95.0%     n_eff     r_hat
      pi[0]      0.17      0.00      0.17      0.17      0.18   6765.86      1.00
      pi[1]      0.83      0.00      0.83      0.82      0.83   6766.00      1.00
 theta[0,0]      0.85      0.05      0.86      0.77      0.93   5084.27      1.00
 theta[0,1]      0.06      0.03      0.05      0.01      0.10   6469.38      1.00
 theta[0,2]      0.89      0.04      0.89      0.82      0.96   4864.51      1.00
 theta[0,3]      0.28      0.06      0.27      0.18      0.38   5242.58      1.00
 theta[0,4]      0.67      0.06      0.67      0.56      0.77   4682.34      1.00
 theta[0,5]      0.69      0.06      0.69      0.58      0.79   5221.52      1.00
 theta[0,6]      0.30      0.06      0.29      0.19      0.40   4822.39      1.00
 theta[0,7]      0.61      0.07      0.61      0.51      0.72   5376.68      1.00
 theta[0,8]      0.89      0.04      0.89      0.82      0.95   5293.73      1.00
 theta[0,9]      0.59      0.06      0.59      0.49      0.70   5938.99      1.00
 theta[1,0]      0.91      0.02      0.91      0.88      0.94   5269.80      1.00
 theta[1,1]      0.01      0.01      0.01      0.00      0.02   5056.84      1.00
 theta[1,2]      0.90      0.02      0.90      0.86      0.93   5359.81      1.00
 theta[1,3]      0.36      0.03      0.36      0.30      0.41   6750.08      1.00
 theta[1,4]      0.66      0.03      0.66      0.61      0.71   4756.77      1.00
 theta[1,5]      0.68      0.03      0.68      0.63      0.73   5011.69      1.00
 theta[1,6]      0.32      0.03      0.32      0.28      0.37   4942.66      1.00
 theta[1,7]      0.69      0.03      0.69      0.64      0.73   5178.98      1.00
 theta[1,8]      0.94      0.01      0.95      0.92      0.97   5130.13      1.00
 theta[1,9]      0.46      0.03      0.46      0.41      0.51   4781.71      1.00
Number of divergences: 0
```
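One detail in the summary may be worth a second look: the posterior standard deviation of `pi` rounds to 0.00, which seems tight for 300 labels. A possible cause, which I have not confirmed inside NumPyro, is that plate `N` at `dim=-2` gives the `y` site a batch shape of `(300, 1)` while the observed `y` has shape `(300,)`, so the log-probability could broadcast to `(300, 300)` and count every label 300 times. The NumPy sketch below only checks the arithmetic and the broadcast shape under that assumption; the variable names are introduced here for illustration.

```python
import numpy as np

n = 300   # number of observed labels in the toy data
p = 0.83  # posterior mean of pi[1] reported in the summary

# Rough posterior std of a proportion estimated from n labels.
std_expected = float(np.sqrt(p * (1 - p) / n))

# Shape-only check: a batch of shape (n, 1) evaluated at observations of
# shape (n,) broadcasts to (n, n). This is plain NumPy, not NumPyro.
broadcast_shape = np.broadcast(np.empty((n, 1)), np.empty(n)).shape

# If every label were counted n times, the std would shrink by sqrt(n).
std_overcounted = float(np.sqrt(p * (1 - p) / (n * n)))

print(round(std_expected, 3), broadcast_shape, round(std_overcounted, 4))
```

If this is what happens, passing the labels as a column (`y[:, None]`) or placing `y` in a plate at `dim=-1` would be the things to try.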

What is troubling me is my misunderstanding of `dim` within `numpyro.plate`. It took me several iterations of trial and error to find the correct `dim` argument for each plate. Unfortunately, this might be confounded with my model definition, because I am repeating the plate over `D` for X_{ij}.

My misunderstanding of dimensions:

1. I would expect the shape of `theta` to be (10, 2), because the first plate, `D`, was given `dim=None` and therefore uses the leftmost dimension, yet `theta` has shape (2, 10).
2. I don’t understand why the two plates `N` and `D` only work with `dim=-2` and `dim=None`, respectively.

My explanation for why `dim=-2` and `dim=None` work for plates `N` and `D`, respectively:

*Using `dim=-2` and `dim=None` (the same as indexing 0?) for plate `N` and plate `D`, respectively, I am first indexing the rows of my 2D training data and then its columns.*
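A shape-only NumPy sketch of the indexing in the `x` site may help check this explanation. It does not use NumPyro, the array values are placeholders, and `probs` is a name introduced here. It shows that `theta[y]` has one row per item and one column per feature, so the `x` site has batch shape `(N, D)`: items sit on dimension `-2` and features on dimension `-1`.

```python
import numpy as np

num_items, num_features, num_classes = 5, 10, 2

theta = np.full((num_classes, num_features), 0.5)   # (C, D), as printed for theta
y = np.array([0, 1, 1, 0, 1])                       # (N,), one class label per item
X = np.zeros((num_items, num_features), dtype=int)  # (N, D), placeholder data

# Integer-array indexing picks one row of theta per item.
probs = theta[y]

print(probs.shape)  # (5, 10)
print(X.shape)      # (5, 10)
```

Read this way, `dim=-2` for `N` and `dim=-1` for `D` simply name the two axes of `X` counted from the right.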

Yet when I use `dim=None` and `dim=-1` for plates `N` and `D`, respectively, I get an assertion error (shown below). My expectation is that `dim=None` indexes the leftmost dimension for plate `N` (the rows) and `dim=-1` indexes the rightmost dimension for plate `D` (the columns).

## AssertionError

```
def model(X, num_classes, y=None,):
    num_items, num_features = X.shape
    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))
    with numpyro.plate("D", num_features, dim=None):
        with numpyro.plate("C", num_classes, dim=None):
            theta = numpyro.sample("theta", dist.Beta(1,1))
            print(f"{theta.shape=}")
    with numpyro.plate("N", num_items, dim=None):
        y = numpyro.sample("y", dist.Categorical(pi), obs=y)
        print(f"{y.shape=}")
        # The assertion is raised when this plate is constructed
        with numpyro.plate("D", num_features, dim=-1):
            print(f"{theta[y].shape=}")
            print(f"{X.shape=}")
            x = numpyro.sample("x", dist.Bernoulli(theta[y]), obs=X)
            print(f"{x.shape=}")


# X_train and y_train are defined elsewhere (not shown here)
with numpyro.handlers.seed(rng_seed=0):
    for _ in range(1):
        model(X_train, 2, y_train)
```

Output:

```
theta.shape=(2, 10)
y.shape=(5,)
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-41-39648db3380c> in <module>
20 with numpyro.handlers.seed(rng_seed=0):
21 for _ in range(1):
---> 22 model(X_train, 2, y_train)
<ipython-input-41-39648db3380c> in model(X, num_classes, y)
11 y = numpyro.sample("y", dist.Categorical(pi), obs=y)
12 print(f"{y.shape=}")
---> 13 with numpyro.plate("D", num_features, dim=-1):
14 print(f"{theta[y].shape=}")
15 print(f"{X.shape=}")
~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/primitives.py in __init__(self, name, size, subsample_size, dim)
401 if dim is not None and dim >= 0:
402 raise ValueError("dim arg must be negative.")
--> 403 self.dim, self._indices = self._subsample(
404 self.name, self.size, subsample_size, dim
405 )
~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/primitives.py in _subsample(name, size, subsample_size, dim)
442 dim = new_dim
443 else:
--> 444 assert dim not in occupied_dims
445 return dim, subsample
446
AssertionError:
```
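The two lines of `_subsample` visible in the traceback (`dim = new_dim` and `assert dim not in occupied_dims`) suggest an allocation rule that can be sketched in plain Python. The helper `allocate_dim` below is hypothetical and is not NumPyro's code; in particular, the search that starts at `-1` and moves left is my assumption about the part of the function that the traceback does not show, although it is consistent with the printed `theta.shape=(2, 10)`.

```python
def allocate_dim(occupied_dims, dim=None):
    """Hypothetical stand-in for the rule seen in the traceback."""
    if dim is None:
        # Assumed behaviour: take the first free slot, counting from the right.
        new_dim = -1
        while new_dim in occupied_dims:
            new_dim -= 1
        return new_dim
    # An explicit dim must not already be taken by an enclosing plate.
    assert dim not in occupied_dims, f"dim {dim} is already occupied"
    return dim


# theta: plate D (dim=None) is entered first, then plate C (dim=None).
occupied = set()
dim_D = allocate_dim(occupied)
occupied.add(dim_D)
dim_C = allocate_dim(occupied)
print(dim_D, dim_C)  # -1 -2

# Failing case: plate N with dim=None takes -1, so plate D with dim=-1 collides.
occupied = set()
dim_N = allocate_dim(occupied)
occupied.add(dim_N)
try:
    allocate_dim(occupied, dim=-1)
    collided = False
except AssertionError:
    collided = True
print(dim_N, collided)  # -1 True

# Working case: plate N is pinned to -2, so plate D with dim=None gets -1.
dim_D_inner = allocate_dim({-2})
print(dim_D_inner)  # -1
```

Under this reading, `dim=None` means "the rightmost free dimension" rather than "the leftmost dimension", which would explain both the (2, 10) shape of `theta` and the assertion error.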

One last, possibly unrelated, question:

3. Is there a better way to index using plates to achieve the exact diagram shown in Figure 10.8(b)? I feel like reusing plate `D` is hacky. The rendered NumPyro model is below:

[Figure: rendered NumPyro model]