Problems sampling a hierarchical model with Gaussian processes

Hi dear Pyro/NumPyro forum,
I have a NumPyro model that I want to fit with MCMC, using the NUTS sampler.

I’m using a single GPU and am running into memory allocation issues, and for a larger number of locations (about 8000, modeled with a Gaussian process) the sampling becomes far too slow (minutes per sample iteration). It also gets slower the longer the sampling runs and as the acceptance probability rises. For smaller data the model works well!

I would like to know whether there are obvious modeling mistakes or ideas to improve the situation.
I am also quite new to NumPyro. Or is NUTS perhaps simply not suitable for this dimensionality?

Ideas I tried or was thinking about so far:

  • running it with smaller data (works quite well, and the results are good!)
    The main problem seems to be the latent-space dimension, i.e. the number of locations, which also determines the size of the GP covariance matrix.
  • using plates as much as possible (is there any other way I could do this?) I tried to put more things into plates but did not see any immediate improvement.
  • reparametrization (see the sketch after this list)
  • simplifying the model to find the hard part
  • Would batching work here? Sadly it is currently super slow even with only a few samples.
  • Trying Hilbert space approximate Gaussian processes, as in the NumPyro example?
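
For the reparametrization point, what I was thinking of is the non-centered form of the spatial GP: sample iid standard normals and correlate them with the Cholesky factor of Sigma instead of sampling the MultivariateNormal directly. A rough sketch (the site names and the 1e-6 jitter are just my choices):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# non-centered spatial GP: iid standard normals, correlated afterwards
Sigma = exponential_covariogram(location_distance_matrix, gamma)
L_Sigma = jnp.linalg.cholesky(Sigma + 1e-6 * jnp.eye(number_of_locations))

with numpyro.plate("locations", number_of_locations):
    z = numpyro.sample("z_lambda", dist.Normal(0.0, 1.0))

lambda_ = numpyro.deterministic("lambda_active", mu + L_Sigma @ z)

As far as I understand this does not remove the O(n^3) Cholesky per model evaluation, but it decorrelates the latent space, which should make the geometry easier for NUTS.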

A more detailed description of the model is below.


# precalculated values
# all_locs: unique locations as (lat, lon) tuples
# location_distance_matrix: matrix of pairwise Euclidean distances between
#   locations, used to build the covariance matrix of the spatial GP
# locations: list of integers, one per location
# number_of_locations, number_of_timesteps: scalar sizes used below

import jax.numpy as jnp
from jax.nn import softplus

import numpyro
import numpyro.distributions as dist


# helper functions
def exponential_covariogram(distances, gamma):
    # exponential (Matern-1/2) kernel: gamma[0] * exp(-|d| / gamma[1])
    return gamma[0] * jnp.exp(-jnp.abs(distances) / gamma[1])


def stable_inv_logit(x):
    # numerically stable logistic sigmoid 1 / (1 + exp(-x))
    return 0.5 * (1.0 + jnp.sign(x) * (2.0 / (1.0 + jnp.exp(-jnp.abs(x))) - 1.0))


# model
def Numpyro_Model(locations, times, previous_values, y=None):
    # locations: integer array, one location index per datapoint, shape (number of datapoints,)
    # times: integer array, one time index per datapoint, shape (number of datapoints,)
    # previous_values: previous results of a different model, shape (number of datapoints,)
    # y: observed datapoints, array of 1s and 0s, shape (number of datapoints,)

    with numpyro.plate('gammas', 2):
        gamma_map = numpyro.sample("gamma_map", dist.Normal(0, 1))
    gamma = softplus(gamma_map)

    Sigma = exponential_covariogram(location_distance_matrix, gamma)
    beta = numpyro.sample('beta', dist.Normal(0, 2.5))
    mu = jnp.full(number_of_locations, beta)
    lambda_ = numpyro.sample("lambda_active", dist.MultivariateNormal(mu, Sigma))
    # spatial GP ^, I guess the covariance Sigma here is a problem?
    # (the MVN log-density needs a Cholesky factorization of the
    # 8000 x 8000 matrix, an O(n^3) operation, on every leapfrog step)

    alpha = numpyro.sample('alpha', dist.Normal(0, 1))
    mu_2 = jnp.full(number_of_timesteps, alpha)
    tau = numpyro.sample("tau_active", dist.MultivariateNormal(mu_2, jnp.eye(number_of_timesteps)))
    # no real GP here, may be added after it works well with one GP

    sumterm = previous_values + lambda_[locations] * tau[times]

    theta = numpyro.deterministic('theta', stable_inv_logit(sumterm))

    with numpyro.plate('data', len(locations)):
        numpyro.sample('obs', dist.Bernoulli(theta), obs=y)
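
Side note: I think dist.Bernoulli also accepts logits directly, so the last lines could be simplified and stable_inv_logit dropped entirely (keeping the theta deterministic only if I need it in the trace):

with numpyro.plate('data', len(locations)):
    numpyro.sample('obs', dist.Bernoulli(logits=sumterm), obs=y)

That way the sigmoid happens inside the distribution's log-prob, which should be numerically safer.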

The model is a hierarchical model with a spatial and (potentially) a temporal Gaussian process prior.
Each datapoint corresponds to a location and a timestep, but the data is given as a 1d array. The model samples 1s and 0s from a Bernoulli distribution parametrized by theta. Theta is a transform of a sumterm, which incorporates previous results of a different model; the last term added to the sumterm is the product of lambda and tau, which are drawn from a spatial and a temporal Gaussian process respectively (at least the spatial one for lambda; the temporal one for tau is less important). The Gaussian process for lambda is parametrized by a mean vector mu and a covariance Sigma, where Sigma in turn has a hyperparameter gamma. mu has one element per location; Sigma is a kernel of the distances between the locations.
dimension of y: (number of datapoints = number of locations * number of timesteps ≈ 8,000,000)
dimension of mu: (number of locations = ideally 8000); dimensions of Sigma: (number of locations, number of locations)
dimension of lambda: (number of locations); dimension of tau: (number of timesteps = roughly between 200 and ideally 6000)
dimensions of sumterm and theta: (number of datapoints)

Sorry for the long description, I was not sure how to shorten the model any further.
I am glad for any help or ideas!

this looks like a challenging problem for MCMC inference and may be at (or beyond) the limit of what is possible, especially without lots of custom work.

it’s hard to give a reliable rule of thumb, but the viability range for most black-box MCMC algorithms is probably something like 100 to 10^4, maybe 10^5, datapoints; 8 million is a lot.

if your GP input dimension is low (which it appears to be) the Hilbert space approximation should give you large speed-ups. i’d try to get that working on a subset of your data first and see how far that lets you push things (probably not to 8 million).
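
to sketch the idea (this is my own minimal 1d version for a squared-exponential kernel, following the Riutort-Mayol et al. construction the numpyro example is based on; L and M are illustrative tuning choices):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def hsgp_1d(x, alpha, length, L=5.0, M=50):
    # x must be rescaled to lie well inside [-L, L]; M basis functions
    m = jnp.arange(1, M + 1)
    sqrt_eig = jnp.pi * m / (2.0 * L)  # sqrt of Laplacian eigenvalues on [-L, L]
    phi = jnp.sqrt(1.0 / L) * jnp.sin(sqrt_eig * (x[:, None] + L))  # (len(x), M) basis
    # spectral density of the 1d squared-exponential kernel
    spd = alpha**2 * jnp.sqrt(2.0 * jnp.pi) * length * jnp.exp(-0.5 * (length * sqrt_eig) ** 2)
    with numpyro.plate("basis", M):
        beta = numpyro.sample("beta_hsgp", dist.Normal(0.0, 1.0))
    return phi @ (jnp.sqrt(spd) * beta)  # approximate GP draw, cost O(len(x) * M)

for your exponential (matern-1/2) kernel you'd swap in its spectral density, 2 * alpha**2 * length / (1 + (length * w) ** 2), and for the 2d spatial field you'd take products of 1d eigenfunctions over lat and lon. the point is that the cost becomes linear in the number of locations instead of cubic.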

at that point you might try using HMCECS (demo). maybe that’ll get you to 8 million but it’s hard to say.
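
for HMCECS the likelihood plate needs to be subsampled. a minimal sketch, with illustrative subsample_size/num_blocks values:

import jax
from numpyro.infer import MCMC, NUTS, HMCECS

# in the model: draw a random minibatch of datapoints each step
with numpyro.plate('data', len(locations), subsample_size=10_000) as idx:
    numpyro.sample('obs', dist.Bernoulli(theta[idx]), obs=y[idx])

# at the top level: wrap NUTS in the energy-conserving subsampling kernel
kernel = HMCECS(NUTS(Numpyro_Model), num_blocks=10)
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), locations, times, previous_values, y=y)

in practice HMCECS mixes much better with a likelihood proxy (i believe the demo uses HMCECS.taylor_proxy around a MAP estimate), so i'd follow the demo rather than the bare kernel above.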

Thanks so much for the reply!
I will try working only on a subset, or not using a GP for the temporal part, and then hopefully I can use some kind of batching. That way I could decrease the number of datapoints quite a bit.
Also thanks for the confirmation that the Hilbert space approximation might be worth a try.
For a dataset of size number of locations (1320) * timesteps (2000) ≈ 3 million,
I already managed to get okay-ish results before, though I have to say the sampling already showed some small issues there.
I will give an update when I have improved results.