Help putting together a hierarchical ordinal regression

Hi all! I’m fairly new to NumPyro and am stuck adding a level to an ordinal regression. I’m building off two tutorials: this talk / example given at a PyData conference and the Ordinal Regression tutorial in the NumPyro documentation. I have survey responses about people’s experiences working at a non-profit, rated from 1 to 5, and I’d like to look at the relationship between various categorical characteristics and those scores. I’ve successfully fit this model on the entire population, but I’d like to improve it by incorporating the state each person is located in as a level.

import pandas as pd

# Dummy example of the kind of data I'm working with:
toy_data = {'state': {0: 'CA', 1: 'CA', 2: 'CA', 3: 'IL', 4: 'MD', 5: 'NY', 6: 'PA', 7: 'PA', 8: 'IL'},
            'is_remote': {0: 1, 1: 1, 2: 1, 3: 1, 4: 0, 5: 0, 6: 1, 7: 0, 8: 1},
            'is_manager': {0: 0, 1: 0, 2: 1, 3: 1, 4: 1, 5: 0, 6: 1, 7: 1, 8: 0},
            'is_new': {0: 1, 1: 1, 2: 1, 3: 1, 4: 0, 5: 1, 6: 0, 7: 1, 8: 0},
            'rating': {0: 4, 1: 3, 2: 5, 3: 3, 4: 1, 5: 4, 6: 2, 7: 5, 8: 5}}

toy_df = pd.DataFrame(toy_data)

# I then create a dummy matrix that represents every possible combination of these features to do a prior predictive simulation
# X = np.array([['CA', '0', '0', '0'], ['CA', '1', '0', '0'], ['CA', '0', '1', '0'], ['CA', '0', '0', '1'] ... etc etc])
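
A minimal sketch of how such a grid of combinations could be built, assuming only the five states and three binary flags from the toy data. The names `states`, `flag_values`, `state_idx`, and `X_flags` are hypothetical; the state is kept as an integer index (rather than a string column) so that it can be used for indexing later on.

```python
import itertools

import numpy as np

# Hypothetical inputs: the states and binary features from the toy data.
states = ["CA", "IL", "MD", "NY", "PA"]
binary_features = ["is_remote", "is_manager", "is_new"]

# Each binary feature can take the value 0 or 1.
flag_values = [(0, 1)] * len(binary_features)

# Every combination of (state index, flag_1, flag_2, flag_3).
grid = np.array(list(itertools.product(range(len(states)), *flag_values)))

state_idx = grid[:, 0]   # integer state index per row, shape (40,)
X_flags = grid[:, 1:]    # binary design matrix, shape (40, 3)
```

With 5 states and 3 binary flags this gives 5 * 2**3 = 40 rows; the real grid grows in the same way with more features.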

Here is a model that seems to work. I’ve adopted .expand() and jnp.matmul, following NumPyro’s Bayesian hierarchical stacking tutorial, to avoid writing out each coefficient as the PyData example does; in reality I have more independent variables than this toy example, so I wanted an efficient way to do this. I also couldn’t figure out how to get the probs object from the NumPyro example with the SimplexToOrderedTransform, so I included the logits / probs calculation from the PyData example:

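# Following https://num.pyro.ai/en/0.6.0/tutorials/ordinal_regression.html & borrowing use of .expand()
# from https://num.pyro.ai/en/stable/tutorials/bayesian_hierarchical_stacking.html#2
import jax
import jax.numpy as jnp
import jax.scipy.special
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer.reparam import TransformReparam

# This works!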
n_survey_options = 5
concentration = np.ones((n_survey_options,)) * 10.0

# X here is the dummy matrix of all my independent variables for the prior predictive simulation, then the actual data later on

def ordered_logistic_regression_cutpoint_prior(X, concentration, n_survey_options, anchor_point=0.0, y=None):

    # Expand here instead of writing out coefficients one by one as in
    # https://github.com/MarcoGorelli/pydataglobal-21/blob/master/ordered%20logistic%20regression.ipynb
    b_X_eta = numpyro.sample("beta", dist.Normal(0, 1).expand([X.shape[1]]))

    # C_y is cutpoints
    with handlers.reparam(config={"c_y": TransformReparam()}):
        c_y = numpyro.sample(
            "c_y",
            dist.TransformedDistribution(
                dist.Dirichlet(concentration),
                dist.transforms.SimplexToOrderedTransform(anchor_point),
            ),
        )
    

    with numpyro.plate('obs', X.shape[0]):  # Number of observations
        # The tutorial uses X * b_X_eta, but with several coefficients we need a matrix product
        eta = jnp.matmul(X, b_X_eta)
        numpyro.sample('y', dist.OrderedLogistic(eta, c_y), obs=y)

    # Extra part to try to get the probs value. Is this correct / OK? I'm trying to get the same output as
    # https://github.com/MarcoGorelli/pydataglobal-21/blob/master/ordered%20logistic%20regression.ipynb,
    # because the NumPyro Ordinal Regression tutorial doesn't give you probs as an output
    logits = c_y - eta[:, jnp.newaxis]
    cumulative_probs = jnp.pad(
        jax.scipy.special.expit(logits),
        pad_width=((0, 0), (1, 1)),
        constant_values=(0, 1),
    )
    probs = numpyro.deterministic("probs", jnp.diff(cumulative_probs))
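
As a sanity check on the probs calculation above, here is a NumPy-only stand-in for the same arithmetic (cutpoints minus eta, sigmoid, pad with 0 and 1, then difference). It is a sketch under the assumption that the cutpoints are strictly increasing; the function name `ordered_logistic_probs` and the example values of `eta` and `cutpoints` are made up for illustration.

```python
import numpy as np

def ordered_logistic_probs(eta, cutpoints):
    """Category probabilities for a cumulative-logit model (NumPy stand-in)."""
    logits = cutpoints - eta[:, np.newaxis]        # shape (N, K-1)
    cdf = 1.0 / (1.0 + np.exp(-logits))            # P(y <= k) for each cutpoint
    # Add the edges P(y <= "below first") = 0 and P(y <= "last") = 1.
    cdf = np.pad(cdf, ((0, 0), (1, 1)), constant_values=(0, 1))
    return np.diff(cdf, axis=-1)                   # shape (N, K)

# Made-up linear predictors and four ordered cutpoints (five categories).
eta = np.array([-1.0, 0.0, 2.0])
cutpoints = np.array([-1.5, -0.5, 0.5, 1.5])
probs = ordered_logistic_probs(eta, cutpoints)
```

If the calculation is right, `probs` has one row per observation and one column per survey option, every entry is non-negative, and each row sums to 1.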

My main question now is: how do I add state as a level? Conceptually, I think I want the mean and standard deviation of the dist.Normal(0, 1) prior on b_X_eta to be derived from some state-level mean and standard deviation instead. Here’s one attempt, in which I also mapped each person to their respective state, following advice and examples from this forum such as this question about adding two levels to a hierarchical model. However, I get this error: ValueError: Cannot broadcast distribution of shape (73728, 2) to shape (16,). I think I understand why I get this error and why what I’m doing might be wrong: I have 16 coefficients (73,728 is the size of my dummy dataset), and I don’t think I should be using a plate for employees, but should be doing something else. I’m just missing what that something else is and how to write it out. Here’s what I have currently:

# Different method from https://num.pyro.ai/en/0.6.0/tutorials/ordinal_regression.html
# We will apply a nudge towards equal probability for each category (corresponds to equal logits of the true data generating process)
# Help from https://forum.pyro.ai/t/hierarchical-model-with-two-levels/3775/5
# You have to map each 'employee' / 'persona' to their respective state

n_survey_options = 5

concentration = np.ones((n_survey_options,)) * 10.0

def hierarchical_ordered_logistic_regression_cutpoint_prior(X, employee_id, employee_to_state_lookup, concentration, n_survey_options, state, anchor_point=0.0, y=None):
    
    # Trying out a hierarchical model that samples the distribution of coefficient values per GROUP from global values

    n_states = np.unique(state).shape[0]

    with numpyro.plate("state", n_states):
        state_mean = numpyro.sample("state_mean", dist.Normal(0.0, 1.0))  
        state_sd = numpyro.sample("state_sd", dist.HalfNormal(1.0))  
    
    with numpyro.plate('employee_id', X.shape[0]):
        b_X_eta = numpyro.sample('b_X_eta', dist.Normal(state_mean[employee_to_state_lookup], state_sd[employee_to_state_lookup]).expand([X.shape[1]]))

    # C_y is cutpoints -- I don't think these are / should be influenced by groups? See https://betanalpha.github.io/assets/case_studies/ordinal_regression.html#22_Surgical_Cut
    with handlers.reparam(config={"c_y": TransformReparam()}):
        c_y = numpyro.sample(
            "c_y",
            dist.TransformedDistribution(
                dist.Dirichlet(concentration),
                dist.transforms.SimplexToOrderedTransform(anchor_point),
            ),
        )
    

    with numpyro.plate('obs', X.shape[0]):  # Number of observations
        # The tutorial uses X * b_X_eta, but with several coefficients we need a matrix product
        eta = jnp.matmul(X, b_X_eta)
        numpyro.sample('y', dist.OrderedLogistic(eta, c_y), obs=y)
    
    # Extra part to try to get probs value
    logits = c_y - eta[:, jnp.newaxis]
    cumulative_probs = jnp.pad(
        jax.scipy.special.expit(logits),
        pad_width=((0, 0), (1, 1)),
        constant_values=(0, 1),
    )
    probs = numpyro.deterministic("probs", jnp.diff(cumulative_probs))
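
To make the shape problem concrete, here is a NumPy-only sketch (not a NumPyro model) with small stand-in sizes. It assumes that .expand() follows NumPy-style broadcasting rules, and it assumes one possible reading of "state as a level": a separate coefficient vector per state, drawn around shared population-level values. The names `state_idx`, `mu`, `sigma`, and `beta_state` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n_states, n_coefs, n_rows = 5, 3, 9   # small stand-ins for the real sizes

# State index of each row, in the toy data's order (CA=0, IL=1, MD=2, NY=3, PA=4).
state_idx = np.array([0, 0, 0, 1, 2, 3, 4, 4, 1])
X = rng.integers(0, 2, size=(n_rows, n_coefs))   # binary features

# The failing pattern: a per-row mean of shape (n_rows,) cannot be
# broadcast to the coefficient shape (n_coefs,).
state_mean = rng.standard_normal(n_states)
try:
    np.broadcast_to(state_mean[state_idx], (n_coefs,))
    broadcast_ok = True
except ValueError:
    broadcast_ok = False

# A pattern whose shapes line up: one coefficient vector per state, drawn
# around hypothetical population-level values mu and sigma.
mu = np.zeros(n_coefs)
sigma = np.ones(n_coefs)
beta_state = mu + sigma * rng.standard_normal((n_states, n_coefs))  # (n_states, n_coefs)

# Each row picks up its own state's coefficients before the dot product.
eta = np.sum(X * beta_state[state_idx], axis=-1)   # shape (n_rows,)
```

In this arrangement the coefficients have shape (n_states, n_coefs) rather than one set per employee, and the per-row linear predictor comes from indexing by state. Whether this is the right way to express it with plates is exactly what I am unsure about.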

Thank you in advance for any help! I really have tried reading and researching, but I think I’m just stuck on the syntax of what plates are doing and how to fold in different levels for models other than a linear regression :slight_smile: