Latent GP with derivative information

Hey there,
Hope this post finds you well. I’ve been using Pyro for research for a year or so, and it is an awesome Python package.
I’ve been facing a new problem lately, somewhat unusual (?). In short, I’ve been trying to learn a latent GP, where I wish to inject derivative information as well as function observations.
I’m attaching below an MWE, which is a toy example but should guide the development of a more complex framework. Therein, I’m learning the function y = latent(x) * x**2, where I assumed latent(x) = x for simplicity.
The strange thing is that the training is successful in learning the whole function without derivative information. Yet, when I try to enforce derivative observations, the model fits the derivative and the function observations independently. In other words, the fit to the derivative has no effect on the function. I am essentially imposing that the derivative of latent(x) equals 1.
I am presumably on the wrong track and am seeking advice on this matter. I’d be really grateful to anyone who gives feedback! Thanks in advance and best regards.

The code:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import Normalize
import torch
from torch.distributions.constraints import positive as torch_pos
import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.infer import Predictive, SVI, Trace_ELBO
pyro.set_rng_seed(0)


def latent(x):
    """Ground-truth latent function of the toy problem: returns ``x`` unchanged."""
    value = x
    return value


def make_data():
    """Generate the noisy toy data set ``y = latent(x2) * x1**2 + eps``.

    Both inputs are 50 evenly spaced points on [-1, 1]; ``eps`` is Gaussian
    noise with standard deviation 0.05 drawn from NumPy's global random state.

    Returns:
        Tuple ``(x1, x2, y)`` of double-precision torch tensors.
    """
    n_points = 50
    x1 = np.linspace(-1, 1, n_points)
    x2 = np.linspace(-1, 1, n_points)
    eps = np.random.normal(0, 0.05, size=x1.shape[0])
    y = latent(x2) * x1**2 + eps

    return tuple(torch.tensor(arr, dtype=torch.double) for arr in (x1, x2, y))


def fdiff(x, y):
    """Approximate dy/dx on the grid ``x`` with finite differences.

    Interior points use a central difference; the first and last points use
    one-sided (forward / backward) differences.

    Args:
        x: Sample locations (assumed 1-D, at least two entries).
        y: Function values at ``x``.

    Returns:
        Tensor shaped like ``y`` holding the derivative estimates.
    """
    slopes = torch.zeros_like(y)

    # One-sided differences at the two boundaries.
    slopes[0] = (y[1] - y[0]) / (x[1] - x[0])
    slopes[-1] = (y[-1] - y[-2]) / (x[-1] - x[-2])

    # Central differences everywhere else.
    rise = y[2:] - y[:-2]
    run = x[2:] - x[:-2]
    slopes[1:-1] = rise / run

    return slopes


def kernel(X, Z, var, length, noise, jitter=1.0e-6, include_noise=True):
    """Squared-exponential (RBF) covariance between function values.

    Adapted from https://num.pyro.ai/en/0.15.3/examples/gp.html

    Args:
        X: Row locations (assumed 1-D).
        Z: Column locations (assumed 1-D).
        var: Kernel variance (amplitude).
        length: Kernel length scale.
        noise: Observation-noise variance added on the diagonal.
        jitter: Small diagonal term added together with ``noise``.
        include_noise: If True, add ``(noise + jitter) * I`` to the result.

    Returns:
        Matrix with entry ``[i, j] = var * exp(-0.5 * ((X[i] - Z[j]) / length)**2)``,
        plus the diagonal term when ``include_noise`` is set.
    """
    scaled_sep = (X[:, None] - Z) / length
    cov = var * torch.exp(-0.5 * torch.pow(scaled_sep, 2.0))
    if not include_noise:
        return cov
    # The identity is sized by X only, so this branch needs len(X) == len(Z).
    cov += (noise + jitter) * torch.eye(X.shape[0])
    return cov


def kernel_f_prime(X, Z, var, length, noise, jitter=1.0e-6, include_noise=False):
    """RBF cross-covariance between derivative values at X and function values at Z.

    Evaluates ``-(X[i] - Z[j]) / length**2 * k(X[i], Z[j])``, i.e. the RBF
    kernel differentiated with respect to its first argument.

    Args:
        X: Locations of the derivative values (rows).
        Z: Locations of the function values (columns).
        var: Kernel variance.
        length: Kernel length scale.
        noise: Noise variance, only used when ``include_noise`` is True.
        jitter: Diagonal jitter, only used when ``include_noise`` is True.
        include_noise: If True, ``(noise + jitter) * I`` is added to the base
            kernel before it is multiplied by the separation.

    Returns:
        Cross-covariance matrix of shape ``(len(X), len(Z))``.
    """
    base = kernel(X, Z, var, length, noise, jitter, include_noise=False)
    if include_noise:
        # NOTE(review): the diagonal term is multiplied by (X - Z) below, so it
        # vanishes wherever X[i] == Z[i]; confirm this branch is ever intended.
        base += (noise + jitter) * torch.eye(X.shape[0])
    separation = X[:, None] - Z
    return (-1.0) * base * separation / length**2


