Predictive adds unnecessary dimension when using stacked plates

Hi, I am using Pyro version 1.4.0. I have an issue with the Predictive class adding an extra dimension to one of my random variables. To illustrate the issue, I have put together the following minimal working example:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Predictive

b_s = 2                  # batch size
max_num_targets = 5
# Uniform logits over {0, ..., max_num_targets}
num_logits_prior = torch.ones(max_num_targets + 1) / (max_num_targets + 1)
num_logits_posterior = torch.ones(b_s, max_num_targets + 1) / (max_num_targets + 1)
dim_z = 3
mu_z_init = torch.zeros(dim_z)
sigma_z_init = torch.ones(dim_z)

def model():
    with pyro.plate('batch_plate', b_s, dim=-1):
        num_targets = pyro.sample('num_targets', dist.Categorical(logits=num_logits_prior))
        print('num_targets shape:', num_targets.shape)
        # Removing the triple quotes below reproduces the issue:
        '''
        with pyro.plate('targets_plate', max_num_targets, dim=-2):
            z = pyro.sample('z', dist.Normal(mu_z_init, sigma_z_init).to_event(1))
            print('z shape:', z.shape)
        '''

def guide():
    with pyro.plate('batch_plate', b_s, dim=-1):
        num_targets = pyro.sample('num_targets', dist.Categorical(logits=num_logits_posterior))
        # Removing the triple quotes below reproduces the issue:
        '''
        with pyro.plate('targets_plate', max_num_targets, dim=-2):
            pyro.sample('z', dist.Normal(torch.ones(max_num_targets, b_s, dim_z),
                                         0.1 * torch.ones(max_num_targets, b_s, dim_z)).to_event(1))
        '''

num_samples = 10
predictive = Predictive(model, guide=guide, num_samples=num_samples, parallel=True)
posterior = predictive()

When run as is, I observe the following output:

num_targets shape: torch.Size([2])
num_targets shape: torch.Size([2])
num_targets shape: torch.Size([10, 2])

I am only interested in the last two lines: the variable num_targets has, as expected, the dimension of the batch_plate, plus the extra (leftmost) vectorized dimension added by Predictive.

With the triple quotes removed in both the model and the guide, I observe the following output:

num_targets shape: torch.Size([2])
z shape: torch.Size([5, 2, 3])
num_targets shape: torch.Size([2])
z shape: torch.Size([5, 2, 3])
num_targets shape: torch.Size([10, 1, 2])
z shape: torch.Size([10, 5, 2, 3])

Again, I am interested in the last three lines. Whereas the shape of z is correct, the shape of num_targets has an extra dimension at index -2.

Hi, I’m not sure I understand what’s going on in make_targets_mask, but the shape manipulation logic looks suspicious - can you clarify the behavior you’re expecting here? What is the expected shape of num_targets_mask, and what effect do you expect it to have on site z?

num_targets has an extra batch dimension because there is an extra plate in your model, and the batch dimension introduced by Predictive is always placed to the left of all plates in the model.

In general, you should treat batch dimensions that are not explicitly set in a plate (such as the dimension introduced by Predictive, or new dimensions introduced during parallel enumeration) as implementation details that are subject to change under different interpretations and different versions of Pyro.
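
For example (a minimal sketch in plain tensor code, not Pyro API), indexing plate dimensions from the right keeps the same code working no matter how many batch dimensions are prepended on the left:

import torch

# Hypothetical shapes for the same sample site: a single model run,
# a vectorized Predictive run, and a run with an extra plate.
for shape in [(2,), (10, 2), (10, 1, 2)]:
    num_targets = torch.zeros(shape)
    batch_size = num_targets.size(-1)           # plate dims indexed from the right
    flat = num_targets.reshape(-1, batch_size)  # collapse whatever is on the left
    print(shape, '->', flat.shape)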

Hi, in fact make_targets_mask is unnecessary for the point I am making here; I shouldn't have included that function in the first place. I have edited my initial question to remove it and to change the b_s variable to 2.

The issue only occurs when the following code (and its equivalent in the guide) is added to the model:

with pyro.plate('targets_plate', max_num_targets, dim=-2):
    z = pyro.sample('z', dist.Normal(mu_z_init, sigma_z_init).to_event(1)) 
    print('z shape:', z.shape)

Before it is added, Predictive adds a vectorized dimension of size num_samples=10, as expected, and the shape of num_targets becomes torch.Size([10, 2]).

However, after z is added, the shape of num_targets becomes torch.Size([10, 1, 2]). For me, that is unexpected, as the random variables num_targets and z are unrelated. This causes errors when broadcasting downstream in the model.
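
For illustration, here is a toy sketch (hypothetical, not my actual downstream code) of the kind of silent mis-broadcast this produces:

import torch

num_targets = torch.zeros(10, 2)      # shape before targets_plate is added
weights = torch.zeros(10, 2)          # a per-sample, per-batch tensor downstream
print((num_targets * weights).shape)  # torch.Size([10, 2]), as intended

num_targets = torch.zeros(10, 1, 2)   # shape after targets_plate is added
print((num_targets * weights).shape)  # torch.Size([10, 10, 2]): sample dims no longer align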

I think this comes from a trade-off made in this PR to solve a technical issue: it makes the results of parallel=True and parallel=False consistent.

However, after z is added, the shape of num_targets becomes torch.Size([10, 1, 2]). For me, that is unexpected, as the random variables num_targets and z are unrelated.

This is expected behavior: it is a result of introducing the second plate targets_plate, not of any implicit relationship between num_targets and z or an incorrect shape at num_targets. Predictive moves its sample batch dimension to the left of all plates in the model, so with targets_plate occupying dim -2 the sample dimension lands at dim -3, and num_targets (which participates only in batch_plate at dim -1) picks up a singleton dimension at dim -2 as a placeholder for targets_plate.

To avoid being tripped up by this, your Pyro model code should always assume there is an arbitrary batch shape to the left of the leftmost plate dimension of samples. See the “Writing parallelizable code” section of the tensor shape tutorial for a guide to handling this behavior and an extended example.
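
For example, here is a minimal sketch in plain tensor code (the mask and values are hypothetical, not from your model) showing right-aligned broadcasting; note that the singleton dimension Predictive inserts at the targets_plate position is exactly what makes the same comparison work in both cases:

import torch

max_num_targets = 5
idx = torch.arange(max_num_targets).unsqueeze(-1)  # (max_num_targets, 1)

# Plain model run: num_targets has shape (b_s,) = (2,)
mask = idx < torch.tensor([3, 1])
print(mask.shape)  # torch.Size([5, 2]): targets_plate at dim -2, batch_plate at dim -1

# Under Predictive: num_targets has shape (10, 1, 2); the singleton at dim -2
# sits exactly where targets_plate lives, so the same comparison still lines up.
mask = idx < torch.randint(0, max_num_targets + 1, (10, 1, 2))
print(mask.shape)  # torch.Size([10, 5, 2])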

Does that answer your question? If you’re still having trouble with shape errors downstream in your model after looking through the tutorial, feel free to post them here and ask for more help.