JAX error with SVI + autoguide (model works fine with NUTS)

I’m estimating a fairly straightforward multinomial logistic regression model with random (hierarchical) coefficients. Inference works fine when using NUTS. But, when using SVI with any kind of automatic guide, it fails with a cryptic error about “attempted boolean conversion” when checking model_shape == guide_shape.

Any thoughts on how to fix this?

I’ve boiled down the code to this minimal example:

import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
import jax.numpy as jnp
from jax.nn import softmax
from jax import random

def npsoftmax(x):
   return np.exp(x)/sum(np.exp(x))

n_people = 200
n_obs = 10
n_alts = 3
n_feats = 2
X = np.random.normal(size=(n_people, n_obs, n_alts, n_feats))

# DGP ----------------------------------------
mu = np.random.normal(loc=0, scale=1, size=n_feats)
tau = np.random.uniform(low=0.5, high=1.5, size=n_feats)
beta = np.random.multivariate_normal(mean=mu, cov=tau*np.eye(n_feats), size=n_people)

y = np.zeros((n_people,n_obs), dtype=int)
for i in range(n_people):
   for t in range(n_obs):
       logits = X[i,t] @ beta[i]
       probs = npsoftmax(logits)
       y[i,t] = np.random.choice(n_alts, size=1, p=probs)[0]
       
numpyro_data = {}
numpyro_data['X'] = jnp.array(X.reshape((n_people*n_obs, n_alts, n_feats)))
numpyro_data['y'] = jnp.array(y.reshape(n_people*n_obs))
numpyro_data['id'] = jnp.array(np.repeat(np.arange(n_people), n_obs))

# Model ----------------------------------------
def model(X, y, id):
   n_obs_total, n_alts, n_pars = X.shape
   n_people = numpyro_data["id"].max() + 1

   mu = numpyro.sample("mu", dist.Normal(scale=5), sample_shape=(n_pars,))
   tau = numpyro.sample("tau", dist.HalfNormal(5), sample_shape=(n_pars,))
   cov = tau * jnp.eye(n_feats)
   beta = numpyro.sample("beta", dist.MultivariateNormal(mu, covariance_matrix=cov), sample_shape=(n_people,))

   beta_expanded = beta[id][:, None, :] 
   X_weighted = X * beta_expanded  
   utils = jnp.sum(X_weighted, axis=-1) 

   p = softmax(utils, axis=-1)
   numpyro.sample(f"choices", dist.Categorical(p), obs=y)

# MCMC ----------------------------------------
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=200, num_samples=200)
mcmc.run(random.PRNGKey(0), X=numpyro_data['X'], y=numpyro_data['y'], id=numpyro_data['id'])

# SVI ----------------------------------------
guide = AutoNormal(model)
optimizer = numpyro.optim.Adam(step_size=0.01)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 2000, X=numpyro_data['X'], y=numpyro_data['y'], id=numpyro_data['id'])

Here is the error:

File ~/miniconda3/envs/cs23/lib/python3.12/site-packages/numpyro/util.py:653, in check_model_guide_match(model_trace, guide_trace)
    651 model_shape = model_site["fn"].shape(model_site["kwargs"]["sample_shape"])
    652 guide_shape = guide_site["fn"].shape(guide_site["kwargs"]["sample_shape"])
--> 653 if model_shape == guide_shape:
    654     continue
    656 for model_size, guide_size in zip_longest(
    657     reversed(model_shape), reversed(guide_shape), fillvalue=1
    658 ):

    [... skipping hidden 1 frame]

File ~/miniconda3/envs/cs23/lib/python3.12/site-packages/jax/_src/core.py:1492, in concretization_function_error.<locals>.error(self, arg)
   1491 def error(self, arg):
-> 1492   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..

Your n_people is a jax arrray. Could you let it have the int type?

Thank you – that resolved the error. It was tricky to spot!