Hello,
I have a simple hierarchical regression model that I’d like to fit with NUTS in NumPyro.
This is my code:
def initialize_parameters(S, d):
    """Build the hyperparameter tuple passed to ``model``.

    Args:
        S: side length of the spatial grid; the structure matrix is
            ``S**2 x S**2``.
        d: number of regressors; sets the width of the ``tau_js`` prior arrays.

    Returns:
        A 5-tuple ``(rho_hp, L, tau_eps_hp, tau_js_hp, beta_0_hp)``:
        the Beta parameters for ``rho``, the structure matrix ``L``
        (an identity placeholder here), the InverseGamma parameters for
        ``tau_eps``, a ``(2, d)`` array of InverseGamma parameters for
        ``tau_js`` with entries in ``[1, 2)``, and a scalar prior scale for
        ``beta0`` in ``[1, 2)``.
    """
    grid_cells = S ** 2
    # Keep this draw order (tau_js before beta_0) so a seeded global
    # np.random state produces the same values as before.
    tau_js_params = 1 + np.random.rand(2, d)
    beta_0_scale = 1 + np.random.rand()
    return (
        (8, 2),                 # rho ~ Beta(8, 2)
        jnp.eye(grid_cells),    # placeholder; the real matrix is application dependent
        (1, 2),                 # tau_eps ~ InverseGamma(1, 2)
        tau_js_params,
        beta_0_scale,
    )
def model(X, y, rho_hp, L, tau_eps_hp, tau_js_hp, beta_0_hp):
    """NumPyro model: hierarchical multi-output linear regression.

    The coefficient matrix ``beta`` has one row per regressor. Each row is
    drawn from a multivariate normal centred on the matching row of the
    group mean ``beta0``. Each observation row is multivariate normal around
    ``X @ beta``. Both levels share the covariance whose precision is
    ``rho * L + (1 - rho) * I``. That covariance's Cholesky factor is
    rescaled by ``tau_js`` for the coefficients and by ``tau_eps`` for the
    residuals.

    The latent dimension is dominated by ``beta0`` and ``beta``: each holds
    ``d * L.shape[0]`` values.

    Args:
        X: design matrix; ``d`` (number of regressors) is read from
            ``X.shape[-1]``. Presumably ``(T, d)``.
        y: observed responses, ``len(y)`` rows of length ``L.shape[0]``.
        rho_hp: ``(a, b)`` of the Beta prior on ``rho``.
        L: square structure matrix mixed with the identity to form the
            precision; its size is the output dimension.
        tau_eps_hp: arguments unpacked into the InverseGamma prior on
            ``tau_eps``.
        tau_js_hp: arguments unpacked into the InverseGamma prior on
            ``tau_js``; presumably two length-``d`` arrays.
        beta_0_hp: scale of the Normal(0, scale) prior on ``beta0``.
    """
    num_outputs = L.shape[0]
    # Derive the number of regressors from the data instead of reading the
    # module-level global ``d``, which made the model depend on script state.
    num_regressors = X.shape[-1]

    rho = numpyro.sample("rho", dist.Beta(*rho_hp))
    Sigma_inv = rho * L + (1 - rho) * jnp.eye(num_outputs)
    # Lower-triangular Cholesky factor of Sigma = inv(Sigma_inv).
    Sigma_chol = numpyro.distributions.util.cholesky_of_inverse(Sigma_inv)

    # NOTE(review): tau_eps and tau_js multiply the Cholesky factor, so they act
    # as standard-deviation-like scales (the covariance scales with their
    # square). If the InverseGamma priors were meant for variances, use
    # jnp.sqrt(...) here -- confirm intent.
    tau_eps = numpyro.sample("tau_eps", dist.InverseGamma(*tau_eps_hp))
    tau_js = numpyro.sample("tau_js", dist.InverseGamma(*tau_js_hp))

    beta_0 = numpyro.sample(
        "beta0",
        dist.Normal(loc=0, scale=beta_0_hp).expand_by((num_regressors, num_outputs)),
    )
    # reshape(-1, 1, 1) gives each tau_js entry (presumably one per regressor)
    # its own scaled copy of Sigma_chol, i.e. a batch of square factors that
    # broadcasts against the (d, S) loc.
    betas = numpyro.sample(
        "beta",
        dist.MultivariateNormal(
            loc=beta_0, scale_tril=Sigma_chol * tau_js.reshape(-1, 1, 1)
        ),
    )

    mean = X @ betas
    with numpyro.plate("data", len(y)):
        numpyro.sample(
            "y",
            dist.MultivariateNormal(loc=mean, scale_tril=Sigma_chol * tau_eps),
            obs=y,
        )
# Synthetic benchmark: T observations, d regressors, S**2 output dimensions
# (the structure matrix L from initialize_parameters is S**2 x S**2).
T = 1000
d = 28
S = 15
X = np.random.randn(T, d)
beta = np.random.randn(d, S ** 2)
y = X @ beta + np.random.randn(T, S ** 2)
rho_hp, L, tau_eps_hp, tau_js_hp, beta_0_hp = initialize_parameters(S, d)

# Cost of a dense mass matrix here:
# - beta0 and beta each hold d * S**2 latent values, 2 * 28 * 225 = 12,600 in
#   total.
# - dense_mass=True adapts a full D x D mass matrix over all of them: about
#   1.6e8 entries.
# - Every leapfrog step does a dense matrix-vector product with it, and it is
#   refactorised whenever the mass matrix is re-estimated during warmup.
#
# Fix: keep the dense block only for the small global sites; the large
# coefficient sites get a diagonal mass matrix.
#
# If trees still saturate max_tree_depth, the centred beta0 -> beta hierarchy
# is a likely funnel; a non-centred parameterisation is the usual remedy.
nuts_kernel = NUTS(
    model,
    max_tree_depth=10,
    dense_mass=[("rho", "tau_eps", "tau_js")],
)
# Gradient-free alternative kernel: SA(model)
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=100, progress_bar=True)
mcmc.run(
    rng_key_,
    X=X,
    y=y,
    rho_hp=rho_hp,
    L=L,
    tau_eps_hp=tau_eps_hp,
    tau_js_hp=tau_js_hp,
    beta_0_hp=beta_0_hp,
)
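Sampling starts off relatively quickly, at about 20 seconds per iteration for the first 25 iterations, but then slows down to about 250 seconds per iteration. For scale, `beta0` and `beta` each have d × S² = 28 × 225 = 6,300 entries, so the model has roughly 12,600 latent dimensions. Can someone suggest what I could do to speed it up? Thanks in advance.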