Hi, I am using Pyro version 1.4.0. I have an issue with the `Predictive` class adding an extra dimension to one of my random variables. To illustrate it, I have put together the following minimal working example:
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Predictive

b_s = 2
max_num_targets = 5
num_logits_prior = torch.ones(max_num_targets+1)/(max_num_targets+1)
num_logits_posterior = torch.ones(b_s,max_num_targets+1)/(max_num_targets+1)
dim_z = 3
mu_z_init = torch.zeros(dim_z)
sigma_z_init = torch.ones(dim_z)

def model():
    with pyro.plate('batch_plate', b_s, dim=-1):
        num_targets = pyro.sample('num_targets', dist.Categorical(logits=num_logits_prior))
        print('num_targets shape:', num_targets.shape)
        '''
        with pyro.plate('targets_plate', max_num_targets, dim=-2):
            z = pyro.sample('z', dist.Normal(mu_z_init, sigma_z_init).to_event(1))
            print('z shape:', z.shape)
        '''

def guide():
    with pyro.plate('batch_plate', b_s, dim=-1):
        num_targets = pyro.sample('num_targets', dist.Categorical(logits=num_logits_posterior))
        '''
        with pyro.plate('targets_plate', max_num_targets, dim=-2):
            pyro.sample('z', dist.Normal(torch.ones(max_num_targets, b_s, dim_z), 0.1*torch.ones(max_num_targets, b_s, dim_z)).to_event(1))
        '''

num_samples = 10
predictive = Predictive(model, guide=guide, num_samples=num_samples, parallel=True)
posterior = predictive()
```
When run as is, I get the following output:
```
num_targets shape: torch.Size([2])
num_targets shape: torch.Size([2])
num_targets shape: torch.Size([10, 2])
```
I am only interested in the final line, which is printed during the vectorized run: as expected, num_targets has the shape of batch_plate plus the extra (leftmost) sample dimension added by Predictive.
With the triple quotes removed in both the model and the guide, I observe the following output:
```
num_targets shape: torch.Size([2])
z shape: torch.Size([5, 2, 3])
num_targets shape: torch.Size([2])
z shape: torch.Size([5, 2, 3])
num_targets shape: torch.Size([10, 1, 2])
z shape: torch.Size([10, 5, 2, 3])
```
Again, I am only interested in the lines printed during the vectorized run (the last two). While the shape of z is correct, num_targets now has an extra singleton dimension at index -2.
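For what it's worth, the shapes above are consistent with the following reading of the plate arithmetic, sketched in plain Python. It assumes (my understanding of how Predictive behaves with parallel=True, not something I have checked against the 1.4.0 source) that the sample dimension is placed at dim = -(max_plate_nesting + 1). Adding targets_plate at dim=-2 would then push the sample dimension from -2 to -3, and num_targets, which lives outside targets_plate, would get size 1 at dim -2. `predicted_site_shape` is a hypothetical helper for illustration, not a Pyro API.

```python
# Hypothetical helper (not part of Pyro): predicts a site's shape under
# Predictive(..., parallel=True), ASSUMING the sample dim sits at
# dim = -(max_plate_nesting + 1) and plates the site is not in broadcast as size 1.
def predicted_site_shape(num_samples, max_plate_nesting, plate_sizes, event_shape=()):
    """plate_sizes maps a negative plate dim (e.g. -1) to that plate's size."""
    batch = [num_samples]
    # Walk plate dims from leftmost (-max_plate_nesting) to rightmost (-1).
    for dim in range(-max_plate_nesting, 0):
        # A site outside a given plate keeps size 1 in that dim.
        batch.append(plate_sizes.get(dim, 1))
    return tuple(batch) + tuple(event_shape)

# MWE without z: only batch_plate (dim=-1), so max_plate_nesting == 1.
print(predicted_site_shape(10, 1, {-1: 2}))                # (10, 2)
# MWE with z: targets_plate at dim=-2 raises max_plate_nesting to 2.
print(predicted_site_shape(10, 2, {-1: 2}))                # (10, 1, 2)
print(predicted_site_shape(10, 2, {-1: 2, -2: 5}, (3,)))   # (10, 5, 2, 3)
```

Under this assumption, the [10, 1, 2] shape would be ordinary broadcasting against the deeper plate nesting rather than a stray dimension, but I would still like to confirm whether that is the intended behavior.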