RVM using Pyro

Hello! Thanks for the great framework!

I’ve been trying to implement a Relevance Vector Machine using Pyro. I have a problem regarding the basis function expansion.
When I apply the kernel to the input variables, it returns a tensor with shape (X.shape[0], X.shape[0]).
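To make the shapes concrete, here is a small dependency-free sketch of what I believe gp.kernels.RBF computes, namely variance * exp(-0.5 * ((a - b) / lengthscale) ** 2). The rbf helper and the sample inputs are my own illustration, not Pyro's API. Rows follow the first argument and columns follow the second, so evaluating against the training set always gives one column per training (basis) point.

```python
import math

def rbf(A, B, lengthscale=1.0, variance=1.0):
    # Hypothetical helper (not Pyro's API): squared-exponential kernel between
    # every 1-D point in A and every 1-D point in B.
    # Returns a nested list with len(A) rows and len(B) columns.
    return [[variance * math.exp(-0.5 * ((a - b) / lengthscale) ** 2) for b in B]
            for a in A]

X_train = [0.0, 0.5, 1.0, 1.5]  # arbitrary illustrative inputs
X_new = [0.25, 2.0]

K_train = rbf(X_train, X_train)  # 4 x 4: one column per training (basis) point
K_cross = rbf(X_new, X_train)    # 2 x 4: rows follow X_new, columns stay the training points

print(len(K_train), len(K_train[0]))  # 4 4
print(len(K_cross), len(K_cross[0]))  # 2 4
```

If I read the Pyro kernel docs correctly, calling the real kernel as rbf(X_new, X) should likewise return a tensor of shape (X_new.shape[0], X.shape[0]).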
Following the RVM approach, I need a separate hyperparameter for each basis vector, so I sample from a Gamma distribution inside a plate of size X.shape[0]. Training is done with MCMC.
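For reference, here is the standard RVM regression model as I understand it, written with the names from my code: b_0 is the bias, β_i is the weight on the i-th basis function k(·, x_i), and α_i is that weight's precision, with a shape/rate Gamma hyperprior.

$$
y_n = b_0 + \sum_{i=1}^{N} \beta_i \, k(x_n, x_i) + \varepsilon_n, \qquad \varepsilon_n \sim \mathcal{N}(0, \sigma^2), \qquad n = 1, \dots, N
$$

$$
\beta_i \mid \alpha_i \sim \mathcal{N}\!\left(0, \alpha_i^{-1}\right), \qquad \alpha_i \sim \mathrm{Gamma}(a, b), \qquad i = 1, \dots, N
$$

$$
\hat{y}(x_*) = b_0 + \sum_{i=1}^{N} \beta_i \, k(x_*, x_i)
$$

The sum always runs over the N training points, so a prediction at a new input x_* needs only the kernel values k(x_*, x_i) for i = 1, ..., N and the N fitted weights. Note that in this formulation α_i is a precision, whereas my model passes alpha directly as the scale of the Normal (and likewise inv_sigma as the observation scale).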
The problem is at test time. Because the model expects a tensor with the same number of “rows” as the training set, I am unable to run inference on new data.
Here is my model:

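import torch
import pyro
import pyro.distributions as dist
import pyro.contrib.gp as gp

def model(X, y=None):
  b0 = pyro.sample('b0', dist.Normal(0, 10))
  rbf = gp.kernels.RBF(input_dim=1, lengthscale=torch.ones(1))
  # note: inv_sigma is used below as the Normal's scale (standard deviation)
  inv_sigma = pyro.sample('inv_sigma', dist.Gamma(1e-4, 1e-4))
  # k[i, j] = RBF(X[i], X[j]); shape (X.shape[0], X.shape[0])
  k = rbf(X, X)
  # a single plate of size X.shape[0] holds both the per-basis-point
  # hyperparameters/weights and the observations
  with pyro.plate('data', X.shape[0]):
    gamma = pyro.sample('gamma', dist.Gamma(1e-4, 1e-4))
    alpha = pyro.sample('alpha', dist.Gamma(1e-4, gamma))
    # alpha is passed as the scale of the weight prior
    beta = pyro.sample('beta', dist.Normal(0, alpha))
    mu = pyro.deterministic('mu', b0.float() + torch.matmul(k.float(), beta.float()))
    pyro.sample('obs', dist.Normal(mu, inv_sigma), obs=y)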
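To check my understanding of the shapes, here is a dependency-free sketch of the mean line in the model above, mu = b0 + k @ beta, evaluated once on the training inputs and once on new inputs while holding a single set of weights fixed. The rbf and predict_mean helpers are hypothetical, and the weights and inputs are arbitrary placeholder numbers, not MCMC output.

```python
import math

def rbf(A, B, lengthscale=1.0, variance=1.0):
    # Hypothetical helper: squared-exponential kernel, len(A) x len(B) nested list.
    return [[variance * math.exp(-0.5 * ((a - b) / lengthscale) ** 2) for b in B]
            for a in A]

def predict_mean(X_eval, X_basis, b0, beta):
    # Hypothetical helper: mu = b0 + K(X_eval, X_basis) @ beta.
    # beta has one entry per basis (training) point, so its size never changes;
    # only the number of rows of K follows X_eval.
    if len(beta) != len(X_basis):
        raise ValueError("need exactly one weight per basis point")
    K = rbf(X_eval, X_basis)
    return [b0 + sum(k_ij * b_j for k_ij, b_j in zip(row, beta)) for row in K]

X_train = [0.0, 0.5, 1.0, 1.5]
beta = [0.3, -0.1, 0.0, 0.8]  # placeholder weights, one per training point
b0 = 0.2                      # placeholder bias

mu_train = predict_mean(X_train, X_train, b0, beta)        # 4 values
mu_new = predict_mean([0.25, 2.0, 3.0], X_train, b0, beta)  # 3 values
print(len(mu_train), len(mu_new))  # 4 3
```

In this sketch the number of new points only changes the number of kernel rows; the weight vector stays tied to the training set.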

How can I change my model so that I can run inference on new data?
Maybe I am missing some really simple detail.