Gaussian process with hierarchical intercept for longitudinal modelling


With the aim of modelling patient health from longitudinal data, I was able to create a Gaussian process with a random (i.e. hierarchical) intercept in PyMC3, as shown in this discourse thread. Since PyMC3 doesn't provide a sparse approximation for latent GPs, I gave Pyro a try. Pyro's documentation, however, shows two different interfaces: models defined with distributions are sampled, whereas GPs are fitted with optimizers. For example, a hierarchical linear model with several predictors and a random intercept looks like:

```python
import numpy as np
import pyro
import pyro.distributions as dist

def lm_ri(idx, X, y=None):
    # Random intercepts: one α per group, drawn around a population mean μ_α
    n_idx = len(np.unique(idx))
    μ_α = pyro.sample("μ_α", dist.Normal(0., 100.))
    σ_α = pyro.sample("σ_α", dist.HalfNormal(100.))
    with pyro.plate("plate_idx", n_idx):
        α = pyro.sample("α", dist.Normal(μ_α, σ_α))
    # Linear predictor: group intercept plus one slope per column of X
    y_hat = α[idx]
    for i in range(X.shape[1]):
        y_hat = y_hat + X[:, i] * pyro.sample(f"beta_{i}", dist.Normal(0., 10.))
    σ = pyro.sample("σ", dist.HalfNormal(100.))

    with pyro.plate("data", X.shape[0]):
        pyro.sample("obs", dist.Normal(y_hat, σ), obs=y)
```

… while GPs are optimized rather than sampled, for instance:

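```python
import torch
import pyro
import pyro.contrib.gp as gp
import pyro.distributions as dist

def gp_ri(X, y, num_steps=400):
    # model definition
    kernel = gp.kernels.RBF(input_dim=X.shape[1])
    Xu = X[::10].clone()  # inducing points: not shown in the original post; every 10th input is an assumption
    sgpr = gp.models.SparseGPRegression(
        X, y, kernel, Xu=Xu,
        jitter=1.0e-5)  # jitter to stabilize the Cholesky decomposition
    sgpr.kernel.lengthscale = pyro.nn.PyroSample(dist.HalfCauchy(3.0))
    sgpr.kernel.variance = pyro.nn.PyroSample(dist.HalfCauchy(3.0))

    # fit by minimizing the ELBO with Adam
    optimizer = torch.optim.Adam(sgpr.parameters(), lr=0.005)
    loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
    losses = []
    for i in range(num_steps):
        optimizer.zero_grad()
        loss = loss_fn(sgpr.model, sgpr.guide)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        if i % 100 == 0:
            print('i = {}'.format(i))
    return sgpr, losses

model, losses = gp_ri(X, y, num_steps=400)
```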

Is there a way to combine both approaches, i.e. to add a random intercept to the GP, either with NUTS sampling or with torch optimizers?


I think you can use mean_function for a random intercept. I believe this semi-parametric notebook would be helpful for you. :slight_smile:
