Hey there,
Hope this post finds you well. I’ve been using Pyro for research for a year or so, and it is an awesome Python package.
I’ve been facing a new and somewhat unusual problem lately. In short, I’m trying to learn a latent GP into which I want to inject derivative information as well as function observations.
I’m attaching an MWE below; it is a toy example, but it should guide the development of a more complex framework. In it, I’m learning the function y = latent(x) * x**2, where I assume latent(x) = x for simplicity.
The strange thing is that training successfully learns the whole function when no derivative information is used. Yet when I try to enforce derivative observations, the model fits the derivative and the function observations independently; in other words, the fit to the derivative has no effect on the function. I am essentially imposing that the derivative of latent(x) equals 1.
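For reference, the derivative kernels in the MWE rely on the standard RBF identities (stating my assumptions explicitly): with k(x, z) = var * exp(-(x - z)**2 / (2 * length**2)),
cov(f'(x), f(z)) = dk(x, z)/dx = -k(x, z) * (x - z) / length**2
cov(f'(x), f'(z)) = d^2 k(x, z)/dx dz = k(x, z) * (length**2 - (x - z)**2) / length**4
which is what kernel_f_prime and kernel_prime_prime below are meant to implement.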
I am presumably on the wrong track and am seeking advice on this matter. I’d be really grateful to anyone who can give feedback! Thanks in advance and best regards.
The code:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import Normalize
import torch
from torch.distributions.constraints import positive as torch_pos
import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.infer import Predictive, SVI, Trace_ELBO
pyro.set_rng_seed(0)
def latent(x):
    return x
def make_data():
    x1 = np.linspace(-1, 1, 50)
    x2 = np.linspace(-1, 1, 50)
    y = latent(x2)*x1**2 + np.random.normal(0, 0.05, size=x1.shape[0])
    # plt.figure()
    # plt.plot(x1, y)
    # plt.show()
    return torch.tensor(x1, dtype=torch.double), torch.tensor(x2, dtype=torch.double), torch.tensor(y, dtype=torch.double)
def fdiff(x, y):
    dydx = torch.zeros_like(y)
    dydx[1:-1] = (y[2:] - y[:-2]) / (x[2:] - x[:-2])
    dydx[0] = (y[1] - y[0]) / (x[1] - x[0])
    dydx[-1] = (y[-1] - y[-2]) / (x[-1] - x[-2])
    return dydx
def kernel(X, Z, var, length, noise, jitter=1.0e-6, include_noise=True):
    """RBF kernel of function observations
    taken from https://num.pyro.ai/en/0.15.3/examples/gp.html
    """
    deltaXsq = torch.pow((X[:, None] - Z) / length, 2.0)
    k = var * torch.exp(-0.5 * deltaXsq)
    if include_noise:
        k += (noise + jitter) * torch.eye(X.shape[0])
    return k
def kernel_f_prime(X, Z, var, length, noise, jitter=1.0e-6, include_noise=False):
    """RBF kernel: covariance between derivative and function observations"""
    diff = X[:, None] - Z
    k = kernel(X, Z, var, length, noise, jitter, include_noise=False)
    if include_noise:
        k += (noise + jitter) * torch.eye(X.shape[0])
    return (-1.0) * k * diff / length**2
def kernel_prime_prime(X, Z, var, length, noise, jitter=1.0e-6, include_noise=False):
    """RBF kernel of derivative observations"""
    diff = X[:, None] - Z
    k = kernel(X, Z, var, length, noise, jitter, include_noise=True)
    if include_noise:
        k += (noise + jitter) * torch.eye(X.shape[0])
    return k * (length**2 - diff**2) / length**4
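# Sanity-check sketch, not part of the original MWE: it assumes the cross kernel
# above is the derivative of the RBF kernel with respect to its first argument,
# and compares it against a central finite difference of kernel().
def check_kernel_f_prime(var=1.0, length=0.5, h=1.0e-5):
    x = torch.linspace(-1.0, 1.0, 20, dtype=torch.double)
    z = torch.linspace(-1.0, 1.0, 25, dtype=torch.double)
    analytic = kernel_f_prime(x, z, var, length, noise=0.0, include_noise=False)
    numeric = (kernel(x + h, z, var, length, noise=0.0, include_noise=False)
               - kernel(x - h, z, var, length, noise=0.0, include_noise=False)) / (2.0 * h)
    # the two matrices should agree up to finite-difference error
    print("max abs deviation:", (analytic - numeric).abs().max().item())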
def kernel_joint(ker, ker_f_p, ker_p_p):
    """Assemble the joint kernel [[K_ff, K_fp], [K_pf, K_pp]] from the function block,
    the derivative/function cross block and the derivative block"""
    return torch.cat([torch.cat([ker, ker_f_p.T], dim=1),
                      torch.cat([ker_f_p, ker_p_p], dim=1)], dim=0)
def inspect_kernel():
    var = 1.
    length = 0.5
    noise = 1.0e-5
    jitter = 1.0e-6
    x, x_prime, y = make_data()
    # x_prime = torch.linspace(-1., 1., 50)
    k = kernel(x, x, var, length, noise, jitter, include_noise=True)
    k_f_p = kernel_f_prime(x_prime, x, var, length, noise, jitter, include_noise=False)
    k_p_p = kernel_prime_prime(x_prime, x_prime, var, length, noise, jitter, include_noise=False)
    k_j = kernel_joint(k, k_f_p, k_p_p)
    plt.figure()
    for _ in range(0, 5):
        sample = pyro.sample("sample", dist.MultivariateNormal(torch.zeros(x.shape[0] + x_prime.shape[0], dtype=torch.double), k_j))
        f = sample[0: x.shape[0]]
        p = sample[x.shape[0]: None]
        f_diff = fdiff(x, f)
        plt.plot(x, f, "b")
        plt.plot(x, f_diff, "g")
        plt.plot(x_prime, p, "r")
    plt.grid(True)
    plt.show()
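# Another hedged check (my addition, not in the original script): the assembled
# joint covariance should be symmetric and admit a Cholesky factorization once
# the observation-noise / jitter terms are on the diagonal.
def check_joint_kernel(var=1.0, length=0.5, noise=1.0e-5, jitter=1.0e-6):
    x, x_prime, _ = make_data()
    k = kernel(x, x, var, length, noise, jitter, include_noise=True)
    k_f_p = kernel_f_prime(x_prime, x, var, length, noise, jitter, include_noise=False)
    k_p_p = kernel_prime_prime(x_prime, x_prime, var, length, noise, jitter, include_noise=False)
    k_j = kernel_joint(k, k_f_p, k_p_p)
    print("symmetric:", torch.allclose(k_j, k_j.T))
    torch.linalg.cholesky(k_j)  # raises if the joint covariance is not positive definite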
def model_wo_derivatives(x1, x2=None, y=None, noise_level=1.):
    """model without derivative information"""
    jitter = 1e-6
    var_x2 = pyro.param("var_x2", torch.tensor([5.]), constraint=torch_pos)
    len_x2 = pyro.param("len_x2", torch.tensor([2.]), constraint=torch_pos)
    noi_x2 = pyro.param("noi_x2", torch.tensor([1.e-2]), constraint=torch_pos)
    noi = pyro.param("noi", torch.tensor([0.5]), constraint=torch_pos)
    der = pyro.param("der", torch.tensor([1.e-2]), constraint=torch_pos)
    ker_x2 = kernel(X=x2, Z=x2, var=var_x2, length=len_x2, noise=noi_x2, jitter=jitter, include_noise=True)
    _x2_est = pyro.sample("_x2_all", dist.MultivariateNormal(loc=torch.zeros_like(ker_x2[:, 0]), covariance_matrix=ker_x2))
    _y_est = _x2_est * (x1**2)
    noise = torch.ones_like(_x2_est)*noi
    with pyro.plate("data", x1.shape[0]):
        pyro.sample("obs", dist.Normal(_y_est, noi).to_event(), obs=y)
    return _y_est
def model_w_derivatives(x1, x2=None, y=None, noise_level=1.):
    """model with derivative information"""
    jitter = 1e-6
    var_x2 = pyro.param("var_x2", torch.tensor([5.]), constraint=torch_pos)
    len_x2 = pyro.param("len_x2", torch.tensor([2.]), constraint=torch_pos)
    noi_x2 = pyro.param("noi_x2", torch.tensor([1.e-2]), constraint=torch_pos)
    noi = pyro.param("noi", torch.tensor([0.5]), constraint=torch_pos)
    der = pyro.param("der", torch.tensor([1.e-2]), constraint=torch_pos)
    ker_x2 = kernel(X=x2, Z=x2, var=var_x2, length=len_x2, noise=noi_x2, jitter=jitter, include_noise=True)
    ker_x2_f_p = kernel_f_prime(X=x2, Z=x2, var=var_x2, length=len_x2, noise=0.0, jitter=0.0, include_noise=False)
    ker_x2_p_p = kernel_prime_prime(X=x2, Z=x2, var=var_x2, length=len_x2, noise=noi_x2, jitter=jitter, include_noise=False)
    ker_x2_j = kernel_joint(ker_x2, ker_x2_f_p, ker_x2_p_p)
    _x2_all = pyro.sample("_x2_all", dist.MultivariateNormal(loc=torch.zeros_like(ker_x2_j[:, 0]), covariance_matrix=ker_x2_j))
    # extract function and derivative values according to how the joint kernel was assembled
    _x2_est = _x2_all[0: x2.shape[0]]
    _x2_diff = _x2_all[x2.shape[0]: None]
    _y_est = torch.concat([_x2_est * (x1**2), _x2_diff])
    noise = torch.concat([torch.ones_like(_x2_est)*noi, torch.ones_like(_x2_diff)*der])
    if y is not None:
        _sample = torch.concat([y, 1.0*torch.ones_like(_x2_diff)])
        with pyro.plate("data"):
            # pyro.sample("obs", dist.MultivariateNormal(loc=_y_est, covariance_matrix=torch.diag(noise)).to_event(), obs=_sample)
            pyro.sample("obs", dist.Normal(_x2_est * (x1**2), noi*torch.ones_like(_x2_est)).to_event(), obs=y)
            pyro.sample("diff", dist.Normal(_x2_diff, der*torch.ones_like(_x2_diff)).to_event(), obs=1.0*torch.ones_like(_x2_diff))
def train(model, gx1, gx2, gy, lr=.05, steps=10, post_samples=1):
    loss = Trace_ELBO()
    guide = AutoDiagonalNormal(model)
    optim = pyro.optim.Adam({"lr": lr})
    losses = []
    svi = SVI(model, guide, optim, loss)
    for step in range(steps):
        guide_trace = pyro.poutine.trace(guide).get_trace(gx1, gx2, gy)
        guide_model = pyro.poutine.trace(model).get_trace(gx1, gx2, gy)
        loss_value = svi.step(gx1, gx2, gy)
        losses.append(loss_value)
        if step % 100 == 0:
            print(f"Step {step} : Loss = {loss_value}")
            print("var_x2", guide_model.nodes["var_x2"]["value"])
            print("len_x2", guide_model.nodes["len_x2"]["value"])
            print("noi_x2", guide_model.nodes["noi_x2"]["value"])
            print("noi", guide_model.nodes["noi"]["value"])
            print("der", guide_model.nodes["der"]["value"])
            print("_x2_all", guide_trace.nodes["_x2_all"]["value"])
            print("----------")
    return Predictive(model, guide=guide, num_samples=post_samples, return_sites=("_RETURN",))
def inference_wo_derivatives(steps=1000):
    x1, x2, y = make_data()
    predictive = train(model_wo_derivatives, x1, x2, y, steps=steps, post_samples=200)
    x1_t = torch.linspace(-1, 1, 50)
    x2_t = torch.linspace(-1, 1, 10)
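    # Hedged plotting sketch added for illustration (not part of the original MWE):
    # evaluate the trained Predictive object on the training inputs (which keeps the
    # dimensionality of the latent GP consistent with the guide) and compare the
    # posterior mean of the returned _y_est against the noisy data.
    samples = predictive(x1, x2)["_RETURN"].detach().numpy()
    plt.plot(x1, samples.mean(axis=0), "b", label="posterior mean")
    plt.plot(x1, y, "k.", label="data")
    plt.legend()
    plt.show()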
def inference_w_derivatives(steps=1000):
    x1, x2, y = make_data()
    predictive = train(model_w_derivatives, x1, x2, y, steps=steps, post_samples=200)
    x1_t = torch.linspace(-1, 1, 50)
    x2_t = torch.linspace(-1, 1, 10)
if __name__ == "__main__":
    # inspect_kernel()
    # inference_wo_derivatives()
    inference_w_derivatives()