I played with AutoNormalMessenger a bit.
I’m happy to contribute some docs for conditioning in amortized inference, if the code is amenable to that. I need some help getting a small working example. Here is my attempt:
OK, I understand how to use .get_posterior and the if/elif name == 'param': control flow.
But some questions remain:

- How do I feed data into it? Where would I put a neural net that takes in data, in an amortized inference setting? And where do I register the params: inside get_posterior, where the pyro.param statements were in the AutoNormalMessenger example?
- What if I wanted to get some other quantity out? Would I have to put a sort of "dummy distribution" on it in the model, with a very wide prior, name it and sample from it (then discard the sample), and then handle it in get_posterior in the guide?
Here is a simplified 1D example of my model:

def model(data):
    size_mini_batch = data.shape[-1]
    with pyro.plate('mini_batch', size_mini_batch, dim=-1):
        # HalfNormal takes a single scale parameter (it has no loc)
        sigma_signal = pyro.sample('sigma_signal', dist.HalfNormal(sigma_signal_scale_gt))
        trans = pyro.sample('translate', dist.Normal(trans_prior_loc, trans_prior_scale))
        atom_center_trans = atom_center.reshape(-1, 1) + trans.reshape(1, -1)
        proj = torch.exp(-((coords.reshape(-1, 1, 1) - atom_center_trans.reshape(1, -1, size_mini_batch)) ** 2) / (2 * sigma_signal ** 2)).sum(1)
        with pyro.plate('pixel', num_pix, dim=-2):
            obs_dist = dist.Normal(proj, sigma_noise_gt)
            pyro.sample("noise", obs_dist, obs=data)
            sim = obs_dist.sample()
    return sim, proj, trans, sigma_signal
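For context, the forward model above is just a sum of translated 1D Gaussians evaluated on a pixel grid. Here is a plain-Python sketch of the projection for a single mini-batch element (no Pyro involved; the names mirror the snippet above but this is only an illustration):

```python
import math

def project(coords, atom_centers, trans, sigma_signal):
    """Sum of 1D Gaussians, each atom center shifted by one translation.

    coords: pixel coordinates; atom_centers: atom positions;
    trans, sigma_signal: scalars for a single mini-batch element.
    """
    proj = []
    for x in coords:
        val = sum(
            math.exp(-((x - (c + trans)) ** 2) / (2 * sigma_signal ** 2))
            for c in atom_centers
        )
        proj.append(val)
    return proj

# A single atom at 0.0 with no translation: the peak sits at the pixel nearest 0.
p = project([-1.0, 0.0, 1.0], [0.0], trans=0.0, sigma_signal=1.0)
```

The vectorized reshape/broadcast expression in the model computes the same quantity for all mini-batch elements at once.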
And my guide:

def guide(data):
    size_mini_batch = data.shape[-1]
    pyro.module("mlp", mlp)
    with pyro.plate('mini_batch', size_mini_batch, dim=-1):
        lam = mlp(data.T)
        trans_loc, trans_log_scale, sigma_signal_log_loc, sigma_signal_log_scale = lam.T
        posterior_trans_dist = dist.Normal(trans_loc, torch.exp(trans_log_scale))
        trans = pyro.sample("translate", posterior_trans_dist)
        # half normal doesn't work for some reason, not sure why
        posterior_sigma_signal_dist = dist.Gamma(torch.exp(sigma_signal_log_loc), torch.exp(sigma_signal_log_scale))
        sigma_signal = pyro.sample("sigma_signal", posterior_sigma_signal_dist)
    return trans, sigma_signal, lam
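The core of the amortization here is just: one shared map takes each datapoint to its own variational parameters. Stripped of Pyro and torch, the lam / lam.T unpacking pattern looks like this (the "network" is a stand-in lambda, not my real mlp):

```python
def amortize(net, batch):
    """Apply one shared map to every datapoint.

    Each row of lam holds one datapoint's variational parameters;
    transposing gives one tuple per parameter across the batch,
    mirroring `trans_loc, trans_log_scale, ... = lam.T` in the guide.
    """
    lam = [net(x) for x in batch]
    return list(zip(*lam))

# Stand-in for the MLP: maps a scalar datapoint to 4 parameters.
net = lambda x: (x, 0.0, -x, 1.0)
trans_loc, trans_log_scale, sig_log_loc, sig_log_scale = amortize(net, [1.0, 2.0, 3.0])
```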
Here is some pseudocode from my attempt to extend the AutoNormalMessenger example in the docs:

class MyGuideMessenger(AutoNormalMessenger):
    def get_posterior(self, name, prior):
        if name == "translate":
            # Use a custom distribution at site 'translate' that conditions on sigma_signal.
            pyro.module("mlp_trans", mlp_trans)
            sigma_signal = self.upstream_value("sigma_signal")
            mbatch_data_sigma_signal_undone = undo_sigma_signal_func(mbatch_data, sigma_signal)  # how can I pass data to here?
            lam_trans = mlp_trans(mbatch_data_sigma_signal_undone)
            trans_loc, trans_log_scale = lam_trans.T
            return dist.Normal(trans_loc, torch.exp(trans_log_scale))  # Can replace Normal with dist of my choice
        # Fall back to mean field.
        return super().get_posterior(name, prior)
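One idea I had for the "how can I pass data to here?" comment: stash the mini-batch on the guide object when it is called, before the sites are visited, so get_posterior can read it from self. I don't know whether AutoNormalMessenger supports this cleanly; here is the control flow I'm imagining, with Pyro stripped out (GuideBase and the returned tuples are hypothetical stand-ins, not real Pyro API):

```python
class GuideBase:
    """Stand-in for an auto-guide: calling it visits each site in order."""
    sites = ("sigma_signal", "translate")

    def __call__(self, data):
        return {name: self.get_posterior(name) for name in self.sites}

    def get_posterior(self, name):
        return ("mean_field", name)

class MyGuide(GuideBase):
    def __call__(self, data):
        # Stash the mini-batch so per-site code can see it; in Pyro this
        # would mean overriding AutoNormalMessenger.__call__ analogously.
        self._data = data
        return super().__call__(data)

    def get_posterior(self, name):
        if name == "translate":
            # The amortized branch can now read self._data.
            return ("amortized", name, self._data)
        return super().get_posterior(name)

out = MyGuide()([1.0, 2.0])
```

Is this pattern (overriding __call__ to record the arguments) the intended way to amortize with the messenger-based autoguides, or is there a supported hook for it?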