Multinomial, plate and total counts

Hi all,

I want to simulate the behavior of customers who are given the option to buy 4 different kinds of apples.
The parameters of interest, for 5 customers, are:

  • The total number of purchases for each customer. I used NegativeBinomial2 to get num_purchase.
  • The probability of each customer buying from each of the apple categories. I used Dirichlet to get desired_probs.
  • The number of apples in each category that each customer bought. This will hopefully be my obs variable later. I used Multinomial(total_count=num_purchase, probs=desired_probs) for that.
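The three bullets describe a negative binomial / Dirichlet / multinomial generative story. A plain-JAX sketch of one draw of that story might look like the block below. The helper name `simulate_customers` is hypothetical, and the negative binomial is built as a gamma-Poisson mixture with mean `mu` and concentration `alpha`, which matches the parameterization of `NegativeBinomial2(mu, alpha)`. Note that the multinomial step calls `int(num_purchase[i])`, so this version only runs eagerly.

```python
import jax
import jax.numpy as jnp


def simulate_customers(key, n_customers=5, n_categories=4, mu=1.0, alpha=2.0):
    """Hypothetical helper: one eager draw of the three quantities in the bullets."""
    k_gamma, k_pois, k_dir, k_cat = jax.random.split(key, 4)
    # Negative binomial with mean mu and concentration alpha, as a gamma-Poisson mixture:
    # rate ~ Gamma(shape=alpha, rate=alpha / mu), num_purchase ~ Poisson(rate).
    rate = jax.random.gamma(k_gamma, alpha, shape=(n_customers,)) * (mu / alpha)
    num_purchase = jax.random.poisson(k_pois, rate)
    # Per-customer probabilities over the apple categories.
    desired_probs = jax.random.dirichlet(
        k_dir, 0.5 * jnp.ones(n_categories), shape=(n_customers,)
    )
    # Multinomial counts: draw num_purchase[i] category picks and tally them.
    # int(...) needs a concrete value, so this loop works eagerly but not under jit/vmap.
    rows = []
    for i, k in enumerate(jax.random.split(k_cat, n_customers)):
        picks = jax.random.categorical(
            k, jnp.log(desired_probs[i]), shape=(int(num_purchase[i]),)
        )
        rows.append(jnp.bincount(picks, length=n_categories))
    apples = jnp.stack(rows)
    return num_purchase, desired_probs, apples


num_purchase, desired_probs, apples = simulate_customers(jax.random.PRNGKey(35235))
print(num_purchase, apples.sum(axis=1))  # the two vectors match
```

The Python loop is the part that cannot survive tracing: each customer's draw has a data-dependent length.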

However, from an error message I found that total_count cannot be a JAX-traced array.
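The restriction comes from JAX rather than from the model: array shapes must be known when a function is traced, and Predictive evaluates the model under JAX transformations, so num_purchase arrives as a tracer. A sampler that sizes an array by max(total_count) therefore needs a concrete value, presumably the reason Multinomial rejects a traced total_count here. A minimal pure-JAX reproduction (the helper `buffer_for_counts` is hypothetical, not NumPyro's code):

```python
import jax
import jax.numpy as jnp


def buffer_for_counts(total_count):
    """Hypothetical helper: allocate one slot per purchase for the largest count."""
    # int(...) asks for a concrete Python value of max(total_count).
    n_max = int(jnp.max(total_count))
    # That value becomes an array *shape*, which JAX must know when it traces.
    return jnp.zeros((n_max,))


counts = jnp.array([1, 3, 2])
print(buffer_for_counts(counts).shape)  # (3,)

try:
    jax.jit(buffer_for_counts)(counts)  # total_count is now a tracer
except TypeError as err:  # JAX's concretization errors subclass TypeError
    print("fails under jit:", type(err).__name__)
```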

What am I doing wrong here? Thanks for the help.

Minimal code:

```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

seed = 35235
# numpyro seed
rng_key = jax.random.PRNGKey(seed)
rng_trace, rng_prior_pred, rng_posterior_pred = jax.random.split(rng_key, 3)

def gen_data(num_apple_category: int = 4,
             n_customers: int = 5):
    mu = 1
    alpha = 2
    with numpyro.plate("customers", n_customers):
        # Each customer's number of purchases is negative binomial,
        # so smaller purchase counts are more likely.
        num_purchase = numpyro.sample("num_purchase", dist.NegativeBinomial2(mu, alpha))
        desired_probs = numpyro.sample("desired_probs", dist.Dirichlet(0.5 * jnp.ones(num_apple_category)))
        apples = numpyro.sample("shade", dist.Multinomial(total_count=num_purchase, probs=desired_probs))

num_samples = 2
# Draw from priors to make the fake data
prior_predictive = Predictive(gen_data, num_samples=num_samples)
prior_predictions = prior_predictive(rng_prior_pred)
```

Could you make a GitHub issue for this? I think we can introduce an argument like total_count_max so that total_count can be a tracer.
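Why a static upper bound helps can be sketched in plain JAX: always draw `total_count_max` categorical picks per row, then mask out the picks beyond each row's own count. Every shape then depends only on the static bound, so total_count may be a tracer. The helper `masked_multinomial` is hypothetical and is not NumPyro's implementation; it assumes a 1-D batch and total_count <= total_count_max.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=3)
def masked_multinomial(key, total_count, probs, total_count_max):
    """Hypothetical helper: multinomial counts for a (possibly traced) total_count.

    total_count: int array of shape (batch,), assumed <= total_count_max.
    probs: float array of shape (batch, K), rows summing to 1.
    total_count_max: static Python int that fixes every array shape.
    """
    num_categories = probs.shape[-1]
    # Always draw total_count_max picks per row: shape (total_count_max, batch).
    picks = jax.random.categorical(
        key, jnp.log(probs), axis=-1, shape=(total_count_max,) + probs.shape[:-1]
    )
    # Keep only the first total_count[i] picks for row i.
    keep = jnp.arange(total_count_max)[:, None] < total_count[None, :]
    one_hot = jax.nn.one_hot(picks, num_categories, dtype=jnp.int32)  # (max, batch, K)
    return jnp.sum(one_hot * keep[..., None], axis=0)  # (batch, K)


key_probs, key_counts = jax.random.split(jax.random.PRNGKey(0))
probs = jax.random.dirichlet(key_probs, 0.5 * jnp.ones(4), shape=(5,))
total_count = jnp.array([0, 3, 7, 1, 10])
counts = masked_multinomial(key_counts, total_count, probs, 10)
print(counts.sum(axis=-1))  # equals total_count
```

The trade-off is wasted work: every row pays for total_count_max draws, so the bound should be tight.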

Thanks @fehiepsi. Here is the link to the issue: