Potential parallel inference bug - Attend Infer Repeat (AIR) - Exact Parallel Inference - Questions

Hey guys,

I’m trying to use Pyro’s discrete exact inference to solve this problem (trying to master this for purposes of solving other problems).

My main question is how to set-up the guide. On the model side, I have code structure like this:

def model(self,data, batch_indices, freeze_params=False, data_impute=None, obs_indices=None):

       # other code

        prior_probs = pyro.param( "num_fig_prob",
            torch.tensor([.01,.495,.495], device=self.device) / (self.max_figs_per_image+1),
            constraint=constraints.simplex
        )

        with pyro.plate('data', N, subsample_size=300) as indices:
            num_figs = pyro.sample(f"num_figs", pyro.distributions.Categorical(prior_probs),infer={"enumerate": "parallel"})
        # code similar to tutorial

def prior_step(self, n, t, prev_x, indices, num_figs):

        mask = (t<=num_figs)
        mask_1=mask.unsqueeze(-1)
        mask_2=mask.unsqueeze(-1).unsqueeze(-1)

        z_where = pyro.sample(f'where_{t}_0',
                                    pyro.distributions.Normal(
                                        loc=torch.ones_like(mask_1, device=self.device) * torch.tensor([3., 0., 0.],device=device).expand(self.batch_size,-1),
                                        #loc=torch.tensor([3., 0., 0.],device=device).expand(self.batch_size,-1),
                                        scale=torch.tensor([.1, 1.,1.],device=device).expand(self.batch_size,-1)
                                    ).to_event(1).mask(mask)
                                )

        z_what = pyro.sample(f'what_{t}_0',
                                    pyro.distributions.Normal(
                                        loc=torch.ones_like(mask_1, device=self.device) * torch.tensor(0.,device=device),
                                        #loc=torch.tensor(0.,device=device),
                                        scale=torch.ones(self.batch_size,self.lat_dim,device=device)
                                    ).to_event(1).mask(mask)
                                )

# other code

However, on the guide side, do I mask the log-probability similarly? For example, would I do something like this?

def step_guide(self, t, data, indices, prev, freeze_params=False):

       # other code here

        num_figs = torch.arange(self.max_figs_per_image+1,device=self.device).reshape(-1,1,1)
        mask = (t<=num_figs)
        mask_1=mask.unsqueeze(-1)
        mask_2=mask.unsqueeze(-1).unsqueeze(-1)

        dist_where = pyro.distributions.Normal(
                                        loc=torch.ones_like(mask_1,device=self.device)*z_where_loc,
                                        scale=z_where_scale
                                    )

        scale_where = 1.0
        with pyro.poutine.scale(scale=scale_where):
            z_where = pyro.sample(f'where_{t}_0',
                                        dist_where.to_event(1).mask(mask)
                                    )

        #other code here

        x_att = image_to_object(z_where, data)
        z_what_loc, z_what_scale = encode(x_att)

        dist_what = pyro.distributions.Normal(
                                        loc=torch.ones_like(mask_1,device=self.device)*z_what_loc,
                                        scale=z_what_scale#torch.exp(lat_embed_scale)
                                   )

        scale_what = scale_where
        with pyro.poutine.scale(scale=scale_what):
            z_what = pyro.sample(f'what_{t}_0',
                                        dist_what.to_event(1).mask(mask)
                                    )

        # more code

Basically - would I have to create masks that match the model side - so I appropriately mask the probabilities for a given value of t?

Thanks,

Mike

There is also a potential bug here - the traceenum_elbo code can’t determine that the downstream sample sites (downstream from the discrete sample-site) have a dependency on the on discrete sample site. You might recognize this code from traceenum_elbo.py:

enum_dims -= non_enum_dims

# other code

 for t, sites_t in cost_sites.items():
        for site in sites_t:
            if enum_dims.isdisjoint(site["packed"]["log_prob"]._pyro_dims):
                # For sites that do not depend on an enumerated variable, proceed as usual.
                marginal_costs.setdefault(t, []).append(site["packed"]["log_prob"])
            else:
                # For sites that depend on an enumerated variable, we need to apply
                # the mask inside- and the scale outside- of the log expectation.
                if "masked_log_prob" not in site["packed"]:
                    site["packed"]["masked_log_prob"] = packed.scale_and_mask(
                        site["packed"]["unscaled_log_prob"], mask=site["packed"]["mask"]
                    )
                cost = site["packed"]["masked_log_prob"]
                log_factors.setdefault(t, []).append(cost)
                scales.append(site["scale"])

The code thinks that the sites don’t depend on the enumerated variable (which they do), and so the first block of the if-else statement gets executed incorrectly.

Is there a way to “trick” the software to “believe” there is a dependency on the enumerated variable, so the correct block in the if-else statement gets executed?

*** Update - it turns out that the issue is that, on the guide-side, the log_prob shape in my code has shape (enum_dim, plate-or-batch_dim). because I’m returning a tensor on the guide-side that matches the enum dim, the code assumes that this is not an enum dimension, which is not correct.

Is there a way to override this behavior? The only way I see a way to override this behavior is to mask the log-probs on the guide-side and add factors on the model side. However, this method of coding seems highly unusual and it seems like it would be best to avoid this if possible.