Is it possible to train a model by mini-batches when it has a latent variable along the sample axis that I want to subsample?
I am imagining a scenario where the shared latent variables are updated at each iteration, but the sample-dependent ones only when the corresponding sample is drawn in the mini-batch (see the hand-written guide sketch at the end of this post for what I mean).
Runnable code example:
import numpy as np
import torch
import pyro
from pyro import poutine
from pyro.infer import Trace_ELBO, SVI, Predictive
from pyro.infer.autoguide import AutoDiagonalNormal

# Synthetic data: 1000 samples x 90 features, matching the plate sizes below.
data = torch.randn(1000, 90)

def inference_model(data):
    feature_plate = pyro.plate("feature_plate", size=90, dim=-1)
    sample_plate = pyro.plate("sample_plate", size=1000, subsample_size=20, dim=-2)
    # Global latent variable, shared by all samples.
    u = pyro.sample("u", pyro.distributions.Normal(0, 10))
    # Local latent variable: one value per sample.
    with sample_plate:
        mu = pyro.sample("mu", pyro.distributions.Normal(u, 1))
    with feature_plate:
        with sample_plate as ind:
            # Observe only the rows selected for the current mini-batch.
            X = pyro.sample("X", pyro.distributions.Normal(mu, 1), obs=data[ind])
pyro.clear_param_store()
guide = AutoDiagonalNormal(inference_model)
optim = pyro.optim.Adam({"lr": 0.005, "betas": (0.95, 0.999)})
elbo = Trace_ELBO()
svi = SVI(inference_model, guide, optim, loss=elbo)
init_loss = svi.loss(inference_model, guide, data)
for i in range(100):
    loss = svi.step(data)
What I would expect is for Pyro to allocate a 1000-dimensional latent variable for mu, but instead Pyro only considers a batched latent variable of size 20, as revealed by:
param_store = pyro.get_param_store()
param_store['AutoDiagonalNormal.loc'].shape
# Prints: torch.Size([21]), i.e. 20 subsampled mu entries + 1 for u,
# instead of the expected torch.Size([1001])
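For reference, here is a minimal sketch of the behavior I am after, written as a hand-crafted guide instead of an autoguide (the parameter names loc_u, scale_u, loc_mu, scale_mu are my own, not part of any Pyro API). The local variational parameters are allocated at full size 1000, and indexing them with the plate's subsample indices means only the drawn slice enters the ELBO and receives gradients on a given step:

import pyro.distributions as dist

def manual_guide(data):
    # Global variational parameters for u (touched at every step).
    loc_u = pyro.param("loc_u", torch.tensor(0.0))
    scale_u = pyro.param("scale_u", torch.tensor(1.0),
                         constraint=dist.constraints.positive)
    pyro.sample("u", dist.Normal(loc_u, scale_u))
    # Full-size local variational parameters for mu: one row per sample.
    loc_mu = pyro.param("loc_mu", torch.zeros(1000, 1))
    scale_mu = pyro.param("scale_mu", torch.ones(1000, 1),
                          constraint=dist.constraints.positive)
    with pyro.plate("sample_plate", size=1000, subsample_size=20, dim=-2) as ind:
        # Only the rows picked by `ind` contribute to the ELBO,
        # so only those rows receive gradient updates on this step.
        pyro.sample("mu", dist.Normal(loc_mu[ind], scale_mu[ind]))

With a guide like this, loc_mu keeps shape [1000, 1] regardless of the subsample size, which is what I expected the autoguide to do for me. Is there a way to get this behavior with AutoDiagonalNormal, or do I have to write the guide by hand?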