Dirichlet Multinomial with centered random effects

Dear all,

I would like to run a Dirichlet Multinomial with centered random effects. Before applying the model to my data, I wanted to ensure that the model is correctly specified. I conducted a simulation where parameters are known, but unfortunately, the model couldn’t retrieve the correct estimates. I ran the same model in Stan and it did work. Please find bellow the code for reproducibility.

Could you please let me know what I might have misspecified in my model building?

Thank you.

import seaborn as sns
import numpy as np
import jax
from jax import random
from jax.nn import softmax
import jax.numpy as jnp
import numpyro as numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
numpyro.set_platform("gpu")
numpyro.set_host_device_count(1)

###############################################################################
############ SIMULATING MULTINOMIAL DATA WITH SOFTMAX LINK FUNCTION ###########
def mysoftmax(x):
    exp_x = np.exp(x - np.max(x))
    return exp_x / np.sum(exp_x, axis=0)

K = 3
N = 100
N_obs = 2
sigma_random = 0.6

########################################################
################### Fixed effect Sim ###################
#a = np.random.normal(0, 1, K)
a = np.array([3,1,1]) # Forcing a values
b_individual = np.random.normal(0, 1, (N, K))
mu = b_individual + a

# Declare an empty Matrix to fill with data
Y = np.empty((N * N_obs, K))

# Declare an empty vector to fill with IDs
id = []

# Loop over each individual
for i in range(N):
    # Simulate N_obs draws from the multinomial
    Y[i*N_obs:(i+1)*N_obs, :] = np.apply_along_axis(lambda x: np.random.multinomial(100, mysoftmax(x)), 0, mu[i])
    # Assign ID vector
    id += [i] * N_obs


N = N*N_obs
K = K
ni = N
y = jnp.array(Y, dtype=jnp.int32).reshape(N, K)
i_ID = jnp.array(id)

dat = dict(
    K = K,
    ni = ni,
    y = y,
    i_ID = i_ID, 
    N_obs = N_obs
)
###############################################################################
#################################### Numpyro Model  ##############################
def model(K, ni, y, i_ID, N_obs):
    a =  numpyro.sample('a', dist.Normal(1).expand([K]))

    #alpha2 = numpyro.sample("alpha2", dist.Normal(0, 1).expand([ni, K]))
    Sigma_individual = numpyro.sample("Sigma_individual", dist.Exponential(1).expand([ni]))
    L_individual =  numpyro.sample('L_individual', dist.LKJCholesky(ni, 50))
    z_individual = numpyro.sample("z_individual", dist.Normal(0, 1).expand([ni, K]))
    alpha = numpyro.deterministic("alpha", ((Sigma_individual[..., None] * L_individual) @ z_individual))

    #Apply softmax along the correct axis
    lk = numpyro.deterministic("alpha_softmax", softmax(a+ alpha[i_ID], axis = -1 ))
    numpyro.sample("y", dist.DirichletMultinomial(lk, N_obs), obs=y)

m = MCMC(NUTS(model, init_strategy = numpyro.infer.init_to_median()), num_warmup=500, num_samples=500, num_chains=1)
m.run(random.PRNGKey(0), extra_fields=["diverging"], **dat)
post = m.get_samples()

###############################################################################
#################################### Pustan Model  #############################
import time as tm
import stan
import nest_asyncio
import numpy as np

import numpy as np
tmp = dat
tmp['y'] = np.array(tmp['y'])
tmp['i_ID'] = np.array(tmp['i_ID']+1)
tmp['ni'] = tmp['ni']
tmp['K'] = tmp['K']
tmp['N'] = int(N)

nest_asyncio.apply()
stan_code = """ 
data {
    int<lower=0>  N;             // number of observations
    int<lower=0>  K;             // number of occupations
    int ni;                     // NUmber of Unique Individauls
    array[N, K] int y;           // array of observed occupation indicators
    array[N]int<lower=0>  i_ID;     // village indicator for each individual
}
parameters {
    vector[K] a;                    // intercept for each occupation
    matrix[ni, K]  z_individual;    // raw random effect for household 
    cholesky_factor_corr[ni] L_individual; // Cholesky factor for 
    vector<lower=0>[ni] Sigma_individual;

}
transformed parameters{
    matrix[K, ni] b_individual;
    b_individual = (diag_pre_multiply(Sigma_individual, L_individual) * z_individual)';
}
model{
    array[N] vector[K] p;
    matrix[K, N] random_effects;
    to_vector(a) ~ normal(0, 1);
    L_individual ~   lkj_corr_cholesky(2);
    Sigma_individual ~ exponential(1);

    to_vector(z_individual) ~ normal(0, 1);


    // Likelihood for
    for (k in 1:K) {
        for (i in 1:N) {
          random_effects[k, i] = b_individual[k, i_ID[i]];
          p[i,k] =  a[k] + random_effects[k, i];
      }
    }

    for (i in 1:(N)) {
        y[i,] ~ multinomial(softmax(p[i,]));
    }
}
"""

start = tm.time()
stan_model = stan.build(stan_code, data = tmp)
fit = stan_model.sample(num_chains=1, num_samples=500, num_warmup = 500, init = [{'L_individual': np.zeros((tmp['ni'], tmp['ni']))}])
end = tm.time()    
#df = fit.to_frame()
print(f"Pystan took: {end - start:.4f} seconds")

###############################################################################
#################################### Comparisons  #############################
df = fit.to_frame()
print(jax.nn.softmax(jnp.array(a))) # Simulated
print(jax.nn.softmax(jnp.mean(post['a'], axis = 0))) # Numpypro estimation
print(jax.nn.softmax(jnp.array([df['a.1'].mean(),df['a.2'].mean()],df['a.3'].mean()))) # Pytstan estimation

is this right?

Hi,

I could force only positive values, but this should work. Regarding the shape of the parameter, it is correct: a mean categorical effect for 3 categories.

I figured out that Dirichlet doesn’t use a simplex unlike Multinomial, so softmax is not required, only positive values. If I run the same model but with exp instead of softmax, it works. However, for strange reasons, it leads to a significant increase in computation time. Here are the results for simulated parameters to recover: [0.786986, 0.10650697, 0.10650697]:

  1. DirichletMultinomial with exp: [0.7513038 , 0.1308593 , 0.11783697], execution time of 16:51 minutes.
  2. Multinomial : [0.753772, 0.1263752, 0.1198528], execution time of 3:41 minutes.

Any ideas on why we observe such an increase in execution speed between DirichletMultinomial and Multinomial?

Interestingly stan is faster with DirichletMultinomial (3:35 minutes) but slower with Multinomial ( 10:30 minutes)