Hi,

I’m trying to build a multinomial discrete choice model with a twist: some choices are screened out. (The interpretation doesn’t matter too much, but you can think of it as a consumer screening out options to create a subset to choose from.) I’ve seen this kind of thing done with custom Gibbs sampling code, but I’m hoping to fit it with a modern PPL if possible.

In this model, each individual is presented with a set of choices, each with different attribute values. We assume constant coefficients, shared by all consumers, that determine the latent utility of each choice (the logits).

If the screening rule is known with certainty, the model fits fine: we can deterministically limit the choices within the model, for instance by screening out all options where the first column is > 0.7. Ideally, though, I’d like to estimate the 0.7 from the data. In other words, given the data on the choices available to each consumer and the choices they made, I’d like to estimate this screening threshold as a parameter. But I’m not sure this is possible.

For simplicity, I’m using a no-intercept model and am holding the coefficient of the first variable fixed at 1 while estimating the other parameters.
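Written out, the model is roughly as follows, with i indexing consumers, j the four options each consumer sees (including the all-zero option), x_{ij0} the value of `attr0`, and c the screening threshold. The simulation uses c = 0.7 and (β₁, β₂, β₃) = (1, 1, 0.5).

$$
u_{ij} = x_{ij0} + \beta_1 x_{ij1} + \beta_2 x_{ij2} + \beta_3 x_{ij3}
$$

$$
P(y_i = j \mid \beta, c) = \frac{\mathbb{1}[x_{ij0} \le c]\, \exp(u_{ij})}{\sum_{k=1}^{4} \mathbb{1}[x_{ik0} \le c]\, \exp(u_{ik})}
$$

Viewed as a function of c, this likelihood is piecewise constant: it only changes when c crosses one of the observed x_{ij0} values, so its derivative with respect to c is zero wherever it exists.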

Here is the full code to estimate with deterministic screening:

```
import numpyro
# Must run before JAX initializes its backend. It sets XLA_FLAGS itself,
# so a separate os.environ['XLA_FLAGS'] assignment is not needed.
numpyro.set_host_device_count(4)

import arviz as az
import jax
import jax.numpy as jnp
import numpy as np
import numpyro.distributions as dist
import pandas as pd
from jax import random
from numpyro.infer import MCMC, NUTS
from scipy.special import softmax


def utility(X):
    # True data-generating utility: coefficients (1, 1, 1, 0.5)
    return X['attr0'] + X['attr1'] + X['attr2'] + .5 * X['attr3']


np.random.seed(0)
inputs = []
outcomes = []
budgets = []
for i in range(800):
    # Three random products plus an all-zero option that is never screened out
    attrs = np.random.uniform(0, 1, size=(3, 4))
    attrs = np.vstack([attrs, np.zeros(4)])
    products = pd.DataFrame(attrs, columns=['attr0', 'attr1', 'attr2', 'attr3'])
    deterministic_utilities = utility(products)
    products['util'] = deterministic_utilities
    # Screening rule: drop every product whose attr0 exceeds 0.7
    censored = np.where(products['attr0'].values > 0.7)[0]
    probs = softmax(np.delete(deterministic_utilities.values, censored))
    choice = np.random.choice(range(4 - len(censored)), p=probs)
    # Map the index among surviving products back to the original row index
    for cen in censored:
        if choice >= cen:
            choice += 1
    choices = [0] * 4
    choices[choice] = 1
    products['choice'] = choices
    nonchoices = np.array([0] * 4)
    nonchoices[np.where(products['attr0'].values > 0.7)] = 1
    products['censored'] = nonchoices
    # print(products)
    X = products[['attr0', 'attr1', 'attr2', 'attr3']].values
    inputs.append(X)
    outcomes.append(jnp.array(products['choice'].values))
    budgets.append(9)


def model(X, outcomes=None):
    vars_num = 3
    beta = numpyro.sample(
        'beta',
        dist.MultivariateNormal(
            jnp.zeros(vars_num),
            jnp.eye(vars_num)
        )
    )
    for idx, budget in enumerate(budgets):
        X_idx = X[idx]
        # Fixed threshold: the mask is computed from concrete NumPy data
        censored = np.where(X_idx[:, 0] > 0.7)[0]
        X_idx = np.delete(X_idx, censored, axis=0)
        # attr0 coefficient fixed at 1; beta covers attr1..attr3
        utilities = X_idx[:, 0] + jnp.matmul(X_idx[:, 1:], beta)
        y = numpyro.sample('obs' + str(idx), dist.MultinomialLogits(utilities),
                           obs=np.delete(outcomes[idx], censored))


m = MCMC(NUTS(model), num_warmup=1000, num_samples=2000, num_chains=4)
m.run(random.PRNGKey(0), inputs, outcomes)
samples = m.get_samples()
data = az.from_numpyro(m)
az.plot_trace(data, var_names=['beta'])
az.summary(data, var_names=['beta'])
```

However, when I try to estimate the 0.7 threshold as a parameter, I run into jax errors.

If I just substitute a threshold parameter where the 0.7 was, like this:

```
def model(X, outcomes=None):
    vars_num = 3
    beta = numpyro.sample('beta', dist.MultivariateNormal(jnp.array([0]*vars_num), jnp.eye(vars_num)))
    thresh = numpyro.sample('thresh', dist.Beta(0.5, 0.5))
    for idx, budget in enumerate(budgets):
        X_idx = X[idx]
        censored = jnp.where(X_idx[:, 0] > thresh)[0]
        X_idx = np.delete(X_idx, censored, axis=0)
        utilities = X_idx[:, 0] + jnp.matmul(X_idx[:, 1:], beta)
        y = numpyro.sample('obs' + str(idx), dist.MultinomialLogits(utilities), obs=np.delete(outcomes[idx], censored))
```

I get a `ConcretizationTypeError` at `jnp.where(X_idx[:, 0] > thresh)`.
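The first error looks like standard JAX behavior: inside NUTS, `thresh` is a tracer, and the one-argument `jnp.where(cond)` (like `jnp.nonzero`) returns an index array whose length depends on the values in `cond`, which JAX cannot know while tracing. A minimal reproduction under `jax.jit`, with toy `x0` values and hypothetical function names, next to a fixed-shape boolean mask that does trace:

```python
import jax
import jax.numpy as jnp

x0 = jnp.array([0.2, 0.9, 0.5, 0.0])  # attr0 of one choice set (toy values)


@jax.jit
def screened_indices(x0, thresh):
    # Output length depends on the *values* of x0 and thresh -> cannot be traced
    return jnp.where(x0 > thresh)[0]


@jax.jit
def screened_mask(x0, thresh):
    # Same information as a fixed-shape boolean mask -> traces fine
    return x0 > thresh


try:
    screened_indices(x0, 0.7)
except jax.errors.ConcretizationTypeError:
    print("ConcretizationTypeError, as in the model above")

print(screened_mask(x0, 0.7))  # only index 1 is screened out
```

The mask keeps the array shape fixed, which is what lets it depend on a sampled parameter; `np.delete` on traced indices cannot work for the same reason.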

If instead I try to substitute negative infinity for the logits of screened-out choices based on the thresh parameter:

```
def model(X, outcomes=None):
    vars_num = 3
    beta = numpyro.sample(
        'beta',
        dist.MultivariateNormal(
            jnp.array([0]*vars_num),
            jnp.eye(vars_num)
        )
    )
    thresh = numpyro.sample(
        'thresh',
        dist.Beta(0.5, 0.5)
    )
    for idx, budget in enumerate(budgets):
        X_idx = X[idx]
        utilities = X_idx[:, 0] + jnp.matmul(X_idx[:, 1:], beta)
        utilities.at[(X_idx[:, 0] > thresh)] = -np.inf
        y = numpyro.sample('obs' + str(idx), dist.MultinomialLogits(utilities), obs=outcomes[idx])
```

But that results in a `TracerIntegerConversionError` at `utilities.at[(X_idx[:, 0] > thresh)] = -np.inf`.
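Two things likely contribute to the second error. JAX arrays are immutable, so `utilities.at[mask] = value` is not a valid update; the functional form is `utilities = utilities.at[idx].set(value)`. Even in that form, boolean-mask indexing with a traced mask can still fail under `jit`/NUTS because it implies a data-dependent shape. The three-argument `jnp.where(cond, a, b)` keeps the shape fixed instead. The sketch below (hypothetical helper `masked_logits`, toy data) uses a large finite negative number rather than `-inf` for screened-out options, on the assumption that a likelihood which multiplies counts by logits would otherwise hit `0 * -inf = nan`.

```python
import jax
import jax.numpy as jnp

# Stand-in for -inf: exp() of it underflows to exactly 0 in float32,
# but it is finite, so expressions like 0 * logit stay finite.
SCREENED_LOGIT = -1e9


@jax.jit
def masked_logits(X, beta, thresh):
    """Hypothetical helper: logits for one choice set with attr0 > thresh screened out."""
    utilities = X[:, 0] + X[:, 1:] @ beta  # attr0 coefficient fixed at 1
    # Three-argument jnp.where keeps the output shape fixed, so thresh may be a tracer
    return jnp.where(X[:, 0] > thresh, SCREENED_LOGIT, utilities)


X = jnp.array([[0.2, 0.1, 0.3, 0.4],
               [0.9, 0.5, 0.5, 0.5],
               [0.0, 0.0, 0.0, 0.0]])
beta = jnp.array([1.0, 1.0, 0.5])
logits = masked_logits(X, beta, 0.7)
probs = jax.nn.softmax(logits)
print(probs)  # the screened-out row (index 1) gets probability 0
```

This removes the tracing errors, but the comparison `X[:, 0] > thresh` has no gradient, so the log-likelihood still gives NUTS no gradient signal for `thresh`.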

I looked into `TracerIntegerConversionError`, but I can’t figure out a way past it. Does anyone have ideas on how I could estimate such a threshold parameter?