Mask distribution based on enumerated discrete variables

Hello, this question is related to the “Attend Infer Repeat” (AIR) example (http://pyro.ai/examples/air.html). I am running Pyro 0.3.0.

During inference in AIR, an image patch is encoded into a continuous latent variable z_what and a discrete random variable z_present. If the patch is “active”, i.e. z_present=1, then z_what is expected to be drawn from a Normal prior. If the patch is “inactive”, i.e. z_present=0, then z_what is unconstrained. This is achieved as follows:

  1. In the model z_present is sampled first
  2. In the model z_what is sampled from a masked distribution (masked by z_present)

My question is:

How can I use z_present to mask the distribution of z_what if z_present is fully enumerated (instead of being sampled)?

The problem is that when a variable is enumerated in parallel, an extra dimension is added on the left, so the shapes no longer match and z_present cannot be used to mask the distribution of z_what.
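To make the mismatch concrete, here is a dependency-free sketch of the shape arithmetic. The helper broadcast_shape is hypothetical (written for this illustration, not a Pyro function) and only mirrors the usual right-aligned NumPy/PyTorch broadcasting rule. The enumerated shape (2, 1, 1) assumes max_plate_nesting=2, so that the enumeration dimension is placed at position -3.

```python
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Hypothetical helper: right-aligned broadcasting of shapes (NumPy/PyTorch rule)."""
    result = []
    for sizes in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        non_one = {n for n in sizes if n != 1}
        if len(non_one) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        result.append(non_one.pop() if non_one else 1)
    return tuple(reversed(result))


batch_size, n_max_objects, dim_z_what = 3, 8, 10

# Batch shape of dist.Normal(z_mu, z_std) before .to_event(1).
normal_batch_shape = (batch_size, n_max_objects, dim_z_what)

# Sampled z_present has shape (3, 8); after unsqueeze(-1) it is (3, 8, 1).
sampled_mask = (batch_size, n_max_objects, 1)

# Enumerated z_present has shape (2, 1, 1) (assuming max_plate_nesting=2);
# after unsqueeze(-1) it is (2, 1, 1, 1).
enumerated_mask = (2, 1, 1, 1)

# The sampled mask fits inside the Normal's batch shape ...
fits_sampled = broadcast_shape(sampled_mask, normal_batch_shape) == normal_batch_shape

# ... but the enumerated mask would need the batch shape to grow to (2, 3, 8, 10).
fits_enumerated = broadcast_shape(enumerated_mask, normal_batch_shape) == normal_batch_shape

print(fits_sampled, fits_enumerated)
```

In other words, the enumerated mask carries one more dimension than the distribution it is supposed to mask, which is the mismatch described above.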

I understand that in AIR it makes no sense to do parallel enumeration, since the inference procedure is sequential, but in other applications it does.

The code below is a minimal implementation that shows the problem.
Masking works perfectly with Trace_ELBO (no parallel enumeration of z_present) but fails with TraceEnum_ELBO (parallel enumeration of z_present).

import torch 
import pyro
import pyro.distributions as dist
from pyro.infer import Trace_ELBO, TraceEnum_ELBO

def test_model(model, guide, loss):
    pyro.clear_param_store()
    loss.loss(model, guide)
    
    
batch_size=3
n_max_objects=8
dim_z_what=10
zero=torch.zeros([1])
one=torch.ones([1])

def guide():
    p = 0.5*torch.ones(batch_size,n_max_objects)
    z_mu  = torch.randn(batch_size,n_max_objects,dim_z_what)
    z_std = torch.exp(torch.randn(batch_size,n_max_objects,dim_z_what))
    
    with pyro.plate("batch", batch_size, dim=-2):
        with pyro.plate("n_objects", n_max_objects, dim =-1):
            z_pres = pyro.sample("prob_object",dist.Bernoulli(probs = p),infer={"enumerate": "parallel"}) 
            z_pres_mask = z_pres.unsqueeze(-1)
            z_what = pyro.sample("z_what",dist.Normal(z_mu,z_std).mask(z_pres_mask).to_event(1))            
            print("GUIDE z_pres.shape, z_what.shape",z_pres.shape,z_what.shape)

def model():
    with pyro.plate("batch", batch_size, dim=-2):
        with pyro.plate("n_objects", n_max_objects, dim =-1):
            z_pres = pyro.sample("prob_object",dist.Bernoulli(probs = 0.5))
            z_pres_mask = z_pres.unsqueeze(-1)
            z_what = pyro.sample("z_what",dist.Normal(zero.expand(batch_size,n_max_objects,dim_z_what),one).mask(z_pres_mask).to_event(1))            
            print("MODEL z_pres.shape, z_what.shape",z_pres.shape,z_what.shape)

            
print("TEST SAMPLE")
test_model(model, guide, Trace_ELBO(max_plate_nesting=2))

print("TEST PARALLEL ENUM")
test_model(model, guide, TraceEnum_ELBO(max_plate_nesting=2))

You should be able to use poutine.mask rather than the TorchDistributionMixin.mask() method in your case (and in most cases). The .mask() method is only needed when the mask has structure inside the event dimensions (.event_dim) of an Independent (i.e. diagonal) distribution.

Here is an updated version of your model (you can do a similar rewrite of your guide):

from pyro import poutine  # needed for poutine.mask

def model():
    with pyro.plate("batch", batch_size, dim=-2):
        with pyro.plate("n_objects", n_max_objects, dim=-1):
            z_pres = pyro.sample("prob_object", dist.Bernoulli(probs=0.5))
            # Convert FloatTensor -> ByteTensor (a BoolTensor in newer PyTorch).
            z_pres_mask = (z_pres != 0)
            # poutine.mask acts on the batch shape, so no unsqueeze(-1) is needed.
            with poutine.mask(mask=z_pres_mask):
                z_what = pyro.sample("z_what",
                                     dist.Normal(zero.expand(dim_z_what), 1.)
                                         .to_event(1))
    print("MODEL z_pres.shape, z_what.shape", z_pres.shape, z_what.shape)