- What tutorial are you running?
Bayesian Regression - Introduction (Part 1)
- What version of Pyro are you using?
1.8.6 (Latest version)
- Please link or paste relevant code, and steps to reproduce.
Hello.
I’m just getting started with Pyro and am looking for information on GPU (CUDA) acceleration.
In Bayesian Regression - Introduction (Part 1), I made the following changes to the code in cell [10] to improve convergence in the SVI process and increase speed. (I added num_particles=7 and vectorize_particles=True to Trace_ELBO().)
Before:
from pyro.infer import SVI, Trace_ELBO
adam = pyro.optim.Adam({"lr": 0.03})
svi = SVI(model, guide, adam, loss=Trace_ELBO())
After:
from pyro.infer import SVI, Trace_ELBO
adam = pyro.optim.Adam({"lr": 0.03})
svi = SVI(model, guide, adam, loss=Trace_ELBO(num_particles=7, vectorize_particles=True))
However, after making these changes, I encountered the following error while running SVI.
pyro.clear_param_store()
num_iterations = 100
for j in range(num_iterations):
# calculate the loss and take a gradient step
loss = svi.step(x_data, y_data)
if j % 100 == 0:
print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(data)))
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
173 try:
--> 174 ret = self.fn(*args, **kwargs)
175 except (ValueError, RuntimeError) as e:
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
11 with context:
---> 12 return fn(*args, **kwargs)
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
11 with context:
---> 12 return fn(*args, **kwargs)
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs)
448 with self._pyro_context:
--> 449 result = super().__call__(*args, **kwargs)
450 if (
451 pyro.settings.get("validate_poutine")
452 and not self._pyro_context.active
453 and _is_module_local_param_enabled()
454 ):
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
/home/ingyojeong/DBTOF/examples/BNN_example_4.ipynb Cell 8 line 1
12 sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
---> 13 mean = self.linear(x).squeeze(-1)
14 with pyro.plate("data", x.shape[0]):
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs)
448 with self._pyro_context:
--> 449 result = super().__call__(*args, **kwargs)
450 if (
451 pyro.settings.get("validate_poutine")
452 and not self._pyro_context.active
453 and _is_module_local_param_enabled()
454 ):
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
113 def forward(self, input: Tensor) -> Tensor:
--> 114 return F.linear(input, self.weight, self.bias)
RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 4D
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
/home/ingyojeong/DBTOF/examples/BNN_example_4.ipynb Cell 8 line 5
2 num_iterations = 100
3 for j in range(num_iterations):
4 # calculate the loss and take a gradient step
----> 5 loss = svi.step(x_data, y_data)
6 if j % 100 == 0:
7 print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(data)))
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/infer/svi.py:145, in SVI.step(self, *args, **kwargs)
143 # get loss and compute gradients
144 with poutine.trace(param_only=True) as param_capture:
--> 145 loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
147 params = set(
148 site["value"].unconstrained() for site in param_capture.trace.nodes.values()
149 )
151 # actually perform gradient steps
152 # torch.optim objects gets instantiated for any params that haven't been seen yet
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/infer/trace_elbo.py:140, in Trace_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
138 loss = 0.0
139 # grab a trace from the generator
--> 140 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
141 loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
142 model_trace, guide_trace
143 )
144 loss += loss_particle / self.num_particles
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/infer/elbo.py:234, in ELBO._get_traces(self, model, guide, args, kwargs)
232 if self.max_plate_nesting == float("inf"):
233 self._guess_max_plate_nesting(model, guide, args, kwargs)
--> 234 yield self._get_vectorized_trace(model, guide, args, kwargs)
235 else:
236 for i in range(self.num_particles):
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/infer/elbo.py:211, in ELBO._get_vectorized_trace(self, model, guide, args, kwargs)
205 def _get_vectorized_trace(self, model, guide, args, kwargs):
206 """
207 Wraps the model and guide to vectorize ELBO computation over
208 ``num_particles``, and returns a single trace from the wrapped model
209 and guide.
210 """
--> 211 return self._get_trace(
212 self._vectorized_num_particles(model),
213 self._vectorized_num_particles(guide),
214 args,
215 kwargs,
216 )
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/infer/trace_elbo.py:57, in Trace_ELBO._get_trace(self, model, guide, args, kwargs)
52 def _get_trace(self, model, guide, args, kwargs):
53 """
54 Returns a single trace from the guide, and the model that is run
55 against it.
56 """
---> 57 model_trace, guide_trace = get_importance_trace(
58 "flat", self.max_plate_nesting, model, guide, args, kwargs
59 )
60 if is_validation_enabled():
61 check_if_enumerated(guide_trace)
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/infer/enum.py:65, in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
63 if detach:
64 guide_trace.detach_()
---> 65 model_trace = poutine.trace(
66 poutine.replay(model, trace=guide_trace), graph_type=graph_type
67 ).get_trace(*args, **kwargs)
69 if is_validation_enabled():
70 check_model_guide_match(model_trace, guide_trace, max_plate_nesting)
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py:198, in TraceHandler.get_trace(self, *args, **kwargs)
190 def get_trace(self, *args, **kwargs):
191 """
192 :returns: data structure
193 :rtype: pyro.poutine.Trace
(...)
196 Calls this poutine and returns its trace instead of the function's return value.
197 """
--> 198 self(*args, **kwargs)
199 return self.msngr.get_trace()
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py:180, in TraceHandler.__call__(self, *args, **kwargs)
178 exc = exc_type("{}\n{}".format(exc_value, shapes))
179 exc = exc.with_traceback(traceback)
--> 180 raise exc from e
181 self.msngr.trace.add_node(
182 "_RETURN", name="_RETURN", type="return", value=ret
183 )
184 return ret
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
170 self.msngr.trace.add_node(
171 "_INPUT", name="_INPUT", type="args", args=args, kwargs=kwargs
172 )
173 try:
--> 174 ret = self.fn(*args, **kwargs)
175 except (ValueError, RuntimeError) as e:
176 exc_type, exc_value, traceback = sys.exc_info()
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs)
447 def __call__(self, *args, **kwargs):
448 with self._pyro_context:
--> 449 result = super().__call__(*args, **kwargs)
450 if (
451 pyro.settings.get("validate_poutine")
452 and not self._pyro_context.active
453 and _is_module_local_param_enabled()
454 ):
455 self._check_module_local_param_usage()
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
/home/ingyojeong/DBTOF/examples/BNN_example_4.ipynb Cell 8 line 1
11 def forward(self, x, y=None):
12 sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
---> 13 mean = self.linear(x).squeeze(-1)
14 with pyro.plate("data", x.shape[0]):
15 obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs)
447 def __call__(self, *args, **kwargs):
448 with self._pyro_context:
--> 449 result = super().__call__(*args, **kwargs)
450 if (
451 pyro.settings.get("validate_poutine")
452 and not self._pyro_context.active
453 and _is_module_local_param_enabled()
454 ):
455 self._check_module_local_param_usage()
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
113 def forward(self, input: Tensor) -> Tensor:
--> 114 return F.linear(input, self.weight, self.bias)
RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 4D
Trace Shapes:
Param Sites:
Sample Sites:
sigma dist 7 1 |
value 7 1 |
linear.weight dist 7 1 | 1 3
value 7 1 | 1 3
linear.bias dist 7 1 | 1
value 7 1 | 1
Question 1: Can you please tell me what is wrong with this code that makes it not work?
Question 2: I need to train a similar but much more complex neural network than the one in the code above. To do this, I need to accelerate the Stochastic Variational Inference (SVI) process. Can you please tell me how to accelerate (parallelize) the SVI process?
I would really appreciate any help you can provide.
Thank you for your attention to my problem!