Hey guys,
I’m trying to use Pyro’s discrete exact inference (parallel enumeration) to solve this problem; I’m also trying to master the technique so I can apply it to other problems.
My main question is how to set up the guide. On the model side, my code is structured like this:
def model(self, data, batch_indices, freeze_params=False, data_impute=None, obs_indices=None):
    # other code
    prior_probs = pyro.param(
        "num_fig_prob",
        torch.tensor([.01, .495, .495], device=self.device) / (self.max_figs_per_image + 1),
        constraint=constraints.simplex
    )
    with pyro.plate('data', N, subsample_size=300) as indices:
        num_figs = pyro.sample(
            "num_figs",
            pyro.distributions.Categorical(prior_probs),
            infer={"enumerate": "parallel"}
        )
        # code similar to tutorial
def prior_step(self, n, t, prev_x, indices, num_figs):
    mask = (t <= num_figs)
    mask_1 = mask.unsqueeze(-1)
    mask_2 = mask.unsqueeze(-1).unsqueeze(-1)
    # broadcast the enum/mask dims into loc so the sample picks up the right shape
    z_where = pyro.sample(
        f'where_{t}_0',
        pyro.distributions.Normal(
            loc=torch.ones_like(mask_1, device=self.device)
                * torch.tensor([3., 0., 0.], device=self.device).expand(self.batch_size, -1),
            scale=torch.tensor([.1, 1., 1.], device=self.device).expand(self.batch_size, -1)
        ).to_event(1).mask(mask)
    )
    z_what = pyro.sample(
        f'what_{t}_0',
        pyro.distributions.Normal(
            loc=torch.ones_like(mask_1, device=self.device) * torch.tensor(0., device=self.device),
            scale=torch.ones(self.batch_size, self.lat_dim, device=self.device)
        ).to_event(1).mask(mask)
    )
    # other code
On the guide side, however, do I need to mask the log-probability in the same way? For example, would I do something like this:
def step_guide(self, t, data, indices, prev, freeze_params=False):
    # other code here
    # enumerate every possible value of num_figs and rebuild the same mask as the model
    num_figs = torch.arange(self.max_figs_per_image + 1, device=self.device).reshape(-1, 1, 1)
    mask = (t <= num_figs)
    mask_1 = mask.unsqueeze(-1)
    mask_2 = mask.unsqueeze(-1).unsqueeze(-1)
    dist_where = pyro.distributions.Normal(
        loc=torch.ones_like(mask_1, device=self.device) * z_where_loc,
        scale=z_where_scale
    )
    scale_where = 1.0
    with pyro.poutine.scale(scale=scale_where):
        z_where = pyro.sample(f'where_{t}_0', dist_where.to_event(1).mask(mask))
    # other code here
    x_att = image_to_object(z_where, data)
    z_what_loc, z_what_scale = encode(x_att)
    dist_what = pyro.distributions.Normal(
        loc=torch.ones_like(mask_1, device=self.device) * z_what_loc,
        scale=z_what_scale
    )
    scale_what = scale_where
    with pyro.poutine.scale(scale=scale_what):
        z_what = pyro.sample(f'what_{t}_0', dist_what.to_event(1).mask(mask))
    # more code
Basically, do I have to create masks in the guide that match the ones on the model side, so that the log-probabilities are masked appropriately for a given value of t?
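For reference, here is a minimal standalone shape check of what I think the masking is doing (K and B are toy sizes, not my real dimensions): the mask broadcasts against the enumeration dimension on the left, and masked-out entries contribute exactly zero to the log-probability.

import torch
import pyro.distributions as dist

K = 3   # max_figs_per_image + 1, i.e. the enumeration support size (toy value)
B = 4   # batch size (toy value)
t = 1

# Mimic the enumerated num_figs: one extra dim on the left per enumerated value.
num_figs = torch.arange(K).reshape(-1, 1, 1)   # shape (K, 1, 1)
mask = (t <= num_figs)                         # shape (K, 1, 1)

# A stand-in for the where/what site: batch shape (B,), event shape (3,).
d = dist.Normal(torch.zeros(B, 3), torch.ones(B, 3)).to_event(1).mask(mask)

z = torch.randn(K, 1, B, 3)
lp = d.log_prob(z)                             # shape (K, 1, B)
print(lp.shape)
# Entries where the mask is False are exactly zero, so they drop out of the ELBO.
print(lp[~mask.expand_as(lp)].abs().max())

My assumption is that as long as the guide applies the same mask at the same site names, the masked-out terms drop out of both the model and guide densities; that is really what I am hoping someone can confirm.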
Thanks,
Mike