Variable shapes change between svi.step and Predictive

Hello!

I noticed a curious difference in the shapes of random variables between calls to svi.step and pyro.infer.Predictive that can cause problems when performing matrix multiplication.

If I have a model with two plates, giving batch shape (plate2_size, plate1_size), and a random variable that lives only in plate1, that variable has shape (plate1_size,) during calls to svi.step but shape (1, plate1_size) during calls to Predictive. It seems that Predictive adds as many dimensions to the variables as there are plates. The simple code below illustrates this case:

```python
import pyro
import torch
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, Predictive
from pyro.optim import Adam
from pyro.infer.autoguide import AutoNormal

pyro.set_rng_seed(101)

pyro.clear_param_store()

def model(x, plate1_dim, plate2_dim):
    plate1 = pyro.plate("plate1", plate1_dim, dim=-1)  # rightmost batch dim
    plate2 = pyro.plate("plate2", plate2_dim, dim=-2)  # second batch dim from the right

    with plate2:
        p_plate2 = pyro.sample("p_plate2", dist.Beta(2., 2.))

    with plate1:
        x_loc = pyro.sample("x_loc", dist.Normal(0., 1.))
        print(x_loc.shape)  # report the shape of x_loc on every model call
        pyro.sample("x", dist.Normal(x_loc, 1.), obs=x)

guide = AutoNormal(model)

plate1_dim = 10
plate2_dim = 5
x = torch.rand(plate1_dim)

optim = Adam({"lr": 0.01})
elbo = Trace_ELBO()
svi = SVI(model, guide, optim, loss=elbo)
for j in range(5):
    svi.step(x, plate1_dim, plate2_dim)
```

```
torch.Size([10])
torch.Size([10])
torch.Size([10])
torch.Size([10])
torch.Size([10])
torch.Size([10])
```

```python
predictive = Predictive(model, guide=guide, num_samples=5)(None, plate1_dim, plate2_dim)
```

```
torch.Size([10])
torch.Size([10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
torch.Size([1, 10])
```

In this case this of course causes no issues and the model runs without error, since elementwise operations broadcast. But in more complex models involving matrix multiplication, such as with torch.einsum (where the dimensions of each operand must be spelled out), an error is thrown because the dimensions of the variables change from svi.step to Predictive!
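A minimal torch-only sketch of the failure described above. It assumes random tensors as stand-ins for x_loc, using the two shapes printed in the post, and a hypothetical weight matrix `W` that is not part of the original model. Elementwise arithmetic accepts both shapes, while an einsum equation with fixed subscripts rejects the extra leading dimension.

```python
import torch

torch.manual_seed(0)

# Random stand-ins for x_loc with the two shapes printed in the post:
# (10,) under svi.step and (1, 10) under sequential Predictive.
x_loc_svi = torch.randn(10)
x_loc_pred = torch.randn(1, 10)

# Hypothetical weight matrix, only here to give einsum something to contract.
W = torch.randn(10, 3)

# Elementwise ops broadcast from the right, so both shapes are accepted.
print((x_loc_svi + 1.0).shape)   # torch.Size([10])
print((x_loc_pred + 1.0).shape)  # torch.Size([1, 10])

def fixed_einsum(x_loc, W):
    # The subscript "i" declares exactly one dim for x_loc, no more, no less.
    return torch.einsum("i,ij->j", x_loc, W)

print(fixed_einsum(x_loc_svi, W).shape)  # torch.Size([3])

try:
    fixed_einsum(x_loc_pred, W)
    pred_failed = False
except RuntimeError as err:
    # The 2-D operand does not match the 1-D subscript "i".
    print("einsum failed:", err)
    pred_failed = True
```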

Is there a potential workaround for this? I can use squeeze() to standardize the variable dimensions, but this will slow down model performance and isn’t completely foolproof. Doesn’t it make more sense for all plate dimensions to be present in each random variable during svi, as they are during Predictive?
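One possible workaround, offered as a sketch rather than an official Pyro recommendation, is to name only the rightmost (plate) dimensions in the einsum equation and let `...` absorb any leading batch dimensions, in the spirit of writing models that index from the right. The weight matrix `W` and the helper `project` are hypothetical names for illustration. The last line shows one way squeeze() is fragile: it also removes a genuine plate whose size happens to be 1.

```python
import torch

torch.manual_seed(0)

# Hypothetical weight matrix mapping the plate1 dim (size 10) to 3 outputs.
W = torch.randn(10, 3)

def project(x_loc, W):
    # "..." stands for zero or more leading batch dims, so only the
    # rightmost plate1 dim "i" is named explicitly.
    return torch.einsum("...i,ij->...j", x_loc, W)

print(project(torch.randn(10), W).shape)        # torch.Size([3])
print(project(torch.randn(1, 10), W).shape)     # torch.Size([1, 3])
print(project(torch.randn(5, 1, 10), W).shape)  # torch.Size([5, 1, 3])

# With plate1_dim = 1, squeeze() turns a (1, 1) sample into a 0-dim tensor,
# dropping plate1 itself along with the padding dim.
print(torch.randn(1, 1).squeeze().shape)  # torch.Size([])
```

Because the ellipsis version never assumes how many dimensions sit to the left, the same line of model code works whether or not Predictive has added any.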

Thanks in advance for your feedback and help!


Predictive adds one additional dimension on the left, which is used to concatenate the num_samples draws. The first two torch.Size([10]) lines come from internal calls to your model, such as the one made by the function that guesses the maximum plate nesting. The next five torch.Size([1, 10]) lines come from calling your model sequentially num_samples times; those samples are then concatenated. The draws can be computed in parallel by setting parallel=True in Predictive, but that requires ensuring your model can handle an additional batch dimension on the left.

That is true indeed, so you have to be careful.