Hi all!
I’m trying to implement Gaussian process factor analysis for spatial count data, similar to the basic idea in MEFISTO.
The model is essentially factor analysis with a Poisson likelihood, but with the latent factors not being sampled independently, but from a 2D Gaussian process. This acts as a smoothness prior.
I decided to implement this with Pyro and the low-level Pyro interface of GPyTorch.
My spatial count data has D features and lies on a regularly spaced square grid with a total of N points, i.e. sqrt(N) points in every direction. The model accepts this data with flattened spatial dimensions, i.e. as a tensor of shape (D, N,).
class P_GPFA(gpytorch.models.ApproximateGP):
def __init__(
self,
data: torch.tensor,
n_latents: int,
n_inducing: int=100,
):
"""Spatial Gaussian Process Factor Analysis with Poisson likelihood.
:param data: tensor of shape (n_features, N) with bin count data.
N is the total number of bins and the last dimension is the
flattened version of the two spatial dimensions of shape
(sqrt(N), sqrt(N)).
:param n_latents: number of latent processes (factors)
:param n_inducing: number of sparse variational GP inducing points per
latent factor
"""
self.device = data.device
self.dtype = data.dtype
self.data = data
self.N = data.shape[1] # total number of bins
self.K = n_latents
self.D = data.shape[0] # number of features
# bin center coordinate tensor of shape (self.N, 2)
self.bin_coordinates = (utils.get_bin_coordinates(self.N)
.to(self.device, dtype=self.dtype))
# bin coordinates projected on 1D axes, shape (sqrt(N), 2)
grid = (utils.get_bin_coordinates_projection(self.N)
.to(self.device, dtype=self.dtype))
# feature wise average intensity, shape (self.D,)
self.mean_intensity = (data.mean(dim=[-1])
.to(self.device, dtype=self.dtype))
# data / bins plate
self.N_plate = pyro.plate(
name='N_plate',
dim=-1,
size=self.N,
device=self.device,
)
# latents plate
self.K_plate = pyro.plate(
name='K_plate',
dim=-2,
size=self.K,
device=self.device,
)
# features plate
self.D_plate = pyro.plate(
name='D_plate',
dim=-3,
size=self.D,
device=self.device,
)
# initial sparse variational GP inducing points locations
inducing_points = torch.rand(
size=[self.K, n_inducing, 2],
device=self.device,
dtype=self.dtype,
)
# GP variational distribution
var_dist = gpytorch.variational.CholeskyVariationalDistribution(
num_inducing_points=n_inducing,
batch_shape=torch.Size(
[self.K],
device=self.device,
dtype=self.dtype,
),
)
variational_strategy = gpytorch.variational.VariationalStrategy(
model=self,
inducing_points=inducing_points,
variational_distribution=var_dist,
)
super().__init__(variational_strategy=variational_strategy)
# GP prior mean
self.mean_module = gpytorch.means.ZeroMean(
batch_shape=torch.Size(
[self.K],
device=self.device,
dtype=self.dtype,
),
)
# GP prior covariance
self.covar_module = gpytorch.kernels.GridKernel(
base_kernel=gpytorch.kernels.RBFKernel(
batch_shape=torch.Size(
[self.K],
device=self.device,
dtype=self.dtype,
),
),
grid=grid,
)
def forward(
self,
x: torch.tensor,
) -> gpytorch.distributions.MultivariateNormal:
"""Sample from GP priors.
:param x: coordinate tensor of shape (N, 2) where the GPs are evaluated
:returns: GP prior distributions at x
"""
mean = self.mean_module(x)
covar = self.covar_module(x)
return gpytorch.distributions.MultivariateNormal(mean, covar)
def model(self):
"""Pyro model"""
pyro.module('gp', self)
# factor loadings
w_loc = torch.tensor(0., device=self.device, dtype=self.dtype)
w_scale = torch.tensor(1., device=self.device, dtype=self.dtype)
with self.D_plate, self.K_plate:
w = pyro.sample(
name='w',
fn=dist.Normal(loc=w_loc, scale=w_scale),
).view([-1, self.D, self.K, 1])
# latent processes (factors)
with self.K_plate, self.N_plate:
z = pyro.sample(
name="z",
fn=self.pyro_model(self.bin_coordinates),
).view([-1, 1, self.K, self.N]).to(self.device, dtype=self.dtype)
# GLM link function
intensity = torch.log(1 + torch.exp(
torch.matmul(w.squeeze(-1), z.squeeze(-3))
)).view(-1, self.D, 1, self.N)
# scale intensity by feature wise mean
intensity *= self.mean_intensity.view([1, self.D, 1, 1])
# make sure the intensity is non-negative
intensity = torch.clamp(intensity, min=1e-4)
# observations
with self.D_plate, self.N_plate:
pyro.sample(
name="obs",
fn=dist.Poisson(intensity),
obs=self.data,
)
def guide(self):
"""Pyro guide"""
# factor loadings
w_loc = pyro.param(
name="w_loc",
init_tensor=torch.zeros(
size=[self.D, self.K, 1],
device=self.device,
dtype=self.dtype,
),
)
w_scale = pyro.param(
name="w_scale",
init_tensor=torch.ones(
size=[self.D, self.K, 1],
device=self.device,
dtype=self.dtype,
),
constraint=dist.constraints.positive,
)
with self.D_plate, self.K_plate:
pyro.sample(
name="w",
fn=dist.Normal(loc=w_loc, scale=w_scale),
)
# latent processes (factors)
with self.K_plate, self.N_plate:
pyro.sample(
name="z",
fn=self.pyro_guide(self.bin_coordinates),
)
I train the model with the following function:
def train_model(
model,
lr: float=5e-3,
gamma: float=0.2,
max_n_epochs: int=10000,
n_particles: int=10,
print_every: int=10,
patience: int=300,
delta: float=0.01
):
"""Train the model.
:param model: Pyro/GPyTorch model
:param lr: initial learning rate
:param gamma: decaying learning rate parameter (lr is this fraction of
itself after 1000 epochs)
:param max_n_epochs: maximum number of training epochs
:param n_particles: number of samples to use for ELBO gradient
:param print_every: frequency of progress prints
:param patience: number of epochs to wait for loss to decrease by
param delta before stopping training
:param delta: absolute value of loss difference in early stopping
"""
model.to(device=model.device)
model.train()
optimizer = pyro.optim.ClippedAdam(
{'lr' : lr, 'lrd' : gamma ** (1 / 1000)}
)
loss_func = pyro.infer.Trace_ELBO(
retain_graph=True,
num_particles=n_particles,
vectorize_particles=True,
)
svi = pyro.infer.SVI(model.model, model.guide, optimizer, loss=loss_func)
min_loss = 1e5
patience_count = 0
initial_loss = svi.evaluate_loss()
model.losses = []
for epoch in range(max_n_epochs):
# rescale loss to be initially of magnitude 1
with pyro.poutine.scale(scale=1.0 / abs(initial_loss)):
model.losses.append(svi.step())
if epoch % print_every == 0:
print(
'Epoch: %s; Loss: %s; Min loss: %s; Patience count: %s'
%(epoch, round(model.losses[-1], 4),
round(min_loss, 4), patience_count)
)
# early stopping
if model.losses[-1] <= (min_loss - delta):
min_loss = model.losses[-1]
patience_count = 0
else:
patience_count += 1
if patience_count > patience:
break
The problem: already with a relatively small dataset (D=50 features, N=50*50=2500 points) this leads to my GPU (24 GB) running out of memory. I tried to figure out what requires so much memory here, but I didn’t get a satisfactory answer.
I put this trick into my training loop and checked the cuda tensors, but a rough calculation of the required bits resulted in only a few GBs.
Does anyone have an idea how to further trace down the large memory consumption? Or suggestions on how to reduce it? I converted everything to float32 already, float16 seems to be not possible for some reason. Reducing n_particles to 1 works, but is close to the limit.