Hi all,
I am trying to reproduce a Bayesian naive Bayes model from the plate diagram in Kevin Murphy’s Machine Learning: A probabilistic Perspective 2012, Chapter 10, Section 10.4.1, Figure 10.8b, page 322. The figure is taken from the course Probabilistic Graphical Models, slide 16 of Lecture 2: Directed Graphical Models:
I got the model to work on a toy dataset, compared the results to scikit-learn’s BernoulliNB
. The model seems to recover the parameters from the toy data, and I have no divergences or any r_hat > 1
. The complete code here, and the minimal code for the model is below.
Minimal code, working example
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
key = jax.random.PRNGKey(0)
key
keys = jax.random.split(key, 2)
d_X = dist.Bernoulli(jnp.array((0.91, 0.02, 0.90, 0.32, 0.66, 0.71, 0.32, 0.70, 0.94, 0.49)))
d_y = dist.Bernoulli(jnp.array((0.8)))
X = d_X.sample(keys[0], (300,))
y = d_y.sample(keys[1], (300,))
def model(X, num_classes, y=None,):
num_items, num_features = X.shape
pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))
with numpyro.plate("D", num_features, dim=None):
with numpyro.plate("C", num_classes):
theta = numpyro.sample("theta", dist.Beta(1,1))
with numpyro.plate("N", num_items, dim=-2):
y = numpyro.sample("y", dist.Categorical(pi), obs=y)
with numpyro.plate("D", num_features, dim=None):
x = numpyro.sample("x", dist.Bernoulli(theta[y]), obs=X)
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000, num_chains=4)
mcmc.run(key, X, 2, y, extra_fields=("potential_energy",))
mcmc.print_summary()
Output:
sample: 100%|██████████| 1500/1500 [00:11<00:00, 129.06it/s, 7 steps of size 3.95e-01. acc. prob=0.87]
sample: 100%|██████████| 1500/1500 [00:06<00:00, 246.59it/s, 15 steps of size 4.08e-01. acc. prob=0.86]
sample: 100%|██████████| 1500/1500 [00:05<00:00, 270.71it/s, 7 steps of size 4.91e-01. acc. prob=0.85]
sample: 100%|██████████| 1500/1500 [00:06<00:00, 240.69it/s, 7 steps of size 4.40e-01. acc. prob=0.86]
mean std median 5.0% 95.0% n_eff r_hat
pi[0] 0.17 0.00 0.17 0.17 0.18 6765.86 1.00
pi[1] 0.83 0.00 0.83 0.82 0.83 6766.00 1.00
theta[0,0] 0.85 0.05 0.86 0.77 0.93 5084.27 1.00
theta[0,1] 0.06 0.03 0.05 0.01 0.10 6469.38 1.00
theta[0,2] 0.89 0.04 0.89 0.82 0.96 4864.51 1.00
theta[0,3] 0.28 0.06 0.27 0.18 0.38 5242.58 1.00
theta[0,4] 0.67 0.06 0.67 0.56 0.77 4682.34 1.00
theta[0,5] 0.69 0.06 0.69 0.58 0.79 5221.52 1.00
theta[0,6] 0.30 0.06 0.29 0.19 0.40 4822.39 1.00
theta[0,7] 0.61 0.07 0.61 0.51 0.72 5376.68 1.00
theta[0,8] 0.89 0.04 0.89 0.82 0.95 5293.73 1.00
theta[0,9] 0.59 0.06 0.59 0.49 0.70 5938.99 1.00
theta[1,0] 0.91 0.02 0.91 0.88 0.94 5269.80 1.00
theta[1,1] 0.01 0.01 0.01 0.00 0.02 5056.84 1.00
theta[1,2] 0.90 0.02 0.90 0.86 0.93 5359.81 1.00
theta[1,3] 0.36 0.03 0.36 0.30 0.41 6750.08 1.00
theta[1,4] 0.66 0.03 0.66 0.61 0.71 4756.77 1.00
theta[1,5] 0.68 0.03 0.68 0.63 0.73 5011.69 1.00
theta[1,6] 0.32 0.03 0.32 0.28 0.37 4942.66 1.00
theta[1,7] 0.69 0.03 0.69 0.64 0.73 5178.98 1.00
theta[1,8] 0.94 0.01 0.95 0.92 0.97 5130.13 1.00
theta[1,9] 0.46 0.03 0.46 0.41 0.51 4781.71 1.00
Number of divergences: 0
What is troubling me is my misunderstanding behind dim
within numpyro.plate
. It took me several iterations of trial and error to find the correct arguments of dim
for each plate. Unfortunately, this might be convoluted with my model definition because I am repeating the plate over D
for X_{ij}.
My misunderstanding of dimensions:
- I would expect the shape of
theta
to be (10,2) because the first plate,D
was givendim=None
therefore uses the leftmost dimension, yet thetheta
has shape (2,10) - I don’t understand why the two plates
N
andD
only work withdim=-2
anddim=None
, respectfully.
My explanation for why dim=-2
and dim=None
for plates N
and D
, respectfully, work:
Using dim=-2
and dim=None
(same as indexing 0?) for plate N
and plate D
, respectfully, I am indexing the number of rows in my 2d training data, then indexing the columns of my training data.
Yet, when I use dim=None
and dim=-1
for plates N
and D
, respectfully, I get an assertion error (shown below)? My expectation is dim=None
and dim=-1
is indexing the leftmost dimension for plate N
(the rows) and rightmost dimension for plate D
(the columns)
AssertionError
def model(X, num_classes, y=None,):
num_items, num_features = X.shape
pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))
with numpyro.plate("D", num_features, dim=None):
with numpyro.plate("C", num_classes, dim=None):
theta = numpyro.sample("theta", dist.Beta(1,1))
print(f"{theta.shape=}")
with numpyro.plate("N", num_items, dim=None):
y = numpyro.sample("y", dist.Categorical(pi), obs=y)
print(f"{y.shape=}")
with numpyro.plate("D", num_features, dim=-1):
print(f"{theta[y].shape=}")
print(f"{X.shape=}")
x = numpyro.sample("x", dist.Bernoulli(theta[y]), obs=X)
print(f"{x.shape=}")
with numpyro.handlers.seed(rng_seed=0):
for _ in range(1):
model(X_train, 2, y_train)
Output:
theta.shape=(2, 10)
y.shape=(5,)
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-41-39648db3380c> in <module>
20 with numpyro.handlers.seed(rng_seed=0):
21 for _ in range(1):
---> 22 model(X_train, 2, y_train)
<ipython-input-41-39648db3380c> in model(X, num_classes, y)
11 y = numpyro.sample("y", dist.Categorical(pi), obs=y)
12 print(f"{y.shape=}")
---> 13 with numpyro.plate("D", num_features, dim=-1):
14 print(f"{theta[y].shape=}")
15 print(f"{X.shape=}")
~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/primitives.py in __init__(self, name, size, subsample_size, dim)
401 if dim is not None and dim >= 0:
402 raise ValueError("dim arg must be negative.")
--> 403 self.dim, self._indices = self._subsample(
404 self.name, self.size, subsample_size, dim
405 )
~/anaconda3/envs/numpyro_play/lib/python3.8/site-packages/numpyro/primitives.py in _subsample(name, size, subsample_size, dim)
442 dim = new_dim
443 else:
--> 444 assert dim not in occupied_dims
445 return dim, subsample
446
AssertionError:
Last possibly, unrelated question:
3. Is there a better way to index using plates to achieve the exact diagram shown in Figure 10.8(b)? I feel like it’s hacky with reusing plate D
. The render numpryo model, below: