Bayesian Hierarchical Linear Regression: how to predict for a new patient who has observations?

Hello there :smile:

I would like to understand how to extend the Bayesian Hierarchical Linear Regression tutorial with NumPyro so that:

  • given a new patient (not part of the training data) with some observed data,
  • we infer his "random effects" (i.e. the posterior distribution of $(\alpha, \beta)$, or only of the offsets $(\alpha - \mu_\alpha, \beta - \mu_\beta)$), while retaining the "fixed effects" of the already trained model (i.e. the posterior distribution of $(\mu_\alpha, \mu_\beta, \sigma_\alpha, \sigma_\beta, \sigma)$),
  • and then predict his outcome $\mathrm{FVC}_{est}$ at other (unobserved) weeks.

To be precise:

  • I am in a setup where I cannot retrain a model on the union of all data (old training data + new data).
  • Since this new patient has observations, I do want to infer his random effects rather than treating them as centered on the fixed effects (as one could for a new patient without any observations).

Thanks for your help :blush:


Hi @exhumea, when you perform Bayesian inference and get a posterior, you can use that posterior (rather than the prior) as the prior for a new inference with the new data.

Hi @fehiepsi! Thanks. So you would estimate the (univariate or multivariate joint) posterior distribution of the “fixed effects” and use it as the new prior for inference on new patients? But then, is there a way to keep these “fixed effects” from being updated during this new inference?

To fix values of a site, you can use block + condition handlers.

> set them as new priors for inference on new patients

Yes, I think it is a reasonable approach.

I may have misunderstood the use of (block +) condition, but doesn’t it only let you replace the priors of some of the model's latent variables with point estimates?

For instance, if after the first inference the posterior distribution of $\mu_\alpha$ seems to be well modeled by $\mathcal{N}(\hat{\mu}_{\mu_\alpha}, \hat{\sigma}_{\mu_\alpha})$, how could I sample this variable from that distribution in later inferences while keeping it from being updated (“blocked”)?

Thanks!

Maybe there’s a more elegant way of doing it, but my code usually looks like the code below. Basically, I have a custom predict function that handles unseen patients with a try-except-else. To decide whether a patient is “seen” or “new” I use sklearn’s LabelEncoder(). The example below is minimal, only controlling for age, but you can easily extend it to control for sex or smoking_status.

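# data loading not shown: `train` is a pandas DataFrame
# with Patient, Weeks, FVC and Age columns
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from sklearn.preprocessing import LabelEncoder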

encoder = LabelEncoder()
encoder.fit(train["Patient"])

N_PATIENTS = len(encoder.classes_)

patient_to_age = train[['Patient', 'Age']] \
    .drop_duplicates() \
    .set_index("Patient") \
    .loc[encoder.classes_, "Age"] \
    .values

data_dict=dict(
    PatientNum = jnp.array(encoder.transform(train['Patient'].values)),
    Weeks = jnp.array(train['Weeks'].values),
    PatientToAge = jnp.array(patient_to_age),
    FVC_obs = jnp.array(train['FVC'].values),
)

def model(PatientNum, Weeks, PatientToAge, FVC_obs=None):
    μ_α = numpyro.sample("μ_α", dist.Normal(0., 100.))
    σ_α = numpyro.sample("σ_α", dist.HalfNormal(100.))
    μ_β = numpyro.sample("μ_β", dist.Normal(0., 100.))
    σ_β = numpyro.sample("σ_β", dist.HalfNormal(100.))
    
    # Age's effect on α
    coef_age = numpyro.sample("coef_age", dist.Normal(0., 100.))

    with numpyro.plate("plate_i", N_PATIENTS) as patient_idx:
        α = numpyro.sample("α", dist.Normal(μ_α + coef_age * PatientToAge[patient_idx], σ_α))
        β = numpyro.sample("β", dist.Normal(μ_β, σ_β))

    σ = numpyro.sample("σ", dist.HalfNormal(100.))
    FVC_est = α[PatientNum] + β[PatientNum] * Weeks
    
    n_obs = PatientNum.shape[0]
    with numpyro.plate("data", n_obs):
        numpyro.sample("obs", dist.Normal(FVC_est, σ), obs=FVC_obs)

# run mcmc
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=2000, num_warmup=2000)
rng_key = random.PRNGKey(0)

mcmc.run(rng_key, **data_dict)

# custom predict function
def predict_single_patient(patient_id, weeks, age, posterior_samples, patient_encoder):
    # logic to handle unseen patients:
    try:
        n = patient_encoder.transform([patient_id])[0]
    except ValueError:
        # new patient but known age: draw α and β from the population
        # distribution (one draw per posterior sample), with independent keys
        key_α, key_β = random.split(random.PRNGKey(0))
        μ_α = posterior_samples["μ_α"]
        σ_α = posterior_samples["σ_α"]
        α = dist.Normal(μ_α + age * posterior_samples["coef_age"], σ_α) \
            .sample(key_α)
            
        β = dist.Normal(posterior_samples["μ_β"], posterior_samples["σ_β"]) \
            .sample(key_β)
        
    else:
        α = posterior_samples["α"][:, n]
        β = posterior_samples["β"][:, n]

    mu = α[:, None] + β[:, None] * weeks
    # note: I'm not including the noise term σ in the prediction here
    # but you could easily do that
    return mu


weeks = jnp.arange(-12, 133)


# try on seen patient
patient = "ID00007637202177411956430"
age = train.query("Patient == @patient")["Age"].iloc[0]
mu_pred = predict_single_patient(patient, weeks, age, mcmc.get_samples(), encoder)

# try on unseen patient but with known age
patient = "this is a random id"
age = 66
mu_pred = predict_single_patient(patient, weeks, age, mcmc.get_samples(), encoder)

Thank you @omarfsosa.
The difference in my setup is that new patients do have observed data that I want to take into account (i.e. to infer their random effects, with the fixed effects somehow excluded from the new inference), so I don’t want to just return a population-level prediction.

Thanks to @fehiepsi, I now understand:

  • how to fix my fixed effects at point estimates (e.g. the mean or mode of the training posterior) for subsequent inferences on new patients, using block + condition
  • that I could replace my initial priors on the fixed effects (the ones used for training) with the posterior distributions obtained from training:
    • but I still don’t understand how to freeze these sites, i.e. exclude them from subsequent inferences, so that only the random effects are inferred

I see. So perhaps you want to use something like numpyro.handlers.do to create a sort of “intervened” model? I’m not convinced there’s an easy way to achieve what you’re trying to do without writing a new model altogether. Strictly speaking, I think the data on the new patient will affect the posterior distribution of the rest of the population (only a little, because you have a lot of patients, but still), so all inferences should be done simultaneously. Inferring the effects of the new patient only feels like a different model to me. Definitely interested in hearing if you find a solution to this :slight_smile:

I guess you want to perform inference for $p(y)$ ($y$ is a random effect) given the density $p(x, y)$ ($x$ is a fixed effect), by marginalizing out $x$. I’m not sure what an easy way to perform such a marginalization is (maybe using funsor here? I don’t know). One way is to approximate it by $p(y) \approx \sum_{i=0}^{n-1} p(x_i)\, p(y \mid x_i)$: you’ll need to create an auxiliary categorical variable over $\{0, \dots, n-1\}$ (with its probs parameter equal to $p(x_i)$), something like:

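def model(x, p_x):
    # x: candidate values x_0, ..., x_{n-1} of the fixed effects; p_x: weights p(x_i)
    c = numpyro.sample("c", dist.Categorical(probs=p_x))
    x_i = x[c]
    # use the original modeling code here
    # with x=x_i rather than a sample site
    numpyro.sample("y", p_y_given_x(x_i))  # p_y_given_x: placeholder for p(y | x_i)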

If you assume that x and y are independent (a mean-field approximation), you can use HMCGibbs: in the gibbs_fn, you can simply draw a sample of your fixed effects.

Disclaimer: please justify the above suggestions both theoretically and empirically.

A bit off topic, but I really like your prediction function here. What I’m wondering is how you would go about including the error term as well.
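On including the error: a minimal sketch, assuming posterior draws `α`, `β` and `σ` for one patient (as in `predict_single_patient`, with `σ = posterior_samples["σ"]`). Adding one Normal(0, σ) draw per posterior sample and week turns the mean prediction into a posterior predictive sample; `predict_with_noise` and `predictive_interval` are hypothetical helpers:

```python
import jax.numpy as jnp
from jax import random


def predict_with_noise(α, β, σ, weeks, rng_key):
    # α, β, σ: posterior draws of shape (num_samples,) for one patient
    mu = α[:, None] + β[:, None] * weeks  # expected FVC, (num_samples, num_weeks)
    # Observation noise: one standard normal draw per sample and week, scaled by σ
    eps = random.normal(rng_key, mu.shape)
    return mu + σ[:, None] * eps


def predictive_interval(pred, q=(2.5, 50.0, 97.5)):
    # Percentiles over posterior draws, one row per requested quantile
    return jnp.percentile(pred, jnp.array(q), axis=0)
```

For seen patients, `numpyro.infer.Predictive(model, mcmc.get_samples())` called with `FVC_obs=None` samples the `obs` site directly, which already includes σ.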

This would be quite helpful, because it would let me store the posterior as a *.nc file via arviz and read that when I don’t have the model “activated”. Basically, I could start the notebook and just use the saved data, without having to re-run the model.

I could even automate that by checking whether I have new data and, if not, loading the saved samples.
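For the caching idea, a minimal sketch that stores the samples dict as a NumPy `.npz` file instead of an ArviZ `.nc` file (a swapped-in format that avoids a netCDF backend dependency); with ArviZ the equivalents are `az.from_numpyro(mcmc).to_netcdf(path)` and `az.from_netcdf(path)`. `load_or_run` and its `run_fn`/`force` parameters are hypothetical:

```python
import os

import numpy as np


def load_or_run(path, run_fn, force=False):
    """Return posterior samples cached at `path`, or compute and cache them.

    run_fn: hypothetical zero-argument callable returning a dict name -> array,
    e.g. a function that runs MCMC and returns mcmc.get_samples().
    force: re-run even if the cache exists (e.g. when new data arrived).
    """
    if not path.endswith(".npz"):
        # np.savez would otherwise append ".npz" and the cache check would miss it
        raise ValueError("path must end with .npz")
    if os.path.exists(path) and not force:
        with np.load(path) as data:
            return {k: data[k] for k in data.files}
    samples = {k: np.asarray(v) for k, v in run_fn().items()}
    np.savez(path, **samples)
    return samples
```

Passing `force=have_new_data` (a hypothetical flag) re-runs the model only when new data arrive; otherwise the notebook starts from the saved samples.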