Masking a distribution based on enumerated discrete variables


Hello, this question is related to the “Attend Infer Repeat” (AIR) example. I am running Pyro 0.3.0.

During inference in AIR, an image patch is encoded into a continuous latent variable z_what and a discrete random variable z_present. If the patch is “active”, i.e. z_present=1, then z_what is expected to be drawn from a Normal prior. If the patch is “inactive”, i.e. z_present=0, then z_what is unconstrained (its prior log-probability is not counted). This is achieved as follows:

  1. In the model, z_present is sampled first.
  2. In the model, z_what is sampled from a distribution masked by z_present.
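To make the masking step concrete, here is a minimal pure-Python sketch (not Pyro's implementation; the helper names are hypothetical) of what a mask does to the prior term: each object's Normal log-probability, summed over the z_what dimensions, is multiplied by its z_present flag, so an inactive object contributes nothing.

```python
import math


def normal_log_prob(x, mu=0.0, sigma=1.0):
    """Log-density of a univariate Normal(mu, sigma) evaluated at x."""
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def masked_z_what_log_prob(z_what, z_present):
    """Per-object log p(z_what) under a standard Normal prior, masked by z_present.

    z_what:    list of objects, each a list of dim_z_what floats (hypothetical layout).
    z_present: list of 0/1 flags, one per object.
    The event dimension is summed (like .to_event(1)) and the result is multiplied
    by the flag, so an object with z_present = 0 contributes exactly 0.
    """
    return [
        present * sum(normal_log_prob(v) for v in obj)
        for obj, present in zip(z_what, z_present)
    ]


z_what = [[0.0, 0.0], [3.0, -3.0]]  # two objects, dim_z_what = 2
z_present = [1, 0]                  # only the first object is active
lp = masked_z_what_log_prob(z_what, z_present)
print(lp)
```

The second object's z_what is far from the prior mean, but because it is masked out it adds nothing to the total.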

My question is:

How can I use z_present to mask the distribution of z_what if z_present is fully enumerated (instead of being sampled)?

The problem is that when a variable is enumerated in parallel, an extra dimension is added on the left. The resulting dimension mismatch makes it impossible to use z_present as a mask for the distribution of z_what.
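The shape arithmetic behind the mismatch can be reproduced without Pyro. The sketch below uses a hypothetical `broadcast_shape` helper implementing the standard NumPy/PyTorch broadcasting rules, and assumes `max_plate_nesting=2` (so the enumerated z_present has shape `(2, 1, 1)`) plus example sizes B, N, D. The sampled mask broadcasts to the Normal's batch shape unchanged, while the enumerated mask yields a shape with an extra leftmost dimension, larger than the batch shape it is applied to; this appears to be what `.mask()` rejects.

```python
def broadcast_shape(*shapes):
    """Right-align the shapes and combine them with NumPy/PyTorch broadcasting rules."""
    ndim = max(len(s) for s in shapes)
    result = []
    for i in range(-ndim, 0):  # walk dimensions from leftmost to rightmost
        sizes = {s[i] for s in shapes if len(s) >= -i and s[i] != 1}
        if len(sizes) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        result.append(sizes.pop() if sizes else 1)
    return tuple(result)


B, N, D = 3, 4, 5         # example sizes: batch_size, n_max_objects, dim_z_what
batch_shape = (B, N, D)   # batch shape of Normal(z_mu, z_std) before .to_event(1)

sampled_mask = (B, N, 1)  # z_present sampled:     (B, N)    -> unsqueeze(-1)
enum_mask = (2, 1, 1, 1)  # z_present enumerated:  (2, 1, 1) -> unsqueeze(-1)

print(broadcast_shape(sampled_mask, batch_shape))  # (3, 4, 5): equals batch_shape
print(broadcast_shape(enum_mask, batch_shape))     # (2, 3, 4, 5): extra dim on the left
```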

I understand that in AIR parallel enumeration makes no sense, since the inference procedure is sequential, but in other applications it does.

The code below is a minimal example that reproduces the problem.
Masking works perfectly with Trace_ELBO (no parallel enumeration of z_present) but fails with TraceEnum_ELBO (parallel enumeration of z_present).

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Trace_ELBO, TraceEnum_ELBO

# Example sizes (not specified in the original post; any small values reproduce the issue)
batch_size = 3
n_max_objects = 4
dim_z_what = 5
zero = torch.tensor(0.)
one = torch.tensor(1.)


def test_model(model, guide, loss):
    loss.loss(model, guide)


def guide():
    p = 0.5 * torch.ones(batch_size, n_max_objects)
    z_mu = torch.randn(batch_size, n_max_objects, dim_z_what)
    z_std = torch.exp(torch.randn(batch_size, n_max_objects, dim_z_what))

    with pyro.plate("batch", batch_size, dim=-2):
        with pyro.plate("n_objects", n_max_objects, dim=-1):
            z_pres = pyro.sample("prob_object", dist.Bernoulli(probs=p), infer={"enumerate": "parallel"})
            # trailing dim so the mask broadcasts over the dim_z_what dimension
            z_pres_mask = z_pres.unsqueeze(-1)
            z_what = pyro.sample("z_what", dist.Normal(z_mu, z_std).mask(z_pres_mask).to_event(1))
            print("GUIDE z_pres.shape, z_what.shape", z_pres.shape, z_what.shape)


def model():
    with pyro.plate("batch", batch_size, dim=-2):
        with pyro.plate("n_objects", n_max_objects, dim=-1):
            z_pres = pyro.sample("prob_object", dist.Bernoulli(probs=0.5))
            z_pres_mask = z_pres.unsqueeze(-1)
            prior = dist.Normal(zero.expand(batch_size, n_max_objects, dim_z_what), one)
            z_what = pyro.sample("z_what", prior.mask(z_pres_mask).to_event(1))
            print("MODEL z_pres.shape, z_what.shape", z_pres.shape, z_what.shape)


print("TEST SAMPLE")
test_model(model, guide, Trace_ELBO(max_plate_nesting=2))  # works

test_model(model, guide, TraceEnum_ELBO(max_plate_nesting=2))  # fails: dimension mismatch in .mask()