SparseGPRegression error from Pytorch's lowrank_multivariate_normal


#1

Hi,

I’m using Pyro 0.2.1 and PyTorch 1.0.0.dev20181109.

I was trying the GPLVM example from the documentation, which uses SparseGPRegression, and got the following error from torch.distributions.lowrank_multivariate_normal:

ValueError                                
Traceback (most recent call last)
<ipython-input-38-bb76365aeb43> in <module>()
      1 # Finally, wrap gpmodel by GPLVM, optimize, and get the "learned" mean of X:
      2 gplvm = gp.models.GPLVM(gpmodel)
----> 3 gplvm.optimize()
      4 X = gplvm.get_param("X_loc")

~/.conda/envs/pyro/lib/python3.6/site-packages/pyro/contrib/gp/models/gplvm.py in optimize(self, optimizer, num_steps)
    138         losses = []
    139         for i in range(num_steps):
--> 140             losses.append(svi.step())
    141         return losses

~/.conda/envs/pyro/lib/python3.6/site-packages/pyro/infer/svi.py in step(self, *args, **kwargs)
     73         # get loss and compute gradients
     74         with poutine.trace(param_only=True) as param_capture:
---> 75             loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
     76 
     77         params = set(site["value"].unconstrained()

~/.conda/envs/pyro/lib/python3.6/site-packages/pyro/infer/trace_elbo.py in loss_and_grads(self, model, guide, *args, **kwargs)
    105         elbo = 0.0
    106         # grab a trace from the generator
--> 107         for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
    108             elbo_particle = 0
    109             surrogate_elbo_particle = 0

~/.conda/envs/pyro/lib/python3.6/site-packages/pyro/infer/trace_elbo.py in _get_traces(self, model, guide, *args, **kwargs)
     53         for i in range(self.num_particles):
     54             guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
---> 55             model_trace = poutine.trace(poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)
     56             if is_validation_enabled():
     57                 check_model_guide_match(model_trace, guide_trace)

~/.conda/envs/pyro/lib/python3.6/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
    200         Calls this poutine and returns its trace instead of the function's return value.
    201         """
--> 202         self(*args, **kwargs)
    203         return self.msngr.get_trace()

~/.conda/envs/pyro/lib/python3.6/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    184                                       name="_INPUT", type="args",
    185                                       args=args, kwargs=kwargs)
--> 186             ret = self.fn(*args, **kwargs)
    187             self.msngr.trace.add_node("_RETURN", name="_RETURN", type="return", value=ret)
    188         return ret

~/.conda/envs/pyro/lib/python3.6/site-packages/pyro/poutine/messenger.py in _wraps(*args, **kwargs)
     25         def _wraps(*args, **kwargs):
     26             with self:
---> 27                 return fn(*args, **kwargs)
     28         _wraps.msngr = self
     29         return _wraps

~/.conda/envs/pyro/lib/python3.6/site-packages/pyro/contrib/gp/models/gplvm.py in model(self)
     94 
     95         self.base_model.set_data(X, self.y)
---> 96         self.base_model.model()
     97 
     98     def guide(self):

~/.conda/envs/pyro/lib/python3.6/site-packages/pyro/contrib/gp/models/sgpr.py in model(self)
    153             y_name = param_with_module_name(self.name, "y")
    154             return pyro.sample(y_name,
--> 155                                dist.LowRankMultivariateNormal(f_loc, W, D, trace_term)
    156                                    .expand_by(self.y.shape[:-1])
    157                                    .independent(self.y.dim() - 1),

~/.conda/envs/pyro/lib/python3.6/site-packages/torch/distributions/lowrank_multivariate_normal.py in __init__(self, loc, cov_factor, cov_diag, validate_args)
     98         if cov_factor.shape[-2:-1] != event_shape:
     99             raise ValueError("cov_factor must be a batch of matrices with shape {} x m"
--> 100                              .format(event_shape[0]))
    101         if cov_diag.shape[-1:] != event_shape:
    102             raise ValueError("cov_diag must be a batch of vectors with shape {}".format(event_shape))

ValueError: cov_factor must be a batch of matrices with shape 150 x m

I managed to run the example with GPRegression though. I’m very new to probabilistic programming. Any help would be much appreciated.

By the way, GPLVM is not listed in the documentation's table of contents, even though its section does appear in the docs right after VariationalSparseGP.

Thanks


#2

If you use a PyTorch pre-release, you will also have to use the pre-release (development) version of Pyro.


#3

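In PyTorch 1.0, the `cov_factor` argument of `LowRankMultivariateNormal` is the transpose of the `W` matrix that Pyro 0.2.1 passes in. PyTorch therefore expects a 150 x m matrix, which is why you get the shape error above. You need to use the development version of Pyro instead.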

Thanks for reporting the issue with the GPLVM docs! It was recently fixed in the Pyro development version.