Multinomial regression - wrong shape in inference step

Hi everyone,

I am trying to predict new data with a multinomial regression model. However, I am getting the following error:

TypeError: sub got incompatible shapes for broadcasting: (30, 34), (22, 34).

I followed the tutorials suggested in this post without success.

Here is my model:

import numpyro as npr
from numpyro.distributions import Dirichlet, Multinomial, Normal
from numpyro.infer import MCMC, NUTS, Predictive

import jax.numpy as jnp
import jax.random as random
from jax.scipy.special import expit

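def model(X, y=None):
    *_, n_features = X.shape
    y_norm = y / y.sum(axis=1)[:, None] if y is not None else None

    alpha = npr.sample("alpha", Normal(0., 1.).expand([n_features + 1]).to_event(1))
    # intercept + linear term over the feature axis: (rows, categories, features) @ (features,) -> (rows, categories)
    # note: Dirichlet concentrations must be strictly positive, so clipping at 0 can produce invalid values
    dconc = jnp.maximum(alpha[0] + jnp.dot(X, alpha[1:]), 0)
    frac = npr.sample("frac", Dirichlet(concentration=dconc))

    # site name and obs= reconstructed (the original call was truncated)
    counts = npr.sample(
        "count",
        Multinomial(total_count=1, probs=frac),
        obs=y_norm,
    )

    print(f"{X.shape=}, {alpha.shape=}, {dconc.shape=}, {frac.shape=}, {counts.shape=}")
    return counts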


kernel = NUTS(model)
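mcmc = MCMC(kernel, jit_model_args=True, **sample_kwargs)  # sample_kwargs (num_warmup, num_samples, ...) is not shown in the post
rng_key, rng_key_run = random.split(random.PRNGKey(0))
mcmc.run(rng_key_run, X_train, y_train)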

# X.shape=(30, 34, 16), alpha.shape=(17,), dconc.shape=(30, 34), frac.shape=(30, 34), counts.shape=(30, 34)

posterior_samples = mcmc.get_samples()

predictive = Predictive(model=model, posterior_samples=posterior_samples, parallel=True, return_sites=["alpha", "frac", "count"])
predictions = predictive(rng_key, X=X_test)
# X.shape=(22, 34, 16), alpha.shape=(17,), dconc.shape=(22, 34), frac.shape=(30, 34), counts.shape=(30, 34)

I am probably missing something in how dconc is passed to frac, but I don't see how to work around it.

Thanks in advance.

You probably need to make your code vectorizable, with changes like right-indexing (y.sum(axis=-2) or what have you), perhaps dconc = jnp.maximum(alpha[..., 0] +, alpha[..., 1:]), -2) or similar, etc.
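The right-indexing idea can be sketched like this: index alpha from the right (alpha[..., 0], alpha[..., 1:]) and contract only the feature axis, so the same code works for a single draw of shape (17,) and for a stack of draws of shape (num_samples, 17). The function name concentration and the use of jnp.einsum are illustrative choices, not part of the original model.

```python
import jax.numpy as jnp
import jax.random as random


def concentration(X, alpha):
    # X: (rows, categories, features); alpha: (..., features + 1) with any leading batch dims.
    intercept = alpha[..., 0, None, None]  # (..., 1, 1): broadcasts over rows and categories
    slopes = alpha[..., 1:]                 # (..., features)
    # Contract X's feature axis with alpha's last axis; alpha's batch dims stay on the left.
    linear = jnp.einsum("rcf,...f->...rc", X, slopes)
    return jnp.maximum(intercept + linear, 0.0)  # same clipping as dconc in the post


X_test = random.normal(random.PRNGKey(0), (22, 34, 16))
alpha_single = random.normal(random.PRNGKey(1), (17,))
alpha_draws = random.normal(random.PRNGKey(2), (100, 17))

print(concentration(X_test, alpha_single).shape)  # (22, 34)
print(concentration(X_test, alpha_draws).shape)   # (100, 22, 34)
```

Note that Dirichlet concentrations must be strictly positive, so a positive transform (for example exp or softplus) is usually safer than clipping at 0.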