Error about Shape Mismatch when introducing Mini-Batch

I am trying to introduce mini-batch training and sample latent variables in a model, but I get the error
ValueError: Shape mismatch inside plate('latent') at site latent_variables dim -1, 8 vs 2. It happens mainly because the last batch from the dataloader contains fewer samples than the configured batch_size. I modified the example in Bayesian Regression - Introduction (Part 1) — Pyro Tutorials 1.8.2 documentation to reproduce the problem in my model. I have tried to control the dependence/independence of the dimensions of the latent variables, but it still fails. Could you please provide some suggestions about how to solve it? Thanks!
The code is as follows:

import logging
import os

import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from torch.distributions import constraints
import torch.nn as nn
from torch.utils.data import DataLoader

import pyro
import pyro.distributions as dist
import pyro.optim as optim
from pyro.nn import PyroSample, PyroModule
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.infer import SVI, Trace_ELBO
from pyro.infer import Predictive

class TorchDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves items of ``data`` converted to float32."""

    def __init__(self, data):
        # Keep a reference to the wrapped indexable container; nothing is copied.
        self.data = data

    def __len__(self):
        """Return the number of items in the wrapped container."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the item at ``index`` after passing it through ``np.float32``."""
        item = self.data[index]
        return np.float32(item)


class BayesianRegression(PyroModule):
    """Linear regression whose weight and bias are Pyro sample sites.

    In addition to the regression itself, ``forward`` draws an auxiliary
    ``latent_variables`` site with one 5-dimensional row per input row.
    """

    def __init__(self, in_features, out_features):
        """Build the linear layer and replace its parameters with priors.

        Args:
            in_features: number of input columns of the linear layer.
            out_features: number of output columns of the linear layer.
        """
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        # Normal(0, 1) prior over the whole weight matrix, treated as one event.
        self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
        # Normal(0, 10) prior over the bias vector, treated as one event.
        self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        """Run the model on a batch and return the regression mean.

        Args:
            x: input batch; ``x.shape[0]`` is used as the size of both plates.
            y: optional observations passed as ``obs`` to the "obs" site.

        Returns:
            The output of the linear layer with its last dimension squeezed.
        """
        sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
        mean = self.linear(x).squeeze(-1)

        # Prior parameters for the per-row latent: one row of 5 values per input row.
        loc = torch.zeros((x.shape[0], 5))
        scale = torch.ones((x.shape[0], 5))

        # NOTE(review): the plate size follows the current batch, so it changes
        # whenever a batch has a different number of rows. The thread reports
        # "Shape mismatch ... 8 vs 2" at this site on the short final batch when
        # training with an autoguide; a reply points to the autoguide
        # plate-creation argument for data subsampling -- confirm against the
        # Pyro autoguide documentation.
        with pyro.plate('latent', x.shape[0]): 
            # The latent variables are generated arbitrarily, but I hope the shape of it is [x.shape[0], num_latent_features]
            latent_variables = pyro.sample('latent_variables', dist.Normal(loc, scale).to_event(1))

        # Likelihood: one observation per input row.
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean


def summary(samples):
    """Summarise sampled values for every site.

    Args:
        samples: mapping from site name to a tensor whose dimension 0
            indexes the individual draws.

    Returns:
        A dict mapping each site name to a dict with the keys "mean",
        "std", "5%" and "95%", each computed along dimension 0.
    """
    site_stats = {}
    for name, draws in samples.items():
        num_draws = len(draws)
        # kthvalue counts from 1 (k-th smallest). With fewer than 20 draws
        # int(num_draws * 0.05) would be 0, so clamp to the smallest valid rank.
        low_k = max(1, int(num_draws * 0.05))
        high_k = max(1, int(num_draws * 0.95))
        site_stats[name] = {
            "mean": torch.mean(draws, 0),
            "std": torch.std(draws, 0),
            "5%": draws.kthvalue(low_k, dim=0)[0],
            "95%": draws.kthvalue(high_k, dim=0)[0],
        }
    return site_stats

