So firstly I realise that there is an open issue to create a tutorial on this topic but as it is sitting as a low priority. I thought I better just ask the question.
When I shard into seperate shards I seem to get very different results to fitting the model across different shards. I am wondering if anyone has any idea what I am doing wrong.
The example is very much based on[ hierarchical example] (Bayesian Hierarchical Linear Regression — NumPyro documentation) shown in the tutorials.
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import numpyro
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist
from jax import random
import arviz as az
train = pd.read_csv(
"https://gist.githubusercontent.com/ucals/"
"2cf9d101992cb1b78c2cdd6e3bac6a4b/raw/"
"43034c39052dcf97d4b894d2ec1bc3f90f3623d9/"
"osic_pulmonary_fibrosis.csv"
)
The model is very similar to the original example excepts it adjusts the priors to the power of 1/n_shards. This is outlined in the consensus method paper Bayes and Big Data: The Consensus Monte Carlo Algorithm – Google Research
def adjust_sigma_by_shards(sigma, n_shards):
var = sigma **2
var_adjust = var**(1/n_shards)
sigma_adjust = np.sqrt(var_adjust)
return sigma_adjust
def model(patient_code, Weeks, FVC_obs=None,n_shards = 1):
μ_α = numpyro.sample("μ_α", dist.Normal(0.0, adjust_sigma_by_shards(500.0,n_shards)))
σ_α = numpyro.sample("σ_α", dist.HalfNormal(adjust_sigma_by_shards(100.0,n_shards)))
μ_β = numpyro.sample("μ_β", dist.Normal(0.0, adjust_sigma_by_shards(3.0,n_shards)))
σ_β = numpyro.sample("σ_β", dist.HalfNormal(adjust_sigma_by_shards(3.0,n_shards)))
n_patients = len(np.unique(patient_code))
with numpyro.plate("plate_i", n_patients):
α = numpyro.sample("α", dist.Normal(μ_α, σ_α))
β = numpyro.sample("β", dist.Normal(μ_β, σ_β))
σ = numpyro.sample("σ", dist.HalfNormal(100.0))
FVC_est = α[patient_code] + β[patient_code] * Weeks
with numpyro.plate("data", len(patient_code)):
numpyro.sample("obs", dist.Normal(FVC_est, σ), obs=FVC_obs)
so lets first fit the model with no shards, as it is fairly easy to do
rom sklearn.preprocessing import LabelEncoder
patient_encoder = LabelEncoder()
train["patient_code"] = patient_encoder.fit_transform(train["Patient"].values)
FVC_obs = train["FVC"].values
Weeks = train["Weeks"].values
patient_code = train["patient_code"].values
from numpyro.diagnostics import summary
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=2000, num_warmup=2000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, patient_code, Weeks, FVC_obs=FVC_obs)
samples = mcmc.get_samples()
# note we drop the variables of the model which are dependent on the patient
# as this makes comparison easier later on
keep_samples = {k:v for k, v in samples.items() if k in ['μ_α', 'μ_β', 'σ', 'σ_α', 'σ_β']}
summary_params_full_model = az.summary(az.from_dict(keep_samples), hdi_prob = 0.95, kind = 'stats', round_to = 9)
Now lets assume that we want to fit the exact same model and same data across 5 shards.
It is recommended to not split across groups. So we will shard by selecting patients into different shards.
n_shards = 5
patients = np.unique(patient_code)
assign_shards = np.random.choice(np.arange(n_shards),size = len(patients), replace=True)
sharded_patients = pd.DataFrame(dict(
patient_code = patients,
shard = assign_shards,
))
train = train.merge(sharded_patients, on = 'patient_code')
Now in practise we would train across multiple machines in parallel but in this example lets just run in series on the same machine. As it is easier to work with.
samples_for_each_shard = []
for _,sharded_df in train.groupby('shard'):
FVC_obs = sharded_df["FVC"].values
Weeks = sharded_df["Weeks"].values
patient_code = sharded_df["patient_code"].values
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=2000, num_warmup=2000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, patient_code, Weeks, FVC_obs=FVC_obs,n_shards =n_shards)
samples = mcmc.get_samples()
keep_samples = {k:v for k, v in samples.items() if k in ['μ_α', 'μ_β', 'σ', 'σ_α', 'σ_β']}
samples_for_each_shard.append(keep_samples)
Now lets combine the results together using the consensus algorithm
from numpyro.infer.hmc_util import parametric, parametric_draws, consensus
combined_draws = consensus(samples_for_each_shard,num_draws = 2000, diagonal=True, rng_key=rng_key)
summary_params_sharded_model = az.summary(az.from_dict(combined_draws), hdi_prob = 0.95, kind = 'stats', round_to = 9)
and finally lets compare the results
display(summary_params_sharded_model)
display(summary_params_full_model)
which gives results that look like this.
param estimates for the sharded model
+-----+-------------+------------+-------------+-------------+
| | mean | sd | hdi_2.5% | hdi_97.5% |
+=====+=============+============+=============+=============+
| μ_α | -0.413051 | 0.106684 | -0.620999 | -0.218885 |
+-----+-------------+------------+-------------+-------------+
| μ_β | 1.20652 | 0.0857113 | 1.04298 | 1.35467 |
+-----+-------------+------------+-------------+-------------+
| σ | 1554.26 | 16.0936 | 1524.26 | 1587.49 |
+-----+-------------+------------+-------------+-------------+
| σ_α | 2.72013 | 0.111768 | 2.49435 | 2.9281 |
+-----+-------------+------------+-------------+-------------+
| σ_β | 6.61333 | 0.261872 | 6.11098 | 7.12419 |
+-----+-------------+------------+-------------+-------------+
and model fitted over the full dataset
+-----+------------+-----------+------------+-------------+
| | mean | sd | hdi_2.5% | hdi_97.5% |
+=====+============+===========+============+=============+
| μ_α | 2774.91 | 56.7946 | 2657.65 | 2883.81 |
+-----+------------+-----------+------------+-------------+
| μ_β | -4.18494 | 0.434388 | -5.02383 | -3.34333 |
+-----+------------+-----------+------------+-------------+
| σ | 136.649 | 2.85289 | 131.039 | 142.082 |
+-----+------------+-----------+------------+-------------+
| σ_α | 724.147 | 31.209 | 665.511 | 786.689 |
+-----+------------+-----------+------------+-------------+
| σ_β | 4.99665 | 0.363692 | 4.36008 | 5.77568 |
+-----+------------+-----------+------------+-------------+
Now we can clearly see that some of the estimates are so different, I wouldnt want to use the consensus method as a replacement. So I assume that I must be doing something wrong.