Maybe there’s a more elegant way of doing it, but my code usually looks like the code below. Basically, I have a custom predict function that handles the logic for unseen patients with a `try`/`except`/`else` block. To decide whether a patient is “seen” or “new”, I use sklearn’s `LabelEncoder()`: its `transform` method raises a `ValueError` for any label it was not fitted on. The example below is minimal, controlling only for `age`, but you can easily extend it to control for `sex` or `smoking_status`.
# not showing import statements nor the code for loading the data

# Map every patient id seen in training to a contiguous integer index.
encoder = LabelEncoder().fit(train["Patient"])
N_PATIENTS = len(encoder.classes_)

# One age per patient, ordered like encoder.classes_ so that position i
# holds the age of the patient encoded as i.
patient_to_age = (
    train[["Patient", "Age"]]
    .drop_duplicates()
    .set_index("Patient")
    .loc[encoder.classes_, "Age"]
    .values
)

# Keyword arguments passed to the model when running MCMC.
data_dict = {
    "PatientNum": jnp.array(encoder.transform(train["Patient"].values)),
    "Weeks": jnp.array(train["Weeks"].values),
    "PatientToAge": jnp.array(patient_to_age),
    "FVC_obs": jnp.array(train["FVC"].values),
}
def model(PatientNum, Weeks, PatientToAge, FVC_obs=None):
    """Partially pooled linear model of FVC over weeks, with age shifting the intercept.

    Each patient gets an intercept α and a slope β drawn from shared
    population-level distributions; the mean of α is additionally shifted by
    ``coef_age * age``.

    Args:
        PatientNum: integer patient index for every observation.
        Weeks: week value for every observation.
        PatientToAge: age per patient, indexed by patient index
            (presumably of length ``N_PATIENTS`` -- confirm against the caller).
        FVC_obs: observed FVC for every observation, or ``None`` to leave the
            "obs" site unobserved.
    """
    # Population-level hyperpriors for intercepts and slopes.
    alpha_mean = numpyro.sample("μ_α", dist.Normal(0., 100.))
    alpha_scale = numpyro.sample("σ_α", dist.HalfNormal(100.))
    beta_mean = numpyro.sample("μ_β", dist.Normal(0., 100.))
    beta_scale = numpyro.sample("σ_β", dist.HalfNormal(100.))

    # Age's effect on α
    age_coef = numpyro.sample("coef_age", dist.Normal(0., 100.))

    # Patient-level intercepts and slopes.
    with numpyro.plate("plate_i", N_PATIENTS) as patient_idx:
        alpha_loc = alpha_mean + age_coef * PatientToAge[patient_idx]
        alpha = numpyro.sample("α", dist.Normal(alpha_loc, alpha_scale))
        beta = numpyro.sample("β", dist.Normal(beta_mean, beta_scale))

    # Observation noise shared by all patients.
    noise_scale = numpyro.sample("σ", dist.HalfNormal(100.))

    # Expected FVC of each observation from its patient's line.
    fvc_mean = alpha[PatientNum] + beta[PatientNum] * Weeks

    with numpyro.plate("data", PatientNum.shape[0]):
        numpyro.sample("obs", dist.Normal(fvc_mean, noise_scale), obs=FVC_obs)
# Fit the model with NUTS: 2000 warm-up steps followed by 2000 retained draws,
# seeded with a fixed key so the run is reproducible.
rng_key = random.PRNGKey(0)
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=2000, num_samples=2000)
mcmc.run(rng_key, **data_dict)
# custom predict function
def _sample_normal(rng_key, loc, scale):
    """Draw one Normal(loc, scale) value per broadcast element of loc and scale."""
    shape = jnp.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
    return loc + scale * random.normal(rng_key, shape)


def predict_single_patient(patient_id, weeks, age, posterior_samples,
                           patient_encoder, rng_key=None):
    """Return posterior draws of the expected FVC line for one patient.

    Args:
        patient_id: patient identifier, looked up via ``patient_encoder``.
        weeks: weeks at which to evaluate the line.
        age: the patient's age; only used when the patient is unseen.
        posterior_samples: mapping of site name to posterior draws with the
            keys "μ_α", "σ_α", "μ_β", "σ_β", "coef_age", "α" and "β".
        patient_encoder: object whose ``transform([patient_id])`` returns the
            patient's integer index, or raises ``ValueError`` if the id is
            unknown.
        rng_key: optional JAX PRNG key used to simulate α and β for an unseen
            patient. Defaults to ``random.PRNGKey(0)`` so repeated calls give
            the same result.

    Returns:
        ``α[:, None] + β[:, None] * weeks``, i.e. one row per posterior draw
        and one column per entry of ``weeks`` (assuming 1-D draws and 1-D
        ``weeks``).
    """
    # logic to handle unseen patients:
    try:
        n = patient_encoder.transform([patient_id])[0]
    except ValueError:
        # new patient but known age: simulate its α and β from the
        # population-level posterior.
        if rng_key is None:
            rng_key = random.PRNGKey(0)
        # α and β must use independent sub-keys. Reusing one key for both
        # draws would give them identical standard-normal noise and make the
        # simulated intercepts and slopes perfectly correlated.
        alpha_key, beta_key = random.split(rng_key)
        alpha_loc = posterior_samples["μ_α"] + age * posterior_samples["coef_age"]
        alpha = _sample_normal(alpha_key, alpha_loc, posterior_samples["σ_α"])
        beta = _sample_normal(
            beta_key, posterior_samples["μ_β"], posterior_samples["σ_β"]
        )
    else:
        # seen patient: use its own posterior draws.
        alpha = posterior_samples["α"][:, n]
        beta = posterior_samples["β"][:, n]

    mu = alpha[:, None] + beta[:, None] * weeks
    # note: I'm not including the noise term σ in the prediction here
    # but you could easily do that
    return mu
# Weeks at which the predicted line is evaluated.
weeks = jnp.arange(-12, 133)

# try on seen patient: the age is read from the patient's first training row
patient = "ID00007637202177411956430"
age = train.loc[train["Patient"] == patient, "Age"].iloc[0]
mu_pred = predict_single_patient(
    patient_id=patient,
    weeks=weeks,
    age=age,
    posterior_samples=mcmc.get_samples(),
    patient_encoder=encoder,
)

# try on unseen patient but with known age
patient = "this is a random id"
age = 66
mu_pred = predict_single_patient(
    patient_id=patient,
    weeks=weeks,
    age=age,
    posterior_samples=mcmc.get_samples(),
    patient_encoder=encoder,
)