# vectorize_particles=True in tutorial - Bayesian Regression - Introduction (Part 1)

• What tutorial are you running?

# Bayesian Regression - Introduction (Part 1)

• What version of Pyro are you using?

Hello,

I’m just getting started with Pyro and am looking for information on acceleration (using CUDA on a GPU).

In Bayesian Regression - Introduction (Part 1), I made the following change to the code in [10] to improve convergence of the SVI process and increase speed: I added `num_particles=7` and `vectorize_particles=True` to `Trace_ELBO()`.

Before:

``````
from pyro.infer import SVI, Trace_ELBO

svi = SVI(model, guide, adam, loss=Trace_ELBO())
``````

After:

``````
from pyro.infer import SVI, Trace_ELBO

svi = SVI(model, guide, adam, loss=Trace_ELBO(num_particles=7, vectorize_particles=True))
``````

However, after making this change, I encountered the following error while running SVI:

``````
pyro.clear_param_store()
num_iterations = 100
for j in range(num_iterations):
    # calculate the loss and take a gradient step
    loss = svi.step(x_data, y_data)
    if j % 100 == 0:
        print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(data)))
``````
``````
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
173 try:
--> 174     ret = self.fn(*args, **kwargs)
175 except (ValueError, RuntimeError) as e:

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
11 with context:
---> 12     return fn(*args, **kwargs)

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
11 with context:
---> 12     return fn(*args, **kwargs)

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs)
448 with self._pyro_context:
--> 449     result = super().__call__(*args, **kwargs)
450 if (
451     pyro.settings.get("validate_poutine")
452     and not self._pyro_context.active
453     and _is_module_local_param_enabled()
454 ):

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525         or _global_backward_pre_hooks or _global_backward_hooks
1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
1529 try:

/home/ingyojeong/DBTOF/examples/BNN_example_4.ipynb Cell 8 line 1
12 sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
---> 13 mean = self.linear(x).squeeze(-1)
14 with pyro.plate("data", x.shape[0]):

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs)
448 with self._pyro_context:
--> 449     result = super().__call__(*args, **kwargs)
450 if (
451     pyro.settings.get("validate_poutine")
452     and not self._pyro_context.active
453     and _is_module_local_param_enabled()
454 ):

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525         or _global_backward_pre_hooks or _global_backward_hooks
1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
1529 try:

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 4D

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
/home/ingyojeong/DBTOF/examples/BNN_example_4.ipynb Cell 8 line 5
2 num_iterations = 100
3 for j in range(num_iterations):
4     # calculate the loss and take a gradient step
----> 5     loss = svi.step(x_data, y_data)
6     if j % 100 == 0:
7         print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(data)))

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/infer/svi.py:145, in SVI.step(self, *args, **kwargs)
143 # get loss and compute gradients
144 with poutine.trace(param_only=True) as param_capture:
--> 145     loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
147 params = set(
148     site["value"].unconstrained() for site in param_capture.trace.nodes.values()
149 )
151 # actually perform gradient steps
152 # torch.optim objects gets instantiated for any params that haven't been seen yet

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/infer/trace_elbo.py:140, in Trace_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
138 loss = 0.0
139 # grab a trace from the generator
--> 140 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
141     loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
142         model_trace, guide_trace
143     )
144     loss += loss_particle / self.num_particles

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/infer/elbo.py:234, in ELBO._get_traces(self, model, guide, args, kwargs)
232     if self.max_plate_nesting == float("inf"):
233         self._guess_max_plate_nesting(model, guide, args, kwargs)
--> 234     yield self._get_vectorized_trace(model, guide, args, kwargs)
235 else:
236     for i in range(self.num_particles):

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/infer/elbo.py:211, in ELBO._get_vectorized_trace(self, model, guide, args, kwargs)
205 def _get_vectorized_trace(self, model, guide, args, kwargs):
206     """
207     Wraps the model and guide to vectorize ELBO computation over
208     ``num_particles``, and returns a single trace from the wrapped model
209     and guide.
210     """
--> 211     return self._get_trace(
212         self._vectorized_num_particles(model),
213         self._vectorized_num_particles(guide),
214         args,
215         kwargs,
216     )

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/infer/trace_elbo.py:57, in Trace_ELBO._get_trace(self, model, guide, args, kwargs)
52 def _get_trace(self, model, guide, args, kwargs):
53     """
54     Returns a single trace from the guide, and the model that is run
55     against it.
56     """
---> 57     model_trace, guide_trace = get_importance_trace(
58         "flat", self.max_plate_nesting, model, guide, args, kwargs
59     )
60     if is_validation_enabled():
61         check_if_enumerated(guide_trace)

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/infer/enum.py:65, in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
63     if detach:
64         guide_trace.detach_()
---> 65     model_trace = poutine.trace(
66         poutine.replay(model, trace=guide_trace), graph_type=graph_type
67     ).get_trace(*args, **kwargs)
69 if is_validation_enabled():
70     check_model_guide_match(model_trace, guide_trace, max_plate_nesting)

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py:198, in TraceHandler.get_trace(self, *args, **kwargs)
190 def get_trace(self, *args, **kwargs):
191     """
192     :returns: data structure
193     :rtype: pyro.poutine.Trace
(...)
196     Calls this poutine and returns its trace instead of the function's return value.
197     """
--> 198     self(*args, **kwargs)
199     return self.msngr.get_trace()

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py:180, in TraceHandler.__call__(self, *args, **kwargs)
178         exc = exc_type("{}\n{}".format(exc_value, shapes))
179         exc = exc.with_traceback(traceback)
--> 180         raise exc from e
182         "_RETURN", name="_RETURN", type="return", value=ret
183     )
184 return ret

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
171     "_INPUT", name="_INPUT", type="args", args=args, kwargs=kwargs
172 )
173 try:
--> 174     ret = self.fn(*args, **kwargs)
175 except (ValueError, RuntimeError) as e:
176     exc_type, exc_value, traceback = sys.exc_info()

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11     with context:
---> 12         return fn(*args, **kwargs)

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11     with context:
---> 12         return fn(*args, **kwargs)

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs)
447 def __call__(self, *args, **kwargs):
448     with self._pyro_context:
--> 449         result = super().__call__(*args, **kwargs)
450     if (
451         pyro.settings.get("validate_poutine")
452         and not self._pyro_context.active
453         and _is_module_local_param_enabled()
454     ):
455         self._check_module_local_param_usage()

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525         or _global_backward_pre_hooks or _global_backward_hooks
1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
1529 try:
1530     result = None

/home/ingyojeong/DBTOF/examples/BNN_example_4.ipynb Cell 8 line 1
11 def forward(self, x, y=None):
12     sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
---> 13     mean = self.linear(x).squeeze(-1)
14     with pyro.plate("data", x.shape[0]):
15         obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs)
447 def __call__(self, *args, **kwargs):
448     with self._pyro_context:
--> 449         result = super().__call__(*args, **kwargs)
450     if (
451         pyro.settings.get("validate_poutine")
452         and not self._pyro_context.active
453         and _is_module_local_param_enabled()
454     ):
455         self._check_module_local_param_usage()

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525         or _global_backward_pre_hooks or _global_backward_hooks
1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
1529 try:
1530     result = None

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 4D
Trace Shapes:
 Param Sites:
Sample Sites:
        sigma dist 7 1 |
             value 7 1 |
linear.weight dist 7 1 | 1 3
             value 7 1 | 1 3
  linear.bias dist 7 1 | 1
             value 7 1 | 1
``````

Question 1: Can you tell me what is wrong with this code?

Question 2: I need to train a similar but much more complex neural network than the one above. To do this, I need to accelerate the stochastic variational inference (SVI) process. Can you tell me how to accelerate (parallelize) SVI?

Thank you for your attention to my problem!

Hi,

You might have to set `vectorize_particles=False`. Due to the vectorization, the shape of `linear.weight` becomes 4-dimensional, but `F.linear` expects a 2-dimensional `weight`.
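A minimal, self-contained sketch of the shape clash (using plain PyTorch, with illustrative sizes taken from the trace shapes above: 3 input features, 1 output, 7 particles; the data size 170 is an assumption):

```python
import torch
import torch.nn.functional as F

x = torch.randn(170, 3)        # data: 170 points, 3 features (sizes illustrative)

# An ordinary nn.Linear weight is 2-D: (out_features, in_features).
w2d = torch.randn(1, 3)
out = F.linear(x, w2d)         # works: result has shape (170, 1)

# With vectorize_particles=True, Pyro samples a batch of weights, e.g.
# (num_particles, 1, out_features, in_features) -- a 4-D tensor.
# F.linear transposes the weight with t(), which only supports <= 2-D
# tensors, hence the RuntimeError in the traceback above.
w4d = torch.randn(7, 1, 1, 3)
try:
    F.linear(x, w4d)
    failed = False
except RuntimeError:
    failed = True
```

This reproduces the `t() expects a tensor with <= 2 dimensions` failure without involving Pyro at all, which is why disabling `vectorize_particles` avoids it.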

To run on the GPU, you can move your model and guide parameters to the GPU with `model.to("cuda")` and `guide.to("cuda")`. Also make sure the data is moved to the CUDA device.
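A hedged sketch of the device placement (plain PyTorch; `nn.Linear` stands in for the model/guide, and the tensor sizes are illustrative):

```python
import torch
import torch.nn as nn

# Pick the GPU if available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(3, 1).to(device)       # stand-in for the Pyro model/guide
x_data = torch.randn(170, 3).to(device)  # data must live on the same device
y_data = torch.randn(170).to(device)     # as the parameters, or matmul fails

mean = model(x_data).squeeze(-1)         # shape (170,), on the chosen device
```

The key point is that parameters and data must be on the same device; moving only the model while leaving `x_data` on the CPU raises a device-mismatch error.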


It looks like `nn.Linear` is not compatible with `vectorize_particles` in Pyro; I’ll keep that in mind.

I also tried GPU acceleration and saw a slight speedup. I assume the gains will grow as the model becomes more complex.

Again, thank you for your attention to this issue!