# Script entry point: load the data, train with SVI on mini-batches, then
# draw posterior predictive samples and summarise them.
if __name__ == '__main__':
    # NOTE(review): smoke_test is assigned but not used anywhere in this script.
    smoke_test = ('CI' in os.environ)
    pyro.set_rng_seed(1)

    # Number of full passes over the dataloader.
    num_iterations = 2

    DATA_URL = "https://d2hg8soec8ck9v.cloudfront.net/datasets/rugged_data.csv"
    data = pd.read_csv(DATA_URL, encoding="ISO-8859-1")
    df = data[["cont_africa", "rugged", "rgdppc_2000"]]
    # Keep only rows with a finite rgdppc_2000, then work with its logarithm.
    df = df[np.isfinite(df.rgdppc_2000)]
    df["rgdppc_2000"] = np.log(df["rgdppc_2000"])
    # Interaction term between the two predictors.
    df["cont_africa_x_rugged"] = df["cont_africa"] * df["rugged"]
    # Column order matters: the last column (rgdppc_2000) is split off as the
    # target inside the training loop.
    data = torch.tensor(df[["cont_africa", "rugged", "cont_africa_x_rugged", "rgdppc_2000"]].values,
                        dtype=torch.float)

    # drop_last=False keeps a final batch that may hold fewer than 8 rows; the
    # log posted in the thread shows the failure on batch 22 ("8 vs 2").
    dataloader = DataLoader(TorchDataset(data), batch_size=8, drop_last=False, num_workers=2)

    # x_data, y_data = data[:, :-1], data[:, -1]

    model = BayesianRegression(3, 1)
    guide = AutoDiagonalNormal(model)

    adam = pyro.optim.Adam({"lr": 0.03})
    svi = SVI(model, guide, adam, loss=Trace_ELBO())

    pyro.clear_param_store()
    for j in range(num_iterations):

        # Sum of the per-batch losses returned by svi.step for this pass.
        batch_loss = 0.0
        i = 0
        for data_ in dataloader:
            i += 1
            print('Batch - ', i)
            # All columns but the last are inputs; the last column is the target.
            x_data, y_data = data_[:, :-1], data_[:, -1]
            batch_loss += svi.step(x_data, y_data)

        # Average over the number of batches, not the number of samples.
        loss = batch_loss / len(dataloader)

        print("[iteration %04d] loss: %.4f" % (j + 1, loss))

    predictive = Predictive(model, guide=guide, num_samples=800,
                            return_sites=("linear.weight", "obs", "_RETURN"))
    # x_data here is whatever the last loop iteration left behind, i.e. only
    # the final mini-batch rather than the full dataset.
    samples = predictive(x_data)
    pred_summary = summary(samples)
    # print(pred_summary)

Assuming x has shape num_data x num_covariates, you probably want x.shape[1] in the first plate?

Hi! Thanks for your reply. I do need a batch of latent variables whose shape is num_data x num_latent_dimensions, but I am confused about why I should use x.shape[1] in the first plate. I modified the code as you suggested, but it still does not work.

“it can not work” is not helpful. Please provide stack traces.

The error from the first version of the code in this post is as follows:

Batch -  1
Batch -  2
Batch -  3
Batch -  4
Batch -  5
Batch -  6
Batch -  7
Batch -  8
Batch -  9
Batch -  10
Batch -  11
Batch -  12
Batch -  13
Batch -  14
Batch -  15
Batch -  16
Batch -  17
Batch -  18
Batch -  19
Batch -  20
Batch -  21
Batch -  22
Traceback (most recent call last):
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\nn\module.py", line 426, in __call__
    return super().__call__(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "E:/科研/科研探索/20220901/test.py", line 47, in forward
    latent_variables = pyro.sample('latent_variables', dist.Normal(loc, scale).to_event(1))
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\primitives.py", line 163, in sample
    apply_stack(msg)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\runtime.py", line 213, in apply_stack
    frame._process_message(msg)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\plate_messenger.py", line 19, in _process_message
    return BroadcastMessenger._pyro_sample(msg)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\contextlib.py", line 52, in inner
    return func(*args, **kwds)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\broadcast_messenger.py", line 71, in _pyro_sample
    target_batch_shape[f.dim],
ValueError: Shape mismatch inside plate('latent') at site latent_variables dim -1, 8 vs 2

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "E:/科研/科研探索/20220901/test.py", line 99, in <module>
    batch_loss += svi.step(x_data, y_data)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\svi.py", line 145, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\trace_elbo.py", line 140, in loss_and_grads
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\elbo.py", line 182, in _get_traces
    yield self._get_trace(model, guide, args, kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\trace_elbo.py", line 58, in _get_trace
    "flat", self.max_plate_nesting, model, guide, args, kwargs
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\enum.py", line 67, in get_importance_trace
    ).get_trace(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 198, in get_trace
    self(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 180, in __call__
    raise exc from e
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\nn\module.py", line 426, in __call__
    return super().__call__(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "E:/科研/科研探索/20220901/test.py", line 47, in forward
    latent_variables = pyro.sample('latent_variables', dist.Normal(loc, scale).to_event(1))
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\primitives.py", line 163, in sample
    apply_stack(msg)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\runtime.py", line 213, in apply_stack
    frame._process_message(msg)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\plate_messenger.py", line 19, in _process_message
    return BroadcastMessenger._pyro_sample(msg)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\contextlib.py", line 52, in inner
    return func(*args, **kwds)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\broadcast_messenger.py", line 71, in _pyro_sample
    target_batch_shape[f.dim],
ValueError: Shape mismatch inside plate('latent') at site latent_variables dim -1, 8 vs 2
     Trace Shapes:        
      Param Sites:        
     Sample Sites:        
        sigma dist   |    
             value   |    
linear.weight dist   | 1 3
             value   | 1 3
  linear.bias dist   | 1  
             value   | 1  
       latent dist   |    
             value 8 |    

Process finished with exit code 1

After I change x.shape[0] to x.shape[1] in the first plate, the error is as follows:

D:\Users\83451\anaconda3\envs\BaseEnv\python.exe E:/科研/科研探索/20220901/test.py 
Batch -  1
Traceback (most recent call last):
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\nn\module.py", line 426, in __call__
    return super().__call__(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "E:/科研/科研探索/20220901/test.py", line 47, in forward
    latent_variables = pyro.sample('latent_variables', dist.Normal(loc, scale).to_event(1))
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\primitives.py", line 163, in sample
    apply_stack(msg)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\runtime.py", line 213, in apply_stack
    frame._process_message(msg)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\plate_messenger.py", line 19, in _process_message
    return BroadcastMessenger._pyro_sample(msg)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\contextlib.py", line 52, in inner
    return func(*args, **kwds)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\broadcast_messenger.py", line 71, in _pyro_sample
    target_batch_shape[f.dim],
ValueError: Shape mismatch inside plate('latent') at site latent_variables dim -1, 3 vs 8

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\nn\module.py", line 426, in __call__
    return super().__call__(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\autoguide\guides.py", line 759, in forward
    self._setup_prototype(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\autoguide\guides.py", line 935, in _setup_prototype
    super()._setup_prototype(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\autoguide\guides.py", line 636, in _setup_prototype
    super()._setup_prototype(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\autoguide\guides.py", line 158, in _setup_prototype
    *args, **kwargs
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 198, in get_trace
    self(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 180, in __call__
    raise exc from e
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\nn\module.py", line 426, in __call__
    return super().__call__(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "E:/科研/科研探索/20220901/test.py", line 47, in forward
    latent_variables = pyro.sample('latent_variables', dist.Normal(loc, scale).to_event(1))
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\primitives.py", line 163, in sample
    apply_stack(msg)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\runtime.py", line 213, in apply_stack
    frame._process_message(msg)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\plate_messenger.py", line 19, in _process_message
    return BroadcastMessenger._pyro_sample(msg)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\contextlib.py", line 52, in inner
    return func(*args, **kwds)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\broadcast_messenger.py", line 71, in _pyro_sample
    target_batch_shape[f.dim],
ValueError: Shape mismatch inside plate('latent') at site latent_variables dim -1, 3 vs 8
     Trace Shapes:      
      Param Sites:      
     Sample Sites:      
        sigma dist |    
             value |    
linear.weight dist | 1 3
             value | 1 3
  linear.bias dist | 1  
             value | 1  

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "E:/科研/科研探索/20220901/test.py", line 99, in <module>
    batch_loss += svi.step(x_data, y_data)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\svi.py", line 145, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\trace_elbo.py", line 140, in loss_and_grads
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\elbo.py", line 182, in _get_traces
    yield self._get_trace(model, guide, args, kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\trace_elbo.py", line 58, in _get_trace
    "flat", self.max_plate_nesting, model, guide, args, kwargs
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\enum.py", line 61, in get_importance_trace
    *args, **kwargs
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 198, in get_trace
    self(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 180, in __call__
    raise exc from e
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\nn\module.py", line 426, in __call__
    return super().__call__(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\autoguide\guides.py", line 759, in forward
    self._setup_prototype(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\autoguide\guides.py", line 935, in _setup_prototype
    super()._setup_prototype(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\autoguide\guides.py", line 636, in _setup_prototype
    super()._setup_prototype(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\infer\autoguide\guides.py", line 158, in _setup_prototype
    *args, **kwargs
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 198, in get_trace
    self(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 180, in __call__
    raise exc from e
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\nn\module.py", line 426, in __call__
    return super().__call__(*args, **kwargs)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "E:/科研/科研探索/20220901/test.py", line 47, in forward
    latent_variables = pyro.sample('latent_variables', dist.Normal(loc, scale).to_event(1))
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\primitives.py", line 163, in sample
    apply_stack(msg)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\runtime.py", line 213, in apply_stack
    frame._process_message(msg)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\plate_messenger.py", line 19, in _process_message
    return BroadcastMessenger._pyro_sample(msg)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\contextlib.py", line 52, in inner
    return func(*args, **kwds)
  File "D:\Users\83451\anaconda3\envs\BaseEnv\lib\site-packages\pyro\poutine\broadcast_messenger.py", line 71, in _pyro_sample
    target_batch_shape[f.dim],
ValueError: Shape mismatch inside plate('latent') at site latent_variables dim -1, 3 vs 8
     Trace Shapes:      
      Param Sites:      
     Sample Sites:      
        sigma dist |    
             value |    
linear.weight dist | 1 3
             value | 1 3
  linear.bias dist | 1  
             value | 1  
Trace Shapes:
 Param Sites:
Sample Sites:

Process finished with exit code 1

For convenience, the code presented above can be run directly as posted. Thank you very much!

If you are using autoguides with data subsampling, you must use the create_plates argument. See this test for a complete example.

Thank you very much!