Hi @fehiepsi, sure! A simplified version of my actual model / use case would look like this:
# Imports
import numpy as np
import jax
from jax import random
from numpyro import sample, deterministic, plate
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist
def model(gender_data=None, diag_data=None, record_intermediates=True):
    """Toy model: gender -> latent Kumaraswamy-distributed MDD state -> diagnosis.

    Args:
        gender_data: observed values for the ``male`` site (one entry per
            individual), or ``None`` to sample it (e.g. under ``Predictive``).
        diag_data: observed values for the ``mdd_diagnosis`` site, or ``None``.
        record_intermediates: if True (the default, i.e. the original
            behaviour), the per-individual intermediates ``disease_a``,
            ``disease_b``, ``mdd_state`` and ``diagnosis_logits`` are
            registered as ``deterministic`` sites. MCMC then keeps one
            (num_draws, nobs) array per such site. Pass False when only the
            population-level parameters are needed, to avoid that storage.
            The sampled latent site ``mdd_base`` is kept either way. This
            must be a plain Python bool, because it decides which sites
            exist in the trace.
    """

    def _record(name, value):
        # Expose `value` as a deterministic site only when asked to.
        # Otherwise return it untouched so nothing extra is traced or stored.
        return deterministic(name, value) if record_intermediates else value

    # Population-level parameters.
    disease_a_male = sample("disease_a_male", dist.Normal(loc=0.0, scale=2.0))
    disease_b_male = sample("disease_b_male", dist.Normal(loc=0.0, scale=2.0))
    diagnosis_logits_mdd = sample("diagnosis_logits_mdd", dist.Normal(loc=0.0, scale=2.0))
    diagnosis_logits_male = sample("diagnosis_logits_male", dist.Normal(loc=0.0, scale=1.0))

    # Without observed data (e.g. under Predictive) one individual is simulated per draw.
    nobs = 1 if gender_data is None else len(gender_data)

    # Individual-level observations / latent states.
    with plate("N", nobs):
        # Observed site A.
        male = sample("male", dist.Bernoulli(np.ones((nobs,)) * 0.5), obs=gender_data)

        # Shape parameters (a, b) of the Kumaraswamy latent state; softplus keeps them positive.
        disease_a = _record("disease_a", jax.nn.softplus(disease_a_male * male))
        disease_b = _record("disease_b", jax.nn.softplus(disease_b_male * male))

        # Continuous-valued latent state in [0, 1] ~ Kumaraswamy(a, b).
        # It is written as the inverse CDF applied to Uniform(0, 1) noise,
        # x = (1 - (1 - u)**(1/b))**(1/a), instead of
        # sample("mdd_state", dist.Kumaraswamy(disease_a, disease_b)).
        mdd_base = sample("mdd_base", dist.Uniform(np.zeros((nobs,)), np.ones((nobs,))))
        mdd_state = _record(
            "mdd_state",
            (1 - (1 - mdd_base) ** (1 / disease_b)) ** (1 / disease_a),
        )

        # Observed site B.
        diagnosis_logits = _record(
            "diagnosis_logits",
            diagnosis_logits_mdd * mdd_state + diagnosis_logits_male * male,
        )
        sample("mdd_diagnosis", dist.Bernoulli(logits=diagnosis_logits), obs=diag_data)
# Simulate data: prior-predictive draws from the model (one individual per draw,
# because the model is called without observations).
data = Predictive(model, num_samples=40000)(random.PRNGKey(1))

# Infer params from simulated data: 4 chains, 500 warmup steps, then 2000 kept draws each.
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=2000, num_chains=4)
observed = {
    "gender_data": data["male"].squeeze(),
    "diag_data": data["mdd_diagnosis"].squeeze(),
}
mcmc.run(random.PRNGKey(0), **observed)
After completion, the `mcmc` object will contain float32 arrays of shape `(n_mcmc, n_obs) = (8k, 400k)` for each of `disease_a`, `disease_b`, `mdd_base`, `mdd_state`, and `diagnosis_logits`. Each of them takes up 12.8 GB of memory (8,000 × 400,000 × 4 bytes). (And again, as a side note: my impression was that some non-in-place array copying might be going on during the final stages / sample collection of the `mcmc.run()` call?)
Is there any way to make the MCMC call not store all of these per-sample variables if I only need a trace for some of them? Is there a way to make them float16 instead of float32 (and would that be a good idea)? Any other tips or tricks?