Estimate a parameter for conditional inclusion in a Multinomial Regression

Hi,

I’m trying to build a discrete choice multinomial model with a twist: some choices are screened out (the interpretation doesn’t matter too much, but you can think of it as a consumer screening out choices to create a subset to choose from). I’ve seen this kind of thing done with custom Gibbs sampling code, but I’m hoping to fit it with a modern PPL if possible.

In this model, each individual is presented with a set of choices, each with different attribute values. We assume constant coefficients, shared by all consumers, that determine the latent utility of each choice (the logits).

If the screening rule is known with certainty, the model fits fine: we can limit the choices deterministically within the model, for instance by screening out all options whose first column is > 0.7. Ideally, though, I’d like to estimate the 0.7 from the data. In other words, given the choices available to each consumer and the choices they made, I’d like to estimate this screening threshold as a parameter. But I’m not sure this is possible.

For simplicity, I’m using a model with no intercept, and I hold the coefficient of the first variable at 1 while estimating the other coefficients.
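Written out, the model looks like this. The notation is introduced here for clarity and is not from the original post: $x_{ijk}$ is attribute $k$ of option $j$ for consumer $i$, $s$ is the attribute column used for screening, and $c$ is the threshold.

$$
u_{ij} = x_{ij0} + \sum_{k=1}^{3} \beta_k x_{ijk},
\qquad
P(y_i = j \mid \beta, c) = \frac{\mathbf{1}[x_{ijs} \le c]\,\exp(u_{ij})}{\sum_{l} \mathbf{1}[x_{ils} \le c]\,\exp(u_{il})}
$$

Deterministic screening fixes $c = 0.7$; the goal is to give $c$ a prior and infer it. The all-zero outside option always survives screening when $c \ge 0$, so the denominator is never empty. Because $c$ enters only through the indicators, for fixed $\beta$ the likelihood is a step function of $c$.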

Here is the full code to estimate with deterministic screening:

import pandas as pd
import numpy as np
import jax.numpy as jnp
import numpyro
from scipy.special import softmax
from numpyro.infer import MCMC, NUTS
import numpyro.distributions as dist
from jax import random
import arviz as az
import jax

# set_host_device_count sets XLA_FLAGS itself; call it before JAX runs anything.
# 4 CPU devices to match num_chains=4 below.
numpyro.set_host_device_count(4)

def utilty(X):
    return X['attr0'] + X['attr1'] + X['attr2'] + .5 * X['attr3']

np.random.seed(0)
inputs = []
outcomes = []
budgets = []
for i in range(800):
    
    attrs = np.random.uniform(0, 1, size=(3,4))
    attrs = np.vstack([attrs, np.zeros(4)])
    products = pd.DataFrame(attrs, columns=['attr0', 'attr1', 'attr2', 'attr3'])
    deterministic_utilities = utilty(products)
    products['util'] = deterministic_utilities
    censored = np.where(products['attr0'].values > 0.7)[0]

    probs = softmax(np.delete(deterministic_utilities.values, censored))
    choice = np.random.choice(range(4-len(censored)), p=probs)
    for cen in censored:
        if choice >= cen:
            choice += 1
    choices = [0]*4
    choices[choice] = 1

    products['choice'] = choices
    nonchoices = np.array([0]*4)

    nonchoices[np.where(products['attr0'].values > 0.7)] = 1
    products['censored'] = nonchoices
    #print(products)

    X = products[['attr0', 'attr1', 'attr2', 'attr3']].values
    inputs.append(X)

    outcomes.append(jnp.array(products['choice'].values))
    budgets.append(9)
    
def model(X, outcomes=None):
    vars_num = 3
    beta = numpyro.sample(
        'beta', 
        dist.MultivariateNormal(
            jnp.zeros(vars_num),
            jnp.eye(vars_num)
        )
    )
    for idx in range(len(X)):
        X_idx = X[idx]
        # Deterministic screening: X and 0.7 are concrete, so NumPy indexing is fine here
        censored = np.where(X_idx[:, 0] > 0.7)[0]
        X_idx = np.delete(X_idx, censored, axis=0)
        # Coefficient on attr0 is fixed at 1; beta covers attr1..attr3
        utilities = X_idx[:, 0] + jnp.matmul(X_idx[:, 1:], beta)
        numpyro.sample('obs' + str(idx), dist.MultinomialLogits(utilities), obs=np.delete(outcomes[idx], censored))

m = MCMC(NUTS(model), num_warmup=1000, num_samples=2000, num_chains=4)

m.run(random.PRNGKey(0), inputs, outcomes)

samples = m.get_samples()
data = az.from_numpyro(m)

az.plot_trace(data, var_names=['beta'])

az.summary(data, var_names=['beta'])
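As an aside on the simulation, one possible simplification (a sketch, not the original code; simulate_choices is a hypothetical helper): instead of deleting screened rows and remapping the sampled index, set screened utilities to -inf and sample with the Gumbel-max trick. It draws from the same distribution as the loop above, though not with the same random numbers.

```python
import numpy as np

def simulate_choices(X, beta, thresh, screen_col=0, rng=None):
    """Draw one one-hot choice per consumer under deterministic screening.

    X: (n_consumers, n_options, n_attrs); attr0 has coefficient 1, beta covers the rest.
    Options with X[..., screen_col] > thresh are screened out. Assumes at least one
    option survives per consumer (true here: the outside option has all zeros).
    """
    rng = np.random.default_rng() if rng is None else rng
    util = X[..., 0] + X[..., 1:] @ beta            # (n_consumers, n_options)
    util = np.where(X[..., screen_col] > thresh, -np.inf, util)
    # Gumbel-max trick: argmax(util + Gumbel noise) is a draw from softmax(util);
    # screened options stay at -inf, so they can never be chosen.
    choice = np.argmax(util + rng.gumbel(size=util.shape), axis=-1)
    return np.eye(X.shape[1], dtype=int)[choice]

rng = np.random.default_rng(0)
X = rng.uniform(0, 1, size=(800, 4, 4))
X[:, 3, :] = 0.0                                    # all-zero outside option, as in the thread
Y = simulate_choices(X, np.array([1.0, 1.0, 0.5]), thresh=0.7, rng=rng)
```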

However, when I try to estimate the 0.7 threshold as a parameter, I run into JAX errors.

If I just substitute a threshold parameter where the 0.7 was, like this:

def model(X, outcomes=None):
    vars_num = 3
    beta = numpyro.sample('beta', dist.MultivariateNormal(jnp.array([0]*vars_num), jnp.eye(vars_num)))
    thresh = numpyro.sample('thresh', dist.Beta(0.5, 0.5))
    for idx, budget in enumerate(budgets):
        X_idx = X[idx]
        censored = jnp.where(X_idx[:, 0] > thresh)[0]
        X_idx = np.delete(X_idx, censored, axis=0)
        utilities = X_idx[:, 0] + jnp.matmul(X_idx[:, 1:], beta)
        y = numpyro.sample('obs' + str(idx), dist.MultinomialLogits(utilities), obs=np.delete(outcomes[idx], censored))

I get a ConcretizationTypeError at jnp.where(X_idx[:, 0] > thresh).

If I instead try to set the logits of screened-out choices to negative infinity based on the thresh parameter:

def model(X, outcomes=None):
    vars_num = 3
    beta = numpyro.sample(
        'beta', 
        dist.MultivariateNormal(
            jnp.array([0]*vars_num),
            jnp.eye(vars_num)
        )
    )
    thresh = numpyro.sample(
        'thresh',
        dist.Beta(0.5, 0.5)
    )
    for idx, budget in enumerate(budgets):
        X_idx = X[idx]
        utilities = X_idx[:, 0] + jnp.matmul(X_idx[:, 1:], beta)
        utilities.at[(X_idx[:, 0] > thresh)] = -np.inf
        y = numpyro.sample('obs' + str(idx), dist.MultinomialLogits(utilities), obs=outcomes[idx])

But that results in a TracerIntegerConversionError at utilities.at[(X_idx[:, 0] > thresh)] = -np.inf.

I looked into TracerIntegerConversionError, but I can’t figure out a way around it. Does anyone have ideas on how I could estimate such a threshold parameter?

I see that you are mixing np and jnp (including the utilities.at... line, which does not follow JAX syntax for updating an array value; JAX arrays are immutable, so it needs to be utilities = utilities.at[...].set(...)) and calling jnp.where without the true and false values. To make MCMC work, you need to use JAX operations in those places. It would be better to first write a function f(x, thresh) and make sure it works correctly on its own; then you can use that function inside the model.
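Following that advice, a standalone version of such a function might look like the sketch below. The names screened_utilities and penalty are hypothetical. The -1e3 penalty mirrors the -1000 used later in the thread rather than -inf, which would give NaN probabilities if every option in a choice set were screened. Running it under jax.jit with a traced thresh checks that it is usable inside a model.

```python
import jax
import jax.numpy as jnp

def screened_utilities(X, beta, thresh, screen_col=0, penalty=-1e3):
    """Utilities with screened-out options replaced by a large negative logit.

    X: (..., n_options, n_attrs). Every array keeps a fixed shape, so this traces
    under jit/grad, unlike one-argument jnp.where (data-dependent output shape)
    or np.delete (which needs concrete values).
    """
    util = X[..., 0] + X[..., 1:] @ beta
    # Three-argument jnp.where is an elementwise select; output shape == util.shape
    return jnp.where(X[..., screen_col] > thresh, penalty, util)

X = jnp.array([[[0.2, 0.1, 0.3, 0.4],    # kept: attr0 = 0.2 <= 0.7
                [0.9, 0.5, 0.5, 0.5],    # screened: attr0 = 0.9 > 0.7
                [0.0, 0.0, 0.0, 0.0]]])  # outside option
beta = jnp.array([1.0, 1.0, 0.5])
out = jax.jit(screened_utilities)(X, beta, 0.7)  # thresh is a traced value here

# Functional update syntax (arrays are immutable): .at[...].set(...) returns a new array.
u = jnp.zeros(4).at[1].set(-1e3)
```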

Thank you! I forgot to use .set there.

I was able to get a version working that infers the value of thresh and removes the for loop from the model. It recovers the parameters pretty well:

import pandas as pd
import numpy as np
import jax.numpy as jnp
import numpyro
from scipy.special import softmax
from numpyro.infer import MCMC, NUTS
import numpyro.distributions as dist
from jax import random
import arviz as az
import jax

# set_host_device_count sets XLA_FLAGS itself; call it before JAX runs anything.
# 4 CPU devices to match num_chains=4 below.
numpyro.set_host_device_count(4)

def utilty(X):
    return 1*X['attr0'] + 4*X['attr1'] + .5*X['attr2'] + .5 * X['attr3']

np.random.seed(0)
inputs = []
censoreds = []
outcomes = []
budgets = []
for i in range(4000):
    
    attrs = np.random.uniform(0, 1, size=(3,4))
    attrs = np.vstack([attrs, np.zeros(4)])
    products = pd.DataFrame(attrs, columns=['attr0', 'attr1', 'attr2', 'attr3'])
    deterministic_utilities = utilty(products)
    products['rand_util'] = deterministic_utilities
    censored = np.where(products['attr1'].values > 0.7)[0]
    censoreds.append(censored)

    probs = softmax(np.delete(deterministic_utilities.values, censored))
    choice = np.random.choice(range(4-len(censored)), p=probs)
    for cen in censored:
        if choice >= cen:
            choice += 1
    choices = [0]*4
    choices[choice] = 1

    products['choice'] = choices
    nonchoices = np.array([0]*4)

    nonchoices[np.where(products['attr1'].values > 0.7)] = 1
    products['censored'] = nonchoices
    #print(products)
    X = products[['attr0', 'attr1', 'attr2', 'attr3']].values
    inputs.append(X)

    outcomes.append(jnp.array(products['choice'].values))
    budgets.append(9)
    
def model(X, outcomes=None):
    vars_num = 3
    beta = numpyro.sample(
        'beta', 
        dist.MultivariateNormal(
            jnp.zeros(vars_num),
            jnp.eye(vars_num)*10
        )
    )
    thresh = numpyro.sample(
        'thresh',
        dist.Exponential(1)
    )
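    # Coefficient on attr0 is fixed at 1; beta covers attr1..attr3
    utilities = X[:, :, 0] + jnp.matmul(X[:, :, 1:], beta)
    # Screened-out options (attr1 > thresh) get a large negative logit, i.e. ~zero choice probability
    final_utilities = jnp.where(X[:, :, 1] > thresh, -1000, utilities)
    numpyro.sample('obs', dist.MultinomialLogits(final_utilities), obs=outcomes)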

m = MCMC(NUTS(model), num_warmup=8000, num_samples=4000, num_chains=4)

m.run(random.PRNGKey(0), jnp.array(inputs), jnp.array(outcomes))

samples = m.get_samples()
data = az.from_numpyro(m)

az.plot_trace(data, var_names=['beta', 'thresh'])
az.summary(data, var_names=['beta', 'thresh'])

I found that it helped to have a long warmup, although I’m also planning to try out some different sampling methods.

I didn’t look at your model in any detail, but AFAIK your model’s log density is not differentiable w.r.t. thresh, which might lead to problems in some regimes (the results can still be correct in general due to the Metropolis-Hastings correction, but sampling might become inefficient).
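To make this concrete: thresh enters the likelihood only through a comparison, so the likelihood is a step function of thresh and its gradient, which is what NUTS uses, is zero wherever it is defined. A minimal check with jax.grad on synthetic data (loglik is a hypothetical helper mirroring the thread’s vectorized model, screening on attr1):

```python
import jax
import jax.numpy as jnp

def loglik(thresh, X, y, beta, screen_col=1):
    """Log-likelihood of one-hot choices y, screening on attr1 as in the vectorized model."""
    util = X[..., 0] + X[..., 1:] @ beta
    util = jnp.where(X[..., screen_col] > thresh, -1000.0, util)
    return jnp.sum(y * jax.nn.log_softmax(util, axis=-1))

key = jax.random.PRNGKey(0)
X = jax.random.uniform(key, (50, 4, 4)).at[:, 3, :].set(0.0)  # last option: outside option
y = jnp.zeros((50, 4)).at[:, 3].set(1.0)  # synthetic: everyone picks the outside option
beta = jnp.array([4.0, 0.5, 0.5])

grad_thresh = jax.grad(loglik)(0.6, X, y, beta)            # 0.0: comparisons carry no gradient
jump = loglik(0.5, X, y, beta) - loglik(0.95, X, y, beta)  # nonzero: the value does change
```

The value changes only when thresh crosses an observed attr1 value, so HMC gets no local information about where to move thresh; any movement comes from momentum and the accept/reject step.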

Thanks. Yes, that’s what I was thinking. Does NumPyro have a good way to sample from something like that? Would SA work better, or would I need to go to Gibbs sampling?

Well, I’d generally only expect problems if your latent space dimensionality gets large-ish, although it’s hard to say. If you have trouble, I’d recommend converting thresh to a discrete latent variable linked to some pre-chosen thresholds.