Low effective sample size

Hello everybody, I am relatively new to Bayesian inference and wanted to try out a hierarchical model. Basically, what I want to do is estimate parameters for individual patients, similar to the documentation. The forward problem is defined by an ODE. The code below runs, but the estimates are completely off, which shows up as high R-hat values and a low effective sample size (ESS). Can anybody tell me how to fix this? Some tips about speeding things up would also be really appreciated!

import dill
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
import pandas as pd
from diffrax import ConstantStepSize, Dopri5, ODETerm, SaveAt, diffeqsolve
from jax import random
from numpyro.infer import MCMC, NUTS  # , Predictive
from sklearn.preprocessing import LabelEncoder

# from fem_cycle_model.fetch_params import fetch_params
# from fem_cycle_model.main import run_model
pd.options.mode.chained_assignment = None  # default='warn'
assert numpyro.__version__.startswith("0.11.0")

# Select the number of cores that numpyro will use
numpyro.set_host_device_count(4)


# From A01 to A26 and B01 to B24
# merge a list of dataframes into one
def merge_dataframes(dataframes):
    return pd.concat(dataframes)


# insert a column in a dataframe
def insert_column(df, column_name, column_values):
    df.insert(0, column_name, column_values)
    return df


# make a list with the same string repeated n times
def make_array(string, n):
    return [string] * n


# Create the patient IDs for the females A01 to A26 and B01 to B24 for cycles 1 and 2
def string_id_generator():
    ids = []
    for j in range(1, 3):
        for i in range(1, 27):
            ids.append("A" + str(i).zfill(2) + "_" + str(j))
        for i in range(1, 25):
            ids.append("B" + str(i).zfill(2) + "_" + str(j))
    return ids


def string_desplitter(code):
    prefix, cycle = code.split("_")
    return prefix, int(cycle)


def test_ode(t, z, theta):
    """
    Lotka–Volterra equations. The real positive parameters `alpha`, `beta`, `gamma`, `delta`
    describe the interaction of two species.
    """
    u = z[0]
    v = z[1]
    alpha, beta, gamma, delta = (
        theta[0],
        theta[1],
        theta[2],
        theta[3],
    )
    du_dt = (alpha - beta * v) * u
    dv_dt = (-gamma + delta * u) * v
    return jnp.stack([du_dt, dv_dt])


# Create a linspace of time points 0, 0.2, 0.4, ...


def sim(params, patient_id=None):
    term = ODETerm(test_ode)
    solver = Dopri5()
    saveat = SaveAt(ts=np.linspace(0, 10, 50))
    stepsize_controller = ConstantStepSize()
    sol = diffeqsolve(
        term,
        solver,
        t0=0,
        t1=10,
        dt0=0.1,
        y0=jnp.array([1, 0.2]),
        saveat=saveat,
        stepsize_controller=stepsize_controller,
        args=params,
    )
    if patient_id is not None:
        df = pd.DataFrame(sol.ys)
        df = df.rename(columns={0: "dom", 1: "non_dom"})
        df = insert_column(df, "patient_id", patient_id)
        return df
    else:
        return sol.ys[:, 0], sol.ys[:, 1]


patient_IDs = string_id_generator()
patient_IDs_test = patient_IDs[0:3]


def run_model_data(patient_IDs, params):
    df = pd.DataFrame()
    for i, patient_id in enumerate(patient_IDs):
        params_i = jnp.array(params) * (i + 1)
        df = pd.concat([df, sim(params_i, patient_id)])
    return df


patient_data = run_model_data(patient_IDs_test, [2, 3, 4, 5])


def run_model_all(params):
    non_dom_foll_array = jnp.array([])
    dom_foll_array = jnp.array([])
    for i in range(len(params[1])):
        dom_foll, non_dom_foll = sim([params[j][i] for j in range(len(params))])
        non_dom_foll_array = jnp.append(non_dom_foll_array, non_dom_foll)
        dom_foll_array = jnp.append(dom_foll_array, dom_foll)
    return dom_foll_array, non_dom_foll_array


