Error with variable minibatch size

Hey!
I am having issues training a model with SVI using minibatches. Since my model
only has local latent variables, I followed the suggestion in "SVI Part II: Conditional Independence, Subsampling, and Amortization" (Pyro Tutorials 1.9.0 documentation),
in the section on subsampling when there are only local random variables,
and used a PyTorch DataLoader to generate minibatches outside of the model (so far so good?).
Since the training dataset size is not exactly divisible by the minibatch size, the last minibatch is
smaller than the others, and this raises an error when training the model. I suppose this could also be
a problem later, when I need to evaluate the model on the train set or, more generally, on
an arbitrary number of data points. This is my model and how I train it:

import os

import torch
from torch.distributions import constraints
from torch.utils.data import TensorDataset

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO

# Prefer CUDA, then Apple's MPS backend, and fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Ask PyTorch to run operators that lack an MPS kernel on the CPU instead of erroring.
# NOTE(review): PyTorch appears to read this variable when it is first imported, so
# setting it after ``import torch`` may have no effect; consider setting it before the
# torch import (or in the shell environment) and confirm.
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

def setup_data_loaders(X, Y, train_mask, test_mask, resplit=False, batch_size=128,
                       num_workers=1, drop_last=False):
    """Build a shuffled training loader and an ordered test loader from ``X``/``Y``.

    Args:
        X, Y: tensors whose first dimension indexes examples.
        train_mask, test_mask: anything that can index the first dimension of
            ``X`` and ``Y`` (e.g. a boolean mask or an index tensor); they select
            the training and test rows respectively.
        resplit: drawing a fresh train/test split is not implemented yet.
        batch_size: minibatch size used by both loaders.
        num_workers: worker processes per loader (default 1, as before).
            NOTE(review): for in-memory tensors ``0`` avoids starting worker
            processes at all; with the "spawn" start method (macOS/Windows)
            workers re-import the main module, so scripts need an
            ``if __name__ == "__main__":`` guard when this is > 0.
        drop_last: if True, the *training* loader drops its final incomplete
            batch, so every training step sees exactly ``batch_size`` rows.
            The test loader always keeps every row.

    Returns:
        ``(train_loader, test_loader)``.

    Raises:
        NotImplementedError: if ``resplit`` is truthy.
    """
    if resplit:
        raise NotImplementedError("new train test split still not implemented")

    train_set = TensorDataset(X[train_mask], Y[train_mask])
    test_set = TensorDataset(X[test_mask], Y[test_mask])

    # Pinned host memory only helps host->CUDA copies, so enable it just for CUDA.
    loader_kwargs = {"num_workers": num_workers, "pin_memory": torch.cuda.is_available()}
    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=batch_size,
        shuffle=True,
        drop_last=drop_last,
        **loader_kwargs,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=batch_size,
        shuffle=False,
        **loader_kwargs,
    )
    return train_loader, test_loader

def train_one_epoch(svi, train_loader):
    """Run one pass of SVI updates over ``train_loader``.

    Every ``(x, y)`` batch is moved to ``device`` and handed to ``svi.step``,
    which takes one optimiser step on that minibatch and returns its loss.

    Returns:
        The summed step losses divided by the number of training examples.
    """
    epoch_loss = 0.0

    for x, y in train_loader:
        # Tensor.to is not in-place: rebind the names to the moved tensors.
        x, y = x.to(device), y.to(device)
        epoch_loss += svi.step(x, y)

    normalizer_train = len(train_loader.dataset)
    return epoch_loss / normalizer_train

def etestuate(svi, test_loader):
    """Return the average per-example loss of ``svi`` over ``test_loader``.

    Each ``(x, y)`` batch is moved to ``device`` and scored with
    ``svi.evaluate_loss``, which estimates the loss without taking an
    optimiser step. The summed batch losses are divided by the number of
    examples in the test dataset. (The function name is kept for existing
    callers.)
    """
    test_loss = 0.0

    for x, y in test_loader:
        # Tensor.to is not in-place: rebind the names to the moved tensors.
        x, y = x.to(device), y.to(device)
        test_loss += svi.evaluate_loss(x, y)

    normalizer_test = len(test_loader.dataset)
    return test_loss / normalizer_test

