Using GP hyperparameters with plates

I’m working with a model that specifies a set of GPs, one for each group member (venue) in the dataset:

    ls = pyro.sample('ls', Gamma(torch.FloatTensor([5.]).to(device), torch.FloatTensor([0.5]).to(device))) 
    amp = pyro.sample('amp', Gamma(torch.FloatTensor([2.]).to(device), torch.FloatTensor([1.]).to(device)))

    with pyro.plate('venues', V): 
        K_w = gp.kernels.RBF(
                input_dim=1,
                variance=amp, 
                lengthscale=ls
        )
        cov_alpha = K_w(torch.FloatTensor(weeks).to(device))
        cov_alpha.view(-1)[::W+1] += jitter
        alpha = pyro.sample('alpha', MultivariateNormal(torch.zeros(W).to(device), covariance_matrix=cov_alpha))

This works fine, as you would expect. However, if I want to specify individual-specific length scales, I assume I just have to move the ls statement into the plate:

    amp = pyro.sample('amp', Gamma(torch.FloatTensor([2.]).to(device), torch.FloatTensor([1.]).to(device)))

    with pyro.plate('venues', V): 
        ls = pyro.sample('ls', Gamma(torch.FloatTensor([5.]).to(device), torch.FloatTensor([0.5]).to(device))) 
        K_w = gp.kernels.RBF(
                input_dim=1,
                variance=amp, 
                lengthscale=ls
        )
        cov_alpha = K_w(torch.FloatTensor(weeks).to(device))
        cov_alpha.view(-1)[::W+1] += jitter
        alpha = pyro.sample('alpha', MultivariateNormal(torch.zeros(W).to(device), covariance_matrix=cov_alpha))

This model runs fine, but when I try to extract estimates with

    estimates = guide.quantiles([0.025, 0.5, 0.975])

I get a RuntimeError:

    ~/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/torch/distributions/normal.py in icdf(self, value)
         83 
         84     def icdf(self, value):
    ---> 85         return self.loc + self.scale * torch.erfinv(2 * value - 1) * math.sqrt(2)
         86 
         87     def entropy(self):

    RuntimeError: The size of tensor a (441) must match the size of tensor b (3) at non-singleton dimension 0

Moreover, if I try to bring both hyperparameters into the plate, the model does not run at all:

    with pyro.plate('venues', V): 
        ls = pyro.sample('ls', Gamma(torch.FloatTensor([5.]).to(device), torch.FloatTensor([0.5]).to(device))) 
        amp = pyro.sample('amp', Gamma(torch.FloatTensor([2.]).to(device), torch.FloatTensor([1.]).to(device)))
        K_w = gp.kernels.RBF(
                input_dim=1,
                variance=amp, 
                lengthscale=ls
        )
        cov_alpha = K_w(torch.FloatTensor(weeks).to(device))
        cov_alpha.view(-1)[::W+1] += jitter
        alpha = pyro.sample('alpha', MultivariateNormal(torch.zeros(W).to(device), covariance_matrix=cov_alpha))

This fails with:

    ~/anaconda3/envs/pytorch_latest_p37/lib/python3.7/site-packages/pyro/contrib/gp/kernels/isotropic.py in forward(self, X, Z, diag)
         87         r2 = self._square_scaled_dist(X, Z)
    ---> 88         return self.variance * torch.exp(-0.5 * r2)
         89 

    RuntimeError: The size of tensor a (32) must match the size of tensor b (67) at non-singleton dimension 1

I get the feeling I’m just doing something dumb here, but I can’t see it. Is there something different about GP parameters specifically?

I think this is a limitation of the GP module (or, more generally, of any torch.nn module that does not support “batched” parameters, such as nn.Linear). The internal implementation assumes that lengthscale is either a scalar or a 1D vector of size input_dim. We should revise the implementation to support batched GP kernels, but it would take a bit of time to cover all kernels and utilities. :frowning:
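To make the shape problem concrete, here is a small plain-PyTorch sketch; it is not Pyro's kernel code. It assumes the kernel divides `(W, 1)`-shaped inputs by the lengthscale before computing squared distances, which is what the `_square_scaled_dist` call in the traceback suggests. Under that assumption, a per-venue lengthscale of shape `(V,)` broadcasts the inputs to `(W, V)`, so the venues behave like extra input features and collapse into a single `W x W` matrix. That would let the second model run without complaint, while a per-venue variance of shape `(V,)` cannot multiply that matrix, much like the third traceback. The hypothetical `batched_rbf` helper instead puts the venue dimension in front and returns one covariance per venue. All sizes and values are made up.

```python
import torch

V, W = 4, 6                                   # hypothetical: 4 venues, 6 weeks (V != W)
weeks = torch.arange(W, dtype=torch.float32)  # 1-D inputs, shape (W,)

# 1) Why an unbatched kernel misbehaves with per-venue hyperparameters.
ls_per_venue = torch.full((V,), 2.0)          # shape (V,), as sampled inside the plate
X = weeks.unsqueeze(-1)                       # (W, 1): a single input feature
scaled = X / ls_per_venue                     # broadcasts to (W, V): venues act like features
# Squared distances summed over that last axis give ONE (W, W) matrix, not V of them.
r2_collapsed = ((scaled[:, None, :] - scaled[None, :, :]) ** 2).sum(-1)

try:
    # A per-venue variance of shape (V,) cannot multiply a (W, W) matrix when V != W.
    torch.full((V,), 1.0) * torch.exp(-0.5 * r2_collapsed)
    shape_error = False
except RuntimeError:
    shape_error = True


# 2) A batched alternative: keep the venue dimension in front.
def batched_rbf(x, variance, lengthscale):
    """Return a (V, W, W) stack of RBF covariances for (V,) hyperparameters."""
    diff = x.unsqueeze(-1) - x.unsqueeze(-2)          # (W, W) pairwise differences
    ls = lengthscale[:, None, None]                   # (V, 1, 1)
    var = variance[:, None, None]                     # (V, 1, 1)
    return var * torch.exp(-0.5 * (diff / ls) ** 2)   # (V, W, W)


amp = torch.linspace(0.5, 2.0, V)                     # made-up per-venue amplitudes
ls = torch.linspace(1.0, 3.0, V)                      # made-up per-venue lengthscales
jitter = 1e-4
cov = batched_rbf(weeks, amp, ls) + jitter * torch.eye(W)  # (V, W, W)
chol = torch.linalg.cholesky(cov)                     # one Cholesky factor per venue
```

The stacked factor `chol` could then be passed as `scale_tril` to a MultivariateNormal sampled inside `pyro.plate('venues', V)`, giving an independent curve per venue with its own amplitude and lengthscale.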

OK, thanks. I thought I was overlooking something simple.

@fonnesbeck you might have more luck with gpytorch+pyro

Thank you. Yeah, I was aware of that integration, but I was trying to keep things simple: I only wanted to model some non-linear effects as a latent GP within a larger model.

The other limitation here is that I am not integrating just one GP into the model; I am using several, adding them together and indexing them as appropriate. All of the pyro+gpytorch integration examples seem to assume a single GP, with no obvious extension to multiple GPs.