[Help pls] Multiple latent GPs but only one f_loc and f_scale_tril

Hello all,

I am trying to use two latent GPs to model a signal of the type:

y(x) = A(x_0) * exp(-0.5 * (d(x_0) - x_1)^2)

where A and d are functions that I am modeling with a GP. The code I have now is:

import torch
import pyro
import pyro.distributions as dist
import pyro.contrib.gp as gp
import matplotlib.pyplot as plt
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
import numpy as np


JITTER = 1e-2
RESOLUTION = 0.1

def experimental_response(intensity_curve, dispersion_curve, X, resolution=1):
    return torch.exp(intensity_curve) * torch.exp(-0.5 * ((dispersion_curve - X[:,1]) / (resolution**2)) ** 2)

def true_dispersion_curve(x0):
    return torch.cos(2 * torch.pi * x0)

def true_intensity_response(x0):
    return 1 + x0

def generate_data(n: int = 100, noise: float = 0.1):
    """Generate synthetic data for GP training."""
    x0 = torch.linspace(0, 1, n) * 2
    x1 = torch.rand(n) * 4 - 2
    dispersion_true = true_dispersion_curve(x0)
    I_true = true_intensity_response(x0)
    X = torch.stack([x0, x1], dim=-1)
    y = experimental_response(I_true, dispersion_true, X) + noise * torch.randn(n)
    return X, y

def create_model(X, y):
    """Create GP model and SVI training setup."""
    pyro.clear_param_store()

    # Create kernels
    dispersion_kernel = gp.kernels.RBF(input_dim=1)
    amplitude_kernel = gp.kernels.RBF(input_dim=1)

    # Create GPs with prefixed names
    dispersion_gp = gp.models.VariationalGP(
        X=X[:, 0:1],
        y=None,
        kernel=dispersion_kernel,
        likelihood=None,
        whiten=True,
        jitter=JITTER
    )

    amplitude_gp = gp.models.VariationalGP(
        X=X[:, 0:1],
        y=None,
        kernel=amplitude_kernel,
        likelihood=None,
        whiten=True,
        jitter=JITTER
    )

    def model():
        dispersion_latent, _ = dispersion_gp.model()
        amplitude_latent, _ = amplitude_gp.model()

        mean_y = experimental_response(amplitude_latent, dispersion_latent, X)

        with pyro.plate("data", len(y)):
            pyro.sample("obs", dist.Normal(mean_y, 0.1), obs=y)

    def guide():
        dispersion_gp.guide()
        amplitude_gp.guide()

    return dispersion_gp, amplitude_gp, model, guide

def train_svi(model, guide, num_steps=5000, lr=0.01):
    """Train the model with SVI."""
    optimizer = Adam({'lr': lr})
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    loss_list = [0]*num_steps
    for i in range(num_steps):
        loss = svi.step()
        loss_list[i] = loss
        if i % 1000 == 0:
            print(f"[{i}] Loss: {loss:.4f}")
    return svi, loss_list

def predict(dispersion_gp, Amplitude_gp, test_X):
    """Predict from trained GP model."""
    with torch.no_grad():
        dispersion_mean, dispersion_std = dispersion_gp(test_X[:, 0].unsqueeze(-1), full_cov=False)
        Amplitude_mean, Amplitude_std = Amplitude_gp(test_X[:, 0].unsqueeze(-1), full_cov=False)
    return dispersion_mean, dispersion_std, Amplitude_mean, Amplitude_std

# actually testing out the behaviour
N = 100
X, y = generate_data(n=N)

dispersion_gp, Amplitude_gp, model, guide = create_model(X, y)
svi, loss_list = train_svi(model, guide, num_steps=10)

test_X, test_y = generate_data(n=500)
dispersion_mean, dispersion_std, Amplitude_mean, Amplitude_std = predict(dispersion_gp, Amplitude_gp, test_X)

The problem is that I keep getting:

RuntimeError: Multiple sample sites named 'f'
Trace Shapes:            
 Param Sites:            
        f_loc     100    
 f_scale_tril 100 100    
Sample Sites:            
       f dist       | 100
        value       | 100

I am pretty sure that both GPs are trying to make parameters that are called f_loc and f_scale_tril. I wish there was a way to do something like:

amplitude_gp = VariationalGP(...., sample_prefix='amplitude_', ...)

I tried putting all the calls to guide() or model() inside their own poutine.block() indent. I am not sure what I can do with this now. I would like just to be able to tell pyro that I have 2 GPs, not 1, so it should make variational parameters for both Amplitude_gp and dispersion_gp. How do I do this? I hope this is not a noob problem. I have tried to read all the forum posts and guide docs that seemed relevant.

Thanks,

I think you need to use Pyro nn.Module Modules in Pyro — Pyro Tutorials 1.9.1 documentation and define your GPs as submodules. Otherwise we have name conflict. See e.g. Inferences for Deep Gaussian Process models in Pyro | fehiepsi's blog

Thanks a lot, especially for the blog post! I will try this and see how it goes.