(Reproducible example) Proper shape for Predictive with parallel=True

I previously posted this question at Proper shape for Predictive with parallel=True - Pyro Discussion Forum, but it was difficult to follow with my specific example. I’ve created an easily reproducible example, but the question is the same. I’m using Spark here, but it can easily be avoided.

import sys
import pandas as pd
from pyspark.sql.functions import *
import numpy as np
import pyro
from pyro.nn import PyroModule, PyroSample
from pyro.contrib.autoname import *
import torch
import torch.nn as nn
from torch.distributions.utils import broadcast_all
from torch.distributions import constraints
import pyro.distributions as dist
from pyro.infer.autoguide import AutoDiagonalNormal, AutoMultivariateNormal, AutoLowRankMultivariateNormal, AutoLaplaceApproximation, init_to_mean
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, Predictive
import pyro.poutine as poutine
from petastorm.spark import SparkDatasetConverter, make_spark_converter
from collections import OrderedDict

n_row = 1000
rng = np.random.default_rng()
df = pd.DataFrame({'x1' : rng.standard_normal(n_row),
                   'x2' : np.random.uniform(1, 10, n_row)})
df['sigma'] = df['x2']
df['mu'] = 2*df['x1'] + df['x2'] + 1000
df['y'] = rng.standard_normal(n_row) * df['sigma'] + df['mu']
df = spark.createDataFrame(df)

if torch.cuda.is_available():
  device = torch.device('cuda')
else: 
  device = torch.device('cpu')

x_feat = ['x1', 'x2']
y_name = 'y'

if torch.cuda.is_available():
  torch.set_default_tensor_type("torch.cuda.FloatTensor")
df.sort('x1', 'x2').display()
# spark stuff
spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF, "file:///dbfs/tmp/petastorm/cache")
df_conv = make_spark_converter(df)

class myLinear(nn.Module):
  def __init__(self, in_features, out_features, bias=True, _print=False):
    super().__init__()
    self.in_features = in_features
    self.out_features = out_features
    self._print = _print
    # note: the `bias` flag is currently unused; bias is always a learned Parameter
    self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
    self.bias = torch.nn.Parameter(torch.randn(out_features))
       
  def forward(self, input):
    if self._print:
      print('myLin weight', self.weight.shape, self.weight)
      print('myLin input', input.shape)
      print('myLin bias', self.bias.shape, self.bias)
      print('myLin input.dim', input.dim())
    if input.dim() > 1:
      x, y = input.shape
      if y != self.in_features:
          sys.exit(f'Wrong Input Features. Please use tensor with {self.in_features} Input Features')
    output = (input @ self.weight.transpose(-1, -2)) + self.bias
    #output = input @ self.weight.squeeze().t() + self.bias
    if self._print:
      print('myLin output', output.shape)
    return output
  
class net(PyroModule):
  # 
    @name_count
    def __init__(self, mu_l, sh_l):
        super().__init__()
        
        # parameter names list
        self.parameter_names = []
        
        layers = []
        for i in range(len(mu_l)-1):
            layers.append(('mu_fc' + str(i), myLinear(mu_l[i], mu_l[i+1], _print=False)))
            if i != (len(mu_l)-2): layers.append(('mu_ReLU' + str(i), nn.ReLU()))
        mu = OrderedDict(layers)
        self.mu = nn.Sequential(mu)
        
        for name, param in self.mu.named_parameters():
            self.parameter_names.append(name)
        
        pyro.nn.module.to_pyro_module_(self.mu)
        for m in self.mu.modules():
          if m._pyro_name == 'mu.mu_fc' + str(len(mu_l)-2):
            for name, value in list(m.named_parameters(recurse=False)):
              setattr(m, name, PyroSample(prior=dist.Normal(0., 1.).expand(value.shape).to_event(value.dim())))
              
        # sigma
        layers = []
        for i in range(len(sh_l)-1):
            layers.append(('sh_fc' + str(i), myLinear(sh_l[i], sh_l[i+1])))
            if i != (len(sh_l)-2): layers.append(('sh_ReLU' + str(i), nn.ReLU()))
        shape = OrderedDict(layers)
        self.shape = nn.Sequential(shape)
        
        for name, param in self.shape.named_parameters():
            self.parameter_names.append(name)
        
        pyro.nn.module.to_pyro_module_(self.shape)
        for m in self.shape.modules():
          if m._pyro_name == 'shape.sh_fc' + str(len(sh_l)-2):
            for name, value in list(m.named_parameters(recurse=False)):
              setattr(m, name, PyroSample(prior=dist.Normal(0., 1.).expand(value.shape).to_event(value.dim())))                

    def forward(self, x, y=None):
        
        mu = self.mu(x).exp().clamp(min = .000001, max = 1e6).squeeze()
        shape = self.shape(x).exp().clamp(min = 0.000001, max = 1e1).squeeze()
        
        with pyro.plate("data", x.shape[0], device = device, dim=-1): # 
          obs = pyro.sample("obs", dist.Normal(mu, shape), obs=y)

        return  torch.cat((mu, shape), 0)

def train_and_evaluate_SVI(svi, model, guide, bs, ne):
    
  model = model.to(device)
  with df_conv.make_torch_dataloader(batch_size=bs) as train_dataloader:

      train_dataloader_iter = iter(train_dataloader)
      steps_per_epoch = len(df_conv) // bs

      for epoch in range(ne):
          if (epoch + 1) % 10 == 0: print('-' * 10)
          if (epoch + 1) % 10 == 0: print('Epoch {}/{}'.format(epoch + 1, ne))
          
          train_loss = train_one_epoch_SVI(svi, train_dataloader_iter, steps_per_epoch, epoch, device)
          
  return train_loss

def train_one_epoch_SVI(svi, train_dataloader_iter, steps_per_epoch, epoch, device):
  model.train()
  
  running_loss = 0.0
  total_lives = 0
  
  # Iterate over the data for one epoch.
  for step in range(steps_per_epoch):
      pd_batch = next(train_dataloader_iter)
      pd_batch['features'] = torch.transpose(torch.stack([pd_batch[x] for x in x_feat]), 0, 1)
      inputs = pd_batch['features'].to(device)
      labels = pd_batch[y_name].to(device)
      loss = svi.step(inputs, labels)

      # statistics
      running_loss += loss
      total_lives += inputs.shape[0]
  
  epoch_loss = running_loss  / total_lives
  
  print('Train Loss: {:.4f}'.format(epoch_loss))
  return epoch_loss

# train
model = net(mu_l = [2, 3, 4, 1], sh_l = [2, 3, 4, 1])
pyro.clear_param_store()
pyro.set_rng_seed(123456789)
loss_fnc = Trace_ELBO(num_particles=2, vectorize_particles=True)
model = model.to(device)
guide = AutoDiagonalNormal(model)
adam = pyro.optim.ClippedAdam({"lr": 0.001, 'betas': (.95, .999), 'weight_decay' : .25, 'clip_norm' : 10.}) 
svi = SVI(model, guide, adam, loss=loss_fnc)
train_and_evaluate_SVI(svi=svi, model = model, guide = guide, bs = 10, ne = 20)

# posterior 
predictive_obs = Predictive(model, guide=guide, num_samples=int(3), return_sites = ['obs', '_RETURN'], parallel = True)
with df_conv.make_torch_dataloader(batch_size=2) as dl:
    dl_iter = iter(dl)
    steps = 1

    # Iterate over all the validation data.
    for step in range(steps):
      pd_batch = next(dl_iter)
      pd_batch['features'] = torch.transpose(torch.stack([pd_batch[x] for x in x_feat]), 0, 1)#.double()    
      inputs = pd_batch['features'].to(device)
      samples_obs = predictive_obs(inputs)

The above trains with both the default Trace_ELBO and Trace_ELBO(num_particles=2, vectorize_particles=True). However, the posterior prediction part fails when parallel = True but not when parallel=False. When it is set to True, I get the following error:

RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1
        Trace Shapes:          
         Param Sites:          
     mu.mu_fc0.weight   3 2    
       mu.mu_fc0.bias     3    
     mu.mu_fc1.weight   4 3    
       mu.mu_fc1.bias     4    
        Sample Sites:          
mu.mu_fc2.weight dist 3 1 | 1 4
                value   3 | 1 4
  mu.mu_fc2.bias dist 3 1 | 1  
                value   3 | 1

I’ve tried different solutions such as adding plates, but I can’t seem to get both training and Predictive to work.

Still working on this with no luck, but here are some observations on shapes. To make the example clearer, I’m running with one hidden layer, model = net(mu_l = [2, 4, 1], sh_l = [2, 4, 1]). When running with Trace_ELBO(), the shapes coming through look like:

Parameter   Size
Weight      torch.Size([1, 4])
Input       torch.Size([1000, 4])
Bias        torch.Size([1])
Output      torch.Size([1000, 1])
mu          torch.Size([1000, 1])
shape       torch.Size([1000, 1])

Running Trace_ELBO(num_particles=10, vectorize_particles=True), two dimensions get added where I would expect only one to be added for the vectorized particles.

Parameter   Size
Weight      torch.Size([10, 1, 1, 4])
Input       torch.Size([1000, 4])
Bias        torch.Size([10, 1, 1])
Output      torch.Size([10, 10, 1000, 1])
mu          torch.Size([10, 10, 1000, 1])
shape       torch.Size([10, 10, 1000, 1])

That is, I would expect weight to be torch.Size([10, 1, 4]). The extra dimension also results in the output picking up an extra dimension of 10: torch.Size([10, 10, 1000, 1]) vs the expected torch.Size([10, 1000, 1]).
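The broadcasting can be reproduced with plain tensors matching the shapes above (a hand-rolled sketch, not Pyro itself): the matmul keeps a singleton at dim -2, and it is the bias addition that broadcasts in the second particle-sized dimension.

import torch

# plain-tensor sketch of the broadcasting inside myLinear.forward, using the shapes above
weight = torch.randn(10, 1, 1, 4)   # sampled weight: (particles, 1, out, in)
bias   = torch.randn(10, 1, 1)      # sampled bias:   (particles, 1, out)
inp    = torch.randn(1000, 4)       # data batch

matmul = inp @ weight.transpose(-1, -2)   # torch.Size([10, 1, 1000, 1])
out    = matmul + bias                    # bias broadcasts against dim -2 -> torch.Size([10, 10, 1000, 1])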

Predictive only works if parallel=False. If I set parallel=True and use a batch size of 1000 and num_samples=10, I get the following shapes and error:

Parameter   Size
Weight      torch.Size([10, 1, 4])
Input       torch.Size([1000, 4])
Bias        torch.Size([10, 1])
RuntimeError: The size of tensor a (1000) must match the size of tensor b (10) at non-singleton dimension 1
        Trace Shapes:            
         Param Sites:            
     mu.mu_fc0.weight     4 2    
       mu.mu_fc0.bias       4    
        Sample Sites:            
mu.mu_fc1.weight dist 10  1 | 1 4
                value    10 | 1 4
  mu.mu_fc1.bias dist 10  1 | 1  
                value    10 | 1

I’m not sure if the issue is with the myLinear layer or if I have defined something wrong in the model.

Anyone have thoughts on this? I’m starting to think a bug is occurring with the dimensions of weights. @eb8680_2 previously wrote:

e.g. `Trace_ELBO(num_particles=100, vectorize_particles=True)`; the ELBO implementation will automatically create an extra batch dimension in your model and guide and average over this dimension when computing the ELBO estimate.

However, here weight has been given two additional dimensions, not one. Same for bias.

I am still trying to track down the cause of this, but I seem to be stuck.

Hi @yoshy ,

Can you change your example so it can be run without using Spark? I tried your example but ran into issues reproducing it (related to Spark, I believe).

Sure no problem!

import sys
import pandas as pd
import numpy as np
import pyro
from pyro.nn import PyroModule, PyroSample
from pyro.contrib.autoname import *
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.distributions.utils import broadcast_all
from torch.distributions import constraints
import pyro.distributions as dist
from pyro.infer.autoguide import AutoDiagonalNormal, AutoMultivariateNormal, AutoLowRankMultivariateNormal, AutoLaplaceApproximation, init_to_mean
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, Predictive
import pyro.poutine as poutine
from collections import OrderedDict

print('torch.__version__ ',torch.__version__)
print('pyro.__version__ ',pyro.__version__)

n_row = 1000
rng = np.random.default_rng()
df = pd.DataFrame({'x1' : rng.standard_normal(n_row),
                   'x2' : np.random.uniform(1, 10, n_row)})
df['sigma'] = df['x2']
df['mu'] = 2*df['x1'] + df['x2'] + 1000
df['y'] = rng.standard_normal(n_row) * df['sigma'] + df['mu']

if torch.cuda.is_available():
  device = torch.device('cuda')
else: 
  device = torch.device('cpu')

x_feat = ['x1', 'x2']
y_name = 'y'

if torch.cuda.is_available():
  torch.set_default_tensor_type("torch.cuda.FloatTensor")

class MyDataset(Dataset):

  def __init__(self,data):
    self.x_train=torch.tensor(data[['x1', 'x2']].values,dtype=torch.float32)
    self.y_train=torch.tensor(data['y'].values, dtype=torch.float32)

  def __len__(self):
    return len(self.y_train)
  
  def __getitem__(self,idx):
    return self.x_train[idx],self.y_train[idx]
ds = MyDataset(df)
train_loader = DataLoader(ds, batch_size=10, shuffle=False)

def train_and_evaluate_SVI(svi, model, guide, bs, ne):
    
  model = model.to(device)
  train_dataloader_iter = iter(train_loader)
  steps_per_epoch = len(train_loader)  # number of batches per epoch

  for epoch in range(ne):
    if (epoch + 1) % 10 == 0: print('-' * 10)
    if (epoch + 1) % 10 == 0: print('Epoch {}/{}'.format(epoch + 1, ne))

    train_loss = train_one_epoch_SVI(svi, train_dataloader_iter, steps_per_epoch, epoch, device)

  return train_loss

def train_one_epoch_SVI(svi, train_dataloader_iter, steps_per_epoch, epoch, device):
  model.train()
  
  running_loss = 0.0
  total_lives = 0
  
  # Iterate over the data for one epoch.
  for step in range(steps_per_epoch):
      x, y = next(train_dataloader_iter)
      inputs = x.to(device)
      labels = y.to(device)
      loss = svi.step(inputs, labels)

      # statistics
      running_loss += loss
      total_lives += inputs.shape[0]
  
  epoch_loss = running_loss  / total_lives
  
  print('Train Loss: {:.4f}'.format(epoch_loss))
  return epoch_loss
model = net(mu_l = [2, 4, 1], sh_l = [2, 4, 1])
#model = net_mu(mu_l = [2, 4, 1])
pyro.clear_param_store()
pyro.set_rng_seed(123456789)
loss_fnc = Trace_ELBO(num_particles=3, vectorize_particles=True)
#loss_fnc = Trace_ELBO()
model = model.to(device)
guide = AutoDiagonalNormal(model)
adam = pyro.optim.ClippedAdam({"lr": 0.001, 'betas': (.95, .999), 'weight_decay' : .25, 'clip_norm' : 10.}) 
svi = SVI(model, guide, adam, loss=loss_fnc)
train_and_evaluate_SVI(svi=svi, model = model, guide = guide, bs = 10, ne = 1)

I also made another, simpler model, net_mu, with just mu, leaving shape constant. It doesn’t change the problem but might be easier to use for debugging.

class net_mu(PyroModule):
  # 
    @name_count
    def __init__(self, mu_l):
        super().__init__()
        
        # parameter names list
        self.parameter_names = []
        
        layers = []
        for i in range(len(mu_l)-1):
            layers.append(('mu_fc' + str(i), myLinear(mu_l[i], mu_l[i+1], _print=True)))
            if i != (len(mu_l)-2): layers.append(('mu_ReLU' + str(i), nn.ReLU()))
        mu = OrderedDict(layers)
        self.mu = nn.Sequential(mu)
        
        for name, param in self.mu.named_parameters():
            self.parameter_names.append(name)
        
        pyro.nn.module.to_pyro_module_(self.mu)
        for m in self.mu.modules():
          if m._pyro_name == 'mu.mu_fc' + str(len(mu_l)-2):
            for name, value in list(m.named_parameters(recurse=False)):
              setattr(m, name, PyroSample(prior=dist.Normal(0., 1.).expand(value.shape).to_event(value.dim())))

    def forward(self, x, y=None):
        
        shape_cons = 5.
        mu = self.mu(x).exp().clamp(min = 10, max = 1e6)#.squeeze()
        mu = mu.squeeze()
        with pyro.plate("data", x.shape[0], device = device, dim=-1): 
          obs = pyro.sample("obs", dist.Normal(mu, shape_cons), obs=y)

        return  mu

To reiterate and condense my problem, here are the shapes that get passed through when training without parallel, training with parallel, and Predictive with parallel. Batch size = 10, 2 input features, 3 samples/particles, and one hidden layer with four neurons. The first connection is not Bayesian, so its shapes are the same in all three cases; the second connection is Bayesian. Training with vectorized particles adds two dimensions to the sampled weight and bias; Predictive in parallel adds one.

First connection (input to hidden)

         Without parallel   Train (parallel)   Predictive (parallel)
Input    [10, 2]            [10, 2]            [10, 2]
Weight   [4, 2]             [4, 2]             [4, 2]
Bias     [4]                [4]                [4]
Output   [10, 4]            [10, 4]            [10, 4]

Second connection (hidden to output)

         Without parallel   Train (parallel)   Predictive (parallel)
Input    [10, 4]            [10, 4]            [10, 4]
Weight   [1, 4]             [3, 1, 1, 4]       [3, 1, 4]
Bias     [1]                [3, 1, 1]          [3, 1]
Output   [10, 1]            [3, 3, 10, 1]      Error

The shapes for Predictive parallel look correct. For both parallel methods, I would expect the weight to be [3, 1, 4], bias to be [3, 1], and final output to be [3, 10, 1].

Hi @yoshy can you try removing all uses of .squeeze()? The raw .squeeze() is dangerous because it eliminates all dims of size 1, and therefore behaves very differently in different contexts. Sometimes I’ll use .squeeze(-k) to squeeze a particular dimension; this usage is safer (because it squeezes at most one dimension) but can still lead to some shape errors.
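A minimal illustration of the difference, with a made-up tensor:

import torch

t = torch.randn(3, 1, 10, 1)
print(t.squeeze().shape)    # torch.Size([3, 10])    -- drops every size-1 dim
print(t.squeeze(-1).shape)  # torch.Size([3, 1, 10]) -- drops at most the trailing dim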

Removing squeeze results in an error for training:

ValueError: at site "obs", invalid log_prob shape
  Expected [10], actual [10, 10]
  Try one of the following fixes:
  - enclose the batched tensor in a with pyro.plate(...): context
  - .to_event(...) the distribution being sampled
  - .permute() data dimensions

Using .squeeze(-1) works for training but still fails for Predictive, with the same error as when using .squeeze():

RuntimeError: The size of tensor a (10) must match the size of tensor b (3) at non-singleton dimension 1
        Trace Shapes:          
         Param Sites:          
     mu.mu_fc0.weight   4 2    
       mu.mu_fc0.bias     4    
        Sample Sites:          
mu.mu_fc1.weight dist 3 1 | 1 4
                value   3 | 1 4
  mu.mu_fc1.bias dist 3 1 | 1  
                value   3 | 1

I had a closer look and at least one of the problems is in this line in myLinear.forward:

output = (input @ self.weight.transpose(-1, -2)) + self.bias

When it is not vectorized the shapes are:

input.shape == (10, 4)
self.weight.shape == (1, 4)
self.bias.shape == (1,)
output.shape == (10, 1)

When vectorized, the shapes of sampled weight and bias are:

self.weight.shape == (3, 1, 1, 4)
self.bias.shape == (3, 1, 1)

where dim=0 is for the vectorized samples (num_samples=3) and dim=1 is the plate (batch) dimension (corresponding to the dimension of size 10 in the input),

and the input and output shapes are:

input.shape == (10, 4)
output.shape == (3, 3, 10, 1)  # this is wrong, there should be only one dim with size=3

You need to adjust the matrix multiplication in the line that I pointed out so that it respects these new dimensions; in particular, the output should have the shape (3, 10, 1).
(Or, in general, if you want to make use of vectorization, ensure that your code works as expected when new dimensions are added from the left and that those new dims line up correctly.)
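To make this concrete, here is a small plain-tensor sketch with the shapes from the batch-size-10 / 3-sample example above: the matmul broadcasts in both cases, and it is the bias addition that produces the spurious extra particle dimension during training and the RuntimeError under Predictive.

import torch

inp = torch.randn(10, 4)

# vectorized training: weight (3, 1, 1, 4), bias (3, 1, 1)
w_train = torch.randn(3, 1, 1, 4)
b_train = torch.randn(3, 1, 1)
out_train = inp @ w_train.transpose(-1, -2) + b_train
print(out_train.shape)  # torch.Size([3, 3, 10, 1]) -- the bias add broadcasts in a second 3

# Predictive with parallel=True: weight (3, 1, 4), bias (3, 1)
w_pred = torch.randn(3, 1, 4)
b_pred = torch.randn(3, 1)
mm = inp @ w_pred.transpose(-1, -2)   # torch.Size([3, 10, 1])
# mm + b_pred raises:
# RuntimeError: The size of tensor a (10) must match the size of tensor b (3)
# at non-singleton dimension 1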

I don’t understand your model well, so it is hard to pin down what is wrong. Making sure that dimensions are lined up properly can be hard; at least it was for me when I started learning Pyro. I would recommend having a look at the Tensor Shapes in Pyro tutorial, then going back to your code and carefully checking the shape of each variable, making sure that they are correct and that your code works properly when additional dimensions are added from the left.
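One convenient way to do that inspection is poutine.trace together with format_shapes(); a minimal sketch, assuming the model instance and the two-feature inputs from the example above:

import torch
import pyro.poutine as poutine

x = torch.randn(10, 2)                      # dummy batch: 10 rows, 2 features
trace = poutine.trace(model).get_trace(x)   # run the model once and record every site
print(trace.format_shapes())                # prints the dist/value shape of every sample site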

Happily, I was able to get this working :smiley:

class myLinear_parallel(nn.Module):
  def __init__(self, in_features, out_features):
    super().__init__()
    self.in_features = in_features
    self.out_features = out_features
    self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))    
    self.bias = torch.nn.Parameter(torch.randn(out_features))
       
  def forward(self, input):

    if self.weight.dim() > 2:
      # sampled (vectorized) parameters: collapse the extra sample/particle dims into a
      # single leading batch dim, e.g. (3, 1, 1, 4) or (3, 1, 4) -> (3, 1, 4)
      tmpw = self.weight.view(-1, 1, self.weight.shape[-1])
      # reshape bias to (samples, 1, 1) so it broadcasts over the data dim rather than against it
      tmpb = self.bias.view(self.bias.shape[0], 1, 1)

    if input.dim() > 1:
      x, y = input.shape
      if y != self.in_features:
          sys.exit(f'Wrong Input Features. Please use tensor with {self.in_features} Input Features')
    if self.weight.dim() > 2:
      output = (input @ tmpw.transpose(-1, -2)) + tmpb   # -> (samples, batch, out_features)
    else:
      output = (input @ self.weight.transpose(-1, -2)) + self.bias
    return output
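As a quick sanity check of that reshaping, here is a plain-tensor sketch (a throwaway helper, fixed_forward, mirroring the forward pass above, with made-up tensors standing in for the sampled parameters); all three cases now come out with a single sample dimension:

import torch

def fixed_forward(inp, weight, bias):
    # mirrors myLinear_parallel.forward; assumes out_features == 1 for the sampled layer
    if weight.dim() > 2:
        w = weight.view(-1, 1, weight.shape[-1])  # collapse sample dims -> (samples, 1, in_features)
        b = bias.view(bias.shape[0], 1, 1)        # (samples, 1, 1) broadcasts over the data dim
        return inp @ w.transpose(-1, -2) + b
    return inp @ weight.transpose(-1, -2) + bias

inp = torch.randn(10, 4)
print(fixed_forward(inp, torch.randn(1, 4),       torch.randn(1)).shape)        # torch.Size([10, 1])
print(fixed_forward(inp, torch.randn(3, 1, 1, 4), torch.randn(3, 1, 1)).shape)  # torch.Size([3, 10, 1])
print(fixed_forward(inp, torch.randn(3, 1, 4),    torch.randn(3, 1)).shape)     # torch.Size([3, 10, 1])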