encoder = LabelEncoder()
encoder.fit(patient_IDs_test)
N_PATIENTS = len(encoder.classes_)


def model(patient_code, dom_foll=None, non_dom_foll=None):
    μ_foll_alpha = numpyro.sample("μ_foll_alpha", dist.Uniform(1.0, 7.0))
    σ_foll_alpha = numpyro.sample("σ_foll_alpha", dist.LeftTruncatedDistribution(dist.Normal(3.0, 2.0)))

    μ_foll_beta = numpyro.sample("μ_foll_beta", dist.Uniform(2.0, 10.0))
    σ_foll_beta = numpyro.sample("σ_foll_beta", dist.LeftTruncatedDistribution(dist.Normal(4.0, 2.0)))

    μ_foll_gamma = numpyro.sample("μ_foll_gamma", dist.Uniform(3.0, 13.0))
    σ_foll_gamma = numpyro.sample("σ_foll_gamma", dist.LeftTruncatedDistribution(dist.Normal(5.0, 2.0)))

    μ_foll_delta = numpyro.sample("μ_foll_delta", dist.Uniform(4.0, 16.0))
    σ_foll_delta = numpyro.sample("σ_foll_delta", dist.LeftTruncatedDistribution(dist.Normal(7.0, 2.0)))

    σ = numpyro.sample("σ", dist.LeftTruncatedDistribution(dist.Normal(0.3, .5)))

    with numpyro.plate("plate_i", N_PATIENTS):
        foll_alpha = numpyro.sample("foll_alpha", dist.Normal(μ_foll_alpha, σ_foll_alpha))
        foll_beta = numpyro.sample("foll_beta", dist.Normal(μ_foll_beta, σ_foll_beta))
        foll_gamma = numpyro.sample("foll_gamma", dist.Normal(μ_foll_gamma, σ_foll_gamma))
        foll_delta = numpyro.sample("foll_delta", dist.Normal(μ_foll_delta, σ_foll_delta))

    foll_dom_est, foll_non_dom_est = run_model_all(
        [
            foll_alpha[patient_code],
            foll_beta[patient_code],
            foll_gamma[patient_code],
            foll_delta[patient_code],
        ]
    )
    numpyro.sample("obs_dom", dist.Normal(foll_dom_est, σ), obs=dom_foll)
    numpyro.sample("obs_non_dom", dist.Normal(foll_non_dom_est, σ), obs=non_dom_foll)


data_dict = dict(
    dom_foll=jnp.array(patient_data.dom),
    non_dom_foll=jnp.array(patient_data.non_dom),
)

# Transform all of the patient IDs into integers, which will be used as the patient variable in the model
patient_code = jnp.array(encoder.transform(patient_IDs_test))

# Add the patient code to the data dictionary
data_dict.update({"patient_code": patient_code})

# Specify the number of chains in the Markov Chain Monte Carlo. Typically set to the number of cores in the computer
mcmc_kwargs = dict(num_samples=2000, num_warmup=2000, num_chains=4)

# Select a random key and split it into different parts. This guarantees that we get the same result each time and
# will lead to reproducible results. For more see:
# https://ericmjl.github.io/dl-workshop/02-jax-idioms/03-deterministic-randomness.html
rng_key = random.PRNGKey(12)
seed1, seed2, seed3, seed4, seed5 = random.split(rng_key, 5)


inference_mcmc = MCMC(NUTS(model, init_strategy=numpyro.infer.init_to_sample()), **mcmc_kwargs)
inference_mcmc.run(seed1, **data_dict)

# This block lets the posterior be pickled
inference_mcmc.sampler._sample_fn = None  # pylint: disable=protected-access
inference_mcmc.sampler._init_fn = None  # pylint: disable=protected-access
inference_mcmc.sampler._postprocess_fn = None  # pylint: disable=protected-access
inference_mcmc.sampler._potential_fn = None  # pylint: disable=protected-access
inference_mcmc.sampler._potential_fn_gen = None  # pylint: disable=protected-access
inference_mcmc._cache = {}  # pylint: disable=protected-access

