SVI. complex guide structure with access to computations in model

I would like my guide to have access to things that happen in the model. In the guide I get samples of the latents, so I have those. But then I have to “redo” deterministic linking functions that have already been coded up in the model. How can the guide get access to those?

What I’m trying to do is to give some conditional structure to the guide and make it aware of what has been estimated.

I’ve browsed the Poutine (Effect handlers) page of the Pyro documentation a bit. Are there any simple worked examples somewhere?

  1. Can I do trace = pyro.poutine.trace(model).get_trace(0.0) in the guide? As the guide proceeds, will I have to run it again, to get things that were computed later on in the model?

  2. How do I ensure something appears in trace? The documentation mentions I need to pass it to pyro.sample() or pyro.param(). However, I don’t want it to affect the optimization. Wouldn’t it contribute to the logprob if I passed it to pyro.sample? And wouldn’t it be optimized if I passed it to pyro.param?

Sorry if I’m a bit fuzzy on how SVI works… please enlighten me… :pray:

Did you find any of the examples in contrib helpful? For instance, the mue example shows how to have a class that encapsulates both the model and the guide.

That being said, I agree: this type of solution does require a bit of boilerplate code. I’d be curious what others think about model modularity.

Hi @geoffwoollard, I definitely agree that the original (model,guide) factorization makes it cumbersome to share deterministic computations between the model and guide. Last autumn Vitalii Kleshchevnikov and I built some machinery to accomplish this via messenger-based guides, a new factorization of variational inference that interleaves model and guide execution, and works with Pyro’s existing SVI. These new guides are available as of Pyro 1.8 as AutoNormalMessenger, AutoHierarchicalNormalMessenger, and AutoRegressiveMessenger.

The usage pattern is to subclass, say, AutoNormalMessenger and implement a method .get_posterior(name, prior) that the framework calls at each sample site, with the name of the site passed as name. Within that .get_posterior() method you can call self.upstream_value("some_other_name") to get the posterior sample value of another upstream pyro.sample site. Also, the prior argument to .get_posterior(name, prior) is the model-side prior distribution conditioned on all upstream posterior samples, and you can use that prior’s parameters to capture upstream deterministic computations. For example, if the model defines

prior = dist.Normal(loc=complex_deterministic_function(stuff), scale=1)
pyro.sample("x", prior)

then you can reuse that complex deterministic computation by reading off prior.loc when you handle .get_posterior("x", prior) in your guide implementation.

This is all pretty new machinery, and we welcome contributions of examples, tutorials, new AutoMessenger guides, and new methods :smile:


I played with AutoNormalMessenger a bit.

I’m happy to contribute some docs for conditioning in amortized inference, if the code is amenable to that. I need some help getting a small working example. Here is my attempt:

Ok I understand how to use .get_posterior and some if/elif name == 'param': control flow.

But some questions remain:

  1. How do I feed data into it? Where would I put a neural net that takes in data in an amortized inference setting? Where do I register the params, inside get_posterior where the pyro.param statements were in the AutoNormalMessenger example?

  2. What if I wanted to get stuff? Would I have to put in a sort of “dummy distribution” on stuff in the model, with a very wide prior, and name it and sample from it (then discard the sample), and then do .get_posterior in the guide?

Here is a simplified 1D example of my model

def model(data):
  size_mini_batch = data.shape[-1]
  with pyro.plate('mini_batch',size_mini_batch,dim=-1):
    sigma_signal = pyro.sample('sigma_signal',dist.HalfNormal(sigma_signal_loc_gt,sigma_signal_scale_gt))
    trans = pyro.sample('translate',dist.Normal(trans_prior_loc,trans_prior_scale))
    atom_center_trans = atom_center.reshape(-1,1) + trans.reshape(1,-1)
    proj = torch.exp(-((coords.reshape(-1,1,1)-atom_center_trans.reshape(1,-1,size_mini_batch))**2)/(2*sigma_signal**2)).sum(1)
    with pyro.plate('pixel',num_pix,dim=-2):
      obs_dist = dist.Normal(proj,sigma_noise_gt)
      pyro.sample("noise", obs_dist, obs=data) 
      sim = obs_dist.sample()

      return sim, proj, trans, sigma_signal

And my guide

