Fitting models with NUTS is slow v2


I have a simple hierarchical regression model, which I’d like to fit using NUTS.

This is my code:

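import numpy as np
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SA

def initialize_parameters(S, d):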
    rho_hp = (8, 2)
    L = jnp.eye(S ** 2)  # placeholder matrix, actual is application dependent
    tau_eps_hp = (1, 2)
    tau_js_hp = np.random.rand(2, d) + 1
    beta_0_hp = np.random.rand() + 1
    return rho_hp, L, tau_eps_hp, tau_js_hp, beta_0_hp

def model(X, y, rho_hp, L, tau_eps_hp, tau_js_hp, beta_0_hp):
    d = X.shape[1]  # number of covariates
    S = L.shape[0]  # number of locations (S ** 2 in the calling script)
    rho = numpyro.sample("rho", dist.Beta(*rho_hp))
    Sigma_inv = rho * L + (1 - rho) * jnp.eye(S)
    Sigma_chol = numpyro.distributions.util.cholesky_of_inverse(Sigma_inv)
    tau_eps = numpyro.sample("tau_eps", dist.InverseGamma(*tau_eps_hp))
    tau_js = numpyro.sample("tau_js", dist.InverseGamma(*tau_js_hp))
    beta_0 = numpyro.sample("beta0", dist.Normal(loc=0, scale=beta_0_hp).expand_by((d, S)))
    betas = numpyro.sample("beta", dist.MultivariateNormal(loc=beta_0, scale_tril=Sigma_chol * tau_js.reshape(-1, 1, 1)))
    mean = X @ betas
    with numpyro.plate("data", len(y)):
        numpyro.sample("y", dist.MultivariateNormal(loc=mean, scale_tril=Sigma_chol * tau_eps), obs=y)

T = 1000
d = 28
S = 15
X = np.random.randn(T, d)
beta = np.random.randn(d, S ** 2)
y = X @ beta + np.random.randn(T, S ** 2)
rho_hp, L, tau_eps_hp, tau_js_hp, beta_0_hp = initialize_parameters(S, d)

nuts_kernel = NUTS(model, max_tree_depth=10, dense_mass=True)
#nuts_kernel = SA(model)
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)

mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=100, progress_bar=True)
mcmc.run(rng_key_, X=X, y=y, rho_hp=rho_hp, L=L, tau_eps_hp=tau_eps_hp, tau_js_hp=tau_js_hp, beta_0_hp=beta_0_hp)

Initially it runs relatively quickly, at 20 seconds per sample for the first 25 samples, but then slows down to 250 seconds per sample. Can someone suggest what I could do to speed it up? Thanks in advance.

What is the dimensionality of your latent space? I would see if you can get away with dense_mass=False and a lower max_tree_depth (e.g. 6 or 8).

I was discussing this with @neerajprad on the PyTorch Slack, who suggested increasing the tree depth if the maximum is hit most of the time, which is happening here.

I don’t quite understand what you mean by latent space; are you referring to the sample sizes of the intermediate RVs?

I mean the total dimensionality of all the unobserved latent random variables.

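Can you explain the logic of your model? You may be running into the folk theorem of statistical computing (when you have computational problems, often there is a problem with your model).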

For example, you could replace these two sample statements with a single equivalent multivariate normal:

beta_0 = numpyro.sample("beta0", dist.Normal(loc=0, scale=beta_0_hp).expand_by((d, S)))
betas = numpyro.sample("beta", dist.MultivariateNormal(loc=beta_0, scale_tril=Sigma_chol * tau_js.reshape(-1, 1, 1)))

Sigma is a matrix of size S x S.
tau_eps_hp is a 2-tuple of floats, and tau_js_hp is a 2-tuple of vectors, each of size d.

tau ~ InvGamma(*tau_js_hp) # tau is of size d
tau_eps ~ InvGamma(*tau_eps_hp) # tau_eps is a float
beta_0 ~ N(0, beta_0_hp ** 2 I) # beta_0 is of size d x S
beta_i | beta_0, tau ~ N(beta_0_i, Sigma * tau_i ** 2) # beta_i is of size S

X is of size T x d, y is of size T x S
The linear model is y = X beta + epsilon, where beta is the collection of beta_i.
The covariance of epsilon is Sigma * tau_eps ** 2, so epsilon_i ~ N(0, Sigma * tau_eps ** 2), i.e. y_i ~ N(beta^T x_i, Sigma * tau_eps ** 2), for i in {1, …, T}.

I wish LaTeX were available for better formatting; apologies for the awkward notation.
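For reference, here is the generative model that the code implements, written in LaTeX. The hyperparameter symbols a_\rho, b_\rho, a_\varepsilon, b_\varepsilon, a_j, b_j and \sigma_0 = beta_0_hp are names introduced here for readability. S is the number of locations (`L.shape[0]`, i.e. 15 ** 2 in the script). As in the code, the \tau's and beta_0_hp act as standard deviations.

```latex
\begin{aligned}
\rho &\sim \mathrm{Beta}(a_\rho, b_\rho), &
  \Sigma &= \bigl(\rho L + (1 - \rho) I_S\bigr)^{-1}, \\
\tau_\varepsilon &\sim \mathrm{InvGamma}(a_\varepsilon, b_\varepsilon), &
  \tau_j &\sim \mathrm{InvGamma}(a_j, b_j), \quad j = 1, \dots, d, \\
\beta_{0,j} &\sim \mathcal{N}\bigl(0, \sigma_0^2 I_S\bigr), &
  \beta_j \mid \beta_{0,j}, \tau_j, \rho &\sim \mathcal{N}\bigl(\beta_{0,j}, \tau_j^2 \Sigma\bigr), \\
y_t \mid \beta, \tau_\varepsilon, \rho &\sim \mathcal{N}\bigl(\beta^\top x_t, \tau_\varepsilon^2 \Sigma\bigr), &
  t &= 1, \dots, T,
\end{aligned}
```

Here \beta_j is row j of the d x S matrix \beta, and x_t is row t of X.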

Thank you for taking a look.

That helps, thanks, but I mean: what are you actually trying to accomplish? What is L? I would start by simplifying your model and seeing at what point HMC starts struggling. You can also integrate out some of these normal random variables by hand.

Alright, thank you. I’ll dig deeper and try simplifying. L is a Laplacian matrix; the model is a Bayesian spatial model.

I was originally under the impression that I could simply write a model and let it sort out the details internally.

There is some funsor-driven collapse machinery, but there is currently no automagical integration logic that works by default and addresses all conceivable cases.

Thanks, I’ll take a look.

There’s a bug in the test: it’s supposed to be model2.

thanks for pointing that out @vishwakftw !

Just want to jump in here. The latent size is (12630,), which is too high for NUTS, and using dense_mass=True would make things much worse. It looks like a solution here is to marginalize out beta_0 and betas with a bit of math.

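Write E_1, E_2 (d x S) and E_3 (T x S) for matrices of independent standard normals, Sigma_chol for the Cholesky factor of Sigma, sigma_0 = beta_0_hp, tau for the vector tau_js, and vec for row-wise stacking. Then:

$$
\begin{aligned}
\beta &= \sigma_0 E_1 + \operatorname{diag}(\tau)\, E_2\, \Sigma_{\mathrm{chol}}^\top, \\
y &= X\beta + \tau_\varepsilon E_3 \Sigma_{\mathrm{chol}}^\top
   = \sigma_0 X E_1 + X \operatorname{diag}(\tau)\, E_2\, \Sigma_{\mathrm{chol}}^\top + \tau_\varepsilon E_3 \Sigma_{\mathrm{chol}}^\top, \\
\operatorname{vec}(y) &= A \operatorname{vec}(E_1) + B \operatorname{vec}(E_2) + C \operatorname{vec}(E_3), \\
A &= \sigma_0 \,(X \otimes I_S), \qquad B = X \operatorname{diag}(\tau) \otimes \Sigma_{\mathrm{chol}}, \qquad C = \tau_\varepsilon\,(I_T \otimes \Sigma_{\mathrm{chol}}), \\
\operatorname{vec}(y) &\sim \mathcal{N}\bigl(0,\; A A^\top + B B^\top + C C^\top\bigr),
\end{aligned}
$$

i.e. a MultivariateNormal with loc 0 and scale_tril = cholesky(AA^T + BB^T + CC^T).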

After this, you can use NUTS with dense_mass=True (if needed) to sample over the variables ‘rho’, ‘tau_eps’ and ‘tau_js’.

I think collapse can handle the above math. Currently, it can handle the Normal-Normal pattern, as in this test, but I believe it should work for MVN too. What do you think, @fritzo?

Edit: to deal with expanded distributions like the ones in the above model, collapse needs the fix in this PR. I will try to add some MVN tests to that PR to see if things work.

Thanks very much, @fehiepsi. However, I need access to betas for application-specific diagnostics, so I can’t marginalize them out. I can, however, marginalize out beta_0, which is what I’m trying to do at the moment.

In that case, I think it is better to use SVI. It is unlikely that HMC will work for such a high-dimensional model.

I may try SVI too; thanks for the suggestion. That said, marginalizing out beta_0 already reduced the runtime drastically, to about 15 seconds per sample on Colab (all earlier timings were also on Colab).

You can recover the marginalized latent variables after inference. Generally speaking, it makes sense to marginalize out everything that you can.


If you want to stick with MCMC for sampling betas, one trick that can help mixing a lot is a non-centered reparameterization:

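# loc (d, S) and scale_tril (d, S, S) are the parameters of the MVN prior on betas,
# e.g. loc = beta_0 and scale_tril = Sigma_chol * tau_js.reshape(-1, 1, 1)
betas_base = numpyro.sample("betas_base", dist.Normal(0, 1).expand([d, S]))
betas = numpyro.deterministic("betas", loc + (scale_tril @ betas_base[..., None])[..., 0])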

Also, make sure that dense_mass=False (a latent size of ~6000 is quite high for the dense_mass=True setting).
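Here is some back-of-the-envelope arithmetic on why dense_mass=True struggles at this size. A dense mass matrix for n latent dimensions has n(n + 1)/2 distinct entries to estimate during warmup. Each leapfrog step then needs an n x n matrix-vector product instead of an elementwise one.

With beta_0 marginalized, n = 1 + 1 + 28 + 28 · 225 = 6330, versus 12630 for the original model. The helper name and the 4-byte float32 assumption are illustrative.

```python
def dense_mass_cost(n, bytes_per_float=4):
    """Entries, distinct entries and memory (MB) of one n x n dense mass matrix."""
    entries = n * n
    unique = n * (n + 1) // 2  # the matrix is symmetric
    megabytes = entries * bytes_per_float / 1e6
    return entries, unique, megabytes


d, n_loc = 28, 15 ** 2
n_marginal = 1 + 1 + d + d * n_loc  # rho, tau_eps, tau_js, beta (beta_0 integrated out)
n_full = n_marginal + d * n_loc     # original model, beta_0 still sampled
for n in (n_marginal, n_full):
    entries, unique, mb = dense_mass_cost(n)
    print(f"n={n}: {entries:,} entries ({unique:,} distinct), ~{mb:.0f} MB per matrix in float32")
```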

“recover the marginalized latent variables after inference”: how would you do that, @martinjankowiak?

Yeah, that’s precisely what I’m doing! I also set dense_mass to be True.

I have the same question as @fehiepsi: @martinjankowiak, how do you sample betas without an explicit sample site in the model?

Well, this may not be easy to do automagically with the current state of the funsor/NumPyro integration, but algorithmically you want to do the following.

Suppose your model has an observed random variable x and latent random variables y and z: p(x|y,z)p(y,z). Suppose you can integrate out y analytically, so that you can compute p(x|z)p(z). Then you can do HMC on p(x|z)p(z) and end up with a bag of samples {z_k}. What you want is a bag of samples {y_k, z_k}. To get it, you sample y_k ~ p(y|x,z_k), where p(y|...) is the conditional posterior for y. By assumption, you can compute the latter distribution analytically.
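Here is that last step for this thread's model with beta_0 marginalized. In the notation of the paragraph above, the integrated-out variable is betas and z = (rho, tau_js, tau_eps). Given one posterior draw of z, betas has a Gaussian conditional posterior, because both its prior and the likelihood are Gaussian.

The sketch forms it densely in d·S dimensions, flattening betas row by row (a convention assumed here). The posterior precision is P^{-1} + (X^T X) ⊗ R^{-1} and the mean is its inverse applied to vec(X^T Y R^{-1}). Here P is the block-diagonal prior covariance and R = tau_eps² Σ. The function name and argument layout are hypothetical.

```python
import jax.numpy as jnp
import numpy as np
from jax import random
from jax.scipy.linalg import block_diag, cho_solve, solve_triangular


def sample_betas_conditional(key, X, Y, Sigma, tau_js, tau_eps, beta_0_hp):
    """One draw from p(betas | Y, rho, tau_js, tau_eps) for the beta_0-marginalized model.

    Prior: beta_j ~ N(0, beta_0_hp**2 I + tau_j**2 Sigma), independent over j.
    Likelihood: Y[t] ~ N(betas.T @ X[t], tau_eps**2 Sigma).
    betas is flattened row-major, so entry (j, s) sits at position j * S + s.
    """
    d, S = X.shape[1], Sigma.shape[0]
    prior_blocks = beta_0_hp ** 2 * jnp.eye(S) + tau_js[:, None, None] ** 2 * Sigma  # (d, S, S)
    prior_prec = block_diag(*jnp.linalg.inv(prior_blocks))
    R_inv = jnp.linalg.inv(tau_eps ** 2 * Sigma)
    post_prec = prior_prec + jnp.kron(X.T @ X, R_inv)
    chol = jnp.linalg.cholesky(post_prec)  # post_prec = chol @ chol.T
    rhs = (X.T @ Y @ R_inv).reshape(-1)
    mean = cho_solve((chol, True), rhs)
    # chol.T^{-1} z has covariance (chol @ chol.T)^{-1} = post_prec^{-1}.
    z = random.normal(key, (d * S,))
    draw = mean + solve_triangular(chol.T, z, lower=False)
    return draw.reshape(d, S), mean.reshape(d, S)
```

In practice you would call this once per posterior draw, rebuilding Sigma from that draw's rho, to turn the bag {z_k} into {betas_k, z_k}.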