# Saving the posterior
with open("savemcmc.pkl", "wb") as f:
    dill.dump(inference_mcmc, f)

inference_mcmc.print_summary()  # print_summary() prints directly and returns None

The summary is the following:

                    mean       std    median      5.0%     95.0%     n_eff     r_hat
 foll_alpha[0]      1.51      0.67      1.71      0.28      2.27      2.45      2.27
 foll_alpha[1]      1.82      1.21      1.37      0.51      3.88      2.21      3.15
 foll_alpha[2]      1.90      0.78      1.73      0.65      3.19      3.05      1.86
  foll_beta[0]      2.43      0.92      2.58      0.68      3.57      3.02      1.69
  foll_beta[1]      3.62      1.69      3.24      1.35      6.16      2.60      2.03
  foll_beta[2]      4.55      1.98      4.11      1.53      7.75      3.30      1.68
 foll_delta[0]      6.57      2.44      5.74      3.78     10.42      3.59      1.51
 foll_delta[1]     13.43      3.39     12.96      8.34     18.96     13.30      1.11
 foll_delta[2]     14.33      3.72     14.27      7.72     20.07     15.41      1.17
 foll_gamma[0]      5.40      2.29      4.59      2.86      8.96      3.52      1.53
 foll_gamma[1]     10.36      2.82      9.95      6.02     15.04     11.74      1.12
 foll_gamma[2]     12.35      3.17     12.26      6.85     17.35     14.76      1.16
  μ_foll_alpha      2.13      1.01      1.85      1.00      3.45     10.88      1.11
   μ_foll_beta      4.06      1.62      3.65      2.00      6.41     12.86      1.10
  μ_foll_delta     10.79      2.95     11.01      6.57     15.95    397.68      1.01
  μ_foll_gamma      8.83      2.34      9.01      5.43     12.83     76.82      1.03
             σ      0.46      0.02      0.46      0.43      0.49     51.93      1.04
  σ_foll_alpha      1.65      1.32      1.29      0.01      3.52      7.84      1.16
   σ_foll_beta      2.78      1.70      2.62      0.02      5.10     10.99      1.11
  σ_foll_delta      6.64      1.81      6.60      3.72      9.64   1975.99      1.01
  σ_foll_gamma      4.93      1.64      4.83      2.17      7.46    917.31      1.02

Number of divergences: 36

some tips/questions:

  • how many patients? if the dimensionality of the latent space is large this might be a very challenging inference problem
  • use 64 bit precision => enable_x64
  • see the bad geometry tutorial (a non-centered reparameterization sketch follows after this list)
  • best to do all pandas manipulations outside of model() in a pre-preprocessing step
  • try simpler versions of the model with fewer latent variables before jumping into the full complexity of what you want
  • wherever possible replace for loops with vectorized code (e.g. with jax.vmap; see the second sketch after this list)

Thank you for your answer :slight_smile:
So in this case there are only 3 patients, but the number will be bumped up to 87! So yes, in total there will be 87*7 parameters.
I will try the 64 bit precision for sure and check out the geometry link.
Honestly, this is already an extremely simplified version of the original problem, so I am quite astonished that the results here are already kind of bad, especially since I have seen talks where people run MCMC on models with a few thousand or even a million parameters.
I will also try the vectorization, but I think it will not be possible in some cases.

But is the general model setup in my case right? (In terms of the definition of the model function)

your code is far too long for me to read in detail, especially given all the extraneous data stuff.

all models are different. sometimes doing inference on a model with 10 parameters is much more difficult than doing inference on a model with 1000 parameters. there are no universal rules of thumb.

So I shortened the code, and now I just try to estimate the four parameters of the Lotka-Volterra equations for one single set of observations. The problem is still there, except for the parameter alpha. For each parameter I assume a hyperprior, just like in the hierarchical model. This should be relatively easy for the inference, no?