def kernel_prime_prime(X, Z, var, length, noise, jitter=1.0e-6, include_noise=False):
    """RBF covariance between derivative values at X and derivative values at Z.

    Evaluates ``k(X[i], Z[j]) * (length**2 - (X[i] - Z[j])**2) / length**4``.

    Args:
        X: Row locations.
        Z: Column locations.
        var: Kernel variance.
        length: Kernel length scale.
        noise: Noise variance folded into the base kernel (see note below).
        jitter: Diagonal jitter folded into the base kernel (see note below).
        include_noise: If True, ``(noise + jitter) * I`` is added a second time.

    Returns:
        Derivative-derivative covariance matrix.
    """
    # NOTE(review): the base kernel is always built with include_noise=True, so
    # (noise + jitter) enters even when include_noise=False (as
    # (noise + jitter) / length**2 wherever X[i] == Z[i]) and X and Z must have
    # the same length. Callers in this file appear to rely on this diagonal term;
    # confirm it is intended before changing it.
    base = kernel(X, Z, var, length, noise, jitter, include_noise=True)
    if include_noise:
        base += (noise + jitter) * torch.eye(X.shape[0])
    separation = X[:, None] - Z
    return base * (length**2 - separation**2) / length**4


def kernel_joint(ker, ker_f_p, ker_p_p):
    """Assemble the joint covariance of function and derivative values.

    Block layout (function entries first, derivative entries second)::

        [[ ker,      ker_f_p.T ],
         [ ker_f_p,  ker_p_p   ]]

    Args:
        ker: Function-function covariance block.
        ker_f_p: Derivative-function cross-covariance block.
        ker_p_p: Derivative-derivative covariance block.

    Returns:
        The assembled joint covariance matrix.
    """
    upper_row = torch.cat((ker, ker_f_p.T), dim=1)
    lower_row = torch.cat((ker_f_p, ker_p_p), dim=1)
    return torch.cat((upper_row, lower_row), dim=0)


def inspect_kernel():
    """Plot five draws from the joint function/derivative GP prior.

    For each draw the function part is plotted in blue, its finite-difference
    derivative in green and the sampled derivative part in red, so the last
    two curves can be compared visually.
    """
    var, length, noise, jitter = 1., 0.5, 1.0e-5, 1.0e-6

    # Only the input grids are needed here; the targets are ignored.
    x, x_prime, _ = make_data()

    cov_ff = kernel(x, x, var, length, noise, jitter, include_noise=True)
    cov_df = kernel_f_prime(x_prime, x, var, length, noise, jitter, include_noise=False)
    cov_dd = kernel_prime_prime(x_prime, x_prime, var, length, noise, jitter, include_noise=False)
    cov_joint = kernel_joint(cov_ff, cov_df, cov_dd)

    n_fun = x.shape[0]
    prior = dist.MultivariateNormal(
        torch.zeros(n_fun + x_prime.shape[0], dtype=torch.double), cov_joint
    )

    plt.figure()
    for _ in range(5):
        draw = pyro.sample("sample", prior)
        # Joint vector layout: function values first, derivative values after.
        f = draw[:n_fun]
        p = draw[n_fun:]

        plt.plot(x, f, "b")
        plt.plot(x, fdiff(x, f), "g")
        plt.plot(x_prime, p, "r")
    plt.grid(True)
    plt.show()



