Using pyro.contrib.gp inside a model()

I’m learning Pyro and trying to get my head around its global hidden state, in particular within the Gaussian Process module, pyro.contrib.gp. I’ve reduced my model to the simple latent Poisson rate model below.

The following code shows how I’m approaching the problem with pyro.contrib.gp, which I can’t get working. I’ve also written the GP manually (switched with the use_contrib argument), and that version does work.

The variance and lengthscale don’t seem to be fitted when using pyro.contrib.gp. I’m fitting with an AutoDelta guide for MAP estimation, but the same thing happens with other guides such as AutoMultivariateNormal.

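from typing import Optional

import torch
import pyro
import pyro.distributions as dist
import pyro.contrib.gp as gp

def model(x: torch.Tensor,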
          count_obs: Optional[torch.Tensor] = None,
          x_pred: Optional[torch.Tensor] = None,
          use_contrib=False):

    n = len(x)
    x = x.unsqueeze(-1)
    jitter = torch.tensor(1e-4)

    # Priors for RBF kernel lengthscale and variance.
    lam = pyro.sample('lam', dist.Gamma(torch.tensor(14.), torch.tensor(1.)))
    tau = pyro.sample('tau', dist.Gamma(torch.tensor(2.), torch.tensor(4.)))

    if use_contrib:
        # Covariance matrix. (contrib: this does not work)
        kern = gp.kernels.RBF(1)
        kern.variance = tau
        kern.lengthscale = lam
        cov_x = kern.forward(x)
    else:
        # Covariance matrix. (manual: this works)
        cov_x = tau * torch.exp(-(x - x.t())**2/lam**2)

    # Sample latent rate.
    cov_x = cov_x + jitter * torch.eye(n)
    lograte = pyro.sample('lograte', dist.MultivariateNormal(torch.zeros(n), cov_x))

    # Observations.
    with pyro.plate('x', n):
        pyro.sample('count_obs', dist.Poisson(rate=lograte.exp()), obs=count_obs)

    # Predictions.
    if x_pred is not None:

        if use_contrib:
            # Get a GP for the estimated latent rate. (contrib: this does not work)
            kern = gp.kernels.RBF(1)
            kern.variance = tau
            kern.lengthscale = lam
            gp_pred = gp.models.GPRegression(x.squeeze(), lograte, kern, jitter=jitter)
            gp_pred.noise = torch.tensor(0.)  # No noise since this is the latent rate.
            lograte_pred_mean, lograte_pred_cov = gp_pred.forward(x_pred, full_cov=True)
        else:
            # Get a GP for the estimated latent rate. (manual: this works)
            x_pred = x_pred.unsqueeze(-1)
            cov_x_pred = tau * torch.exp(-(x - x_pred.t())**2/lam**2)
            cov_x_pred_pred = tau * torch.exp(-(x_pred - x_pred.t())**2/lam**2)
            cov_x_chol = torch.linalg.cholesky(cov_x)
            alpha = torch.cholesky_solve(lograte.unsqueeze(-1), cov_x_chol)
            v = torch.linalg.solve_triangular(cov_x_chol, cov_x_pred, upper=False)
            lograte_pred_mean = (cov_x_pred.T @ alpha).squeeze()
            lograte_pred_cov = cov_x_pred_pred - v.t() @ v

        # Sample predicted latent rate.
        lograte_pred_cov = lograte_pred_cov + jitter * torch.eye(len(x_pred))
        lograte_pred = pyro.sample('lograte_pred', dist.MultivariateNormal(lograte_pred_mean, lograte_pred_cov))

        # Sample predicted counts.
        with pyro.plate('x_pred', len(x_pred)):
            pyro.sample('count_pred', dist.Poisson(rate=lograte_pred.exp()))

I think you can use PyroSample as in the Gaussian Processes tutorial (Pyro Tutorials 1.8.4 documentation), or delete the existing parameter and assign the sampled tensor:

del kernel.variance
kernel.variance = tau

Thank you, the del approach works well, and I can even reuse kern in the # Predictions block.

The PyroSample approach doesn’t work well, for several reasons. When the RBF kernel’s forward function is called, it calls _square_scaled_dist, which reads the lengthscale property twice to pre-scale both inputs: pyro/isotropic.py at dev · pyro-ppl/pyro (github.com). Each of those reads returns a different lengthscale, so the resulting covariance matrix is not positive definite and a PositiveDefinite constraint error is raised. Reading self.lengthscale only once fixes the skewed covariance, but then reusing the kernel at the prediction site raises a Multiple sample sites named 'kernel.lengthscale' error.

I think you need to use pyro_method to avoid the multiple sample site issue. How about replacing .forward(...) with (...)?

Switching from forward(...) to (...) fixed the different-lengthscale problem. I’m now trying to use PyroSample, since that seems to be the correct way to replace an attribute that is initialized with PyroParam, but for some reason it creates two versions each of lengthscale and variance. For example, for lengthscale one is called "AutoDelta.lengthscale" and the other "kernel.lengthscale_map", and neither of these nor the variance ones are fitted properly (or at all?).

I don’t know what should be decorated with pyro_method, since I’m not creating a new PyroModule class with custom methods, just writing the model as a plain function. Here is a notebook with the full reproducible context: pyro_gp_learn.ipynb - Colaboratory (google.com).

In contrib.gp, GPR.forward() uses its .guide() method to draw posterior samples, and kernel.lengthscale_map is a parameter of GPR.guide. Since you are using an AutoDelta guide, I don’t think GPR.forward will work for you; you might need to implement the prediction by hand. In the GP tutorials and examples, we train models with contrib.gp.util.train instead of AutoDelta and SVI. You can also find some examples in this gist. Hope this clarifies your observations.

I tried using the kernels in contrib.gp only to build the covariance matrices with (...), without GPR, and predicting by hand as in the manual model, but even that doesn’t work with Auto guides (it gives the dreaded RuntimeError: Multiple sample sites named 'lengthscale'). No worries, I’ll just implement the kernels manually.

Yeah, I think implementing them manually is better if you want to use your custom/auto guides. Just a reference for future readers:

  • To avoid multiple sample sites when using a Pyro nn.Module (PyroModule), we should call module(input) rather than module.forward(input).
  • We can use built-in utilities like gp.util.train to fit models and make predictions, as in the current tutorials and examples. pyro.contrib.gp currently supports these autoguides: Delta, Normal and (blocked) MultivariateNormal, which might be enough for most uses.