numpyro.plate dim - Bayesian Naive Bayes

Hi all,

I am trying to reproduce a Bayesian naive Bayes model from the plate diagram in Kevin Murphy’s Machine Learning: A Probabilistic Perspective (2012), Chapter 10, Section 10.4.1, Figure 10.8b, page 322. The figure is taken from slide 16 of Lecture 2, Directed Graphical Models, of the course Probabilistic Graphical Models:
[Image: plate diagram of the Bayesian naive Bayes model, Figure 10.8b]

I got the model to work on a toy dataset and compared the results to scikit-learn’s BernoulliNB. The model seems to recover the parameters of the toy data, and there are no divergences and no r_hat > 1. The complete code is here, and the minimal code for the model is below.
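Because both priors in this model are conjugate (Dirichlet for pi, Beta(1, 1) for each theta[c, d]) and the labels are observed, the exact posterior means can be computed by counting, which gives a second reference point besides BernoulliNB. A minimal sketch, assuming integer class labels 0..num_classes-1 and a 0/1 feature matrix; the function name `conjugate_posterior_means` is hypothetical:

```python
import jax
import jax.numpy as jnp


def conjugate_posterior_means(X, y, num_classes):
    """Exact posterior means for pi ~ Dirichlet(ones) and theta[c, d] ~ Beta(1, 1)."""
    onehot = jax.nn.one_hot(y, num_classes)        # (num_items, num_classes)
    n_c = onehot.sum(axis=0)                       # number of items in each class
    ones_cd = onehot.T @ X.astype(onehot.dtype)    # count of X[i, d] == 1 within class c
    # pi | y ~ Dirichlet(1 + n_c)
    pi_mean = (n_c + 1.0) / (y.shape[0] + num_classes)
    # theta[c, d] | data ~ Beta(1 + ones_cd, 1 + n_c - ones_cd)
    theta_mean = (ones_cd + 1.0) / (n_c[:, None] + 2.0)
    return pi_mean, theta_mean


# Tiny hand-checkable example: 3 items, 2 features, 2 classes.
X_small = jnp.array([[1, 0], [1, 1], [0, 1]])
y_small = jnp.array([0, 1, 1])
pi_mean, theta_mean = conjugate_posterior_means(X_small, y_small, 2)
print(pi_mean)     # values 2/5 and 3/5
print(theta_mean)  # rows (2/3, 1/3) and (2/4, 3/4)
```

BernoulliNB with its default alpha=1.0 applies the same add-one smoothing to the per-class feature probabilities, so exponentiating its `feature_log_prob_` should match `theta_mean` here.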

Minimal code, working example
import jax
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

key = jax.random.PRNGKey(0)

keys = jax.random.split(key, 2)
d_X = dist.Bernoulli(jnp.array((0.91, 0.02, 0.90, 0.32, 0.66, 0.71, 0.32, 0.70, 0.94, 0.49)))
d_y = dist.Bernoulli(jnp.array((0.8)))
X = d_X.sample(keys[0], (300,))
y = d_y.sample(keys[1], (300,))

def model(X, num_classes, y=None):
    num_items, num_features = X.shape
    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

    with numpyro.plate("D", num_features, dim=None):
        with numpyro.plate("C", num_classes):
            theta = numpyro.sample("theta", dist.Beta(1,1))  # theta.shape == (num_classes, num_features)

    with numpyro.plate("N", num_items, dim=-2):
        y = numpyro.sample("y", dist.Categorical(pi), obs=y)  # site batch shape (num_items, 1)
        with numpyro.plate("D", num_features, dim=None):
            x = numpyro.sample("x", dist.Bernoulli(theta[y]), obs=X)  # batch shape (num_items, num_features)
            
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000, num_chains=4)
mcmc.run(key, X, 2, y, extra_fields=("potential_energy",))
mcmc.print_summary()

Output:

sample: 100%|██████████| 1500/1500 [00:11<00:00, 129.06it/s, 7 steps of size 3.95e-01. acc. prob=0.87] 
sample: 100%|██████████| 1500/1500 [00:06<00:00, 246.59it/s, 15 steps of size 4.08e-01. acc. prob=0.86]
sample: 100%|██████████| 1500/1500 [00:05<00:00, 270.71it/s, 7 steps of size 4.91e-01. acc. prob=0.85]
sample: 100%|██████████| 1500/1500 [00:06<00:00, 240.69it/s, 7 steps of size 4.40e-01. acc. prob=0.86] 
                mean       std    median      5.0%     95.0%     n_eff     r_hat
     pi[0]      0.17      0.00      0.17      0.17      0.18   6765.86      1.00
     pi[1]      0.83      0.00      0.83      0.82      0.83   6766.00      1.00
theta[0,0]      0.85      0.05      0.86      0.77      0.93   5084.27      1.00
theta[0,1]      0.06      0.03      0.05      0.01      0.10   6469.38      1.00
theta[0,2]      0.89      0.04      0.89      0.82      0.96   4864.51      1.00
theta[0,3]      0.28      0.06      0.27      0.18      0.38   5242.58      1.00
theta[0,4]      0.67      0.06      0.67      0.56      0.77   4682.34      1.00
theta[0,5]      0.69      0.06      0.69      0.58      0.79   5221.52      1.00
theta[0,6]      0.30      0.06      0.29      0.19      0.40   4822.39      1.00
theta[0,7]      0.61      0.07      0.61      0.51      0.72   5376.68      1.00
theta[0,8]      0.89      0.04      0.89      0.82      0.95   5293.73      1.00
theta[0,9]      0.59      0.06      0.59      0.49      0.70   5938.99      1.00
theta[1,0]      0.91      0.02      0.91      0.88      0.94   5269.80      1.00
theta[1,1]      0.01      0.01      0.01      0.00      0.02   5056.84      1.00
theta[1,2]      0.90      0.02      0.90      0.86      0.93   5359.81      1.00
theta[1,3]      0.36      0.03      0.36      0.30      0.41   6750.08      1.00
theta[1,4]      0.66      0.03      0.66      0.61      0.71   4756.77      1.00
theta[1,5]      0.68      0.03      0.68      0.63      0.73   5011.69      1.00
theta[1,6]      0.32      0.03      0.32      0.28      0.37   4942.66      1.00
theta[1,7]      0.69      0.03      0.69      0.64      0.73   5178.98      1.00
theta[1,8]      0.94      0.01      0.95      0.92      0.97   5130.13      1.00
theta[1,9]      0.46      0.03      0.46      0.41      0.51   4781.71      1.00

Number of divergences: 0

What troubles me is my misunderstanding of the dim argument of numpyro.plate. It took several iterations of trial and error to find the correct dim for each plate. Unfortunately, this may be entangled with my model definition, because I repeat the plate over D for X_{ij}.

My misunderstanding of dimensions:

  1. I would expect theta to have shape (10, 2), because the outer plate D was given dim=None and therefore uses the leftmost dimension; yet theta has shape (2, 10).
  2. I don’t understand why the two plates N and D only work with dim=-2 and dim=None, respectively.

My explanation for why dim=-2 and dim=None work for plates N and D, respectively:
Using dim=-2 for plate N and dim=None (same as indexing 0?) for plate D, I index the rows of my 2D training data and then its columns.

Yet when I use dim=None and dim=-1 for plates N and D, respectively, I get an AssertionError (shown below). I expected dim=None and dim=-1 to index the leftmost dimension for plate N (the rows) and the rightmost dimension for plate D (the columns).

AssertionError
def model(X, num_classes, y=None,):
    num_items, num_features = X.shape
    pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))
    
    with numpyro.plate("D", num_features, dim=None):
        with numpyro.plate("C", num_classes, dim=None):
            theta = numpyro.sample("theta", dist.Beta(1,1))
            print(f"{theta.shape=}")
            
    with numpyro.plate("N", num_items, dim=None):
        y = numpyro.sample("y", dist.Categorical(pi), obs=y)
        print(f"{y.shape=}")
        with numpyro.plate("D", num_features, dim=-1):
            print(f"{theta[y].shape=}")
            print(f"{X.shape=}")
            x = numpyro.sample("x", dist.Bernoulli(theta[y]), obs=X)
            print(f"{x.shape=}")
    
    
with numpyro.handlers.seed(rng_seed=0):
    for _ in range(1):
        model(X_train, 2, y_train)

Output:

theta.shape=(2, 10)
y.shape=(5,)
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-41-39648db3380c> in <module>
     20 with numpyro.handlers.seed(rng_seed=0):
     21     for _ in range(1):
---> 22         model(X_train, 2, y_train)

<ipython-input-41-39648db3380c> in model(X, num_classes, y)
     11         y = numpyro.sample("y", dist.Categorical(pi), obs=y)
     12         print(f"{y.shape=}")
---> 13         with numpyro.plate("D", num_features, dim=-1):
     14             print(f"{theta[y].shape=}")
     15             print(f"{X.shape=}")

~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/primitives.py in __init__(self, name, size, subsample_size, dim)
    401         if dim is not None and dim >= 0:
    402             raise ValueError("dim arg must be negative.")
--> 403         self.dim, self._indices = self._subsample(
    404             self.name, self.size, subsample_size, dim
    405         )

~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/primitives.py in _subsample(name, size, subsample_size, dim)
    442             dim = new_dim
    443         else:
--> 444             assert dim not in occupied_dims
    445         return dim, subsample
    446 

AssertionError: 

Last, possibly unrelated, question:
3. Is there a better way to use plates to achieve the exact diagram shown in Figure 10.8(b)? I feel like reusing plate D is hacky. The rendered NumPyro model is below:
[Image: the model above as rendered by NumPyro]

Oops, sorry, the doc is wrong. It should be rightmost rather than leftmost… Do you want to submit a fix?

reusing plate D

You can do

def model():
    plate_D = numpyro.plate("D", size, dim=...)
    with plate_D:
        ...
    # reuse the plate
    with plate_D:
        ...

achieve the exact diagram

It seems to me that the two diagrams are similar. Could you elaborate on where the difference is?

It should be rightmost rather than leftmost

Oh… Yeah, that makes a lot more sense. I think I noticed this when I was stepping through the debugger. I will submit the PR shortly.

Let me double check my understanding:

    with numpyro.plate("D", num_features, dim=None):
        with numpyro.plate("C", num_classes):
            theta = numpyro.sample("theta", dist.Beta(1,1))

The code above constructs a rank-2 tensor whose rightmost dimension is indexed by plate D, of size num_features. Because plate C is nested inside D, it is allocated the next dimension to the left, of size num_classes, resulting in theta.shape == (num_classes, num_features).

plate_D = numpyro.plate("D", size, dim=...)

I cannot believe I didn’t think of that, haha.

Could you elaborate on where the difference is?

I guess I was expecting a one-to-one correspondence, such that numpyro.render_model would draw the plate nesting with plate N enclosing both variable y and x.

Since it’s not one-to-one, I feel I could rewrite this model without reusing plate_D so that numpyro.render_model shows a similar enclosing plate around ‘y’ and ‘x’. Maybe this is a little naive.