Sharding with MCMC consensus: getting similar results to the full model

So firstly, I realise there is an open issue to create a tutorial on this topic, but as it is sitting at low priority I thought I had better just ask the question.

When I split the data into separate shards and combine the per-shard results, I seem to get very different results from fitting the model on the full dataset. I am wondering if anyone has any idea what I am doing wrong.
The example is very much based on the hierarchical example (Bayesian Hierarchical Linear Regression — NumPyro documentation) shown in the tutorials.

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import numpyro
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist
from jax import random
import arviz as az


train = pd.read_csv(
    "https://gist.githubusercontent.com/ucals/"
    "2cf9d101992cb1b78c2cdd6e3bac6a4b/raw/"
    "43034c39052dcf97d4b894d2ec1bc3f90f3623d9/"
    "osic_pulmonary_fibrosis.csv"
)

The model is very similar to the original example except that it raises the priors to the power of 1/n_shards, as outlined in the consensus method paper, Bayes and Big Data: The Consensus Monte Carlo Algorithm – Google Research.

def adjust_sigma_by_shards(sigma, n_shards):
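    # raise the prior variance to the power 1/n_shards, then convert back to a sigma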
    var = sigma **2
    var_adjust = var**(1/n_shards)
    sigma_adjust = np.sqrt(var_adjust)
    return sigma_adjust

def model(patient_code, Weeks, FVC_obs=None, n_shards=1):
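    # hyperpriors, with scales adjusted for the number of shards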
    μ_α = numpyro.sample("μ_α", dist.Normal(0.0, adjust_sigma_by_shards(500.0, n_shards)))
    σ_α = numpyro.sample("σ_α", dist.HalfNormal(adjust_sigma_by_shards(100.0, n_shards)))
    μ_β = numpyro.sample("μ_β", dist.Normal(0.0, adjust_sigma_by_shards(3.0, n_shards)))
    σ_β = numpyro.sample("σ_β", dist.HalfNormal(adjust_sigma_by_shards(3.0, n_shards)))

    n_patients = len(np.unique(patient_code))

    with numpyro.plate("plate_i", n_patients):
        α = numpyro.sample("α", dist.Normal(μ_α, σ_α))
        β = numpyro.sample("β", dist.Normal(μ_β, σ_β))

    σ = numpyro.sample("σ", dist.HalfNormal(100.0))
    FVC_est = α[patient_code] + β[patient_code] * Weeks

    with numpyro.plate("data", len(patient_code)):
        numpyro.sample("obs", dist.Normal(FVC_est, σ), obs=FVC_obs)

So let's first fit the model with no sharding, as it is fairly easy to do.

from sklearn.preprocessing import LabelEncoder

patient_encoder = LabelEncoder()
train["patient_code"] = patient_encoder.fit_transform(train["Patient"].values)

FVC_obs = train["FVC"].values
Weeks = train["Weeks"].values
patient_code = train["patient_code"].values

from numpyro.diagnostics import summary
nuts_kernel = NUTS(model)

mcmc = MCMC(nuts_kernel, num_samples=2000, num_warmup=2000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, patient_code, Weeks, FVC_obs=FVC_obs)

samples = mcmc.get_samples()
# note we drop the variables of the model which are dependent on the patient
# as this makes comparison easier later on
keep_samples = {k:v for k, v in samples.items() if k in ['μ_α', 'μ_β', 'σ', 'σ_α', 'σ_β']}
summary_params_full_model = az.summary(az.from_dict(keep_samples), hdi_prob = 0.95, kind = 'stats', round_to = 9)

Now let's assume that we want to fit the exact same model and the same data across 5 shards.
It is recommended not to split a group across shards, so we will shard by assigning whole patients to different shards.

n_shards = 5
patients = np.unique(patient_code)
assign_shards = np.random.choice(np.arange(n_shards), size=len(patients), replace=True)
sharded_patients = pd.DataFrame(dict(
    patient_code = patients,
    shard = assign_shards,
))
train = train.merge(sharded_patients, on = 'patient_code')

Now in practice we would train across multiple machines in parallel, but in this example let's just run in series on the same machine, as it is easier to work with.

samples_for_each_shard = []
for _,sharded_df in train.groupby('shard'):
    FVC_obs = sharded_df["FVC"].values
    Weeks = sharded_df["Weeks"].values
    patient_code = sharded_df["patient_code"].values
    
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_samples=2000, num_warmup=2000)
    rng_key = random.PRNGKey(0)
    mcmc.run(rng_key, patient_code, Weeks, FVC_obs=FVC_obs, n_shards=n_shards)

    samples = mcmc.get_samples()
    keep_samples = {k:v for k, v in samples.items() if k in ['μ_α', 'μ_β', 'σ', 'σ_α', 'σ_β']}
    samples_for_each_shard.append(keep_samples)

Now let's combine the results using the consensus algorithm:

from numpyro.infer.hmc_util import parametric, parametric_draws, consensus

combined_draws = consensus(samples_for_each_shard, num_draws=2000, diagonal=True, rng_key=rng_key)
summary_params_sharded_model = az.summary(az.from_dict(combined_draws), hdi_prob = 0.95, kind = 'stats', round_to = 9)

And finally let's compare the results:

display(summary_params_sharded_model)
display(summary_params_full_model)

which gives results that look like this.

Parameter estimates for the sharded model:

+-----+-------------+------------+-------------+-------------+
|     |        mean |         sd |    hdi_2.5% |   hdi_97.5% |
+=====+=============+============+=============+=============+
| μ_α |   -0.413051 |  0.106684  |   -0.620999 |   -0.218885 |
+-----+-------------+------------+-------------+-------------+
| μ_β |    1.20652  |  0.0857113 |    1.04298  |    1.35467  |
+-----+-------------+------------+-------------+-------------+
| σ   | 1554.26     | 16.0936    | 1524.26     | 1587.49     |
+-----+-------------+------------+-------------+-------------+
| σ_α |    2.72013  |  0.111768  |    2.49435  |    2.9281   |
+-----+-------------+------------+-------------+-------------+
| σ_β |    6.61333  |  0.261872  |    6.11098  |    7.12419  |
+-----+-------------+------------+-------------+-------------+

And the model fitted over the full dataset:

+-----+------------+-----------+------------+-------------+
|     |       mean |        sd |   hdi_2.5% |   hdi_97.5% |
+=====+============+===========+============+=============+
| μ_α | 2774.91    | 56.7946   | 2657.65    |  2883.81    |
+-----+------------+-----------+------------+-------------+
| μ_β |   -4.18494 |  0.434388 |   -5.02383 |    -3.34333 |
+-----+------------+-----------+------------+-------------+
| σ   |  136.649   |  2.85289  |  131.039   |   142.082   |
+-----+------------+-----------+------------+-------------+
| σ_α |  724.147   | 31.209    |  665.511   |   786.689   |
+-----+------------+-----------+------------+-------------+
| σ_β |    4.99665 |  0.363692 |    4.36008 |     5.77568 |
+-----+------------+-----------+------------+-------------+

Now we can clearly see that some of the estimates are so different that I wouldn't want to use the consensus method as a replacement, so I assume I must be doing something wrong.

haven’t looked at your code in detail but you don’t appear to be scaling all priors e.g.

σ = numpyro.sample("σ", dist.HalfNormal(100.0))
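
i.e. if you keep your helper, that one would need the same treatment too, something like:

σ = numpyro.sample("σ", dist.HalfNormal(adjust_sigma_by_shards(100.0, n_shards)))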

also instead of using that helper you should just enclose all prior terms in a single with numpyro.handlers.scale(scale=1/n): context manager. something like:

with numpyro.handlers.scale(scale=1/n):
    z = sample("prior", ...)
sample("obs", ...)  # likelihood evaluated on fraction of data

also i don’t think these consensus methods work all that well. it’s an active area of research. how big is your dataset?

Ah yes thank you, I was missing σ, and using the handler is a lot cleaner.

My actual dataset is fairly massive: it is on the order of 100-500 million rows.
So I am aiming for each shard to be about 50-100 thousand rows.

Does having a large dataset help or hinder me when using the consensus method? Is there anything you would recommend instead?

in theory consensus approaches should work better as you get more data.

you probably need to reparameterize your model, in particular alpha and beta.
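
e.g. a non-centered parameterization of the per-patient effects, something like this (sketch only):

from numpyro.handlers import reparam
from numpyro.infer.reparam import LocScaleReparam

# decentre α and β; centered=0 gives the fully non-centered parameterization
reparam_model = reparam(
    model, config={"α": LocScaleReparam(centered=0), "β": LocScaleReparam(centered=0)}
)
nuts_kernel = NUTS(reparam_model)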

you may also need more MCMC samples.

i think you know this already but fyi: i guess if your sharding scheme is by patient you don’t need to scale alpha and beta since they are per-patient and only appear once in each shard (in contrast to the top level parameters)

btw you could also try hmcecs although i think you’d need to integrate out alpha and beta by hand first

okay thank you