import sys
# from pathlib import Path

# import arviz as az
import dill
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
import pandas as pd
from diffrax import ConstantStepSize, Dopri5, ODETerm, SaveAt, diffeqsolve
from jax import random
from numpyro.infer import MCMC, NUTS  # , Predictive
from sklearn.preprocessing import LabelEncoder

# from fem_cycle_model.fetch_params import fetch_params
# from fem_cycle_model.main import run_model
numpyro.enable_x64()
pd.options.mode.chained_assignment = None  # default='warn'
assert numpyro.__version__.startswith("0.11.0")

# Select the number of cores that numpyro will use
numpyro.set_host_device_count(16)


def test_ode(t, z, theta):
    """
    Lotka–Volterra equations. The real positive parameters `alpha`, `beta`, `gamma`, `delta`
    describe the interaction of two species.
    """
    u = z[0]
    v = z[1]
    alpha, beta, gamma, delta = (
        theta[0],
        theta[1],
        theta[2],
        theta[3],
    )
    du_dt = (alpha - beta * v) * u
    dv_dt = (-gamma + delta * u) * v
    return jnp.stack([du_dt, dv_dt])


# Create a linspace of time points 0, 0.2, 0.4, ...


def sim(params):
    term = ODETerm(test_ode)
    solver = Dopri5()
    saveat = SaveAt(ts=np.linspace(0, 10, 50))
    stepsize_controller = ConstantStepSize()
    sol = diffeqsolve(
        term,
        solver,
        t0=0,
        t1=10,
        dt0=0.1,
        y0=jnp.array([1, 0.2]),
        saveat=saveat,
        stepsize_controller=stepsize_controller,
        args=params,
    )
    return sol.ys[:, 0], sol.ys[:, 1]



def run_model_data(num_patients, params):
    y1_total = jnp.array([])
    y2_total = jnp.array([])
    for i in range(num_patients):
        params_i = jnp.array(params) * (i + 1)
        y1, y2 = sim(params_i)
        y1_total = jnp.append(y1_total, y1)
        y2_total = jnp.append(y2_total, y2)
    return y1_total, y2_total


patient_y1, patient_y2 = run_model_data(1, [2, 3, 4, 5])


def run_model_all(params):
    y1, y2 = sim(params)
    return y1, y2




def model(y1=None, y2=None):
    μ_foll_alpha = numpyro.sample("μ_foll_alpha", dist.Uniform(1.0, 7.0))
    σ_foll_alpha = numpyro.sample("σ_foll_alpha", dist.LeftTruncatedDistribution(dist.Normal(3.0, 2.0)))

    μ_foll_beta = numpyro.sample("μ_foll_beta", dist.Uniform(2.0, 10.0))
    σ_foll_beta = numpyro.sample("σ_foll_beta", dist.LeftTruncatedDistribution(dist.Normal(4.0, 2.0)))

    μ_foll_gamma = numpyro.sample("μ_foll_gamma", dist.Uniform(3.0, 13.0))
    σ_foll_gamma = numpyro.sample("σ_foll_gamma", dist.LeftTruncatedDistribution(dist.Normal(5.0, 2.0)))

    μ_foll_delta = numpyro.sample("μ_foll_delta", dist.Uniform(4.0, 16.0))
    σ_foll_delta = numpyro.sample("σ_foll_delta", dist.LeftTruncatedDistribution(dist.Normal(7.0, 2.0)))

    σ = numpyro.sample("σ", dist.LeftTruncatedDistribution(dist.Normal(0.3, .5)))

    #with numpyro.plate("plate_i", N_PATIENTS):
    foll_alpha = numpyro.sample("foll_alpha", dist.Normal(μ_foll_alpha, σ_foll_alpha))
    foll_beta = numpyro.sample("foll_beta", dist.Normal(μ_foll_beta, σ_foll_beta))
    foll_gamma = numpyro.sample("foll_gamma", dist.Normal(μ_foll_gamma, σ_foll_gamma))
    foll_delta = numpyro.sample("foll_delta", dist.Normal(μ_foll_delta, σ_foll_delta))

    y1_est, y2_est = run_model_all(
        [
            foll_alpha,
            foll_beta,
            foll_gamma,
            foll_delta,
        ]
    )
    with numpyro.plate("likelihood", len(y1)):
        numpyro.sample("obs_dom", dist.Normal(y1_est, σ), obs=y1)
        numpyro.sample("obs_non_dom", dist.Normal(y2_est, σ), obs=y2)


