I’m estimating a fairly straightforward multinomial logistic regression model with random (hierarchical) coefficients. Inference works fine when using NUTS. But, when using SVI with any kind of automatic guide, it fails with a cryptic error about “attempted boolean conversion” when checking model_shape == guide_shape
.
Any thoughts on how to fix this?
I’ve boiled down the code to this minimal example:
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
import jax.numpy as jnp
from jax.nn import softmax
from jax import random
def npsoftmax(x):
return np.exp(x)/sum(np.exp(x))
n_people = 200
n_obs = 10
n_alts = 3
n_feats = 2
X = np.random.normal(size=(n_people, n_obs, n_alts, n_feats))
# DGP ----------------------------------------
mu = np.random.normal(loc=0, scale=1, size=n_feats)
tau = np.random.uniform(low=0.5, high=1.5, size=n_feats)
beta = np.random.multivariate_normal(mean=mu, cov=tau*np.eye(n_feats), size=n_people)
y = np.zeros((n_people,n_obs), dtype=int)
for i in range(n_people):
for t in range(n_obs):
logits = X[i,t] @ beta[i]
probs = npsoftmax(logits)
y[i,t] = np.random.choice(n_alts, size=1, p=probs)[0]
numpyro_data = {}
numpyro_data['X'] = jnp.array(X.reshape((n_people*n_obs, n_alts, n_feats)))
numpyro_data['y'] = jnp.array(y.reshape(n_people*n_obs))
numpyro_data['id'] = jnp.array(np.repeat(np.arange(n_people), n_obs))
# Model ----------------------------------------
def model(X, y, id):
n_obs_total, n_alts, n_pars = X.shape
n_people = numpyro_data["id"].max() + 1
mu = numpyro.sample("mu", dist.Normal(scale=5), sample_shape=(n_pars,))
tau = numpyro.sample("tau", dist.HalfNormal(5), sample_shape=(n_pars,))
cov = tau * jnp.eye(n_feats)
beta = numpyro.sample("beta", dist.MultivariateNormal(mu, covariance_matrix=cov), sample_shape=(n_people,))
beta_expanded = beta[id][:, None, :]
X_weighted = X * beta_expanded
utils = jnp.sum(X_weighted, axis=-1)
p = softmax(utils, axis=-1)
numpyro.sample(f"choices", dist.Categorical(p), obs=y)
# MCMC ----------------------------------------
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=200, num_samples=200)
mcmc.run(random.PRNGKey(0), X=numpyro_data['X'], y=numpyro_data['y'], id=numpyro_data['id'])
# SVI ----------------------------------------
guide = AutoNormal(model)
optimizer = numpyro.optim.Adam(step_size=0.01)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 2000, X=numpyro_data['X'], y=numpyro_data['y'], id=numpyro_data['id'])
Here is the error:
File ~/miniconda3/envs/cs23/lib/python3.12/site-packages/numpyro/util.py:653, in check_model_guide_match(model_trace, guide_trace)
651 model_shape = model_site["fn"].shape(model_site["kwargs"]["sample_shape"])
652 guide_shape = guide_site["fn"].shape(guide_site["kwargs"]["sample_shape"])
--> 653 if model_shape == guide_shape:
654 continue
656 for model_size, guide_size in zip_longest(
657 reversed(model_shape), reversed(guide_shape), fillvalue=1
658 ):
[... skipping hidden 1 frame]
File ~/miniconda3/envs/cs23/lib/python3.12/site-packages/jax/_src/core.py:1492, in concretization_function_error.<locals>.error(self, arg)
1491 def error(self, arg):
-> 1492 raise TracerBoolConversionError(arg)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..