In the tutorial, the examples show a single variable input into the sparse Gaussian process model. However, we have multiple variables inputs (2-dimensional) in most real-world cases. So, I just revise the tutorial example slightly to the following code before applying it to real-world dataset. And it always pops the error ‘torch.linalg.cholesky: U() is zero, singular U’
I have checked RuntimeError during Cholesky Decomposition and Numerical issue with cholesky decomposition on GitHub, but none of post can really remove the issue.
nd = 15
xu1, xu2 = np.meshgrid(np.linspace(0.0, 5.0, nd), np.linspace(0.0, 52.0, nd))
Xu = np.concatenate([xu1.reshape(nd*nd, 1), xu2.reshape(nd*nd, 1)], 1)
Xu = torch.tensor(Xu,dtype=torch.float32)
N = 500
X = torch.zeros((N,2))
x = dist.Uniform(0.0, 5.0).sample(sample_shape=(N,))
x1 = 0.5 * torch.sin(3*x)
x2 = dist.Normal(0.0, 0.2).sample(sample_shape=(N,))
X[:,0] = x1
X[:,1] = x2
y = x1+x2
from torch.distributions import constraints
# initialize the kernel and model
pyro.clear_param_store()
kernel = gp.kernels.RBF(input_dim=2)
likelihood = gp.likelihoods.Gaussian()
kernel.variance = pyro.param( torch.ones(1), constraint=constraints.positive)
vsgp = gp.models.VariationalSparseGP(X, y, kernel, Xu=Xu, likelihood=likelihood, whiten=True)
# instead of defining our own training loop, we will
# use the built-in support provided by the GP module
num_steps = 1500 if not smoke_test else 2
losses = gp.util.train(vsgp, num_steps=num_steps)
plt.plot(losses);
Full error messages:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/opt/anaconda3/envs/pyro/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
173 try:
--> 174 ret = self.fn(*args, **kwargs)
175 except (ValueError, RuntimeError) as e:
/opt/anaconda3/envs/pyro/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
11 with context:
---> 12 return fn(*args, **kwargs)
13
/opt/anaconda3/envs/pyro/lib/python3.8/site-packages/pyro/nn/module.py in cached_fn(self, *args, **kwargs)
635 with self._pyro_context:
--> 636 return fn(self, *args, **kwargs)
637
/opt/anaconda3/envs/pyro/lib/python3.8/site-packages/pyro/contrib/gp/models/vsgp.py in model(self)
121 Kuu.view(-1)[:: M + 1] += self.jitter # add jitter to the diagonal
--> 122 Luu = torch.linalg.cholesky(Kuu)
123
RuntimeError: torch.linalg.cholesky: U(41,41) is zero, singular U.
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
/var/folders/qj/w_dwwf3134gfzrp00w5c2qk40000gn/T/ipykernel_3460/2240828194.py in <module>
13 # use the built-in support provided by the GP module
14 num_steps = 1500 if not smoke_test else 2
---> 15 losses = gp.util.train(vsgp, num_steps=num_steps)
16 plt.plot(losses);
/opt/anaconda3/envs/pyro/lib/python3.8/site-packages/pyro/contrib/gp/util.py in train(gpmodule, optimizer, loss_fn, retain_graph, num_steps)
190 losses = []
191 for i in range(num_steps):
--> 192 loss = optimizer.step(closure)
193 losses.append(torch_item(loss))
194 return losses
/opt/anaconda3/envs/pyro/lib/python3.8/site-packages/torch/optim/optimizer.py in wrapper(*args, **kwargs)
86 profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__)
87 with torch.autograd.profiler.record_function(profile_name):
---> 88 return func(*args, **kwargs)
89 return wrapper
90
/opt/anaconda3/envs/pyro/lib/python3.8/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
26 def decorate_context(*args, **kwargs):
27 with self.__class__():
---> 28 return func(*args, **kwargs)
29 return cast(F, decorate_context)
30
/opt/anaconda3/envs/pyro/lib/python3.8/site-packages/torch/optim/adam.py in step(self, closure)
64 if closure is not None:
65 with torch.enable_grad():
---> 66 loss = closure()
67
68 for group in self.param_groups:
/opt/anaconda3/envs/pyro/lib/python3.8/site-packages/pyro/contrib/gp/util.py in closure()
184 def closure():
185 optimizer.zero_grad()
--> 186 loss = loss_fn(gpmodule.model, gpmodule.guide)
187 torch_backward(loss, retain_graph)
188 return loss
/opt/anaconda3/envs/pyro/lib/python3.8/site-packages/pyro/infer/trace_elbo.py in differentiable_loss(self, model, guide, *args, **kwargs)
119 loss = 0.0
120 surrogate_loss = 0.0
--> 121 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
122 loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
123 model_trace, guide_trace
/opt/anaconda3/envs/pyro/lib/python3.8/site-packages/pyro/infer/elbo.py in _get_traces(self, model, guide, args, kwargs)
184 else:
185 for i in range(self.num_particles):
--> 186 yield self._get_trace(model, guide, args, kwargs)
/opt/anaconda3/envs/pyro/lib/python3.8/site-packages/pyro/infer/trace_mean_field_elbo.py in _get_trace(self, model, guide, args, kwargs)
80
81 def _get_trace(self, model, guide, args, kwargs):
---> 82 model_trace, guide_trace = super()._get_trace(model, guide, args, kwargs)
83 if is_validation_enabled():
84 _check_mean_field_requirement(model_trace, guide_trace)
/opt/anaconda3/envs/pyro/lib/python3.8/site-packages/pyro/infer/trace_elbo.py in _get_trace(self, model, guide, args, kwargs)
55 against it.
56 """
---> 57 model_trace, guide_trace = get_importance_trace(
58 "flat", self.max_plate_nesting, model, guide, args, kwargs
59 )
/opt/anaconda3/envs/pyro/lib/python3.8/site-packages/pyro/infer/enum.py in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
53 if detach:
54 guide_trace.detach_()
---> 55 model_trace = poutine.trace(
56 poutine.replay(model, trace=guide_trace), graph_type=graph_type
57 ).get_trace(*args, **kwargs)
/opt/anaconda3/envs/pyro/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
196 Calls this poutine and returns its trace instead of the function's return value.
197 """
--> 198 self(*args, **kwargs)
199 return self.msngr.get_trace()
/opt/anaconda3/envs/pyro/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
178 exc = exc_type(u"{}\n{}".format(exc_value, shapes))
179 exc = exc.with_traceback(traceback)
--> 180 raise exc from e
181 self.msngr.trace.add_node(
182 "_RETURN", name="_RETURN", type="return", value=ret
/opt/anaconda3/envs/pyro/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
172 )
173 try:
--> 174 ret = self.fn(*args, **kwargs)
175 except (ValueError, RuntimeError) as e:
176 exc_type, exc_value, traceback = sys.exc_info()
/opt/anaconda3/envs/pyro/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
13
14
/opt/anaconda3/envs/pyro/lib/python3.8/site-packages/pyro/nn/module.py in cached_fn(self, *args, **kwargs)
634 def cached_fn(self, *args, **kwargs):
635 with self._pyro_context:
--> 636 return fn(self, *args, **kwargs)
637
638 return cached_fn
/opt/anaconda3/envs/pyro/lib/python3.8/site-packages/pyro/contrib/gp/models/vsgp.py in model(self)
120 Kuu = self.kernel(self.Xu).contiguous()
121 Kuu.view(-1)[:: M + 1] += self.jitter # add jitter to the diagonal
--> 122 Luu = torch.linalg.cholesky(Kuu)
123
124 zero_loc = self.Xu.new_zeros(self.u_loc.shape)
RuntimeError: torch.linalg.cholesky: U(41,41) is zero, singular U.
Trace Shapes:
Param Sites:
Xu 225 2
kernel.lengthscale
kernel.variance 1
Sample Sites: