Minibatch-train a model with latent along the batch axis

Is it possible to minibatch-train a model that has a latent variable along the sample axis I want to subsample?

I am imagining a scenario where the shared latent variables are updated at every iteration, but the sample-specific ones are updated only when the corresponding sample is drawn in the minibatch.

Runnable code example:

import numpy as np
import torch
import pyro
from pyro import poutine 
from pyro.infer import Trace_ELBO, SVI, Predictive
from pyro.infer.autoguide import AutoDiagonalNormal

def inference_model(data):
    feature_plate = pyro.plate("feature_plate", size=90, dim=-1)
    sample_plate = pyro.plate("sample_plate", size=1000, subsample_size=20, dim=-2)
    u = pyro.sample("u", pyro.distributions.Normal(0, 10))
    with sample_plate:
        mu = pyro.sample("mu", pyro.distributions.Normal(u, 1))
    with feature_plate:
        with sample_plate as ind:
            X = pyro.sample("X", pyro.distributions.Normal(mu, 1), obs=data[ind])

# synthetic data: 1000 samples x 90 features
data = pyro.distributions.Normal(0, 1).sample((1000, 90)) + torch.arange(1000)[:, None]

pyro.clear_param_store()

guide = AutoDiagonalNormal(inference_model)
optim = pyro.optim.Adam({"lr": 0.005, "betas": (0.95, 0.999)})
elbo = Trace_ELBO()
svi = SVI(inference_model, guide, optim, loss=elbo)
init_loss = svi.loss(inference_model, guide, data)
for i in range(100):
    loss = svi.step(data)

What I would expect is for Pyro to treat mu as a 1000-dimensional latent variable, but instead Pyro only creates a batched latent variable of size 20 (the subsample size; the 21 below is 20 entries for mu plus 1 for the scalar u), as revealed by

param_store = pyro.get_param_store()
param_store['AutoDiagonalNormal.loc'].shape
# Prints: torch.Size([21])

I think my problem is similar to this issue (which is about NumPyro):

but I can't seem to make it work for my example.

Hi @Gioelelm, this is already possible, it’s just not really supported by Pyro’s autoguides. One thing you can do is use AutoGuideList to combine a standard autoguide for the non-subsampled latent variables (u here) and a custom guide for the subsampled ones (mu here).

We usually recommend using amortized guides for subsampled local latent variables, because non-amortized guides require storing a local parameter for each datapoint in the entire dataset (somewhat negating the point of subsampling) and because amortized guides are much simpler to apply to new data.

I will look into both options and come back with questions if I have any.
In the meantime, I was able to achieve sort of what I wanted by doing the following (code below).
Is it too much of a hack? Do you recommend avoiding it?

def inference_model(data, ind=torch.arange(1000)):
    feature_plate = pyro.plate("feature_plate", size=90, dim=-1)
    # the subsample indices are passed in explicitly so they can be shared with the guide
    sample_plate = pyro.plate("sample_plate", size=1000, subsample=ind, dim=-2)
    u = pyro.sample("u", pyro.distributions.Normal(0, 10))
    with sample_plate:
        mu = pyro.sample("mu", pyro.distributions.Normal(u, 1))
    with feature_plate:
        with sample_plate as ind:
            # mu already has the subsampled shape inside the plate, so only the data is indexed
            X = pyro.sample("X", pyro.distributions.Normal(mu, 1), obs=data[ind])

data = pyro.distributions.Normal(0, 1).sample((1000,90)) + torch.arange(1000)[:, None]

from pyro.infer import Trace_ELBO, SVI, Predictive
from pyro.infer.autoguide import AutoNormal

pyro.clear_param_store()

# create_plates lets the autoguide share the model's subsample indices, so its
# full-size per-sample parameters are subsampled consistently with the model
guide = AutoNormal(
    inference_model,
    create_plates=lambda data, ind=None: pyro.plate("sample_plate", size=1000, subsample=ind, dim=-2),
)
optim = pyro.optim.Adam({"lr": 0.005, "betas": (0.95, 0.999)})
elbo = Trace_ELBO()
svi = SVI(inference_model, guide, optim, loss=elbo)
init_loss = svi.loss(inference_model, guide, data, ind=None)
for i in range(10000):
    loss = svi.step(data, torch.randint(0, 1000, (40,)))
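
As a rough sanity check (assuming AutoNormal registers its per-site parameters in the param store under names like AutoNormal.locs.mu), the guide now keeps full-size location and scale parameters for mu, and only the 40 subsampled rows receive gradients at each step:

param_store = pyro.get_param_store()
# assumed parameter name, analogous to the AutoDiagonalNormal.loc check above
param_store['AutoNormal.locs.mu'].shape
# Expected: torch.Size([1000, 1])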

Also, if in the meantime you could draft a solution with AutoGuideList, it would be immensely helpful.

It is correct for AutoNormal but would not work with other autoguides whose implementations sample all latent variable values from a single joint Distribution object, e.g. AutoMultivariateNormal.

Note also that the guide is keeping the local parameters for the entire dataset in memory on whatever device they were initialized on, which sort of defeats the point of subsampling. There may be specific contexts where subsampling is useful for other reasons, e.g. when a model or guide includes a deep neural network that only learns effectively with minibatch gradients, but for the most part if all of your data fits into GPU memory then subsampling is just needlessly injecting noise and you should use full-batch gradients.

I’m afraid I don’t have time to write and test a complete working example and I can’t think of an existing one off the top of my head, but here’s a snippet that you can build on:

import torch.nn as nn
from pyro.nn import PyroModule
from pyro.distributions import Normal
from pyro.infer.autoguide import AutoGuideList, AutoNormal

class AmortizedGuide(PyroModule):
    def __init__(self, data_size, hidden_size, feature_size):
        super().__init__()
        self.data_size = data_size
        # encoder maps each datapoint to a (loc, log-scale) pair for its local mu
        self.encoder = PyroModule[nn.Sequential](
            nn.Linear(feature_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 2),
        )

    def forward(self, data, ind=None):
        with pyro.plate("sample_plate", size=self.data_size, subsample=ind, dim=-2) as idx:
            # encode only the minibatch rows; keep a trailing dim of size 1 so mu
            # lands at plate dim -2, matching the model
            loc_logscale = self.encoder(data[idx])
            loc, scale = loc_logscale[..., :1], torch.exp(loc_logscale[..., 1:])
            pyro.sample("mu", Normal(loc, scale))

guide = AutoGuideList(model)
guide.append(AutoNormal(poutine.block(model, expose=["u"])))
guide.append(AmortizedGuide(1000, 10, 90))
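
Untested, but training with this combined guide would then look roughly like the following; it assumes the subsample=ind version of the model (called model here) and the data tensor from earlier in the thread, with the same indices passed to both model and guide:

pyro.clear_param_store()
optim = pyro.optim.Adam({"lr": 0.005})
svi = SVI(model, guide, optim, loss=Trace_ELBO())
for i in range(1000):
    # pass the same indices to model and guide so the amortized guide
    # encodes exactly the rows the model is scoring
    ind = torch.randperm(1000)[:40]
    loss = svi.step(data, ind)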

Thank you for the clear reply.

Note also that the guide is keeping the local parameters for the entire dataset in memory on whatever device they were initialized on, which sort of defeats the point of subsampling

I was reasoning that it could speed up learning in situations where the hierarchy and rank of the latent variables conveniently allow it. For example, in matrix factorization we would still be learning the right-hand low-rank factor even if not all the rows of the left factor have been determined yet.
But I can see that this might be the wrong intuition. I will try amortization.