# Potential parallel inference bug - Attend Infer Repeat (AIR) - Exact Parallel Inference - Questions

Hey guys,

I’m trying to use Pyro’s discrete exact inference to solve this problem (trying to master this for purposes of solving other problems).

My main question is how to set up the guide. On the model side, I have code structured like this:

``````python
def model(self, data, batch_indices, freeze_params=False, data_impute=None, obs_indices=None):

# other code

prior_probs = pyro.param( "num_fig_prob",
torch.tensor([.01,.495,.495], device=self.device) / (self.max_figs_per_image+1),
constraint=constraints.simplex
)

with pyro.plate('data', N, subsample_size=300) as indices:
num_figs = pyro.sample(f"num_figs", pyro.distributions.Categorical(prior_probs),infer={"enumerate": "parallel"})
# code similar to tutorial

def prior_step(self, n, t, prev_x, indices, num_figs):

z_where = pyro.sample(f'where_{t}_0',
pyro.distributions.Normal(
loc=torch.ones_like(mask_1, device=self.device) * torch.tensor([3., 0., 0.],device=device).expand(self.batch_size,-1),
#loc=torch.tensor([3., 0., 0.],device=device).expand(self.batch_size,-1),
scale=torch.tensor([.1, 1.,1.],device=device).expand(self.batch_size,-1)
)

z_what = pyro.sample(f'what_{t}_0',
pyro.distributions.Normal(
#loc=torch.tensor(0.,device=device),
scale=torch.ones(self.batch_size,self.lat_dim,device=device)
)

# other code

``````

However, on the guide side, do I mask the log-probability similarly? For example, would I do something like this?

``````python
def step_guide(self, t, data, indices, prev, freeze_params=False):

# other code here

num_figs = torch.arange(self.max_figs_per_image+1,device=self.device).reshape(-1,1,1)

dist_where = pyro.distributions.Normal(
scale=z_where_scale
)

scale_where = 1.0
with pyro.poutine.scale(scale=scale_where):
z_where = pyro.sample(f'where_{t}_0',
)

#other code here

x_att = image_to_object(z_where, data)
z_what_loc, z_what_scale = encode(x_att)

dist_what = pyro.distributions.Normal(
scale=z_what_scale#torch.exp(lat_embed_scale)
)

scale_what = scale_where
with pyro.poutine.scale(scale=scale_what):
z_what = pyro.sample(f'what_{t}_0',
)

# more code
``````

Basically - would I have to create masks that match the model side - so I appropriately mask the probabilities for a given value of t?

Thanks,

Mike

There is also a potential bug here - the TraceEnum_ELBO code can’t determine that the downstream sample sites (downstream from the discrete sample site) have a dependency on the discrete sample site. You might recognize this code from traceenum_elbo.py:

``````python
enum_dims -= non_enum_dims

# other code

for t, sites_t in cost_sites.items():
for site in sites_t:
if enum_dims.isdisjoint(site["packed"]["log_prob"]._pyro_dims):
# For sites that do not depend on an enumerated variable, proceed as usual.
marginal_costs.setdefault(t, []).append(site["packed"]["log_prob"])
else:
# For sites that depend on an enumerated variable, we need to apply
# the mask inside- and the scale outside- of the log expectation.