Gaussian Process Factor Analysis with Pyro+GPyTorch

Hi all!

I’m trying to implement Gaussian process factor analysis for spatial count data, similar to the basic idea in MEFISTO.

The model is essentially factor analysis with a Poisson likelihood, but with the latent factors not being sampled independently, but from a 2D Gaussian process. This acts as a smoothness prior.

I decided to implement this with Pyro and the low-level Pyro interface of GPyTorch.

My spatial count data has D features and lies on a regularly spaced square grid with a total of N points, i.e. sqrt(N) points in every direction. The model accepts this data with flattened spatial dimensions, i.e. as a tensor of shape (D, N,).

class P_GPFA(gpytorch.models.ApproximateGP):
def __init__(
    self,
    data: torch.tensor,
    n_latents: int,
    n_inducing: int=100,
    ):
    """Spatial Gaussian Process Factor Analysis with Poisson likelihood.

    :param data: tensor of shape (n_features, N) with bin count data.
        N is the total number of bins and the last dimension is the 
        flattened version of the two spatial dimensions of shape
        (sqrt(N), sqrt(N)).
    :param n_latents: number of latent processes (factors)
    :param n_inducing: number of sparse variational GP inducing points per
        latent factor
    """
    self.device = data.device
    self.dtype = data.dtype
    self.data = data
    self.N = data.shape[1] # total number of bins
    self.K = n_latents
    self.D = data.shape[0] # number of features

    # bin center coordinate tensor of shape (self.N, 2)
    self.bin_coordinates = (utils.get_bin_coordinates(self.N)
        .to(self.device, dtype=self.dtype))

    # bin coordinates projected on 1D axes, shape (sqrt(N), 2)
    grid = (utils.get_bin_coordinates_projection(self.N)
        .to(self.device, dtype=self.dtype))
    
    # feature wise average intensity, shape (self.D,)
    self.mean_intensity = (data.mean(dim=[-1])
        .to(self.device, dtype=self.dtype))

    # data / bins plate
    self.N_plate = pyro.plate(
        name='N_plate',
        dim=-1,
        size=self.N,
        device=self.device,
    )

    # latents plate
    self.K_plate = pyro.plate(
        name='K_plate',
        dim=-2,
        size=self.K,
        device=self.device,
    )

    # features plate
    self.D_plate = pyro.plate(
        name='D_plate',
        dim=-3,
        size=self.D,
        device=self.device,
    )

    # initial sparse variational GP inducing points locations
    inducing_points = torch.rand(
        size=[self.K, n_inducing, 2],
        device=self.device,
        dtype=self.dtype,
    )

    # GP variational distribution
    var_dist = gpytorch.variational.CholeskyVariationalDistribution(
        num_inducing_points=n_inducing,
        batch_shape=torch.Size(
            [self.K],
            device=self.device,
            dtype=self.dtype,
        ),
    )
    variational_strategy = gpytorch.variational.VariationalStrategy(
        model=self,
        inducing_points=inducing_points,
        variational_distribution=var_dist,
    )

    super().__init__(variational_strategy=variational_strategy)

    # GP prior mean
    self.mean_module = gpytorch.means.ZeroMean(
        batch_shape=torch.Size(
            [self.K],
            device=self.device,
            dtype=self.dtype,
        ),
    )

    # GP prior covariance
    self.covar_module = gpytorch.kernels.GridKernel(
        base_kernel=gpytorch.kernels.RBFKernel(
            batch_shape=torch.Size(
                [self.K],
                device=self.device,
                dtype=self.dtype,
            ),
        ),
        grid=grid,
    )

def forward(
    self,
    x: torch.tensor,
    ) -> gpytorch.distributions.MultivariateNormal:
    """Sample from GP priors.
    
    :param x: coordinate tensor of shape (N, 2) where the GPs are evaluated
    :returns: GP prior distributions at x
    """
    mean = self.mean_module(x)
    covar = self.covar_module(x)
    return gpytorch.distributions.MultivariateNormal(mean, covar)