def guide(data):
  size_mini_batch = data.shape[-1]
  pyro.module("mlp", mlp)
  with pyro.plate('mini_batch', size_mini_batch, dim=-1):
    lam = mlp(data.T) 
    trans_loc, trans_log_scale, sigma_signal_log_loc, sigma_signal_log_scale = lam.T
    posterior_trans_dist = dist.Normal(trans_loc,torch.exp(trans_log_scale))
    trans = pyro.sample("translate", posterior_trans_dist)
    posterior_sigma_signal_dist = dist.Gamma(torch.exp(sigma_signal_log_loc),torch.exp(sigma_signal_log_scale)) # half normal doesn't work for some reason, not sure why
    sigma_signal = pyro.sample("sigma_signal", posterior_sigma_signal_dist)
  return trans, sigma_signal, lam

Here is some pseudo code from my attempt to extend from the AutoNormalMessenger example in the docs

class MyGuideMessenger(AutoNormalMessenger):
    def get_posterior(self, name, prior):
        if name == "translate":
          # Use a custom distribution at site translate that conditions on sigma_signal.
          pyro.module("mlp_trans", mlp_trans)
          sigma_signal = self.upstream_value("sigma_signal")
          mbatch_data_sigma_signal_undone = undo_sigma_signal_func(mbatch_data,sigma_signal) # how can I pass data to here?
          lam_trans = mlp_trans(mbatch_data_sigma_signal_undone)
          trans_loc, trans_log_scale = lam_trans.T
          return dist.Normal(trans_loc, torch.exp(trans_log_scale)) # Can replace Normal with dist of my choice
        # Fall back to mean field.
        return super().get_posterior(name, prior)

Thanks for sharing that example, but I couldn’t see any conditioning happening inside the guide. When the trace was accessed, it was outside both the model and guide. pyro/models.py at dev · pyro-ppl/pyro · GitHub

@fritzo’s example is more what I’m looking for. I just need some help extending to amortized inference.

Hi @geoffwoollard, we’re looking forward to any help clarifying docs! Answering:

  1. How do I feed data into it?

When you call svi.step(*args, **kwargs) those args and kwargs are saved as self.args_kwargs : Tuple[tuple, dict], which is accessible inside self.get_posterior(...).

Where would I put a neural net that takes in data in an amortized inference setting?

Declare it in the .__init__() method, pass data in via self.args_kwargs, and be sure to specify amortized_plates in super().__init__(amortized_plates=...).

Where do I register the params, inside get_posterior where the pyro.param statements were in the AutoNormalMessenger example?

You can register params either via pyro.param inside self.get_posterior() or via PyroParam inside .__init__() (after calling super().__init__(...)) because AutoGuideMessenger is a PyroModule. For the same reason you won’t need to call pyro.module on your neural net as long as you set it as an attribute in .__init__() after calling super().__init__(...). This is standard PyroModule behavior.

What if I wanted to get stuff?

You could save stuff by calling pyro.deterministic("stuff", stuff) in the model and calling stuff = self.upstream_value("stuff") in the guide. A kludgier alternative is to set self.stuff = stuff in the model and read stuff = self.stuff in the guide, but that would leak some memory.

Here are some tweaks to your guide class:

class MyGuideMessenger(AutoNormalMessenger):
    def __init__(self, model, mlp_trans: torch.nn.Module):
        # Declare the amortized plates.
        super().__init__(model, amortized_plates=("mini_batch",))
        # Register the neural net.
        self.mlp_trans = mlp_trans

    def get_posterior(self, name, prior):
        # unpack args, kwargs
        args, kwargs = self.args_kwargs
        mbatch_data = args[0]

        # Declare a plate over the minibatch.
        size_mini_batch = mbatch_data.shape[-1]
        with pyro.plate('mini_batch', size_mini_batch, dim=-1):
            if name == "translate":
                # Use a custom distribution at site translate that conditions on sigma_signal.
                sigma_signal = self.upstream_value("sigma_signal")
                mbatch_data_sigma_signal_undone = undo_sigma_signal_func(mbatch_data,sigma_signal)
                lam_trans = self.mlp_trans(mbatch_data_sigma_signal_undone)
                trans_loc, trans_log_scale = lam_trans.T
                return dist.Normal(trans_loc, torch.exp(trans_log_scale)) # Can replace Normal with dist of my choice
            # Fall back to mean field.
            return super().get_posterior(name, prior)

which would be called in an svi loop as

guide = MyGuideMessenger(model, mlp_trans)  # the constructor above also takes the neural net
svi = SVI(model, guide, ..., Trace_ELBO())  # "..." stands for the optimizer
for epoch in range(num_epochs):
    for mbatch_data in my_minibatcher(data):
        svi.step(mbatch_data)  # passed to guide.args_kwargs