Mask a sample downstream

I would like to do model selection based on the principle of parsimony. The model is a linear differential equation. Up front, I set a max_order that I want to explore, and a discrete variable samples the order of the differential equation to simulate. However, for the parsimony principle to show up in the ELBO, I need to mask the parameters that are not used.

I have three sample sites, a, b, and c, each a vector of shape 1 x max_dim.

When the sampled order is N < max_dim, I need to mask the remaining max_dim - N entries.

The issue is that choosing which parameters a particular order N uses in the simulation is not that simple. The straightforward approach is to take the first N parameters, in which case setting the mask is easy: first sample N, then use poutine.mask when sampling a, b, and c so that only the first N entries are kept (all the rest are masked).
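To make the effect of the prefix mask concrete, here is a stdlib-only toy (plain Python, not Pyro) of how masking changes the log-prior term of the ELBO. It assumes independent standard-normal priors on the entries of a and made-up sample values; prefix_mask and masked_log_prior are hypothetical helpers.

```python
import math

def normal_logpdf(x, loc=0.0, scale=1.0):
    # Log-density of Normal(loc, scale) evaluated at x
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)

def prefix_mask(n, max_dim):
    # Keep the first n entries, mask the remaining max_dim - n
    return [i < n for i in range(max_dim)]

def masked_log_prior(values, mask):
    # Masked entries contribute nothing to the sum, like zeroed log_prob terms
    return sum(normal_logpdf(v) for v, keep in zip(values, mask) if keep)

max_dim = 4
a = [0.3, -1.2, 0.8, 2.0]  # hypothetical sampled values
for n in range(1, max_dim + 1):
    mask = prefix_mask(n, max_dim)
    # order, number of masked entries, masked log-prior
    print(n, mask.count(False), round(masked_log_prior(a, mask), 3))
```

Under this prior, each extra unmasked entry adds a term no larger than about -0.919 (the standard-normal log-density at 0), so larger orders are penalized; without the mask, that penalty is the same for every N.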

However, the parameters are selected based on a relation among a, b, and c. So I need to draw the samples first and only then decide which N values to use. But since I am not masking the rest, the ELBO always accounts for all max_dim values of a, b, and c, and the parsimony principle has no effect.
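The difficulty can be sketched the same way. In this toy (plain Python, not Pyro), select_which_parameters is a hypothetical stand-in for the real selection rule, which is not shown here: it keeps the N entries with the largest |a_i| + |b_i| + |c_i|. The point is that the resulting mask is not a prefix and can only be computed from the sampled values.

```python
def select_which_parameters(a, b, c, n):
    # HYPOTHETICAL rule: keep the n entries with the largest combined
    # magnitude |a_i| + |b_i| + |c_i|; the real relation is problem-specific
    score = [abs(x) + abs(y) + abs(z) for x, y, z in zip(a, b, c)]
    keep = set(sorted(range(len(a)), key=lambda i: score[i], reverse=True)[:n])
    return [i in keep for i in range(len(a))]

# Hypothetical sampled values for max_dim = 4
a = [0.1, -1.5, 0.4, 2.2]
b = [0.0, 0.3, -0.2, 0.9]
c = [0.2, 1.1, 0.1, -0.5]

# The mask depends on the values, so it exists only after a, b, c are sampled
mask = select_which_parameters(a, b, c, 2)
print(mask)  # → [False, True, False, True]
```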

I am thinking of subclassing Trace to change how scale_and_mask is applied, but I was wondering whether there is a more elegant solution.
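For reference, here is a plain-Python sketch of the masking semantics I assume pyro.distributions.util.scale_and_mask applies to a site's log_prob (multiply by scale, replace masked-out entries with 0). It is a list-based toy, not Pyro's implementation, and the log_prob values are made up. Because the mask only matters when log-probabilities are summed, applying it after all values have been drawn gives the same total as summing only the kept entries.

```python
def scale_and_mask(values, scale=1.0, mask=None):
    # Toy, list-based version of the assumed semantics: scale every term
    # and replace masked-out terms by 0.0
    if mask is None:
        return [v * scale for v in values]
    return [v * scale if keep else 0.0 for v, keep in zip(values, mask)]

log_probs = [-0.92, -2.04, -1.00, -3.34]  # hypothetical per-entry log_prob values
mask = [False, True, False, True]         # known only after sampling

# Late mask: computed on log-probabilities of values that were already drawn
masked_total = sum(scale_and_mask(log_probs, mask=mask))

# Reference: sum only the kept entries directly
kept_only = sum(lp for lp, keep in zip(log_probs, mask) if keep)
print(masked_total, kept_only)
```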

In a nutshell, simplified (pseudo)code would be:

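import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

def model(self):
    # Categorical probabilities over the candidate orders (simplex-constrained)
    N_par = pyro.param('N_par',
                       dist.Dirichlet(torch.ones(self.max_dim).to(self.device) * 2500).sample(),
                       constraint=constraints.simplex)
    pole = pyro.sample('pole', dist.Categorical(N_par))

    with pyro.plate('dims', self.max_dim):
        a = pyro.sample('a', ...)
        b = pyro.sample('b', ...)
        c = pyro.sample('c', ...)

    # Which entries are used depends jointly on a, b, c and the sampled order
    sel_ind = select_which_parameters(a, b, c, pole)
    sim_mu = simulate(a, b, c, sel_ind)

    # Problematic part: this should have the same effect as poutine.mask,
    # but it can only be applied after a, b and c have been sampled
    mask_somehow(a, b, c, mask=sel_ind > 0)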

As a hack, I wrote a new Messenger. The basic idea is to mutate the trace after a node has been created; the idea came from this line.

Since the Trace class does not support late updates to an existing node, I added a new method, def edit_node(self, site_name, **kwargs). The new messenger is implemented as a class, ModifyMessenger, with a corresponding handler, poutine.modify. Its _process_message function is:

def _process_message(self, msg):
    # Tag every message passing through this handler with the attributes to edit
    msg['modify_attrs'] = self.modify_attrs

Finally, in TraceMessenger I changed how a site is recorded:

if "modify_attrs" in msg:
    self.trace.edit_node(msg["name"], **msg["modify_attrs"])
else:
    self.trace.add_node(msg["name"], **msg.copy())
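The bookkeeping of this hack can be modelled without Pyro. In the toy below, ToyTrace and record are hypothetical stand-ins (not pyro.poutine.Trace or the real TraceMessenger): add_node stores a site's attributes, edit_node updates them later, and record reproduces the branching on "modify_attrs" shown above.

```python
from collections import OrderedDict

class ToyTrace:
    # Stand-in for a trace: ordered mapping from site name to site attributes.
    # NOT pyro.poutine.Trace; only a model of the behaviour described above.
    def __init__(self):
        self.nodes = OrderedDict()

    def add_node(self, site_name, **kwargs):
        # Refuse to record the same site twice
        if site_name in self.nodes:
            raise RuntimeError(f"Multiple sites named '{site_name}'")
        self.nodes[site_name] = kwargs

    def edit_node(self, site_name, **kwargs):
        # Hypothetical late update of an already-recorded site
        if site_name not in self.nodes:
            raise KeyError(f"No site named '{site_name}' to edit")
        self.nodes[site_name].update(kwargs)

def record(trace, msg):
    # Messages tagged with "modify_attrs" edit an existing node;
    # all other messages add a new node
    if "modify_attrs" in msg:
        trace.edit_node(msg["name"], **msg["modify_attrs"])
    else:
        trace.add_node(msg["name"], **msg.copy())

tr = ToyTrace()
record(tr, {"name": "a", "type": "sample", "value": [0.1, -1.5, 0.4], "mask": None})
# A later message, tagged the way a ModifyMessenger would tag it, attaches the mask
record(tr, {"name": "a", "modify_attrs": {"mask": [False, True, True]}})
print(tr.nodes["a"]["mask"])  # → [False, True, True]
```

In this toy, the edit only changes what is stored in the trace, so for a late mask to affect the ELBO, the edit must happen before log-probabilities are computed from the trace.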

However, I am not sure whether this approach is in line with Pyro's overall design.