Proper shape for Predictive with parallel=True

I’m trying to run Predictive with parallel=True on a neural network (the same problem I’ve posted about before). I think I’m close to getting the shapes right for parallel inference, but I’m stuck.

Based on what I found searching the forum, I’m using my own replacement for nn.Linear, which is meant to handle the shapes properly.

import torch
import torch.nn as nn

class myLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True, _print=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self._print = _print
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
        # a bias parameter is always created; the `bias` flag is currently unused
        self.bias = torch.nn.Parameter(torch.randn(out_features))

    def forward(self, input):
        # check only the feature (last) dimension instead of unpacking exactly two dims
        if input.shape[-1] != self.in_features:
            raise ValueError(f'Wrong input features. Please use a tensor with {self.in_features} input features')
        output = input @ self.weight.squeeze().t() + self.bias.squeeze()
        return output

I get an error with Predictive.

predictive_obs = Predictive(model, guide=guide, num_samples=12, return_sites=['obs'], parallel=True)
with converter_val.make_torch_dataloader(batch_size=5) as val_dataloader:
    val_dataloader_iter = iter(val_dataloader)
    validation_steps = 1

    # Iterate over the validation data.
    for step in range(validation_steps):
        pd_batch = next(val_dataloader_iter)
        # stack the feature columns into a [batch_size, n_features] tensor
        pd_batch['features'] = torch.transpose(torch.stack([pd_batch[x] for x in x_feat]), 0, 1)
        inputs = pd_batch['features'].to(device)
        labels = pd_batch[y_name].to(device)
        samples_obs = predictive_obs(inputs, subsample=False)

ValueError: Shape mismatch inside plate('data') at site obs dim -1, 5 vs 12

For a batch size of 5 and 12 samples, the distribution parameters such as mu come out with shape [5, 12] = [batch size, samples]. Is the output of my linear module correct, or is the error somewhere in the plate?
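A minimal sketch of where [5, 12] can come from. It assumes the final Bayesian layer receives its weight as Size([12, 1, 8]) and its bias as Size([12, 1]) under parallel=True (the shapes reported later in this thread) and uses random stand-in tensors rather than the real model:

```python
import torch

num_samples, batch_size, hidden = 12, 5, 8

# stand-ins (assumed shapes) for what the last layer sees under Predictive(parallel=True)
inputs = torch.randn(batch_size, hidden)        # [5, 8] from the previous layer
weight = torch.randn(num_samples, 1, hidden)    # [12, 1, 8] vectorized weight samples
bias = torch.randn(num_samples, 1)              # [12, 1] vectorized bias samples

w = weight.squeeze()                 # [12, 8]: the out_features=1 dim is dropped
w_t = w.t()                          # [8, 12]: the sample dim now plays the role of out_features
out = inputs @ w_t + bias.squeeze()  # [5, 8] @ [8, 12] -> [5, 12], plus a [12] bias
print(out.shape)                     # torch.Size([5, 12])
```

In this sketch the sample dimension ends up at dim -1, which is exactly where plate("data", 5, dim=-1) expects the batch, matching the "5 vs 12" message.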

Full error output:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/databricks/python/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    173             try:
--> 174                 ret = self.fn(*args, **kwargs)
    175             except (ValueError, RuntimeError) as e:

/databricks/python/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 

/databricks/python/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 

/databricks/python/lib/python3.8/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     27             with self.__class__():
---> 28                 return func(*args, **kwargs)
     29         return cast(F, decorate_context)

/databricks/python/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 

/databricks/python/lib/python3.8/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
    425         with self._pyro_context:
--> 426             return super().__call__(*args, **kwargs)
    427 

/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used

<command-3701536137140920> in forward(self, x, y, subsample)
    118           with pyro.plate("data", x.shape[0], device = device, dim=-1):
--> 119             obs = pyro.sample("obs", GammaHurdle(concentration = shape, rate = shape / mu, theta = theta), obs=y) #.to_event(1)
    120               #print(obs.shape)

/databricks/python/lib/python3.8/site-packages/pyro/primitives.py in sample(name, fn, *args, **kwargs)
    162         # apply the stack and return its return value
--> 163         apply_stack(msg)
    164         return msg["value"]

/databricks/python/lib/python3.8/site-packages/pyro/poutine/runtime.py in apply_stack(initial_msg)
    212 
--> 213         frame._process_message(msg)
    214 

/databricks/python/lib/python3.8/site-packages/pyro/poutine/plate_messenger.py in _process_message(self, msg)
     18         super()._process_message(msg)
---> 19         return BroadcastMessenger._pyro_sample(msg)
     20 

/usr/lib/python3.8/contextlib.py in inner(*args, **kwds)
     74             with self._recreate_cm():
---> 75                 return func(*args, **kwds)
     76         return inner

/databricks/python/lib/python3.8/site-packages/pyro/poutine/broadcast_messenger.py in _pyro_sample(msg)
     64                 ):
---> 65                     raise ValueError(
     66                         "Shape mismatch inside plate('{}') at site {} dim {}, {} vs {}".format(

ValueError: Shape mismatch inside plate('data') at site obs dim -1, 5 vs 12

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
<command-3701536137140873> in <module>
     11       inputs = pd_batch['features'].to(device)
     12       labels = pd_batch[y_name].to(device)
---> 13       samples_obs = predictive_obs(inputs, subsample=False)

/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

/databricks/python/lib/python3.8/site-packages/pyro/infer/predictive.py in forward(self, *args, **kwargs)
    271                 model_kwargs=kwargs,
    272             )
--> 273         return _predictive(
    274             self.model,
    275             posterior_samples,

/databricks/python/lib/python3.8/site-packages/pyro/infer/predictive.py in _predictive(model, posterior_samples, num_samples, return_sites, return_trace, parallel, model_args, model_kwargs)
    135         )
    136 
--> 137     trace = poutine.trace(
    138         poutine.condition(vectorize(model), reshaped_samples)
    139     ).get_trace(*model_args, **model_kwargs)

/databricks/python/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
    196         Calls this poutine and returns its trace instead of the function's return value.
    197         """
--> 198         self(*args, **kwargs)
    199         return self.msngr.get_trace()

/databricks/python/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    178                 exc = exc_type(u"{}\n{}".format(exc_value, shapes))
    179                 exc = exc.with_traceback(traceback)
--> 180                 raise exc from e
    181             self.msngr.trace.add_node(
    182                 "_RETURN", name="_RETURN", type="return", value=ret

/databricks/python/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    172             )
    173             try:
--> 174                 ret = self.fn(*args, **kwargs)
    175             except (ValueError, RuntimeError) as e:
    176                 exc_type, exc_value, traceback = sys.exc_info()

/databricks/python/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 
     14 

/databricks/python/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 
     14 

/databricks/python/lib/python3.8/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     26         def decorate_context(*args, **kwargs):
     27             with self.__class__():
---> 28                 return func(*args, **kwargs)
     29         return cast(F, decorate_context)
     30 

/databricks/python/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 
     14 

/databricks/python/lib/python3.8/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
    424     def __call__(self, *args, **kwargs):
    425         with self._pyro_context:
--> 426             return super().__call__(*args, **kwargs)
    427 
    428     def __getattr__(self, name):

/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

<command-3701536137140920> in forward(self, x, y, subsample)
    117         else:
    118           with pyro.plate("data", x.shape[0], device = device, dim=-1):
--> 119             obs = pyro.sample("obs", GammaHurdle(concentration = shape, rate = shape / mu, theta = theta), obs=y) #.to_event(1)
    120               #print(obs.shape)
    121 #         if SUBSAMPLE_SVI_N is None:

/databricks/python/lib/python3.8/site-packages/pyro/primitives.py in sample(name, fn, *args, **kwargs)
    161         }
    162         # apply the stack and return its return value
--> 163         apply_stack(msg)
    164         return msg["value"]
    165 

/databricks/python/lib/python3.8/site-packages/pyro/poutine/runtime.py in apply_stack(initial_msg)
    211         pointer = pointer + 1
    212 
--> 213         frame._process_message(msg)
    214 
    215         if msg["stop"]:

/databricks/python/lib/python3.8/site-packages/pyro/poutine/plate_messenger.py in _process_message(self, msg)
     17     def _process_message(self, msg):
     18         super()._process_message(msg)
---> 19         return BroadcastMessenger._pyro_sample(msg)
     20 
     21     def __enter__(self):

/usr/lib/python3.8/contextlib.py in inner(*args, **kwds)
     73         def inner(*args, **kwds):
     74             with self._recreate_cm():
---> 75                 return func(*args, **kwds)
     76         return inner
     77 

/databricks/python/lib/python3.8/site-packages/pyro/poutine/broadcast_messenger.py in _pyro_sample(msg)
     63                     and target_batch_shape[f.dim] != f.size
     64                 ):
---> 65                     raise ValueError(
     66                         "Shape mismatch inside plate('{}') at site {} dim {}, {} vs {}".format(
     67                             f.name,

ValueError: Shape mismatch inside plate('data') at site obs dim -1, 5 vs 12
           Trace Shapes:            
            Param Sites:            
        mu.mu_fc0.weight     8 2    
          mu.mu_fc0.bias       8    
        mu.mu_fc1.weight     8 8    
          mu.mu_fc1.bias       8    
     shape.sh_fc0.weight     8 2    
       shape.sh_fc0.bias       8    
     shape.sh_fc1.weight     8 8    
       shape.sh_fc1.bias       8    
     theta.th_fc0.weight     8 2    
       theta.th_fc0.bias       8    
     theta.th_fc1.weight     8 8    
       theta.th_fc1.bias       8    
           Sample Sites:            
   mu.mu_fc2.weight dist 12  1 | 1 8
                   value    12 | 1 8
     mu.mu_fc2.bias dist 12  1 | 1  
                   value    12 | 1  
shape.sh_fc2.weight dist 12  1 | 1 8
                   value    12 | 1 8
  shape.sh_fc2.bias dist 12  1 | 1  
                   value    12 | 1  
theta.th_fc2.weight dist 12  1 | 1 8
                   value    12 | 1 8
  theta.th_fc2.bias dist 12  1 | 1  
                   value    12 | 1  
               data dist       |    
                   value     5 |

Here is my model class (I’ve posted it in previous questions as well). I’ll edit things down for brevity.

from collections import OrderedDict

import pyro
import pyro.distributions as dist
import torch
import torch.nn as nn
from pyro.nn import PyroModule, PyroSample

# name_count, GammaHurdle and device come from elsewhere in the notebook (not shown)

class BayesNN_Gamma_muShapeTheta(PyroModule):
    @name_count
    def __init__(self, in_features, mu_l, sh_l, th_l, out_features=1, unif_bound=3.):
        super().__init__()

        # parameter names list
        self.parameter_names = []

        layers = []
        for i in range(len(mu_l) - 1):
            layers.append(('mu_fc' + str(i), myLinear(mu_l[i], mu_l[i + 1], _print=True)))
            if i != (len(mu_l) - 2):
                layers.append(('mu_ReLU' + str(i), nn.ReLU()))
        mu = OrderedDict(layers)
        self.mu = nn.Sequential(mu)

        for name, param in self.mu.named_parameters():
            self.parameter_names.append(name)

        pyro.nn.module.to_pyro_module_(self.mu)
        # the following step makes only the final layer Bayesian;
        # the lower layers remain constants as in a regular neural net
        for m in self.mu.modules():
            if m._pyro_name == 'mu.mu_fc' + str(len(mu_l) - 2):
                for name, value in list(m.named_parameters(recurse=False)):
                    setattr(m, name, PyroSample(prior=dist.Normal(0., 1.).expand(value.shape).to_event(value.dim())))

        # shape and theta are basically the same
        ...

    def forward(self, x, y=None, subsample=True):
        mu = self.mu(x).exp().clamp(min=.000001, max=1e6)
        shape = self.shape(x).exp().clamp(min=0.000001, max=1e2)
        theta = self.theta(x)

        with pyro.plate("data", x.shape[0], device=device, dim=-1):
            obs = pyro.sample("obs", GammaHurdle(concentration=shape, rate=shape / mu, theta=theta), obs=y)

        return torch.cat((mu, shape, theta), 0)

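Hi @yoshy, I’m unsure about the root cause of your problem, but I’d recommend avoiding the torch methods .t() and .squeeze(), which don’t play well with the extra batch dimensions Pyro adds: .t() only accepts tensors with at most 2 dimensions, and .squeeze() drops every size-1 dimension, including batch dimensions: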

- output = input @ self.weight.squeeze().t() + self.bias.squeeze()
+ output = input @ self.weight.transpose(-1, -2) + self.bias

Does that fix your issue?
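A small sketch of the difference, using a random tensor shaped like the vectorized weight reported later in this thread (Size([12, 1, 8]), assumed here):

```python
import torch

weight = torch.randn(12, 1, 8)  # assumed [num_samples, out_features, in_features]

# .t() only accepts tensors with at most 2 dimensions
try:
    weight.t()
    t_failed = False
except RuntimeError:
    t_failed = True

# .squeeze() removes every size-1 dim, so the result depends on the sizes involved
print(weight.squeeze().shape)                # torch.Size([12, 8])
print(torch.randn(1, 1, 8).squeeze().shape)  # torch.Size([8]), e.g. with num_samples=1

# transpose(-1, -2) swaps only the two matrix dims and leaves batch dims alone
print(weight.transpose(-1, -2).shape)        # torch.Size([12, 8, 1])
```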

That still fails because the shapes are wrong when parallel=True. Note in my particular case, I have 3 layers but only the last layer is Bayesian. For batch size 5 and num_samples = 12 and parallel=True, I get the following shapes for the final Bayesian layer.

inputs from the previous layer :: Size([5, 8])
weight :: Size([12, 1, 8]) = [num_samples, 1, neurons from previous layer]
bias :: Size([12, 1])
inputs @ weight.transpose(-1, -2) :: Size([12, 5, 1])

With output = input @ self.weight.transpose(-1, -2) + self.bias, we’re therefore adding Size([12, 5, 1]) + Size([12, 1]), which raises:

The size of tensor a (5) must match the size of tensor b (12) at non-singleton dimension 1
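One possible way to make the bias broadcast (a suggestion, not something confirmed in this thread) is to give it a size-1 dimension for the data rows with bias.unsqueeze(-2), so Size([12, 1]) becomes Size([12, 1, 1]). The helper batched_linear below is hypothetical and assumes the weight and bias shapes listed above:

```python
import torch

def batched_linear(x, weight, bias):
    """Hypothetical helper: x @ W^T + b with optional leading sample dims.

    x:      [..., batch, in_features]
    weight: [..., out_features, in_features]
    bias:   [..., out_features]
    """
    # unsqueeze(-2) inserts a size-1 row dim so the bias lines up with the rows of x
    return x @ weight.transpose(-1, -2) + bias.unsqueeze(-2)

x = torch.randn(5, 8)

# reproduces the error above: [12, 5, 1] + [12, 1]
try:
    x @ torch.randn(12, 1, 8).transpose(-1, -2) + torch.randn(12, 1)
    broadcast_failed = False
except RuntimeError:
    broadcast_failed = True

# vectorized Bayesian layer: 12 weight/bias samples
out_vec = batched_linear(x, torch.randn(12, 1, 8), torch.randn(12, 1))
print(out_vec.shape)  # torch.Size([12, 5, 1])

# ordinary deterministic layer: no sample dims
out_det = batched_linear(torch.randn(5, 2), torch.randn(8, 2), torch.randn(8))
print(out_det.shape)  # torch.Size([5, 8])
```

The result still carries the out_features=1 dimension on the right, so before it reaches plate("data", dim=-1) it would likely need something like mu.squeeze(-1) (an explicit dim, unlike a bare squeeze()).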

I have a possible workaround, though I don’t think it’s correct as it stands. I reread the "Tensor shapes in Pyro" tutorial (Pyro Tutorials 1.8.1 documentation) and tried adding another layer of pyro.plate.

def forward(self, x, y=None, subsample=True):
    mu = self.mu(x).exp().clamp(min=.000001, max=1e6)
    shape = self.shape(x).exp().clamp(min=0.000001, max=1e2)
    theta = self.theta(x)

    with pyro.plate('extra_dim', dim=-3):
        with pyro.plate("data", x.shape[0], device=device, dim=-2):
            obs = pyro.sample("obs", GammaHurdle(concentration=shape, rate=shape / mu, theta=theta), obs=y)

    return torch.cat((mu, shape, theta), 0)

The following code now runs but with confusing results.

predictive_obs = Predictive(model, guide=guide, num_samples=12, return_sites=['obs', '_RETURN'], parallel=True)
with converter_val.make_torch_dataloader(batch_size=5) as val_dataloader:
    val_dataloader_iter = iter(val_dataloader)
    validation_steps = 1  # np.max((1, len(converter_val) // 5)) if max_batches is None else max_batches

    # Iterate over the validation data.
    for step in range(validation_steps):
        pd_batch = next(val_dataloader_iter)
        pd_batch['features'] = torch.transpose(torch.stack([pd_batch[x] for x in x_feat]), 0, 1)
        inputs = pd_batch['features'].to(device)
        labels = pd_batch[y_name].to(device)
        samples_obs = predictive_obs(inputs, subsample=False)

The results do look like proper samples from the posterior (though I still need to verify that), but the shape is a little confusing. I think the extra singleton dimensions are what make me doubt my code.

print(samples_obs['obs'].size())
print(samples_obs['obs'].squeeze())
torch.Size([12, 1, 5, 1])
tensor([[   0.0000,    0.0000,    0.0000, 1141.8843,    0.0000],
        [  44.0103,  235.4155,  506.2648,  243.5357,  682.7267],
        [ 137.7539,  135.4984,  737.5692,  224.2097,    0.0000],
        [ 290.2030,    0.0000,  298.2992, 1305.8361,  398.1299],
        [ 218.5411,    0.0000,  423.9644, 1238.9811,    0.0000],
        [   0.0000,    0.0000,  834.5766,  604.4004,  318.4350],
        [   0.0000,  280.6942,  476.8374, 1169.0358,  518.7895],
        [   0.0000,    0.0000,    0.0000,    0.0000,  129.0209],
        [ 168.0999,    0.0000,    0.0000, 1515.7461,  121.2689],
        [ 289.1123,  220.8311,  565.0992,  479.5122,  233.5051],
        [ 423.8335,    0.0000,  501.7006, 5017.4473,  201.9666],
        [ 421.4970,   96.5464,  141.6303,  811.3078,  670.8171]])
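If the printed shape [12, 1, 5, 1] is read as [num_samples, extra_dim, data, out_features] (an interpretation based on the plate layout above), a sketch of dropping only the two singleton dimensions by position. This stays correct when the batch size happens to be 1, where a bare squeeze() would not:

```python
import torch

samples = torch.randn(12, 1, 5, 1)             # stand-in for samples_obs['obs']
print(samples.squeeze(-1).squeeze(-2).shape)   # torch.Size([12, 5])

# with a single data row, a bare squeeze() also removes the data dimension
one_row = torch.randn(12, 1, 1, 1)
print(one_row.squeeze().shape)                 # torch.Size([12])
print(one_row.squeeze(-1).squeeze(-2).shape)   # torch.Size([12, 1])
```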

After squeezing, the posterior samples look reasonable, but now training doesn’t work: the loss is much higher than before and the model doesn’t seem to learn at all.

If I look at the shape of obs, I get [training batch size, 1], again an extra singleton dimension.
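One thing that may be worth ruling out (an assumption, since the training code isn’t shown): in plain torch.distributions, if the distribution parameters carry a trailing singleton dimension but the labels don’t, log_prob silently broadcasts to a batch_size × batch_size matrix. A loss built on that would sum over every pairing of predictions and labels. A sketch with Normal standing in for GammaHurdle:

```python
import torch
from torch.distributions import Normal

batch_size = 5
loc = torch.randn(batch_size, 1)  # e.g. a network output with out_features=1
y = torch.randn(batch_size)       # labels without the trailing dim

# Normal is a stand-in for the custom GammaHurdle distribution (assumption)
paired = Normal(loc, 1.0).log_prob(y)               # every loc paired with every y
per_row = Normal(loc.squeeze(-1), 1.0).log_prob(y)  # one term per data row
print(paired.shape)   # torch.Size([5, 5])
print(per_row.shape)  # torch.Size([5])
```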

Hi @yoshy, I can’t really tell what’s going on in your model (a full script with shape assertions might help), but here are some things that catch my eye:

  • Why is there no plate with dim=-1? That’s unusual: plates normally fill up the rightmost dimensions, so you’ll have plates with dims {-1}, {-1,-2}, or {-1,-2,-3}.
  • What is this dimension of size 5? Is it covered by a plate? I’m surprised it doesn’t conflict with the plate of size 12 earlier.
  • What is the extra_dim plate and what should its size be?

I realized I need to come up with a simple full script. I’ll do that when I have time.

For most of these, the honest answer to “why” is that I’ve been trying things at random to get it to work; I’m still learning how Pyro handles shapes internally. For my tests I’ve been using batch size = 5 and num_samples = 12, and I believe the dimension of size 5 is covered by:

with pyro.plate("data", x.shape[0], device=device, dim=-1):
    obs = pyro.sample("obs", GammaHurdle(concentration=shape, rate=shape / mu, theta=theta), obs=y)

Part of my confusion is what size that extra plate should have. The tensor shapes tutorial says: "We would also like to use the pyro.plate construct to parallelize the ELBO estimator over num_particles. This is done by wrapping the contents of model/guide inside an outermost pyro.plate context." That’s why I added the “extra_dim” plate. Should its size be num_particles?

I’ve created a simpler, reproducible version of this question in a new thread: (Reproducible example) Proper shape for Predictive with parallel=True - Pyro Discussion Forum
