Hey @fehiepsi I was trying to implement the solution suggested in this thread for a similar problem. I am attaching my code that generates the synthetic data and tries to combines GPR and a simple linear model.
import random
import numpy as np
import torch
import pyro
import pyro.distributions as dist
import pyro.optim as optim
import pyro.contrib.gp as gp
import pyro.nn.module as mod
from matplotlib import pyplot as plt
from torch.distributions import constraints
from pyro.infer.autoguide.guides import AutoDelta, AutoDiagonalNormal
from pyro.infer.autoguide import init_to_mean, init_to_feasible
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate, infer_discrete, Trace_ELBO, Predictive, MCMC, NUTS
# %% Create trajectories
np.random.seed(10)
samp_t = np.linspace(0, 10, num=100)
samp_traj = 2 * samp_t + np.random.rand(np.size(samp_t, 0),)*3
X = torch.tensor(samp_t, dtype=torch.float)
y = torch.tensor(samp_traj, dtype=torch.float)
# %% Training
class Linear(gp.Parameterized):
def __init__(self):
super(Linear, self).__init__()
self.a = mod.PyroSample(dist.Normal(0., 5.))
self.b_x0 = mod.PyroSample(dist.Normal(0., 3.))
def forward(self, X):
a = self.a
b_x0 = self.b_x0
m = a + b_x0 * X
return m
linear = Linear()
def model(X, y):
linear.set_mode("model")
f, f_var = gpmodel.model()
mean = linear(X)
sigma = pyro.param("sigma",
torch.tensor(1.),
constraint=torch.distributions.constraints.positive)
with pyro.plate("data", len(y)):
pyro.sample("obs", dist.Normal(mean + f, sigma + f_var), obs=y)
def guide(X, y):
linear.set_mode("guide")
gpmodel.set_data(X, None)
gpmodel.guide()
# %% Using SVI
pyro.clear_param_store()
pyro.set_rng_seed(1524)
kern = gp.kernels.Matern52(input_dim=1,
variance=torch.tensor(0.1),
lengthscale=torch.tensor(1.))
gpmodel = gp.models.GPRegression(X, y, kern, mean_function=linear)
svi = SVI(model,
guide,
optim.Adam({"lr": .01}),
loss=pyro.infer.Trace_ELBO().differentiable_loss)
num_iters = 500
losses = []
for i in range(num_iters):
elbo = svi.step(X, y)
losses.append(elbo)
print(elbo)
However, I receive an error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
156 try:
--> 157 ret = self.fn(*args, **kwargs)
158 except (ValueError, RuntimeError):
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
7 with context:
----> 8 return fn(*args, **kwargs)
9
~/etl_tasks/etl_tasks/out_of_warehouse/wound_analysis/rish_analyses/general_experiments/probabilistic_programs/mixture_gp_test.py in model(X, y)
47 f, f_var = gpmodel.model()
---> 48 mean = linear(X)
49 sigma = pyro.param("sigma",
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
285 with self._pyro_context:
--> 286 return super().__call__(*args, **kwargs)
287
/anaconda3/envs/experimental/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
540 else:
--> 541 result = self.forward(*input, **kwargs)
542 for hook in self._forward_hooks.values():
~/etl_tasks/etl_tasks/out_of_warehouse/wound_analysis/rish_analyses/general_experiments/probabilistic_programs/mixture_gp_test.py in forward(self, X)
34 def forward(self, X):
---> 35 a = self.a
36 b_x0 = self.b_x0
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/nn/module.py in __getattr__(self, name)
327 prior = prior(self)
--> 328 value = pyro.sample(fullname, prior)
329 context.set(fullname, value)
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/primitives.py in sample(name, fn, *args, **kwargs)
109 # apply the stack and return its return value
--> 110 apply_stack(msg)
111 return msg["value"]
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/poutine/runtime.py in apply_stack(initial_msg)
197 for frame in stack[-pointer:]:
--> 198 frame._postprocess_message(msg)
199
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/poutine/messenger.py in _postprocess_message(self, msg)
137 if hasattr(self, method_name):
--> 138 return getattr(self, method_name)(msg)
139 return None
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in _pyro_post_sample(self, msg)
118 if not self.param_only:
--> 119 self.trace.add_node(msg["name"], **msg.copy())
120
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/poutine/trace_struct.py in add_node(self, site_name, **kwargs)
108 # Cannot sample after a previous sample statement.
--> 109 raise RuntimeError("Multiple {} sites named '{}'".format(kwargs['type'], site_name))
110
RuntimeError: Multiple sample sites named 'mean_function.a'
During handling of the above exception, another exception occurred:
RuntimeError Traceback (most recent call last)
~/etl_tasks/etl_tasks/out_of_warehouse/wound_analysis/rish_analyses/general_experiments/probabilistic_programs/mixture_gp_test.py in
75 losses = []
76 for i in range(num_iters):
---> 77 elbo = svi.step(X, y)
78 losses.append(elbo)
79 print(elbo)
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/infer/svi.py in step(self, *args, **kwargs)
123 # get loss and compute gradients
124 with poutine.trace(param_only=True) as param_capture:
--> 125 loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
126
127 params = set(site["value"].unconstrained()
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/infer/svi.py in _loss_and_grads(*args, **kwargs)
66 if loss_and_grads is None:
67 def _loss_and_grads(*args, **kwargs):
---> 68 loss_val = loss(*args, **kwargs)
69 if getattr(loss_val, 'requires_grad', False):
70 loss_val.backward(retain_graph=True)
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/infer/trace_elbo.py in differentiable_loss(self, model, guide, *args, **kwargs)
104 loss = 0.
105 surrogate_loss = 0.
--> 106 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
107 loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(model_trace, guide_trace)
108 surrogate_loss += surrogate_loss_particle / self.num_particles
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/infer/elbo.py in _get_traces(self, model, guide, args, kwargs)
165 else:
166 for i in range(self.num_particles):
--> 167 yield self._get_trace(model, guide, args, kwargs)
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/infer/trace_elbo.py in _get_trace(self, model, guide, args, kwargs)
48 """
49 model_trace, guide_trace = get_importance_trace(
---> 50 "flat", self.max_plate_nesting, model, guide, args, kwargs)
51 if is_validation_enabled():
52 check_if_enumerated(guide_trace)
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/infer/enum.py in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
43 guide_trace.detach_()
44 model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
---> 45 graph_type=graph_type).get_trace(*args, **kwargs)
46 if is_validation_enabled():
47 check_model_guide_match(model_trace, guide_trace, max_plate_nesting)
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
175 Calls this poutine and returns its trace instead of the function's return value.
176 """
--> 177 self(*args, **kwargs)
178 return self.msngr.get_trace()
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
159 exc_type, exc_value, traceback = sys.exc_info()
160 shapes = self.msngr.trace.format_shapes()
--> 161 raise exc_type(u"{}\n{}".format(exc_value, shapes)).with_traceback(traceback)
162 self.msngr.trace.add_node("_RETURN", name="_RETURN", type="return", value=ret)
163 return ret
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
155 args=args, kwargs=kwargs)
156 try:
--> 157 ret = self.fn(*args, **kwargs)
158 except (ValueError, RuntimeError):
159 exc_type, exc_value, traceback = sys.exc_info()
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
6 def _context_wrap(context, fn, *args, **kwargs):
7 with context:
----> 8 return fn(*args, **kwargs)
9
10
~/etl_tasks/etl_tasks/out_of_warehouse/wound_analysis/rish_analyses/general_experiments/probabilistic_programs/mixture_gp_test.py in model(X, y)
46 linear.set_mode("model")
47 f, f_var = gpmodel.model()
---> 48 mean = linear(X)
49 sigma = pyro.param("sigma",
50 torch.tensor(1.),
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
284 def __call__(self, *args, **kwargs):
285 with self._pyro_context:
--> 286 return super().__call__(*args, **kwargs)
287
288 def __getattr__(self, name):
/anaconda3/envs/experimental/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
539 result = self._slow_forward(*input, **kwargs)
540 else:
--> 541 result = self.forward(*input, **kwargs)
542 for hook in self._forward_hooks.values():
543 hook_result = hook(self, input, result)
~/etl_tasks/etl_tasks/out_of_warehouse/wound_analysis/rish_analyses/general_experiments/probabilistic_programs/mixture_gp_test.py in forward(self, X)
33
34 def forward(self, X):
---> 35 a = self.a
36 b_x0 = self.b_x0
37 m = a + b_x0 * X
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/nn/module.py in __getattr__(self, name)
326 if not hasattr(prior, "sample"): # if not a distribution
327 prior = prior(self)
--> 328 value = pyro.sample(fullname, prior)
329 context.set(fullname, value)
330 return value
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/primitives.py in sample(name, fn, *args, **kwargs)
108 msg["is_observed"] = True
109 # apply the stack and return its return value
--> 110 apply_stack(msg)
111 return msg["value"]
112
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/poutine/runtime.py in apply_stack(initial_msg)
196
197 for frame in stack[-pointer:]:
--> 198 frame._postprocess_message(msg)
199
200 cont = msg["continuation"]
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/poutine/messenger.py in _postprocess_message(self, msg)
136 method_name = "_pyro_post_{}".format(msg["type"])
137 if hasattr(self, method_name):
--> 138 return getattr(self, method_name)(msg)
139 return None
140
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in _pyro_post_sample(self, msg)
117 def _pyro_post_sample(self, msg):
118 if not self.param_only:
--> 119 self.trace.add_node(msg["name"], **msg.copy())
120
121 def _pyro_post_param(self, msg):
/anaconda3/envs/experimental/lib/python3.7/site-packages/pyro/poutine/trace_struct.py in add_node(self, site_name, **kwargs)
107 elif kwargs['type'] != "param":
108 # Cannot sample after a previous sample statement.
--> 109 raise RuntimeError("Multiple {} sites named '{}'".format(kwargs['type'], site_name))
110
111 # XXX should copy in case site gets mutated, or dont bother?
RuntimeError: Multiple sample sites named 'mean_function.a'
Trace Shapes:
Param Sites:
kernel.lengthscale
kernel.variance
noise
Sample Sites:
mean_function.a dist |
value |
mean_function.b_x0 dist |
value |
Could you suggest a solution to this problem? I tried removing the mean_function
argument from gp.models.GPRegression
, the error disappears but the model never converges to a solution.
Thanks