Size error when setting variance prior on GP kernel

I am trying to specify priors on a kernel that is being used in a GP as a submodel of a larger model. The kernel is specified in a plate as follows:

with pyro.plate("venues", V):
        K_v = gp.kernels.Sum(
            gp.kernels.Matern32(
                input_dim=1, 
                lengthscale=torch.tensor(4.), 
                variance=torch.tensor(1.)
            ),
            gp.kernels.RBF(
                input_dim=1,
                lengthscale=torch.tensor(15.),
                variance=torch.tensor(1.)
            )
        )

I can set the lengthscale priors as follows:

        K_v.kern0.set_prior(
            "lengthscale",
            Uniform(torch.tensor(5.0), torch.tensor(10.0))
        )
        K_v.kern1.set_prior(
            "lengthscale",
            Uniform(torch.tensor(10.0), torch.tensor(20.0))
        )

and the model fits with SVI without error.

However, when I try to do the same for the variance:

        K_v.kern0.set_prior(
            "variance",
            Gamma(torch.tensor(2.0), torch.tensor(0.5))
        )
        K_v.kern1.set_prior(
            "variance",
            Gamma(torch.tensor(2.0), torch.tensor(0.5))
        )

I get a RuntimeError that complains about its shape:

~/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/pyro/contrib/gp/kernels/isotropic.py in forward(self, X, Z, diag)
    148         r = self._scaled_dist(X, Z)
    149         sqrt3_r = 3**0.5 * r
--> 150         return self.variance * (1 + sqrt3_r) * torch.exp(-sqrt3_r)
    151 
    152 

RuntimeError: The size of tensor a (211) must match the size of tensor b (252) at non-singleton dimension 1
      Trace Shapes:                  
       Param Sites:                  
      Sample Sites:                  
          m_mu dist          |       
              value          |       
          s_mu dist          |       
              value          |       
            mu dist     4207 |       
              value     4207 |       
   lengthscale dist       13 |       
              value       13 |       
       f_tilde dist 252   13 |       
              value 252   13 |       
         noise dist          |       
              value          |       
          beta dist          | 252 13
              value          | 252 13
kern0.variance dist      211 |       
              value      211 |       
Trace Shapes:
 Param Sites:
Sample Sites:

This is confusing because a) it's a scalar prior, so I expect it to work irrespective of the shapes of other things, and b) setting the lengthscale prior works just fine.

I have tried adding .to_event() to the priors, but this does not resolve the problem.

Any ideas?

I’ve also tried simply setting a constraint on the variance via:

K_v.kern0.set_constraint("variance", constraints.interval(torch.tensor(1.0), torch.tensor(10.0)))

However, there does not appear to be a set_constraint method on the kernel:

AttributeError: 'Matern32' object has no attribute 'set_constraint'

I thought that kernels were supposed to have this method?

So, I tried another approach, using PyroSample:

with pyro.plate("venues", V):
    ...
    K_v.kern0.lengthscale = pyro.nn.PyroSample(Uniform(torch.tensor(5.0), torch.tensor(10.0)))
    K_v.kern0.variance = pyro.nn.PyroSample(Gamma(torch.tensor(2.0), torch.tensor(0.5)))

This also results in shape issues, though in a different place:

~/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/pyro/contrib/gp/kernels/isotropic.py in _square_scaled_dist(self, X, Z)
     50 
---> 51         scaled_X = X / self.lengthscale
     52         scaled_Z = Z / self.lengthscale

RuntimeError: The size of tensor a (252) must match the size of tensor b (13) at non-singleton dimension 0
   Trace Shapes:         
    Param Sites:         
   Sample Sites:         
       m_mu dist        |
           value        |
       s_mu dist        |
           value        |
         mu dist 4207 1 |
           value 4207 1 |
lengthscale dist   13 1 |
           value   13 1 |

Am I wrong in assuming scalars should be automatically handled correctly by the plate?

Hi @fonnesbeck, to set a constraint for a parameter, you can use the pattern

K_v.kern0.variance = PyroParam(..., constraint=...)

For sample attributes under a plate, I think this will work:

K_v.kern0.lengthscale = pyro.nn.PyroSample(...)
K_v.kern0.variance = pyro.nn.PyroSample(...)

def forward(self, ...):
    with pyro.plate("venues", V):
        f = self.K_v.kern0(X, X)

Under the hood, when K_v.kern0 is called, we will need to access the attribute lengthscale of K_v.kern0. At that point, a pyro.sample statement with distribution defined in pyro.nn.PyroSample will be called, and Uniform(torch.tensor(5.0), torch.tensor(10.0)) will be expanded properly. This is explained in this section.

I would prefer to define the shape of the prior eagerly:

K_v.kern0.lengthscale = pyro.nn.PyroSample(Uniform(torch.tensor(5.0), torch.tensor(10.0)).expand(...).to_event())

Edit: Oops, it looks like you need a batched GP? Currently, pyro.contrib.gp does not support that. :frowning:

I’m not using the GPRegression class here; I am just using the kernels to generate covariance matrices for MultivariateNormal variables that get used elsewhere. For example:

    with pyro.plate("levels", L):

        ...

        cov_beta = K_l(torch.arange(D, device=device)).contiguous()

        with pyro.plate("level_days", D):
            f_tilde = pyro.sample("f_tilde", Normal(torch.tensor(0.0), torch.tensor(1.0)))

        f = pyro.deterministic(
            "f", (cov_beta + torch.eye(D, device=device) * jitter).cholesky() @ f_tilde
        )

    noise = pyro.sample("noise", HalfNormal(torch.tensor(1.0)))
    beta = pyro.sample("beta", Normal(loc=f, scale=noise).to_event())

I’m also specifying the model as a function, not as a PyroModule. Is it still possible to use PyroSample in a function?

I see. Unfortunately, the kernels also do not support “batching” (when I wrote them a few years ago, I didn’t have much experience writing code that supports broadcasting). :frowning:
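To illustrate what goes wrong, here is a minimal plain-PyTorch sketch. It is not the pyro.contrib.gp implementation: the function name batched_matern32 and the small sizes V and D are made up for this example. A variance of shape (V,) cannot broadcast against a (D, D) matrix, which appears to be the mismatch in the first traceback (211 vs 252), whereas adding two trailing singleton dimensions yields one covariance matrix per venue.

```python
import torch

def batched_matern32(X, variance, lengthscale):
    """Hand-rolled Matern 3/2 kernel; hypothetical helper, not a Pyro API.

    X: (D,) inputs; variance, lengthscale: scalars or tensors of shape (V,).
    Returns (V, D, D) for batched parameters, (1, D, D) for scalar ones.
    """
    variance = torch.as_tensor(variance)[..., None, None]
    lengthscale = torch.as_tensor(lengthscale)[..., None, None]
    r = (X[:, None] - X[None, :]).abs() / lengthscale
    sqrt3_r = 3 ** 0.5 * r
    return variance * (1 + sqrt3_r) * torch.exp(-sqrt3_r)

V, D = 3, 5  # placeholder sizes standing in for 211 venues and 252 days
X = torch.arange(D, dtype=torch.float32)
variance = torch.tensor([1.0, 2.0, 3.0])   # one variance per venue, shape (V,)
lengthscale = torch.full((V,), 4.0)

# Unbatched formula: (V,) * (D, D) cannot broadcast when V != D.
r = (X[:, None] - X[None, :]).abs() / 4.0
try:
    variance * (1 + r)
    broadcast_failed = False
except RuntimeError:
    broadcast_failed = True

cov = batched_matern32(X, variance, lengthscale)
print(broadcast_failed, cov.shape)
```

Each slice cov[v] is then the covariance for venue v, with variance[v] on its diagonal.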

Is it still possible to use PyroSample in a function?

Yes, as long as it is an attribute of a PyroModule (e.g. the gp kernels are PyroModules). When that attribute is accessed, a pyro.sample statement will be called for that attribute and returns a sample.

If I cannot batch them, can I still create several GPs by looping over the plate, instead of using the context manager? I realize this will affect the speed of the fitting.

Yes, it seems that you can do it with ModuleList (let me know if it is not working for your case - I haven’t used ModuleList previously):

import torch.nn as nn
from pyro.nn import PyroModule

class A(PyroModule):
    def __init__(self):
        super().__init__()
        # create a list of modules by looping over range(4);
        # nn.Linear is only a stand-in here for a GP or kernel module
        self.gps = nn.ModuleList([PyroModule[nn.Linear](2, 3) for i in range(4)])

    def forward(self, x):
        # using list indexing to access a gp in the list
        return self.gps[0](x)

a = A()
{k: v for k, v in a.named_parameters()}