I’ve got a Pyro model that uses a GP kernel to model a latent function, which runs just fine if I use any of the isotropic kernels:
...
# RBF kernel over a single input dimension; its lengthscale and variance are
# assigned PyroSample priors (Uniform(3, 10) and HalfCauchy(1) respectively).
self.age_kernel = gp.kernels.RBF(input_dim=1)
self.age_kernel.lengthscale = PyroSample(dist.Uniform(dtensor(3.), dtensor(10.)))
self.age_kernel.variance = PyroSample(dist.HalfCauchy(dtensor(1.)))
...
However, when I try to add a Linear or Polynomial kernel into the mix:
...
# Sum of an RBF kernel and a degree-2 Polynomial kernel, each over one input dimension.
self.age_kernel = gp.kernels.Sum(
gp.kernels.RBF(input_dim=1),
gp.kernels.Polynomial(input_dim=1, degree=2)
)
...
I get a runtime error that appears to be related to inverting a non-positive definite matrix:
RuntimeError Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/pyro/contrib/gp/kernels/dot_product.py in _dot_product(self, X, Z, diag)
34 raise ValueError("Inputs must have the same number of features.")
35
---> 36 return X.matmul(Z.t())
37
38
RuntimeError: "addmm_cuda" not implemented for 'Long'
It's not clear to me why adding a simple kernel (and if I run just a Polynomial kernel, I get the same problem) would cause issues with positive definiteness. Am I missing something here?
Many thanks in advance.
Here is a runnable example that demonstrates the problem:
# -*- coding: utf-8 -*-
import pandas as pd
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import (
SVI,
Trace_ELBO
)
import pyro.infer.autoguide as autoguides
import pyro.contrib.gp as gp
# Load the monthly in-situ CO2 table, skipping the 57 header rows and
# supplying the column names explicitly.
_columns = (
    "Yr","Mn", "Date_excel", "Date", "CO2", "seasonally_adj", "fit","seasonally_fit", "CO2_filled", "seasonally_filled"
)
data = pd.read_csv(
    "https://scrippsco2.ucsd.edu/assets/data/atmospheric/stations/in_situ_co2/monthly/monthly_in_situ_co2_mlo.csv",
    skiprows=57,
    names=_columns,
)

# Keep only rows with a positive CO2 reading, then split the two retained
# columns into separate arrays.
_observed = data.loc[data.CO2 > 0, ["CO2", "Date_excel"]]
CO2, date = _observed.values.T
N = len(CO2)

device = 'cpu'
jitter = torch.tensor(1.0e-4, device=device)


def dtensor(value):
    """Wrap *value* in a ``torch.tensor`` placed on the module-level ``device``."""
    return torch.tensor(value, device=device)
class MaunaLoa(gp.Parameterized):
    """Latent-GP regression model with an RBF + Linear covariance.

    The latent function ``f`` is built as ``L @ f_tilde`` where ``f_tilde`` is
    a vector of standard-normal draws and ``L`` is the Cholesky factor of the
    kernel matrix plus a small diagonal ``jitter``.  Observations are Normal
    around ``f`` with a LogNormal-distributed noise scale ``sigma``.

    Relies on the module-level names ``N``, ``device``, ``jitter`` and
    ``dtensor``.
    """

    def __init__(self, *args, **kwargs):
        """Build the kernel and reset Pyro's parameter store.

        ``*args`` and ``**kwargs`` are accepted but not used.
        """
        super().__init__()
        self.age_kernel = gp.kernels.Sum(
            gp.kernels.RBF(input_dim=1),
            gp.kernels.Linear(input_dim=1),
        )
        pyro.clear_param_store()

    def forward(self, X, y=None):
        """Run the generative model.

        Args:
            X: kernel inputs; the resulting covariance must be ``N x N``.
            y: observed targets, or ``None`` to sample the ``"obs"`` site.
        """
        # Aging curve
        cov = self.age_kernel(X).contiguous()
        with pyro.plate("months", N):
            f_tilde = pyro.sample("f_tilde", dist.Normal(dtensor(0.0), dtensor(1.0)))

        # Jitter on the diagonal is added before factorising the covariance.
        chol = torch.linalg.cholesky(cov + torch.eye(N, device=device) * jitter)
        f = pyro.deterministic("f", chol @ f_tilde.squeeze())

        self.sigma = pyro.sample("sigma", dist.LogNormal(dtensor(0.), dtensor(0.5)))

        # The plate is named differently from the "obs" sample site: a plate
        # registers a site under its own name, so sharing the name collides.
        with pyro.plate("obs_plate", N):
            pyro.sample(
                "obs",
                # Fixed: previously read the undefined attribute
                # `self.sigma_stuff` instead of the sampled `self.sigma`.
                dist.Normal(loc=f, scale=self.sigma),
                obs=y,
            )
# Optimisation hyperparameters: the per-step decay `lrd` is chosen so that the
# learning rate has been multiplied by `gamma` after N_STEPS steps.
N_STEPS = 5000
initial_lr = 0.01
gamma = 0.1
lrd = gamma ** (1 / N_STEPS)

# Model, automatic low-rank multivariate-normal guide, and SVI driver.
projection_model = MaunaLoa()
guide = autoguides.AutoLowRankMultivariateNormal(projection_model)
optimizer = pyro.optim.ClippedAdam({"lr": initial_lr, "lrd": lrd})
svi = SVI(projection_model, guide, optimizer, loss=Trace_ELBO())

pyro.clear_param_store()
for j in range(N_STEPS):
    inputs = dtensor(date).float()
    targets = dtensor(CO2).float()
    loss = svi.step(inputs, targets)
    # Report the per-observation loss every 100 steps.
    if j % 100 == 0:
        print("[iteration %04d] loss: %.4f" % (j + 1, loss / N))
I suggest using double precision, and probably also changing the units of time to be O(1), e.g.
# Centre the dates and divide by 365 so the time inputs are roughly O(1)
# (presumably converting a day count to years — confirm the units of `date`).
date = (date - date.mean()) / 365
or what have you