Hello!
I noticed a curious difference in the shapes of random variables between calls to svi.step and pyro.infer.Predictive, which can cause problems when performing matrix multiplication.
If I have a two-plate model of shape (plate2_size, plate1_size) and a random variable that lives only in plate1, that variable has shape (plate1_size,) during calls to svi.step but shape (1, plate1_size) during calls to Predictive. It seems that Predictive gives each variable as many dimensions as there are plates. The simple code below illustrates this case:
import pyro
import torch
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.optim import Adam
from pyro.infer.autoguide import AutoNormal
pyro.set_rng_seed(101)
pyro.clear_param_store()
def model(x, plate1_dim, plate2_dim):
    plate1 = pyro.plate("plate1", plate1_dim, dim=-1)
    plate2 = pyro.plate("plate2", plate2_dim, dim=-2)
    with plate2:
        p_plate2 = pyro.sample("p_plate2", dist.Beta(2., 2.))
    with plate1:
        x_loc = pyro.sample("x_loc", dist.Normal(0., 1.))
        print(x_loc.shape)
        pyro.sample("x", dist.Normal(x_loc, 1.), obs=x)
guide = AutoNormal(model)
plate1_dim = 10
plate2_dim = 5
x = torch.rand(plate1_dim)
optim = Adam({"lr": 0.01})
elbo = Trace_ELBO()
svi = SVI(model, guide, optim, loss=elbo)
losses = []
for j in range(5):
    svi.step(x, plate1_dim, plate2_dim)
torch.Size([10])
torch.Size([10])
torch.Size([10])
torch.Size([10])
torch.Size([10])
torch.Size([10])
predictive = Predictive(model, guide=guide, num_samples=5)(None, plate1_dim, plate2_dim)
torch.Size([10])
torch.Size([10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
In this case, of course, this causes no issues and the model runs without error. But in more complex models involving matrix multiplication, such as with torch.einsum (where the dimensions of each operand must be spelled out), an error will be thrown because the number of dimensions of a variable can change between svi.step and Predictive!
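To make the failure concrete, here is a minimal sketch in plain torch that stands in for the two shapes printed above. The names `weights` and `out_dim` are hypothetical and not part of the model. It also shows that an einsum equation written with an ellipsis tolerates the extra leading dimension, which may help in some cases.

```python
import torch

plate1_dim, out_dim = 10, 4                 # out_dim: hypothetical, for illustration only
weights = torch.rand(plate1_dim, out_dim)   # hypothetical matrix, not part of the model above

x_loc_svi = torch.rand(plate1_dim)          # shape during svi.step: (10,)
x_loc_pred = x_loc_svi.unsqueeze(0)         # shape during Predictive: (1, 10)

# Fixed subscripts work for the svi.step shape but reject the extra dimension.
out_fixed = torch.einsum("i,ij->j", x_loc_svi, weights)
try:
    torch.einsum("i,ij->j", x_loc_pred, weights)
    fixed_fails = False
except RuntimeError:
    fixed_fails = True

# "..." absorbs any leading dimensions, so one equation covers both shapes.
out_svi = torch.einsum("...i,ij->...j", x_loc_svi, weights)    # shape (4,)
out_pred = torch.einsum("...i,ij->...j", x_loc_pred, weights)  # shape (1, 4)
print(fixed_fails, out_svi.shape, out_pred.shape)
```

The two ellipsis results hold the same values and differ only by the leading singleton dimension.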
Is there a potential workaround for this? I can use squeeze() to standardize the variable dimensions, but this will slow down the model and isn’t completely foolproof. Wouldn’t it make more sense for all plate dimensions to be present in each random variable during svi, as they are during Predictive?
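For reference, this is roughly the squeeze-based workaround I have in mind, written as a hypothetical helper (`strip_leading_singletons` is my own name, not a Pyro function). It only removes size-1 dimensions from the left until the tensor has the expected rank, so a plate that genuinely has size 1 is left alone, unlike a bare squeeze(). Since squeeze returns a view rather than a copy, the overhead may be smaller than I feared, though I have not measured it.

```python
import torch

def strip_leading_singletons(t: torch.Tensor, ndim: int) -> torch.Tensor:
    """Hypothetical helper: drop size-1 leading dims until t has `ndim` dims."""
    while t.dim() > ndim and t.shape[0] == 1:
        t = t.squeeze(0)  # returns a view; no data is copied
    return t

a = strip_leading_singletons(torch.rand(10), ndim=1)     # already (10,), unchanged
b = strip_leading_singletons(torch.rand(1, 10), ndim=1)  # (1, 10) -> (10,)
c = strip_leading_singletons(torch.rand(1, 1), ndim=1)   # (1, 1) -> (1,), trailing dim kept
print(a.shape, b.shape, c.shape)
```

This still does nothing for a leading dimension larger than 1, so it is a partial fix at best.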
Thanks in advance for your feedback and help!