# Multinomial, plate and total counts

Hi all,

I want to simulate the behavior of customers who are given the option to buy 4 different kinds of apples.
The parameters of interest, for 5 customers, are:

• The total number of purchases for each customer. I used `NegativeBinomial2` to get `num_purchase`.
• The probability of each customer buying from each of the apple categories. I used `Dirichlet` to get `desired_probs`.
• The number of apples of each kind that each customer bought. This will hopefully be my `obs` variable later. I used `Multinomial(total_count=num_purchase, probs=desired_probs)` for that.
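Written out, the generative model in the list above is roughly the following. This assumes NumPyro's `NegativeBinomial2(mean, concentration)` parameterization, with the `mu = 1`, `alpha = 2` and Dirichlet concentration `0.5` taken from the code; $i$ indexes the 5 customers and $\mathbf{y}_i$ is the future `obs`:

$$
\begin{aligned}
N_i &\sim \operatorname{NegativeBinomial2}(\mu, \alpha), \qquad \mathbb{E}[N_i] = \mu, \quad \operatorname{Var}[N_i] = \mu + \frac{\mu^2}{\alpha},\\
\mathbf{p}_i &\sim \operatorname{Dirichlet}(0.5 \cdot \mathbf{1}_4),\\
\mathbf{y}_i \mid N_i, \mathbf{p}_i &\sim \operatorname{Multinomial}(N_i, \mathbf{p}_i), \qquad i = 1, \dots, 5.
\end{aligned}
$$

With $\mu = 1$ and $\alpha = 2$, $P(N_i = 0) = \left(\frac{\alpha}{\alpha + \mu}\right)^{\alpha} = (2/3)^2 = 4/9 \approx 0.44$, so many simulated customers buy nothing and their $\mathbf{y}_i$ is all zeros. The key point is that the multinomial's `total_count` $N_i$ is itself a random (sampled) quantity.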

However, from an error message I found that `total_count` cannot be a JAX traced array.
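As far as I can tell, the restriction comes from how JAX traces code: inside `jit`, `vmap` or other transformed code (which MCMC and `Predictive` typically use), a sampled `num_purchase` is a tracer whose value is unknown until run time, while array shapes must be fixed at trace time. A sampler that makes `total_count` draws needs that count as a shape. The toy `draws_for` function below is hypothetical, plain JAX rather than NumPyro's actual internals, but it reproduces the same kind of failure:

```python
import jax
import jax.numpy as jnp


def draws_for(key, n):
    # n categorical draws over 4 equally likely categories; the output
    # shape (n,) depends on the *value* of n.
    return jax.random.categorical(key, jnp.zeros(4), shape=(n,))


key = jax.random.PRNGKey(0)

# Eagerly, n is a concrete Python int, so the output shape is known.
eager = draws_for(key, 3)

# Under jit, n becomes a tracer: its value is unknown at trace time, so
# JAX cannot build an array of shape (n,) and raises a TypeError.
try:
    jax.jit(draws_for)(key, 3)
    failed = False
except TypeError:
    failed = True
```

The same call succeeds eagerly and fails under `jit`, which is why a sampled `total_count` works in plain Python but not inside traced inference code.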

What am I doing wrong here? Thanks for the help.

Minimal code:

```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

seed = 35235
# numpyro seed
rng_key = jax.random.PRNGKey(seed)
rng_trace, rng_prior_pred, rng_posterior_pred = jax.random.split(rng_key, 3)


def gen_data(num_apple_category: int = 4, n_customers: int = 5):
    mu = 1
    alpha = 2
    with numpyro.plate("customers", n_customers):
        # The latent number of purchases for each customer is negative binomial;
        # lower purchase counts are more likely.
        num_purchase = numpyro.sample("num_purchase", dist.NegativeBinomial2(mu, alpha))
        desired_probs = numpyro.sample(
            "desired_probs", dist.Dirichlet(0.5 * jnp.ones(num_apple_category))
        )
        # Number of apples of each kind each customer bought (the future `obs`).
        # This is the call that fails when `num_purchase` is a traced array.
        numpyro.sample(
            "obs", dist.Multinomial(total_count=num_purchase, probs=desired_probs)
        )
```

Could you make a GitHub issue for this? I think we can introduce an argument like `total_count_max` so that `total_count` can be a tracer.
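The `total_count_max` idea can be sketched in plain JAX: if an upper bound on the count is known ahead of time as a static Python int, we can always draw that many categorical samples per customer and mask out the ones beyond each customer's `total_count`, so no array shape depends on a traced value. `padded_multinomial` is a hypothetical helper for illustration, not a NumPyro API, and it assumes every `total_count` is at most `total_count_max` (larger counts would be silently truncated, which matters because a negative binomial count is unbounded):

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="total_count_max")
def padded_multinomial(key, probs, total_count, total_count_max):
    """Multinomial counts whose total_count may be a traced array (hypothetical helper).

    probs: (batch, k) category probabilities.
    total_count: (batch,) integer counts, each assumed <= total_count_max.
    total_count_max: static Python int fixing the number of draws per row.
    """
    batch, k = probs.shape
    # Draw a fixed number of categorical samples per row -> (batch, total_count_max).
    draws = jax.random.categorical(
        key, jnp.log(probs)[:, None, :], shape=(batch, total_count_max)
    )
    # Keep only the first total_count[i] draws of row i (a traced comparison, static shape).
    keep = jnp.arange(total_count_max)[None, :] < total_count[:, None]
    # One-hot encode the kept draws and sum over draws -> (batch, k) counts.
    one_hot = jax.nn.one_hot(draws, k, dtype=jnp.int32) * keep[..., None]
    return one_hot.sum(axis=1)


key = jax.random.PRNGKey(0)
probs = jnp.array(
    [[0.25, 0.25, 0.25, 0.25], [0.7, 0.1, 0.1, 0.1], [0.0, 0.0, 0.0, 1.0]]
)
total_count = jnp.array([3, 0, 5])  # passed as an array, so it is traced under jit
counts = padded_multinomial(key, probs, total_count, total_count_max=8)
```

Each row of `counts` sums to that customer's `total_count`. The price is always drawing `total_count_max` samples per customer, so memory grows with the bound rather than the actual count. Recent NumPyro releases may expose a `total_count_max` keyword on `dist.Multinomial` along these lines; check the documentation for the version you use.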