data_dict = dict(
    y1=patient_y1,
    y2=patient_y2,
)


# Specify the number of chains in the Markov Chain Monte Carlo. Typically set to the number of cores in the computer
mcmc_kwargs = dict(num_samples=2000, num_warmup=2000, num_chains=4)

# Select a random key and split it into different parts. This guarantees that we get the same result each time and
# will lead to reproducible results. For more see:
# https://ericmjl.github.io/dl-workshop/02-jax-idioms/03-deterministic-randomness.html
rng_key = random.PRNGKey(12)
seed1, seed2, seed3, seed4, seed5 = random.split(rng_key, 5)

inference_mcmc = MCMC(NUTS(model, init_strategy=numpyro.infer.init_to_sample(), dense_mass=True), **mcmc_kwargs)
inference_mcmc.run(seed1, **data_dict)
inference_mcmc.print_summary()  # print_summary() prints directly and returns None
                    mean       std    median      5.0%     95.0%     n_eff     r_hat
    foll_alpha      2.08      0.40      2.00      1.51      2.91    158.08      1.04
     foll_beta      4.18      1.69      3.00      3.00      7.01      3.81      1.51
    foll_delta      8.32      4.61      5.00      3.54     15.25      3.71      1.73
    foll_gamma      7.45      4.68      4.00      3.11     14.89      3.55      1.76
  μ_foll_alpha      3.44      2.00      3.26      1.00      6.01      3.08      1.83
   μ_foll_beta      7.73      2.30      9.04      3.60      9.72      3.85      1.53
  μ_foll_delta     11.21      2.25     11.38      7.11     15.28    159.77      1.03
  μ_foll_gamma      9.18      2.29      9.26      7.05     12.78      4.12      1.41
             σ      0.25      0.26      0.20      0.00      0.54      2.02     10.93
  σ_foll_alpha      3.25      2.06      3.02      1.11      5.99      2.64      2.00
   σ_foll_beta     11.28      9.78      9.97      2.19     28.08      2.02      8.77
  σ_foll_delta     10.79      9.68      6.61      2.56     27.20      2.03      8.37
  σ_foll_gamma      3.07      2.02      1.89      1.32      6.46      3.50      1.59

this should be relatively easy for the inference, no?

i’m not sure, it depends on details, e.g.

  • is ConstantStepSize() appropriate? if the ODE solution is not smooth because e.g. the step size is too large, HMC will get confused
  • are the hyperpriors reasonable? an ODE that behaves well in one part of parameter space may behave poorly in another part of parameter space. depending on the context, it may be important to ensure that your hyperprior prevents you from entering bad parts of the parameter space
  • you might also try initializing at a good initial parameter, see e.g. here
  • Thanks for the hint, the step size is actually a good point. The problem is that by changing it to PIDController, I get the following error (see the adaptive-stepping sketch after this list):
    jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: Buffer Definition Event: Generated function failed:
    Is this numpyro or Diffrax related?

  • The hyperprior is defined in a way that it captures the real parameters; in this example with a single observed dataset it is maybe a bit wide, but in the more general case I would say it is as narrow as it can get

  • Thanks for the initial parameter link, this might also be an option. Can you maybe tell me how to use it? So just init_to_value(guess par1, guess par 2, …), or do I need to specify the kernel noise, etc.?

not numpyro

do I need to specify the kernel noise

you specify the (named) latent variables in your model. your model does not have a kernel noise