def model_wo_derivatives(x1, x2=None, y=None, noise_level=1.):
    """Pyro model of ``y = g(x2) * x1**2`` with a GP prior on ``g`` (no derivatives).

    Args:
        x1: Inputs entering through the known factor ``x1**2``.
        x2: Inputs of the latent GP. Defaults to None in the signature, but the
            kernel indexes it, so a tensor must be supplied.
        y: Observed targets; when None the "obs" site is sampled instead of
            being conditioned on.
        noise_level: Unused; kept so the signature stays unchanged.

    Returns:
        The noise-free prediction ``g(x2) * x1**2`` for the sampled latent values.
    """
    jitter = 1e-6

    var_x2 = pyro.param("var_x2", torch.tensor([5.]), constraint=torch_pos)
    len_x2 = pyro.param("len_x2", torch.tensor([2.]), constraint=torch_pos)
    noi_x2 = pyro.param("noi_x2", torch.tensor([1.e-2]), constraint=torch_pos)
    noi = pyro.param("noi", torch.tensor([0.5]), constraint=torch_pos)
    # "der" is not used by this model, but it stays registered because train()
    # reads the "der" node from the model trace when reporting progress.
    pyro.param("der", torch.tensor([1.e-2]), constraint=torch_pos)

    ker_x2 = kernel(X=x2, Z=x2, var=var_x2, length=len_x2, noise=noi_x2, jitter=jitter, include_noise=True)

    _x2_est = pyro.sample("_x2_all", dist.MultivariateNormal(loc=torch.zeros_like(ker_x2[:, 0]), covariance_matrix=ker_x2))
    _y_est = _x2_est * (x1**2)

    with pyro.plate("data", x1.shape[0]):
        # Inside a sized plate the batch dimension must be left to the plate.
        # Calling .to_event() here would produce a distribution with empty
        # batch shape that the plate broadcasts to its size, so the joint
        # likelihood would be counted x1.shape[0] times.
        pyro.sample("obs", dist.Normal(_y_est, noi), obs=y)

    return _y_est


def model_w_derivatives(x1, x2=None, y=None, noise_level=1.):
    """Pyro model of ``y = g(x2) * x1**2`` with a joint GP prior on ``g`` and ``g'``.

    The latent site "_x2_all" stacks the function values ``g(x2)`` followed by
    the derivative values ``g'(x2)``, drawn from the joint RBF covariance.
    When ``y`` is given, two likelihood terms are added: "obs" ties the
    function part to ``y`` and "diff" ties every derivative value to the
    hard-coded target 1.0.

    Args:
        x1: Inputs entering through the known factor ``x1**2``.
        x2: Inputs of the latent GP. Defaults to None in the signature, but the
            kernels index it, so a tensor must be supplied.
        y: Observed targets; when None no observation sites are created.
        noise_level: Unused; kept so the signature stays unchanged.

    Returns:
        The noise-free prediction ``g(x2) * x1**2`` for the sampled latent
        values (previously nothing was returned, so the "_RETURN" site was None).
    """
    jitter = 1e-6

    var_x2 = pyro.param("var_x2", torch.tensor([5.]), constraint=torch_pos)
    len_x2 = pyro.param("len_x2", torch.tensor([2.]), constraint=torch_pos)
    noi_x2 = pyro.param("noi_x2", torch.tensor([1.e-2]), constraint=torch_pos)
    noi = pyro.param("noi", torch.tensor([0.5]), constraint=torch_pos)
    der = pyro.param("der", torch.tensor([1.e-2]), constraint=torch_pos)

    ker_x2 = kernel(X=x2, Z=x2, var=var_x2, length=len_x2, noise=noi_x2, jitter=jitter, include_noise=True)
    ker_x2_f_p = kernel_f_prime(X=x2, Z=x2, var=var_x2, length=len_x2, noise=0.0, jitter=0.0, include_noise=False)
    ker_x2_p_p = kernel_prime_prime(X=x2, Z=x2, var=var_x2, length=len_x2, noise=noi_x2, jitter=jitter, include_noise=False)
    ker_x2_j = kernel_joint(ker_x2, ker_x2_f_p, ker_x2_p_p)

    _x2_all = pyro.sample("_x2_all", dist.MultivariateNormal(loc=torch.zeros_like(ker_x2_j[:, 0]), covariance_matrix=ker_x2_j))
    # Split according to how the joint kernel was assembled:
    # function values first, derivative values second.
    n_fun = x2.shape[0]
    _x2_est = _x2_all[:n_fun]
    _x2_diff = _x2_all[n_fun:]

    _y_est = _x2_est * (x1**2)

    if y is not None:
        with pyro.plate("data"):
            # Two independent observation sites. They used to be joined by a
            # stray `*` whose product was discarded; they are plain statements now.
            pyro.sample("obs", dist.Normal(_y_est, noi * torch.ones_like(_x2_est)).to_event(), obs=y)
            pyro.sample("diff", dist.Normal(_x2_diff, der * torch.ones_like(_x2_diff)).to_event(), obs=torch.ones_like(_x2_diff))

    return _y_est
    

