Adding Gaussian processes to mixed effects models

Hi everyone:

I’ve been trying to learn Pyro, and the best way I seem to learn is through actual use cases for the type of problem I’m trying to solve. The problem I have is to fit a model with some random effects that also has a Gaussian process on time.

Neeraj’s GDP tutorial here was an amazing example of the groundwork I’d need to learn about implementing random effects, but what I’m not sure of yet is how I’d go about adding Gaussian process noise to my model. The GP example here evaluates the GP model and guide using gp.models.GPRegression, which is fed into the SVI, but I’m wondering if anyone has implemented adding a GP to the mean function (mu in the GDP example)?

If there were a way to evaluate realizations of the GP so that they can be added to my mean function, that would solve my problem… any tips?

Thanks!

N

Great question! You can do it easily using the ‘mean_function’ argument in GP. It is a hidden feature for which I have not yet come up with an example. The mean function can indeed be an arbitrary stochastic function. I will post an example here soon :slight_smile:

@fehiepsi Aha, that is interesting! I would have expected to take the mixed effects model setup and add the GP stuff on top of it, but you’re suggesting that it’d be feasible to take the GPR setup and add mixed effect elements to it, eh?

And awesome, thanks for taking the time to write up an example! It would be great if you could show the example with a distribution like good ol’ random effects (since it takes me way too long to figure this stuff out!).

# use a torch module as the mean function; pyro.module registers its
# parameters so they are trained along with the GP
linear = nn.Linear(10, 2)

def mean(x):
    return pyro.module("linear", linear)(x)

gpmodel = gp.models.VariationalSparseGP(..., mean_function=mean, ...)

Another way to do it (more general):

def model(X, y):
    pyro.module("linear", linear)
    m = linear(X)
    gpmodel.set_data(X, None)  # y=None makes gpmodel.model() return (f_loc, f_var)
    f, f_var = gpmodel.model()
    noise = pyro.param(...)
    pyro.sample("y", dist.Cauchy(m + f, f_var + noise), obs=y)

def guide(X, y):
    gpmodel.guide()

svi = SVI(model, guide, ...)

As you can see, the gp module is flexible enough to build sophisticated probabilistic models (which is unique among the GP frameworks I know of).


Oh my, that looks amazing! I think I’m definitely going to start with the separate model/guide setup, just because it’d be easier to slide into the mixed effects example. I’ll post some code here once I have a working model going! Thanks again @fehiepsi!


@fehiepsi Here’s an example I wrote up:

## Sim data
N = 1000
x1 = dist.Uniform(0.0, 5.0).sample(sample_shape=(N,))
x2 = dist.Normal(0., 0.005).sample(sample_shape=(N,))
y = 0.2 + \
    0.5 * torch.sin(3 * x1) + \
    0.05 * x2 + dist.Normal(0.0, 0.1).sample(sample_shape=(N,))

## Have GP on x1 (the sine part)
kern = gp.kernels.Matern52(input_dim=1, variance=torch.tensor(0.1),
                           lengthscale=torch.tensor(10.))
gpmodel = gp.models.GPRegression(x1, y, kernel=kern)

def model(x1, x2, y):
    a = pyro.sample("a", dist.Normal(8., 1000.))
    b_2 = pyro.sample("b2", dist.Normal(0., 1.))
    sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
    mean = a + b_2 * x2
    gpmodel.set_data(x1, None)
    f, f_var = gpmodel.model()
    pyro.sample("y", dist.Normal(mean + f, f_var + sigma), obs=y)

def guide(x1, x2, y):
    gpmodel.guide()

svi = SVI(model,
          guide,
          optim.Adam({"lr": .01}),
          loss=Trace_ELBO())
pyro.clear_param_store()
for i in range(50):
    elbo = svi.step(x1, x2, y)  # arguments in the same order as model/guide

Here are the param store outputs I get:

for name in pyro.get_param_store().get_all_param_names():
    print("[%s]: %.3f" % (name, pyro.param(name).data.numpy()))

[Matern52$$$variance]: 0.116
[Matern52$$$lengthscale]: 10.498
[GPR$$$noise]: 1.155

Here are the questions I have:

  1. Is my setup correct to begin with? It looks like my mean function isn’t feeding into my GP model at all…

  2. Speaking of, I had to define the GP model outside my model function so that the guide could refer to it, but that also means I had to create the GP model object without defining the mean function. I guess, technically, the pyro.sample over a Normal takes care of that by adding the mean function in the last line of model().

  3. Why aren’t my a and b_2 parameters getting picked up in the param store? Is this because the guide comes 100% from the GP model and so it’s overriding everything else?

You should define a guide for a, b2, and sigma like in any Pyro program. gpmodel.guide() only builds guides for its own latent variables (such as lengthscale, if you set a prior for lengthscale).
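
For example, a hand-written guide for those three sites might look like this (a minimal sketch; the variational parameter names are illustrative, not from your code):

from torch.distributions import constraints

def guide(x1, x2, y):
    a_loc = pyro.param("a_loc", torch.tensor(0.))
    a_scale = pyro.param("a_scale", torch.tensor(1.), constraint=constraints.positive)
    b2_loc = pyro.param("b2_loc", torch.tensor(0.))
    b2_scale = pyro.param("b2_scale", torch.tensor(1.), constraint=constraints.positive)
    sigma_map = pyro.param("sigma_map", torch.tensor(1.), constraint=constraints.interval(0., 10.))
    pyro.sample("a", dist.Normal(a_loc, a_scale))
    pyro.sample("b2", dist.Normal(b2_loc, b2_scale))
    pyro.sample("sigma", dist.Delta(sigma_map))  # MAP point for sigma
    gpmodel.set_data(x1, None)
    gpmodel.guide()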

If you want a guide automatically built for a and b2, then you can do it this way:

class Linear(Parameterized):
    def __init__(self, a, b2):
        super().__init__()
        self.a = Parameter(a)
        self.b2 = Parameter(b2)
        self.set_prior("a", dist.Normal(8., 1000.))
        self.set_prior("b2", dist.Normal(0., 1.))

    def forward(self, x):
        a = self.get_param("a")
        b2 = self.get_param("b2")
        return a + b2 * x

linear = Linear(torch.tensor(0.), torch.tensor(0.))

Then you can define model and guide like this

def model(x1, x2, y):
    linear.set_mode("model")
    f, f_var = gpmodel.model()
    mean = linear(x2)
    sigma = pyro.param("sigma", torch.tensor(1.), constraint=torch.distributions.constraints.positive)
    pyro.sample("y", dist.Normal(mean + f, f_var + sigma), obs=y)

def guide(x1, x2, y):
    linear.set_mode("guide")
    # in Pyro, we run `guide` before `model`; for GPR, it does not
    # matter though because GPR guide does not interact with data
    gpmodel.set_data(x1, None)  
    gpmodel.guide()

A simpler way is to merge x1 and x2 into a single input, set active_dims=[0] on the kernel so the GP only sees the first column, and modify Linear to take effect only on the second dimension:

    def forward(self, x):
        ...
        return a + b * x[1]  # assume x.shape = 2 x num_data

then define

gpmodel = GPR(X, y, kernel, mean_function=linear)
svi = SVI(gpmodel.model, gpmodel.guide,...)
for i in range(1000):
    # if you want to change data (like batch training), then call `gpmodel.set_data(Xnew, ynew)`
    svi.step()

I prefer the second option. It is simpler, but it requires us to understand how the kernel interacts with the input under the hood. The first option is easier to understand, I guess.

As in the deep kernel learning example, I used an input warping function to feed a CNN into the kernel. That way is simple. But you can definitely create a separate model and guide if you don’t want to use a warping kernel.


@sadatnfs Have you made it work? If not yet, I’ll write up explicit code for you.

@fehiepsi oh sorry, I was a bit occupied with my regular job stuff during the day; I’m planning to give it a try tonight! Thanks!

It looks like I got things to work! Obviously my sim data is crap, so I’d have to sim it properly, but thanks again!!

The next thing I’m planning to do is to simulate draws from the posterior of the params, so that’ll be fun!

class Linear(Parameterized):
    def __init__(self, a, b_2):
        super(Linear, self).__init__()
        self.a = Parameter(a)
        self.b_2 = Parameter(b_2)
        self.set_prior("a", dist.Normal(0., 100.))
        self.set_prior("b_2", dist.Normal(0., 100.))

    def forward(self, x):
        a = self.get_param("a")
        b_2 = self.get_param("b_2")
        return a + b_2 * x[:, 1]

linear = Linear(torch.tensor(0.), torch.tensor(0.))

def model(x, y):
    linear.set_mode("model")
    f, f_var = gpmodel.model()
    mean = linear(x)
    sigma = pyro.param("sigma",
                       torch.tensor(1.),
                       constraint=torch.distributions.constraints.positive)
    pyro.sample("y", dist.Normal(mean + f, f_var + sigma), obs=y)

def guide(x, y):
    linear.set_mode("guide")
    # in Pyro, we run `guide` before `model`; for GPR, it does not
    # matter though because the GPR guide does not interact with data
    gpmodel.set_data(x, None)
    gpmodel.guide()

## GP kernel (active on the first input column only)
kern = gp.kernels.Matern52(input_dim=1, active_dims=[0],
                           variance=torch.tensor(0.1),
                           lengthscale=torch.tensor(1.))
gpmodel = gp.models.GPRegression(X, y, kern)

@fehiepsi any tips on the easiest way to sample the parameters given this setup? The end goal is to be able to get draws of the parameters so that I can create draws of my output y.

It looks like GPRegression has an optimize() function which can be run after setting the data (giving me a mean and covariance matrix to simulate outputs from), but given that I’m running SVI, I probably wouldn’t want that, eh?

I’d recommend using the TracePredictive class (as in @neerajprad’s tutorial). You first need to set svi.num_samples = 10000, then simply call svi.run(x, y) to generate a trace posterior with 10000 traces from your guide. After this, you have your posterior samples. To get statistics (e.g. quantile, mean, std) of the parameters in these posterior traces, you can use the EmpiricalMarginal class. To get statistics of predictions on new data, you can use TracePredictive (note that you have to run with y=None to turn y from an observed node into a latent node).
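
In code, that pattern might look like this (a minimal sketch against the pre-1.0 Pyro API used in this thread; x_new is a hypothetical tensor of new inputs):

from pyro.infer import EmpiricalMarginal, TracePredictive

svi.num_samples = 10000
posterior = svi.run(x, y)  # 10000 traces from the guide

# statistics of a latent site, e.g. "a"
a_marginal = EmpiricalMarginal(posterior, sites="a")
print(a_marginal.mean, a_marginal.variance)

# posterior predictive: re-run the model with y=None so "y" becomes latent
trace_pred = TracePredictive(model, posterior, num_samples=1000)
y_marginal = EmpiricalMarginal(trace_pred.run(x_new, None), sites="y")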

A statistician would follow the above method (draw samples of the parameters, …). GP optimize() is useful for predicting the output at X_new using only one sample from the guide. If your guide does not have any pyro.sample(...) calls (except sampling from Delta distributions), then it is safe to use linear.set_mode("guide"); f_new, f_new_var = gpmodel(X_new); y_new = f_new + linear(X_new). If you set linear as the mean function (as I suggested in my last comment), then to get the output at X_new you simply call y_new, y_new_var = gpmodel(X_new).
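
Spelled out (the same sketch, with X_new standing in for whatever new inputs you want predictions at):

# prediction with a separate `linear` term
linear.set_mode("guide")
f_new, f_new_var = gpmodel(X_new)  # GP posterior predictive at X_new
y_new = f_new + linear(X_new)

# or, if `linear` was passed as mean_function:
y_new, y_new_var = gpmodel(X_new)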


Hey @fehiepsi, I was trying to implement the solution suggested in this thread for a similar problem. I am attaching my code, which generates the synthetic data and tries to combine GPR and a simple linear model.

import numpy as np
import torch
import pyro
import pyro.distributions as dist
import pyro.optim as optim
import pyro.contrib.gp as gp
import pyro.nn.module as mod
from pyro.infer import SVI, Predictive

# %% Create trajectories
np.random.seed(10)
samp_t = np.linspace(0, 10, num=100)
samp_traj = 2 * samp_t + np.random.rand(np.size(samp_t, 0)) * 3

X = torch.tensor(samp_t, dtype=torch.float)
y = torch.tensor(samp_traj, dtype=torch.float)
# %% Training


class Linear(gp.Parameterized):

    def __init__(self):
        super(Linear, self).__init__()
        self.a = mod.PyroSample(dist.Normal(0., 5.))
        self.b_x0 = mod.PyroSample(dist.Normal(0., 3.))

    def forward(self, X):
        a = self.a
        b_x0 = self.b_x0
        m = a + b_x0 * X
        return m


linear = Linear()


def model(X, y):

    linear.set_mode("model")
    f, f_var = gpmodel.model()
    mean = linear(X)
    sigma = pyro.param("sigma",
                       torch.tensor(1.),
                       constraint=torch.distributions.constraints.positive)
    with pyro.plate("data", len(y)):
        pyro.sample("obs", dist.Normal(mean + f, sigma + f_var), obs=y)


def guide(X, y):
    linear.set_mode("guide")
    gpmodel.set_data(X, None)
    gpmodel.guide()


# %% Using SVI
pyro.clear_param_store()
pyro.set_rng_seed(1524)
kern = gp.kernels.Matern52(input_dim=1,
                           variance=torch.tensor(0.1),
                           lengthscale=torch.tensor(1.))
gpmodel = gp.models.GPRegression(X, y, kern, mean_function=linear)

svi = SVI(model,
          guide,
          optim.Adam({"lr": .01}),
          loss=pyro.infer.Trace_ELBO().differentiable_loss)
num_iters = 500
losses = []
for i in range(num_iters):
    elbo = svi.step(X, y)
    losses.append(elbo)
    print(elbo)

However, I receive an error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
(abridged to the relevant frames)

~/.../probabilistic_programs/mixture_gp_test.py in model(X, y)
     47     f, f_var = gpmodel.model()
---> 48     mean = linear(X)

~/.../probabilistic_programs/mixture_gp_test.py in forward(self, X)
     34     def forward(self, X):
---> 35         a = self.a

.../site-packages/pyro/nn/module.py in __getattr__(self, name)
--> 328                         value = pyro.sample(fullname, prior)

.../site-packages/pyro/poutine/trace_struct.py in add_node(self, site_name, **kwargs)
    108                 # Cannot sample after a previous sample statement.
--> 109                 raise RuntimeError("Multiple {} sites named '{}'".format(kwargs['type'], site_name))

RuntimeError: Multiple sample sites named 'mean_function.a'
          Trace Shapes:
           Param Sites:
     kernel.lengthscale
        kernel.variance
                  noise
          Sample Sites:
   mean_function.a dist |
                  value |
mean_function.b_x0 dist |
                  value |

Could you suggest a solution to this problem? I tried removing the mean_function argument from gp.models.GPRegression; the error disappears, but the model never converges to a solution.

Thanks

@grishabhg I am not sure if it is intended, but it seems to me your mean plays the role of mean_function and your sigma plays the role of noise. That said, I believe training with gpmodel.model and gpmodel.guide directly will do the job for you.
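
Concretely, that would look something like this (a sketch reusing the kern, linear, X, y from your code; GPRegression’s built-in noise parameter then plays the role of your sigma):

gpmodel = gp.models.GPRegression(X, y, kern, mean_function=linear)
svi = SVI(gpmodel.model,
          gpmodel.guide,
          optim.Adam({"lr": .01}),
          loss=pyro.infer.Trace_ELBO().differentiable_loss)
for i in range(500):
    svi.step()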

Assuming that is your intention, there are two cases here. In case linear is different from mean_function, you can train the model, guide pair of the following module:

class Model(gp.Parameterized):
    def __init__(self, gpmodel):
        super().__init__()
        self.linear = Linear()
        self.gpmodel = gpmodel

    @pyro.nn.pyro_method
    def model(self, X, y):
        self.linear.set_mode("model")
        f, f_var = self.gpmodel.model()
        mean = self.linear(X)
        sigma = pyro.param("sigma",
                           torch.tensor(1.),
                           constraint=torch.distributions.constraints.positive)
        with pyro.plate("data", len(y)):
            pyro.sample("obs", dist.Normal(mean + f, sigma + f_var), obs=y)

    @pyro.nn.pyro_method
    def guide(self, X, y):
        self.linear.set_mode("guide")
        # this is required to trigger `sample` statements
        self.linear._load_pyro_samples()
        self.gpmodel.set_data(X, None)
        self.gpmodel.guide()

If linear is the mean_function, you can replace self.linear(X) with self.gpmodel.mean_function(X).

I tried removing the mean_function argument from gp.models.GPRegression , the error disappears but the model never converges to a solution.

I think this is the right way to do it (i.e. using only one of linear and mean_function). For a mixed effects model, this gist might be helpful for you. :slight_smile:
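
For instance, a random-effects mean function could look something like this (a sketch with hypothetical names; it assumes the last column of X holds integer group codes):

class RandomIntercept(gp.Parameterized):
    def __init__(self, num_groups):
        super().__init__()
        # one intercept per group, with a shared Normal prior
        self.intercepts = mod.PyroSample(
            dist.Normal(0., 1.).expand([num_groups]).to_event(1))

    def forward(self, X):
        group = X[:, -1].long()        # group membership per row
        return self.intercepts[group]  # one random effect per observation

gpmodel = gp.models.GPRegression(X, y, kern,
                                 mean_function=RandomIntercept(num_groups=5))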

Hey @fehiepsi, thanks for your reply. I tried your solution and it works. But I have some follow-up questions, which probably originate from my limited knowledge of Pyro and Bayesian models in general. First of all, here is how I implemented your suggested solution:

# %% Training

class Linear(gp.Parameterized):
    def __init__(self):
        super(Linear, self).__init__()
        self.a = mod.PyroSample(dist.Normal(0., 5.))
        self.b_x0 = mod.PyroSample(dist.Normal(0., 3.))

    def forward(self, X):
        a = self.a
        b_x0 = self.b_x0
        m = a + b_x0 * X
        return m

class Model(gp.Parameterized):
    def __init__(self, gpmodel):
        super().__init__()
        self.linear = Linear()
        self.gpmodel = gpmodel

    @pyro.nn.pyro_method
    def model(self, X, y):
        self.linear.set_mode("model")
        f, f_var = self.gpmodel.model()
        mean = self.linear(X)
        sigma = pyro.param("sigma",
                           torch.tensor(1.),
                           constraint=torch.distributions.constraints.positive)
        with pyro.plate("data", len(y)):
            pyro.sample("obs", dist.Normal(mean + f, sigma + f_var), obs=y)

    @pyro.nn.pyro_method
    def guide(self, X, y):
        self.linear.set_mode("guide")
        # this is required to trigger `sample` statements
        self.linear._load_pyro_samples()
        self.gpmodel.set_data(X, None)
        self.gpmodel.guide()

# %% Using SVI
pyro.clear_param_store()
pyro.set_rng_seed(1524)
kern = gp.kernels.Matern52(input_dim=1,
                           variance=torch.tensor(0.1),
                           lengthscale=torch.tensor(1.))
gpmodel = gp.models.GPRegression(X, y, kern)
full_model = Model(gpmodel)  # renamed to avoid shadowing the `mod` import alias
svi = SVI(full_model.model,
          full_model.guide,
          optim.Adam({"lr": .01}),
          loss=pyro.infer.Trace_ELBO().differentiable_loss)
num_iters = 500
losses = []
for i in range(num_iters):
    elbo = svi.step(X, y)
    losses.append(elbo)
    print(elbo)

Here are my questions:

  1. In the guide function, when we set self.linear.set_mode("guide") and then call self.gpmodel.guide, does it mean that the parameters from the linear model and gpmodel are somehow getting appended to the main guide?

  2. Now, in order to get the posterior distribution of my parameters, I did this:

     # %%
     num_samples = 1000
     predictive = Predictive(full_model.model, guide=full_model.guide, num_samples=num_samples)
     for k, v in predictive(X, y).items():
         print(k)
    
     data = predictive(X, y)
    

    The only parameters that I can see with this are: linear.a, linear.b_x0, and obs. I was wondering why I don’t see lengthscale and variance here?

  3. If I wanted to get the posterior predictive distribution over both Xnew* and Ynew* (where Ynew* is some initial observed values of the trend), would I need to make any changes to the code snippets above?

  4. Also, how do I get predicted values for some Ynew where Xnew > Xnew*, i.e., new time points that extend beyond the values used to build the posterior distribution?

  5. Under class Model, in the guide function, when we set the data with self.gpmodel.set_data(X, None), why do we set y = None? If we set it to None, then during inference, when we want to condition our posterior on Ynew*, how would it include that information?

  6. My thinking behind this model was that the Linear function would approximate the major trend in the data, and the GP regression would help me model some subtle variations in the trend. Do you think this is the right approach?

Sorry if some of the questions don’t make sense; I can elaborate if anything is unclear. But your answers will solve a lot of my problems here.

Thanks a lot

@grishabhg I hope the following points answer your questions:

  1. When you set_mode("guide"), latent variables will be generated from “some params” (in your model, they are MAP points). The parameters are generated automatically if you call the gpmodel.autoguide() method (a Delta guide is used by default).

  2. lengthscale and variance are not latent variables in your model; they are parameters. You can get their values with list(gpmodel.named_pyro_params()) or gpmodel.kernel.variance.

3, 4. If you want to predict, it is better to use the forward method gpmodel(Xnew*). It will condition on the training data X, y (otherwise, you only condition on learned parameters or latent samples); see the sketch after this list.

  5. I thought it was your intention to set y=None? I rarely use that pattern. When you set y=None, gpmodel.model() will return f_loc and f_var. I think it is only useful when you want to build a DeepGP or something like that.

  6. Yes, but again, you should use only one of mean_function or linear. IMO, it is better to use only mean_function. The mean_function is quite flexible (e.g. it can be a neural net, or some random effect, as in the above gist).
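
To make points 2 and 3, 4 concrete (a sketch; Xnew is a hypothetical tensor of new time points, which may extend beyond the training range):

# point 2: kernel hyperparameters are plain params, so inspect them directly
print(list(gpmodel.named_pyro_params()))
print(gpmodel.kernel.variance, gpmodel.kernel.lengthscale)

# points 3, 4: predict at new inputs; this conditions on the training X, y
Xnew = torch.linspace(0., 15., 50)
with torch.no_grad():
    f_new, f_var_new = gpmodel(Xnew, full_cov=False, noiseless=False)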
