I want to simulate the behavior of customers who are given the option to buy 4 different kinds of apples.
I am interested in the following quantities for 5 customers.

The total number of purchases for each customer. I used a negative binomial distribution (`NegativeBinomial2`) to get `num_purchase`.
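For reference, NumPyro's `NegativeBinomial2(mean, concentration)` is the mean/concentration parameterization of the negative binomial, which can be sampled as a gamma-Poisson mixture. The pure-JAX sketch below illustrates that under this assumption; `sample_nb2` is a hypothetical helper, not part of NumPyro. With `mu = 1` and `alpha = 2` the mean is 1 and the variance is `mu + mu**2 / alpha = 1.5`.

```python
import jax
import jax.numpy as jnp


def sample_nb2(key, mean, concentration, shape):
    """Hypothetical helper: negative binomial counts with the given mean and
    concentration, drawn as a gamma-Poisson mixture."""
    key_gamma, key_poisson = jax.random.split(key)
    # jax.random.gamma draws Gamma(concentration, rate=1); scaling by
    # mean / concentration gives Gamma(concentration, rate=concentration / mean),
    # whose mean is `mean`.
    rate = jax.random.gamma(key_gamma, concentration, shape) * (mean / concentration)
    # Conditional on its gamma-distributed rate, each count is Poisson.
    return jax.random.poisson(key_poisson, rate, shape)


key = jax.random.PRNGKey(35235)
num_purchase = sample_nb2(key, mean=1.0, concentration=2.0, shape=(5,))
print(num_purchase.shape, num_purchase.dtype)  # one integer count per customer
```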

The probability of each customer buying from each of the apple categories. I used a Dirichlet distribution to get `desired_probs`.
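Inside `numpyro.plate("customers", n_customers)`, the Dirichlet's event shape `(num_apple_category,)` gets a batch dimension of `n_customers`, so `desired_probs` has shape `(5, 4)`, one probability vector per customer. A minimal pure-JAX equivalent of that draw, using `jax.random.dirichlet` with the same concentration of 0.5 per category:

```python
import jax
import jax.numpy as jnp

num_apple_category = 4
n_customers = 5

key = jax.random.PRNGKey(0)
# Symmetric concentration of 0.5 per category, as in
# dist.Dirichlet(0.5 * jnp.ones(num_apple_category)).
alpha = 0.5 * jnp.ones(num_apple_category)
# `shape` is the batch shape; the last axis comes from alpha,
# so the result has shape (n_customers, num_apple_category).
desired_probs = jax.random.dirichlet(key, alpha, shape=(n_customers,))
# Each row is a probability vector over the apple categories.
print(desired_probs.sum(axis=-1))
```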

The number of apples of each category that each customer bought. This will hopefully be my observed (`obs`) variable later. I used `Multinomial(total_count=num_purchase, probs=desired_probs)` for that.
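A multinomial draw with total count `n` and probabilities `p` is equivalent to drawing `n` independent categorical samples and counting how many land in each category. A minimal pure-JAX sketch for a single customer with a concrete (Python `int`) count; `multinomial_one` is a hypothetical helper:

```python
import jax
import jax.numpy as jnp


def multinomial_one(key, n: int, probs):
    """Hypothetical helper: multinomial counts for one customer with a concrete n."""
    # n categorical draws over the categories. n must be a Python int here,
    # because it determines the shape of the array of draws.
    draws = jax.random.categorical(key, jnp.log(probs), shape=(n,))
    # Count how many draws fell into each category.
    return jnp.bincount(draws, length=probs.shape[0])


probs = jnp.array([0.1, 0.2, 0.3, 0.4])
apples = multinomial_one(jax.random.PRNGKey(1), 7, probs)
print(apples, apples.sum())  # four category counts that sum to 7
```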

However, from an error message I found that `total_count` cannot be a JAX traced array.
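Presumably this restriction comes from JAX requiring array shapes to be static during tracing: the number of categorical draws would depend on the traced value of `num_purchase`. One common workaround, sketched here in pure JAX and assuming you can fix an upper bound `max_purchase` (a hypothetical name) on every customer's total count, is to always draw `max_purchase` samples per customer and mask out those beyond each customer's count. `multinomial_masked` is a hypothetical helper and works under `jax.jit` with a traced `total_count`.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="max_purchase")
def multinomial_masked(key, total_count, probs, max_purchase: int):
    """Hypothetical helper: batched multinomial counts with a traced total_count.

    total_count: int array of shape (n_customers,), may be traced.
    probs: float array of shape (n_customers, n_categories).
    max_purchase: static Python int upper bound on total_count.
    """
    n_customers, n_categories = probs.shape
    # Always draw max_purchase categories per customer, so the shape is static.
    draws = jax.random.categorical(
        key, jnp.log(probs), shape=(max_purchase, n_customers)
    ).T  # shape (n_customers, max_purchase)
    # Keep only the first total_count draws of each customer.
    keep = jnp.arange(max_purchase)[None, :] < total_count[:, None]
    one_hot = jax.nn.one_hot(draws, n_categories, dtype=jnp.int32)
    # Sum the kept one-hot draws: shape (n_customers, n_categories).
    return jnp.sum(one_hot * keep[:, :, None], axis=1)


key_p, key_m = jax.random.split(jax.random.PRNGKey(2))
# Example inputs; in the question these come from NegativeBinomial2 and Dirichlet.
total_count = jnp.array([0, 1, 3, 2, 5])
probs = jax.random.dirichlet(key_p, 0.5 * jnp.ones(4), shape=(5,))
apples = multinomial_masked(key_m, total_count, probs, max_purchase=20)
print(apples.sum(axis=1))  # [0 1 3 2 5]
```

Counts above `max_purchase` would be silently truncated to `max_purchase`, so the bound should sit well above the counts you expect from `NegativeBinomial2(1, 2)`; changing it also triggers recompilation because it is static.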

What am I doing wrong here? Thanks for any help.

Minimal code:

```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

seed = 35235
# numpyro seed
rng_key = jax.random.PRNGKey(seed)
rng_trace, rng_prior_pred, rng_posterior_pred = jax.random.split(rng_key, 3)


def gen_data(num_apple_category: int = 4, n_customers: int = 5):
    mu = 1
    alpha = 2
    with numpyro.plate("customers", n_customers):
        # The number of purchases for each customer is negative binomial.
        # Lower purchase counts are more likely.
        num_purchase = numpyro.sample("num_purchase", dist.NegativeBinomial2(mu, alpha))
        # One probability vector over the apple categories per customer.
        desired_probs = numpyro.sample("desired_probs", dist.Dirichlet(0.5 * jnp.ones(num_apple_category)))
        # The error is raised here: total_count (num_purchase) is a traced array.
        apples = numpyro.sample("shade", dist.Multinomial(total_count=num_purchase, probs=desired_probs))


num_samples = 2
# Draw from priors to make the fake data
prior_predictive = Predictive(gen_data, num_samples=num_samples)
prior_predictions = prior_predictive(rng_prior_pred)
```