def train(model, guide, train_loader, test_loader, lr=0.001, NUM_EPOCHS=100, test_FREQUENCY=5):
    """Fit ``guide`` to ``model`` with SVI + Adam, evaluating periodically.

    Clears the global Pyro param store, then runs ``NUM_EPOCHS`` passes of
    ``train_one_epoch``. Every ``test_FREQUENCY`` epochs (starting at epoch 0)
    it also computes the per-example test loss with ``etestuate``. Both losses
    are printed as they are computed.

    Returns:
        ``(train_elbo, test_elbo)``: lists of negated per-example losses, one
        entry per training epoch and one per evaluation respectively.
    """
    pyro.clear_param_store()

    train_elbo = []
    test_elbo = []

    adam = pyro.optim.Adam({"lr": lr})
    svi = SVI(model, guide, adam, loss=Trace_ELBO())

    for epoch in range(NUM_EPOCHS):
        total_epoch_loss_train = train_one_epoch(svi, train_loader)
        train_elbo.append(-total_epoch_loss_train)
        print("[epoch %03d]  average training loss: %.4f" % (epoch, total_epoch_loss_train))

        if epoch % test_FREQUENCY == 0:
            # report test diagnostics
            total_epoch_loss_test = etestuate(svi, test_loader)
            test_elbo.append(-total_epoch_loss_test)
            print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test))

    return train_elbo, test_elbo

def latent_model(X, Y, subsample=None, full_size=None):
    """Observation model with a 2-d local latent ``z`` for every individual.

    ``c``, ``gamma`` and ``W_G`` are point-estimated ``pyro.param``s. Inside
    the plate ``"N"`` each individual draws ``z ~ MultivariateNormal(0, I_2)``.
    ``X`` is then observed under a Categorical whose logits are ``W_G @ z``
    plus a constant logit of 1 for a third category, with the 1760 positions
    treated as one event. ``Y`` is observed under a Bernoulli with logit
    ``c + z @ gamma``.

    Args:
        X: observed categorical data for the current minibatch; assumed shape
            ``(batch, 1760)`` with values in ``{0, 1, 2}``. TODO confirm.
        Y: observed binary outcomes for the same rows; assumed floating point
            with shape ``(batch,)``. TODO confirm.
        subsample: optional 1-d index tensor giving the dataset indices of the
            rows in this minibatch. When given, the plate is declared as a
            subsample of ``full_size`` individuals (Pyro then rescales the
            likelihood by ``full_size / batch``) instead of a standalone plate
            of size ``batch``.
        full_size: total number of individuals; required with ``subsample``.

    Raises:
        ValueError: if ``subsample`` is given without ``full_size``.

    NOTE(review): with the default standalone plate, an ``AutoNormal`` guide
    sizes its per-individual ``z`` parameters from the first batch it sees and
    ties them to batch *positions*. A smaller final batch then mismatches, and
    with a shuffling loader the parameters do not follow individuals. Passing
    ``subsample``/``full_size`` here, and building the guide with
    ``AutoNormal(..., create_plates=...)`` returning a plate with the same
    indices, is Pyro's subsampling mechanism for this; confirm against the
    Pyro docs.
    """
    # Keep observations on the same device as the parameters (no-op if already there).
    X = X.to(device)
    Y = Y.to(device)

    # Lazily initialised directly on ``device``: the param store then owns
    # device tensors, and no random init tensors are built after the first call.
    c = pyro.param("c", lambda: torch.randn((), device=device))
    gamma = pyro.param("gamma", lambda: torch.randn(2, device=device))
    W_G = pyro.param("W_G", lambda: torch.randn(1760, 2, 2, device=device))

    if subsample is None:
        plate_individuals = pyro.plate("N", size=X.shape[0])
    else:
        if full_size is None:
            raise ValueError("full_size must be given together with subsample")
        plate_individuals = pyro.plate("N", size=full_size, subsample=subsample)

    with plate_individuals:
        z = pyro.sample(
            "z",
            dist.MultivariateNormal(torch.zeros(2, device=device), torch.eye(2, device=device)),
        )

    # (1760, 2, 2) @ (2, batch) -> (1760, 2, batch) -> permuted to (batch, 1760, 2).
    W_Gxz = torch.permute(torch.matmul(W_G, z.T), (2, 0, 1))
    # Append a constant logit of 1 as a third category at every position.
    logits = torch.cat((W_Gxz, torch.ones_like(W_Gxz[..., :1])), dim=-1)
    mean = c + torch.matmul(z, gamma)

    with plate_individuals:
        pyro.sample("X", dist.Categorical(logits=logits).to_event(1), obs=X)
        pyro.sample("y", dist.Bernoulli(logits=mean), obs=Y)


# Here X and Y are imported and preprocessed (code omitted from this post).

auto_guide_normal = pyro.infer.autoguide.guides.AutoNormal(latent_model)
train_loader, test_loader = setup_data_loaders(X, Y, train_idx, test_idx)
train(
    latent_model,
    auto_guide_normal,
    train_loader,
    test_loader,
    lr=0.001,
    NUM_EPOCHS=200,
    test_FREQUENCY=5,
)

Thank you very much!! :slight_smile:
Federico