def train(model, gx1, gx2, gy, lr=.05, steps=10, post_samples=1):
    """Fit ``model`` by stochastic variational inference and return a Predictive.

    Uses an ``AutoDiagonalNormal`` guide, the Adam optimizer and ``Trace_ELBO``.
    Every 100 steps the loss, the model parameters and the guide's current
    sample of "_x2_all" are printed.

    NOTE(review): a diagonal-Normal guide cannot represent posterior
    correlations between entries of "_x2_all"; worth confirming this is
    adequate when that site stacks function and derivative values.

    Args:
        model: Pyro model callable taking ``(x1, x2, y)``.
        gx1: First model input.
        gx2: Second model input.
        gy: Observed targets.
        lr: Adam learning rate.
        steps: Number of SVI steps.
        post_samples: Number of samples the returned Predictive draws.

    Returns:
        ``Predictive`` built from the model and the fitted guide, exposing the
        model's return value through the "_RETURN" site.
    """
    guide = AutoDiagonalNormal(model)
    svi = SVI(model, guide, pyro.optim.Adam({"lr": lr}), Trace_ELBO())

    reported_params = ("var_x2", "len_x2", "noi_x2", "noi", "der")
    for step in range(steps):
        report = step % 100 == 0
        if report:
            # Traces are only read by the progress report, so they are recorded
            # on reporting steps only (before the update, as before) instead of
            # paying two extra forward passes on every iteration.
            guide_trace = pyro.poutine.trace(guide).get_trace(gx1, gx2, gy)
            model_trace = pyro.poutine.trace(model).get_trace(gx1, gx2, gy)

        loss_value = svi.step(gx1, gx2, gy)

        if report:
            print(f"Step {step} : Loss = {loss_value}")
            for name in reported_params:
                print(name, model_trace.nodes[name]["value"])
            print("_x2_all", guide_trace.nodes["_x2_all"]["value"])
            print("----------")

    return Predictive(model, guide=guide, num_samples=post_samples, return_sites=("_RETURN",))


def inference_wo_derivatives(steps=1000):
    """Fit the derivative-free model to the toy data with SVI.

    Args:
        steps: Number of SVI steps passed to ``train``.
    """
    dataset = make_data()
    predictive = train(model_wo_derivatives, *dataset, steps=steps, post_samples=200)
    # Evaluation grids; created but not used yet, like the predictive above.
    x1_t = torch.linspace(-1, 1, 50)
    x2_t = torch.linspace(-1, 1, 10)


def inference_w_derivatives(steps=1000):
    """Fit the derivative-informed model to the toy data with SVI.

    Args:
        steps: Number of SVI steps passed to ``train``.
    """
    dataset = make_data()
    predictive = train(model_w_derivatives, *dataset, steps=steps, post_samples=200)
    # Evaluation grids; created but not used yet, like the predictive above.
    x1_t = torch.linspace(-1, 1, 50)
    x2_t = torch.linspace(-1, 1, 10)


if __name__ == "__main__":
    # Script entry point: only the derivative-informed experiment is run;
    # the other two experiments are left disabled below.

    # inspect_kernel()
    # inference_wo_derivatives()
    inference_w_derivatives()

Hi aletgn,

I haven’t looked at your code in complete detail, but I have faced a similar problem with implementing derivative and sparse GP with numpyro. I think the issue lies in your guide function. The guide function does not “know” that it is a distribution of function values and derivatives, so you shouldn’t expect it to take into account the relationship between the function and its derivative.

I recommend building your model so your latent variables are only the function values. That way, the derivatives appear only when conditioning your derivative GP with the function values. The derivative observations would then be implemented as a regularization term via numpyro.factor. By doing things this way, you’re effectively constraining your guide to retain the derivative structure (i.e., q(f' | f) = p_{GP}(f' | f)). I don’t have time to sketch a proof for this, but I’m fairly certain of this since I’ve proven a similar result for sparse GPs.

Take a look at my github repo below. The first and third notebooks may help you with implementation.

derivative-tinygp/03_svi_1d_deriv_gp.ipynb at main · edwarddramirez/derivative-tinygp · GitHub

2 Likes

Hello @edwarddramirez,
Thanks for your interest in my question and for sharing your repo’s link – very much appreciated!
I have quickly gone through your notebooks to get some input for the implementation, and I believe your approach will surely help me. I am actively working on this, but I need some time to get things to work. If I need further input, I will get back to you.
Thanks again for your feedback thus far.