Maybe there's a more elegant way of doing it, but my code usually looks like the code below. Basically, I have a custom predict function that handles the logic for unseen patients with a `try-except-else`. To decide whether the patient is "seen" or "new", I use sklearn's `LabelEncoder()`. The example below is minimal, only controlling for `age`, but you can easily extend it to control for `sex` or `smoking_status`.
# not showing import statements nor the code for loading the data

# Map patient IDs to integer codes; fit() returns the encoder itself, so
# construction and fitting are chained.
encoder = LabelEncoder().fit(train["Patient"])
N_PATIENTS = len(encoder.classes_)

# Ages looked up in the order of encoder.classes_, so that entry i belongs to
# the patient stored at encoder.classes_[i].
_age_by_patient = train[["Patient", "Age"]].drop_duplicates().set_index("Patient")
patient_to_age = _age_by_patient.loc[encoder.classes_, "Age"].values

# Keyword arguments passed to the model when running MCMC.
data_dict = {
    "PatientNum": jnp.array(encoder.transform(train["Patient"].values)),
    "Weeks": jnp.array(train["Weeks"].values),
    "PatientToAge": jnp.array(patient_to_age),
    "FVC_obs": jnp.array(train["FVC"].values),
}
def model(PatientNum, Weeks, PatientToAge, FVC_obs=None):
    """Hierarchical linear model of FVC over weeks with an age effect on the intercept.

    Each patient gets its own intercept α and slope β. The intercepts are
    centred on ``μ_α + coef_age * age``; the slopes are centred on ``μ_β``.

    Args:
        PatientNum: integer patient code per observation; used to index the
            per-patient α and β.
        Weeks: week value per observation.
        PatientToAge: age per patient, indexed by patient code (expected to
            have ``N_PATIENTS`` entries).
        FVC_obs: observed FVC per observation, or ``None`` to leave the
            "obs" site unobserved.
    """
    μ_α = numpyro.sample("μ_α", dist.Normal(0., 100.))
    σ_α = numpyro.sample("σ_α", dist.HalfNormal(100.))
    μ_β = numpyro.sample("μ_β", dist.Normal(0., 100.))
    σ_β = numpyro.sample("σ_β", dist.HalfNormal(100.))

    # Age's effect on α
    coef_age = numpyro.sample("coef_age", dist.Normal(0., 100.))

    # One intercept and one slope per patient.
    with numpyro.plate("plate_i", N_PATIENTS) as patient_idx:
        α = numpyro.sample(
            "α", dist.Normal(μ_α + coef_age * PatientToAge[patient_idx], σ_α)
        )
        β = numpyro.sample("β", dist.Normal(μ_β, σ_β))

    # Observation noise shared by all observations.
    σ = numpyro.sample("σ", dist.HalfNormal(100.))
    FVC_est = α[PatientNum] + β[PatientNum] * Weeks

    n_obs = PatientNum.shape[0]
    with numpyro.plate("data", n_obs):
        numpyro.sample("obs", dist.Normal(FVC_est, σ), obs=FVC_obs)


# run mcmc
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=2000, num_warmup=2000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, **data_dict)


# custom predict function
def predict_single_patient(patient_id, weeks, age, posterior_samples,
                           patient_encoder, rng_key=None):
    """Return posterior draws of the FVC trend line for one patient.

    Args:
        patient_id: patient identifier; may or may not be known to
            ``patient_encoder``.
        weeks: week values to predict at (assumed 1-D).
        age: the patient's age; only used when the patient is unseen.
        posterior_samples: dict of posterior samples keyed by the model's
            site names ("μ_α", "σ_α", "μ_β", "σ_β", "coef_age", "α", "β").
        patient_encoder: fitted encoder whose ``transform`` raises
            ``ValueError`` for an unknown patient ID.
        rng_key: optional JAX PRNG key used to draw α and β for an unseen
            patient. Defaults to ``random.PRNGKey(0)``.

    Returns:
        Array with one row per posterior sample and one column per entry of
        ``weeks``, holding ``α + β * weeks``.
    """
    if rng_key is None:
        rng_key = random.PRNGKey(0)

    # logic to handle unseen patients:
    try:
        n = patient_encoder.transform([patient_id])[0]
    except ValueError:
        # new patient but known age
        # Use independent keys for α and β: reusing one key would feed the
        # same underlying noise into both draws and make them perfectly
        # correlated.
        key_α, key_β = random.split(rng_key)
        μ_α = posterior_samples["μ_α"]
        σ_α = posterior_samples["σ_α"]
        α = dist.Normal(μ_α + age * posterior_samples["coef_age"], σ_α) \
            .sample(key_α)
        β = dist.Normal(posterior_samples["μ_β"], posterior_samples["σ_β"]) \
            .sample(key_β)
    else:
        # seen patient: reuse that patient's own posterior draws
        α = posterior_samples["α"][:, n]
        β = posterior_samples["β"][:, n]

    mu = α[:, None] + β[:, None] * weeks
    # note: I'm not including the noise term σ in the prediction here
    # but you could easily do that
    return mu
# Weeks to predict at, and the posterior draws shared by both examples below.
weeks = jnp.arange(-12, 133)
posterior = mcmc.get_samples()

# try on seen patient
patient = "ID00007637202177411956430"
# Take the age from this patient's first row in the training data.
age = train.loc[train["Patient"] == patient, "Age"].iloc[0]
mu_pred = predict_single_patient(
    patient_id=patient,
    weeks=weeks,
    age=age,
    posterior_samples=posterior,
    patient_encoder=encoder,
)

# try on unseen patient but with known age
patient = "this is a random id"
age = 66
mu_pred = predict_single_patient(
    patient_id=patient,
    weeks=weeks,
    age=age,
    posterior_samples=posterior,
    patient_encoder=encoder,
)