def model(self):
    """Pyro model"""
    pyro.module('gp', self)

    # factor loadings
    w_loc = torch.tensor(0., device=self.device, dtype=self.dtype)
    w_scale = torch.tensor(1., device=self.device, dtype=self.dtype)
    with self.D_plate, self.K_plate:
        w = pyro.sample(
            name='w',
            fn=dist.Normal(loc=w_loc, scale=w_scale),
        ).view([-1, self.D, self.K, 1])

    # latent processes (factors)
    with self.K_plate, self.N_plate:
        z = pyro.sample(
            name="z",
            fn=self.pyro_model(self.bin_coordinates),
        ).view([-1, 1, self.K, self.N]).to(self.device, dtype=self.dtype)

    # GLM link function
    intensity = torch.log(1 + torch.exp(
        torch.matmul(w.squeeze(-1), z.squeeze(-3))
    )).view(-1, self.D, 1, self.N)

    # scale intensity by feature wise mean
    intensity *= self.mean_intensity.view([1, self.D, 1, 1])

    # make sure the intensity is non-negative
    intensity = torch.clamp(intensity, min=1e-4)

    # observations
    with self.D_plate, self.N_plate:
        pyro.sample(
            name="obs",
            fn=dist.Poisson(intensity),
            obs=self.data,
        )

def guide(self):
    """Pyro guide"""
    # factor loadings
    w_loc = pyro.param(
        name="w_loc",
        init_tensor=torch.zeros(
            size=[self.D, self.K, 1],
            device=self.device,
            dtype=self.dtype,
        ),
    )

    w_scale = pyro.param(
        name="w_scale",
        init_tensor=torch.ones(
            size=[self.D, self.K, 1],
            device=self.device,
            dtype=self.dtype,
            ),
        constraint=dist.constraints.positive,
    )
    
    with self.D_plate, self.K_plate:
        pyro.sample(
            name="w",
            fn=dist.Normal(loc=w_loc, scale=w_scale),
        )

    # latent processes (factors)
    with self.K_plate, self.N_plate:
        pyro.sample(
            name="z",
            fn=self.pyro_guide(self.bin_coordinates),
        )

I train the model with the following function:

def train_model(
    model,
    lr: float=5e-3,
    gamma: float=0.2,
    max_n_epochs: int=10000,
    n_particles: int=10,
    print_every: int=10,
    patience: int=300,
    delta: float=0.01
    ):
    """Train the model.

    :param model: Pyro/GPyTorch model
    :param lr: initial learning rate
    :param gamma: decaying learning rate parameter (lr is this fraction of
    itself after 1000 epochs)
    :param max_n_epochs: maximum number of training epochs
    :param n_particles: number of samples to use for ELBO gradient
    :param print_every: frequency of progress prints
    :param patience: number of epochs to wait for loss to decrease by
    param delta before stopping training
    :param delta: absolute value of loss difference in early stopping
    """
    model.to(device=model.device)
    model.train()

    optimizer = pyro.optim.ClippedAdam(
    {'lr' : lr, 'lrd' : gamma ** (1 / 1000)}
    )

    loss_func = pyro.infer.Trace_ELBO(
        retain_graph=True,
        num_particles=n_particles,
        vectorize_particles=True,
    )

    svi = pyro.infer.SVI(model.model, model.guide, optimizer, loss=loss_func)

    min_loss = 1e5
    patience_count = 0
    initial_loss = svi.evaluate_loss()
    model.losses = []

    for epoch in range(max_n_epochs):
        # rescale loss to be initially of magnitude 1
        with pyro.poutine.scale(scale=1.0 / abs(initial_loss)):
            model.losses.append(svi.step())

        if epoch % print_every == 0:
            print(
                'Epoch: %s; Loss: %s; Min loss: %s; Patience count: %s'
                %(epoch, round(model.losses[-1], 4),
                round(min_loss, 4), patience_count)
            )

        # early stopping
        if model.losses[-1] <= (min_loss - delta):
            min_loss = model.losses[-1]
            patience_count = 0
        else:
            patience_count += 1
    
        if patience_count > patience:
            break

The problem: already with a relatively small dataset (D=50 features, N=50*50=2500 points) this leads to my GPU (24 GB) running out of memory. I tried to figure out what requires so much memory here, but I didn’t get a satisfactory answer.
I put this trick into my training loop and checked the cuda tensors, but a rough calculation of the required bits resulted in only a few GBs.

Does anyone have an idea how to further trace down the large memory consumption? Or suggestions on how to reduce it? I converted everything to float32 already, float16 seems to be not possible for some reason. Reducing n_particles to 1 works, but is close to the limit.

i don’t believe you can use GridKernel with VariationalStrategy like that.

have you tried removing GridKernel? the approximations inherent in the latter should already be sufficient. if you’re still having memory issues you may also need to use data subsampling

Thank you for the advice, GridKernel was indeed redundant. Removing it did however not resolve the memory issues, so I will probably have to look at ways to use subsampling.