vectorize_particles=True in tutorial - Bayesian Regression - Introduction (Part 1)

  • What tutorial are you running?

Bayesian Regression - Introduction (Part 1)

  • What version of Pyro are you using?

1.8.6 (Latest version)

  • Please link or paste relevant code, and steps to reproduce.

Hello.

I’m just getting started with Pyro and am looking for information on acceleration (using GPU CUDA).

In Bayesian Regression - Introduction (Part 1), I changed the code in [10] as follows to improve SVI convergence and speed: I added num_particles=7 and vectorize_particles=True to Trace_ELBO().

Before:

from pyro.infer import SVI, Trace_ELBO


adam = pyro.optim.Adam({"lr": 0.03})
svi = SVI(model, guide, adam, loss=Trace_ELBO())

After:

from pyro.infer import SVI, Trace_ELBO


adam = pyro.optim.Adam({"lr": 0.03})
svi = SVI(model, guide, adam, loss=Trace_ELBO(num_particles=7, vectorize_particles=True))

However, after making these changes, I encountered the following error while running SVI.

pyro.clear_param_store()
num_iterations = 100
for j in range(num_iterations):
    # calculate the loss and take a gradient step
    loss = svi.step(x_data, y_data)
    if j % 100 == 0:
        print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(data)))
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
    173 try:
--> 174     ret = self.fn(*args, **kwargs)
    175 except (ValueError, RuntimeError) as e:

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
     11 with context:
---> 12     return fn(*args, **kwargs)

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
     11 with context:
---> 12     return fn(*args, **kwargs)

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs)
    448 with self._pyro_context:
--> 449     result = super().__call__(*args, **kwargs)
    450 if (
    451     pyro.settings.get("validate_poutine")
    452     and not self._pyro_context.active
    453     and _is_module_local_param_enabled()
    454 ):

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:

/home/ingyojeong/DBTOF/examples/BNN_example_4.ipynb Cell 8 line 1
     12 sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
---> 13 mean = self.linear(x).squeeze(-1)
     14 with pyro.plate("data", x.shape[0]):

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs)
    448 with self._pyro_context:
--> 449     result = super().__call__(*args, **kwargs)
    450 if (
    451     pyro.settings.get("validate_poutine")
    452     and not self._pyro_context.active
    453     and _is_module_local_param_enabled()
    454 ):

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 4D

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
/home/ingyojeong/DBTOF/examples/BNN_example_4.ipynb Cell 8 line 5
      2 num_iterations = 100
      3 for j in range(num_iterations):
      4     # calculate the loss and take a gradient step
----> 5     loss = svi.step(x_data, y_data)
      6     if j % 100 == 0:
      7         print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(data)))

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/infer/svi.py:145, in SVI.step(self, *args, **kwargs)
    143 # get loss and compute gradients
    144 with poutine.trace(param_only=True) as param_capture:
--> 145     loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
    147 params = set(
    148     site["value"].unconstrained() for site in param_capture.trace.nodes.values()
    149 )
    151 # actually perform gradient steps
    152 # torch.optim objects gets instantiated for any params that haven't been seen yet

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/infer/trace_elbo.py:140, in Trace_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
    138 loss = 0.0
    139 # grab a trace from the generator
--> 140 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
    141     loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
    142         model_trace, guide_trace
    143     )
    144     loss += loss_particle / self.num_particles

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/infer/elbo.py:234, in ELBO._get_traces(self, model, guide, args, kwargs)
    232     if self.max_plate_nesting == float("inf"):
    233         self._guess_max_plate_nesting(model, guide, args, kwargs)
--> 234     yield self._get_vectorized_trace(model, guide, args, kwargs)
    235 else:
    236     for i in range(self.num_particles):

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/infer/elbo.py:211, in ELBO._get_vectorized_trace(self, model, guide, args, kwargs)
    205 def _get_vectorized_trace(self, model, guide, args, kwargs):
    206     """
    207     Wraps the model and guide to vectorize ELBO computation over
    208     ``num_particles``, and returns a single trace from the wrapped model
    209     and guide.
    210     """
--> 211     return self._get_trace(
    212         self._vectorized_num_particles(model),
    213         self._vectorized_num_particles(guide),
    214         args,
    215         kwargs,
    216     )

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/infer/trace_elbo.py:57, in Trace_ELBO._get_trace(self, model, guide, args, kwargs)
     52 def _get_trace(self, model, guide, args, kwargs):
     53     """
     54     Returns a single trace from the guide, and the model that is run
     55     against it.
     56     """
---> 57     model_trace, guide_trace = get_importance_trace(
     58         "flat", self.max_plate_nesting, model, guide, args, kwargs
     59     )
     60     if is_validation_enabled():
     61         check_if_enumerated(guide_trace)

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/infer/enum.py:65, in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
     63     if detach:
     64         guide_trace.detach_()
---> 65     model_trace = poutine.trace(
     66         poutine.replay(model, trace=guide_trace), graph_type=graph_type
     67     ).get_trace(*args, **kwargs)
     69 if is_validation_enabled():
     70     check_model_guide_match(model_trace, guide_trace, max_plate_nesting)

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py:198, in TraceHandler.get_trace(self, *args, **kwargs)
    190 def get_trace(self, *args, **kwargs):
    191     """
    192     :returns: data structure
    193     :rtype: pyro.poutine.Trace
   (...)
    196     Calls this poutine and returns its trace instead of the function's return value.
    197     """
--> 198     self(*args, **kwargs)
    199     return self.msngr.get_trace()

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py:180, in TraceHandler.__call__(self, *args, **kwargs)
    178         exc = exc_type("{}\n{}".format(exc_value, shapes))
    179         exc = exc.with_traceback(traceback)
--> 180         raise exc from e
    181     self.msngr.trace.add_node(
    182         "_RETURN", name="_RETURN", type="return", value=ret
    183     )
    184 return ret

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
    170 self.msngr.trace.add_node(
    171     "_INPUT", name="_INPUT", type="args", args=args, kwargs=kwargs
    172 )
    173 try:
--> 174     ret = self.fn(*args, **kwargs)
    175 except (ValueError, RuntimeError) as e:
    176     exc_type, exc_value, traceback = sys.exc_info()

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs)
    447 def __call__(self, *args, **kwargs):
    448     with self._pyro_context:
--> 449         result = super().__call__(*args, **kwargs)
    450     if (
    451         pyro.settings.get("validate_poutine")
    452         and not self._pyro_context.active
    453         and _is_module_local_param_enabled()
    454     ):
    455         self._check_module_local_param_usage()

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

/home/ingyojeong/DBTOF/examples/BNN_example_4.ipynb Cell 8 line 1
     11 def forward(self, x, y=None):
     12     sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
---> 13     mean = self.linear(x).squeeze(-1)
     14     with pyro.plate("data", x.shape[0]):
     15         obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs)
    447 def __call__(self, *args, **kwargs):
    448     with self._pyro_context:
--> 449         result = super().__call__(*args, **kwargs)
    450     if (
    451         pyro.settings.get("validate_poutine")
    452         and not self._pyro_context.active
    453         and _is_module_local_param_enabled()
    454     ):
    455         self._check_module_local_param_usage()

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/miniconda3/envs/DBTOF/lib/python3.11/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 4D
     Trace Shapes:          
      Param Sites:          
     Sample Sites:          
        sigma dist 7 1 |    
             value 7 1 |    
linear.weight dist 7 1 | 1 3
             value 7 1 | 1 3
  linear.bias dist 7 1 | 1  
             value 7 1 | 1  
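The failing call can be reproduced with plain PyTorch. This minimal sketch uses random tensors and an arbitrary batch of 5 data points, with the shapes from the Trace Shapes table: `F.linear` accepts the non-vectorized `(1, 3)` weight but raises a `RuntimeError` for the vectorized `(7, 1, 1, 3)` one. The exact error text may differ between PyTorch versions.

```python
import torch
import torch.nn.functional as F

x = torch.randn(5, 3)  # 5 data points, 3 features (arbitrary sizes)

# Non-vectorized shapes from the tutorial: weight (1, 3), bias (1,)
out = F.linear(x, torch.randn(1, 3), torch.randn(1))
print(out.shape)  # torch.Size([5, 1])

# Vectorized shapes from the Trace Shapes table: batch (7, 1) + event (1, 3)
weight_4d = torch.randn(7, 1, 1, 3)
bias_3d = torch.randn(7, 1, 1)
try:
    F.linear(x, weight_4d, bias_3d)
except RuntimeError as err:
    error_message = str(err)  # F.linear only accepts a 1-D or 2-D weight
else:
    error_message = None
print(error_message)
```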

Question 1: Could you tell me what is wrong with this code and why it fails?

Question 2: I need to train a neural network that is similar to, but much more complex than, the one in the code above, so I need to speed up the Stochastic Variational Inference (SVI) process. How can I accelerate (for example, parallelize) SVI?

I would really appreciate any help you can provide.

Thank you for your attention to my problem!

Hi,

You might have to set vectorize_particles=False. Because of the vectorization, linear.weight has 4 dimensions (7 1 | 1 3 in the Trace Shapes output), but F.linear expects a 2-dimensional weight.

To run on the GPU, you can move your model and guide weights to the GPU with model.to("cuda") and guide.to("cuda"). Also make sure that the data is moved to the CUDA device.


Thanks for the answer!

It looks like nn.Linear is not compatible with vectorize_particles in Pyro; I'll keep that in mind.

I also tried GPU acceleration and saw a slight speedup. I assume the speedup will grow as the model becomes more complex.

Again, thank you for your attention to this issue!