Help putting together a hierarchical ordinal regression

Hi all! I'm new-ish to NumPyro and have gotten stuck adding a level to an ordinal regression. The tutorials I'm building off of are this talk / example given at a PyData conference and the Ordinal Regression tutorial in the NumPyro documentation. I have survey responses, ranging from 1 to 5, on people's experiences working at a non-profit, and I'd like to look at the relationship between various categorical characteristics and those scores. I've successfully implemented this model on the entire population, but I'd like to improve it by incorporating the state someone is located in as a level.

# Dummy example of the kind of data I'm working with:
import pandas as pd

toy_data = {
    'state': {0: 'CA', 1: 'CA', 2: 'CA', 3: 'IL', 4: 'MD', 5: 'NY', 6: 'PA', 7: 'PA', 8: 'IL'},
    'is_remote': {0: 1, 1: 1, 2: 1, 3: 1, 4: 0, 5: 0, 6: 1, 7: 0, 8: 1},
    'is_manager': {0: 0, 1: 0, 2: 1, 3: 1, 4: 1, 5: 0, 6: 1, 7: 1, 8: 0},
    'is_new': {0: 1, 1: 1, 2: 1, 3: 1, 4: 0, 5: 1, 6: 0, 7: 1, 8: 0},
    'rating': {0: 4, 1: 3, 2: 5, 3: 3, 4: 1, 5: 4, 6: 2, 7: 5, 8: 5},
}

toy_df = pd.DataFrame(toy_data)

# I then create a dummy matrix that represents every possible combination of these features to do a prior predictive simulation
# X = np.array([['CA', '0', '0', '0'], ['CA', '1', '0', '0'], ['CA', '0', '1', '0'], ['CA', '0', '0', '1'] ... etc etc])
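
In case it helps anyone reproduce this, here's a minimal sketch of how that combination matrix can be built with itertools.product (the variable names here are my own):

import itertools
import numpy as np

# Every combination of state and the three binary flags from the toy data
states = ['CA', 'IL', 'MD', 'NY', 'PA']
X = np.array([
    [state, remote, manager, new]
    for state, remote, manager, new in itertools.product(states, '01', '01', '01')
])
# (the state column would still need to be dummy-encoded to numbers before the jnp.matmul below)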

Here is a model that seems to work. I've adopted .expand() and jnp.matmul following NumPyro's Bayesian hierarchical stacking tutorial, to avoid writing out each coefficient one by one the way the PyData example does (in reality I have more independent variables than this toy example, so I wanted an efficient way to handle them all). I also couldn't figure out how to get the probs object from the NumPyro example with the SimplexToOrderedTransform, so I included the logits / probs calculation part from the PyData example:

# Following https://num.pyro.ai/en/0.6.0/tutorials/ordinal_regression.html & borrowing use of .expand()
# from https://num.pyro.ai/en/stable/tutorials/bayesian_hierarchical_stacking.html#2

# This works!
import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer.reparam import TransformReparam

n_survey_options = 5
concentration = np.ones((n_survey_options,)) * 10.0

# X here is that dummy matrix of all my independent variables for the prior pred. simulation / then the actual data later on

def ordered_logistic_regression_cutpoint_prior(X, concentration, n_survey_options, anchor_point=0.0, y=None):

    b_X_eta = numpyro.sample("beta", dist.Normal(0, 1).expand([X.shape[1]])) # Expand here instead of writing out coefficients one by one like in https://github.com/MarcoGorelli/pydataglobal-21/blob/master/ordered%20logistic%20regression.ipynb

    # C_y is cutpoints
    with handlers.reparam(config={"c_y": TransformReparam()}):
        c_y = numpyro.sample(
            "c_y",
            dist.TransformedDistribution(
                dist.Dirichlet(concentration),
                dist.transforms.SimplexToOrderedTransform(anchor_point),
            ),
        )
    

    with numpyro.plate('obs', X.shape[0]):  # Number of observations
        eta = jnp.matmul(X, b_X_eta)  # the tutorial writes X * b_X_eta, but with a full coefficient vector we need a matrix product
        numpyro.sample('y', dist.OrderedLogistic(eta, c_y), obs=y)  
    
    # Extra part to get the probs value << is this correct / ok? I'm trying to match the output of https://github.com/MarcoGorelli/pydataglobal-21/blob/master/ordered%20logistic%20regression.ipynb, since the NumPyro Ordinal Regression tutorial doesn't expose probs as an output
    logits = c_y - eta[:, jnp.newaxis]
    cumulative_probs = jnp.pad(
        jax.scipy.special.expit(logits),
        pad_width=((0, 0), (1, 1)),
        constant_values=(0, 1),
    )
    probs = numpyro.deterministic("probs", jnp.diff(cumulative_probs))
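
For completeness, here's how I run the prior predictive simulation against this model with numpyro.infer.Predictive (a sketch, assuming X has already been converted to a numeric design matrix):

from numpyro.infer import Predictive

# Prior predictive: y is left as None, so ratings are drawn from the priors
prior_pred = Predictive(ordered_logistic_regression_cutpoint_prior, num_samples=500)
prior_samples = prior_pred(jax.random.PRNGKey(0), X, concentration, n_survey_options)
# prior_samples["probs"] has shape (500, X.shape[0], n_survey_options)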

My main question now is: how do I add state as a level? Conceptually, I think I want the mean and standard deviation of dist.Normal(0, 1) for b_X_eta to instead be derived from some state-level mean and standard deviation. Here's one attempt I've tried, where I also went back and mapped each person to their respective state, following advice and examples from this forum like this question about adding two levels to a hierarchical model. However, I get this error: ValueError: Cannot broadcast distribution of shape (73728, 2) to shape (16,) << I think I understand why I get this error and why what I'm doing might be wrong: I have 16 coefficients (73,728 is the size of my dummy dataset), so I shouldn't be sampling the coefficients inside a plate over employees, but should be doing something else. I'm just missing what that something else is / how to write it out... here's what I have currently:

# Different method from https://num.pyro.ai/en/0.6.0/tutorials/ordinal_regression.html
# We will apply a nudge towards equal probability for each category (corresponds to equal logits of the true data generating process)
# Help from https://forum.pyro.ai/t/hierarchical-model-with-two-levels/3775/5
# You have to map each 'employee' / 'persona' to their respective state

n_survey_options = 5

concentration = np.ones((n_survey_options,)) * 10.0

def hierarchical_ordered_logistic_regression_cutpoint_prior(
    X, employee_id, employee_to_state_lookup,
    concentration, n_survey_options, state,
    anchor_point=0.0, y=None):
    
    # Trying out a hierarchical model that samples the distribution of coefficient values per GROUP from global values

    n_states = np.unique(state).shape[0]

    with numpyro.plate("state", n_states):
        state_mean = numpyro.sample("state_mean", dist.Normal(0.0, 1.0))  
        state_sd = numpyro.sample("state_sd", dist.HalfNormal(1.0))  
    
    with numpyro.plate('employee_id', X.shape[0]):
        b_X_eta = numpyro.sample('b_X_eta', dist.Normal(state_mean[employee_to_state_lookup], state_sd[employee_to_state_lookup]).expand([X.shape[1]]))
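        # ^ this is the sample statement that triggers the ValueError above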

    # C_y is cutpoints -- I don't think these are / should be influenced by groups? See https://betanalpha.github.io/assets/case_studies/ordinal_regression.html#22_Surgical_Cut
    with handlers.reparam(config={"c_y": TransformReparam()}):
        c_y = numpyro.sample(
            "c_y",
            dist.TransformedDistribution(
                dist.Dirichlet(concentration),
                dist.transforms.SimplexToOrderedTransform(anchor_point),
            ),
        )
    

    with numpyro.plate('obs', X.shape[0]):  # Number of observations
        eta = jnp.matmul(X, b_X_eta)  # the tutorial writes X * b_X_eta, but with a full coefficient vector we need a matrix product
        numpyro.sample('y', dist.OrderedLogistic(eta, c_y), obs=y)
    
    # Extra part to try to get probs value
    logits = c_y - eta[:, jnp.newaxis]
    cumulative_probs = jnp.pad(
        jax.scipy.special.expit(logits),
        pad_width=((0, 0), (1, 1)),
        constant_values=(0, 1),
    )
    probs = numpyro.deterministic("probs", jnp.diff(cumulative_probs))
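
For anyone hitting the same error: the mismatch is because b_X_eta is sampled inside a plate over observations (X.shape[0]) while .expand([X.shape[1]]) asks for one value per coefficient, and the two can't broadcast. A minimal sketch of one way the shapes can line up, drawing one coefficient row per state instead (state_idx here is my name for each row's integer state code; this is not the final solution, which follows in the update):

def state_coef_sketch(X, state_idx, n_states, concentration, anchor_point=0.0, y=None):
    # One row of coefficients per state, drawn around that state's mean / sd
    with numpyro.plate("state", n_states, dim=-2):
        state_mean = numpyro.sample("state_mean", dist.Normal(0.0, 1.0))
        state_sd = numpyro.sample("state_sd", dist.HalfNormal(1.0))
        with numpyro.plate("features", X.shape[1], dim=-1):
            b_X_eta = numpyro.sample("b_X_eta", dist.Normal(state_mean, state_sd))

    with handlers.reparam(config={"c_y": TransformReparam()}):
        c_y = numpyro.sample(
            "c_y",
            dist.TransformedDistribution(
                dist.Dirichlet(concentration),
                dist.transforms.SimplexToOrderedTransform(anchor_point),
            ),
        )

    with numpyro.plate("obs", X.shape[0]):
        # Row-wise dot product of each observation with its own state's coefficients
        eta = jnp.sum(X * b_X_eta[state_idx], axis=-1)
        numpyro.sample("y", dist.OrderedLogistic(eta, c_y), obs=y)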

Thank you in advance for any help! I really have tried reading and researching, but I think I'm just stuck on the syntax: what the plates are doing, and how to fold different levels into models other than a linear regression :slight_smile:

Ok! It's been a journey. I finally figured this out and wanted to post an update in case it helps anyone else struggling with this kind of thing. Instead of using my actual data / variables, I'll simplify things a bit so it's easier to break down: I have a survey with individual responses from 1 through 5, and I want a model with three levels: state, county, and political party affiliation:

# Sample data:
toy_data = {
    'respondent_id': {0: '1', 1: '2', 2: '3', 3: '4', 4: '5', 5: '6', 6: '7', 7: '8', 8: '9'},
    'party': {0: 'D', 1: 'D', 2: 'R', 3: 'G', 4: 'R', 5: 'D', 6: 'G', 7: 'G', 8: 'D'},
    'state': {0: 'CA', 1: 'CA', 2: 'CA', 3: 'IL', 4: 'PA', 5: 'NY', 6: 'PA', 7: 'NY', 8: 'IL'},
    'county': {0: 'Lake', 1: 'Mendocino', 2: 'Sonoma', 3: 'Cook', 4: 'Allegheny', 5: 'Queens', 6: 'Philadelphia', 7: 'Kings', 8: 'Champaign'},
    'rating': {0: 4, 1: 3, 2: 5, 3: 3, 4: 1, 5: 4, 6: 2, 7: 5, 8: 5},
}

toy_df = pd.DataFrame(toy_data)

First, following the process detailed here, I encoded my party, state and county columns and linked them together. For each sub-level (party and county), I created a new variable that combines the level columns. I did this because the labels at each level have to be distinct (for example, there is a 'Lake County' in California and also one in Florida; you'll have Green party members in different counties / states). My data was in a pandas DataFrame, so I just made new columns that concatenate the relevant columns:

# Create a new column that separates out the party level correctly
toy_df['party_level'] = toy_df['party'] + ' - ' + toy_df['county'] + ' - ' + toy_df['state']

# Create a new column that separates out the county level correctly
toy_df['county_level'] = toy_df['county'] + ' - ' + toy_df['state']

# Encode each, link them together:
from sklearn.preprocessing import LabelEncoder

party_level_encoder = LabelEncoder()
toy_df['party_level_code'] = party_level_encoder.fit_transform(toy_df['party_level'])

county_level_encoder = LabelEncoder()
toy_df['county_level_code'] = county_level_encoder.fit_transform(toy_df['county_level'])

state_level_encoder = LabelEncoder()
toy_df['state_level_code'] = state_level_encoder.fit_transform(toy_df['state'])

# Link each of these to each other:
map_party_level_to_county_level = (
    toy_df[['party_level_code', 'county_level_code']]
    .drop_duplicates()
    .set_index('party_level_code', verify_integrity=True)
    .sort_index()['county_level_code']
    .values
)

map_county_level_to_state_level = (
    toy_df[['county_level_code', 'state_level_code']]
    .drop_duplicates()
    .set_index('county_level_code', verify_integrity=True)
    .sort_index()['state_level_code']
    .values
)
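
A quick sanity check that I found reassuring here (these two asserts are my own addition): verify_integrity=True already guarantees each fine-level code maps to exactly one coarser code, and each lookup array should have one entry per fine-level group:

# Each fine-level group should appear exactly once in its lookup array
assert len(map_party_level_to_county_level) == toy_df['party_level_code'].nunique()
assert len(map_county_level_to_state_level) == toy_df['county_level_code'].nunique()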

I also needed to map the response variable onto 0-indexed categories:

# Ratings are stored as ints 1-5; the model's categories are 0-4
responses_to_index = {1: 0, 2: 1, 3: 2, 4: 3, 5: 4}

y = toy_df['rating'].map(responses_to_index).to_numpy()

Here's what the model code looks like (adapted from the sources cited above plus examples from Statistical Rethinking):

n_survey_options = 5
concentration = np.ones((n_survey_options,)) * 10.0

def model(
    party_level_code, map_party_level_to_county_level,
    map_county_level_to_state_level,
    concentration, n_survey_options,
    anchor_point=0.0, y=None):

    """Hierarchical ordinal regression model"""

    μ_β_global = numpyro.sample("μ_β_global", dist.Normal(0.0, 1.0))
    σ_β_global = numpyro.sample("σ_β_global", dist.HalfNormal(1.0))
    
    # Set up this for plates:
    n_state_level = len(np.unique(map_county_level_to_state_level))
    n_county_level = len(np.unique(map_party_level_to_county_level))
    n_party_level = len(np.unique(party_level_code))

    with numpyro.plate("plate_state_level", n_state_level):
        μ_β_state_level = numpyro.sample(
            "μ_β_state_level", dist.Normal(μ_β_global, σ_β_global)
        )
    
    with numpyro.plate("plate_county_level", n_county_level):
        β_county_level = numpyro.sample(
            "β_county_level", dist.Normal(μ_β_state_level[map_county_level_to_state_level], σ_β_global)
        )
    
    with numpyro.plate("plate_i", n_party_level):
        β = numpyro.sample(
            "β",
            dist.Normal(β_county_level[map_party_level_to_county_level], σ_β_global),
        )
    
    # C_y is cutpoints -- I don't think these are / should be influenced by groups? See https://betanalpha.github.io/assets/case_studies/ordinal_regression.html#22_Surgical_Cut
    with handlers.reparam(config={"c_y": TransformReparam()}):
        c_y = numpyro.sample(
            "c_y",
            dist.TransformedDistribution(
                dist.Dirichlet(concentration),
                dist.transforms.SimplexToOrderedTransform(anchor_point),
            ),
        )
        
    # If you have any independent variables in your model, you multiply them by β[party_level_code] here
    prediction = β[party_level_code]

    logits = c_y - prediction[:, jnp.newaxis]
    cumulative_probs = jnp.pad(
        jax.scipy.special.expit(logits),
        pad_width=((0, 0), (1, 1)),
        constant_values=(0, 1),
    )

    probs = numpyro.deterministic("probs", jnp.diff(cumulative_probs))

    numpyro.sample(
        'y',
        dist.Categorical(probs=probs),
        obs=y,
    )
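
One trick that finally made the plate shapes click for me: trace the model once under a seed handler and print every sample site's shape before committing to MCMC (party_level_code and friends are defined just below):

# Inspect every sample site's shape without running inference
with numpyro.handlers.seed(rng_seed=0):
    exec_trace = numpyro.handlers.trace(model).get_trace(
        party_level_code, map_party_level_to_county_level,
        map_county_level_to_state_level,
        concentration, n_survey_options)

for name, site in exec_trace.items():
    if site["type"] == "sample":
        print(name, site["value"].shape)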

I also reparameterized, switching the hierarchical sites to a non-centered parameterization (LocScaleReparam(0)), which tends to help NUTS with the funnel-shaped posteriors hierarchical models produce:

from numpyro.handlers import reparam
from numpyro.infer.reparam import LocScaleReparam

reparam_config = {
    "μ_β_state_level": LocScaleReparam(0),
    "β_county_level": LocScaleReparam(0),
    "β": LocScaleReparam(0),
}
reparam_model = reparam(model, config=reparam_config)

Then I do a prior predictive check with dummy data (sketched just below), and then run on the actual data:
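
The prior predictive check itself is just Predictive with y=None (a sketch; with real dummy data you'd swap the toy inputs for the dummy ones):

from numpyro.infer import Predictive

# Prior predictive: y=None, so ratings are sampled from the priors
prior_pred = Predictive(reparam_model, num_samples=500)
prior_samples = prior_pred(
    jax.random.PRNGKey(0),
    toy_df['party_level_code'].values, map_party_level_to_county_level,
    map_county_level_to_state_level,
    concentration, n_survey_options, anchor_point=0.0, y=None)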

# Just need the smallest sub level, because we've defined our maps above 
party_level_code = toy_df['party_level_code'].values

nuts_kernel = numpyro.infer.NUTS(reparam_model)  

mcmc = numpyro.infer.MCMC(nuts_kernel, num_samples=3000, num_warmup=5000)
rng_key = jax.random.PRNGKey(42)
mcmc.run(rng_key,
    party_level_code, map_party_level_to_county_level,
    map_county_level_to_state_level,
    concentration, n_survey_options, anchor_point=0.0, y=y)

posterior_samples = mcmc.get_samples()
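
Two quick follow-ups worth doing (my addition): print the convergence diagnostics, and note that since probs was declared via numpyro.deterministic it should show up in the posterior samples as well:

mcmc.print_summary()  # r_hat / n_eff diagnostics for the latent sites

# "probs" was declared with numpyro.deterministic, so it should be in the samples too
print(posterior_samples["probs"].shape)  # (num_samples, n_obs, n_survey_options)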

There you go / good luck!