Hi everyone,
I am trying to predict new data with a multinomial regression model. However, I am getting the following error:
TypeError: sub got incompatible shapes for broadcasting: (30, 34), (22, 34).
I followed the tutorials suggested in this post without success.
Here is my model:
import numpyro as npr
from numpyro.distributions import Dirichlet, Multinomial, Normal
from numpyro.infer import MCMC, NUTS, Predictive
import jax.numpy as jnp
import jax.random as random
from jax.scipy.special import expit
def model(X, y=None):
    """Multinomial regression with Dirichlet-distributed class probabilities.

    Args:
        X: Covariate array whose last axis indexes features. The product
            ``jnp.dot(X, alpha[1:])`` contracts that axis, so the
            concentration has the shape of ``X`` without its last axis.
        y: Optional observations. Each entry is divided by the sum over
            axis 1 before being used as the observed value of ``"counts"``.
            When ``None``, ``"counts"`` is sampled instead.

    Returns:
        The value of the ``"counts"`` sample site.

    Note:
        ``"frac"`` is a latent site whose shape follows ``X``. Posterior
        draws of ``"frac"`` obtained with one ``X`` therefore cannot be
        substituted into a run with an ``X`` that has a different number
        of rows.
    """
    *_, n_features = X.shape
    y_norm = None if y is None else y / y.sum(axis=1, keepdims=True)

    # alpha[0] is the intercept, alpha[1:] holds one weight per feature.
    alpha = npr.sample(
        "alpha",
        Normal(0.0, 1.0).expand([n_features + 1]).to_event(1),
    )

    linear = alpha[0] + jnp.dot(X, alpha[1:])
    # Dirichlet concentrations must be strictly positive. A hard clamp at 0
    # would produce invalid zero concentrations and a zero gradient wherever
    # the linear predictor is negative, so use softplus, log(1 + exp(x)),
    # written with logaddexp to avoid a new import.
    dconc = jnp.logaddexp(linear, 0.0)

    frac = npr.sample("frac", Dirichlet(concentration=dconc))

    # NOTE(review): y_norm holds proportions rather than integer counts;
    # confirm that a Multinomial likelihood with total_count=1 is intended.
    counts = npr.sample(
        "counts",
        Multinomial(total_count=1, probs=frac),
        obs=y_norm,
    )
    print(f"{X.shape=}, {alpha.shape=}, {dconc.shape=}, {frac.shape=}, {counts.shape=}")
    return counts
# Split once so that inference and prediction use independent keys;
# rng_key keeps the value it had before (first half of the split).
rng_key, predict_key = random.split(random.PRNGKey(1))
kernel = NUTS(model)
mcmc = MCMC(kernel, jit_model_args=True, **sample_kwargs)
mcmc.run(rng_key, X_train, y_train)
# X.shape=(30, 34, 16), alpha.shape=(17,), dconc.shape=(30, 34), frac.shape=(30, 34), counts.shape=(30, 34)
posterior_samples = mcmc.get_samples()

# "frac" is a per-row latent: its posterior draws have the training shape
# (30, 34). Passing them to Predictive substitutes that value into a model
# built from X_test (dconc of shape (22, 34)), which is what raised
# "sub got incompatible shapes for broadcasting: (30, 34), (22, 34)".
# Keep only the samples that do not depend on the number of rows, so that
# "frac" is drawn afresh from Dirichlet(dconc) for the new data.
global_samples = {
    name: draws for name, draws in posterior_samples.items() if name != "frac"
}

predictive = Predictive(
    model=model,
    posterior_samples=global_samples,
    parallel=True,
    # The site is registered as "counts"; "count" matches no site.
    return_sites=["alpha", "frac", "counts"],
)
predictions = predictive(predict_key, X=X_test)
# Expected shapes inside the model for X_test (to be confirmed by running):
# X.shape=(22, 34, 16), alpha.shape=(17,), dconc.shape=(22, 34), frac.shape=(22, 34), counts.shape=(22, 34)
I am probably missing something when passing from the dconc to the frac parameters, yet I don’t see how to work around it.
Thanks in advance to anyone.