Parametric Interventions in Pyro

Hi Team,

I’m trying to implement what are termed “soft”, “shift” or “parametric” interventions, and I understand that pyro.do statements represent hard interventions. While I understand the theory, I’m curious what an implementation would look like using Pyro’s built-in utilities.

For instance, I want to keep the linear relationship between X and Y such that Y = wX + b, where w is a constant weight and b ~ N(0, 1), and when intervening I want to change only the noise variance, i.e. the variance of b.

Is this possible with some modification in the definition of the do function or is it something I must change by hand in my model each time I want to intervene on a variable?

I assume the worst case is that I could pass the noise variance for each node in my DAG as an argument, but I was wondering whether soft interventions are planned to be offered as part of the library.

Hi, the implementation of pyro.do is pretty straightforward. You could fork it and add some code in DoMessenger._pyro_sample, at the point where the intervention is applied, to handle the case where the intervention is specified as a function of the information at the sample site:

...
intervention = self.data[msg['name']]
...
if isinstance(intervention, (numbers.Number, torch.Tensor)):
    ...
elif callable(intervention):  # this part is new
    # here intervention is a function that returns a new distribution
    msg["fn"] = intervention(msg)
else:
    raise NotImplementedError(...)
...

Sounds good, I’ll take a look and get back to you if I run into issues!

import pyro
import pyro.distributions as dist
import torch

def dummymodel(x):
    s = pyro.param("s", torch.tensor(0.5))
    z = pyro.sample("z", dist.Normal(x, s))
    return z ** 2

def intervention(msg):
    return msg['fn']    # for now, don't change anything

intervention_fn = intervention
intervened_model = pyro.poutine.do(dummymodel, data={"z": intervention_fn})

tr = pyro.poutine.trace(intervened_model).get_trace(2.0)

Error:

/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/poutine/do_messenger.py in _pyro_sample(self, msg)
     80             if isinstance(intervention, (numbers.Number, torch.Tensor)):
---> 81                 msg['value'] = intervention
     82                 msg['is_observed'] = True

NotImplementedError: Interventions of type <class 'function'> not implemented (yet)

The above exception was the direct cause of the following exception:

NotImplementedError                       Traceback (most recent call last)
<ipython-input-261-d3cb2510b9b2> in <module>
     13 intervened_model = do(dummymodel, data={"z": intervention_fn})
     14 
---> 15 tr = pyro.poutine.trace(intervened_model).get_trace(2.0)

do_messenger.py

        if isinstance(intervention, (numbers.Number, torch.Tensor)):
            msg['value'] = intervention
            msg['is_observed'] = True
            msg['stop'] = True
        elif callable(intervention):
            msg['fn'] = intervention(msg)  # return a reparametrized function for the node
        else:
            raise NotImplementedError(
                "Interventions of type {} not implemented (yet)".format(type(intervention)))

So I ran into this error. I think it’s not recognizing intervention_fn as a callable within the implementation in do_messenger.py. Is there a recommended approach to debugging this so that I can see the message contents, or is a local IDE the best approach? I tried to use pdb.set_trace(), but it doesn’t pause execution.

I was working on a server with Jupyter notebooks, so I’m moving everything to a local machine with an IDE now so I can add breakpoints and inspect msg in memory. I will keep you posted, but if I’m doing something silly, please feel free to correct me.

I can’t say exactly what’s going on with your code, but it looks like you could fix it by deleting the elif callable branch and the final NotImplementedError:

if isinstance(intervention, (numbers.Number, torch.Tensor)):
    ...
else:
    msg['fn'] = intervention(msg)

Thanks E.B., I did try that, but it seems like some expected behaviour breaks because of this change and shows up downstream. I’ll keep trying to understand what’s breaking.

I might be naive about this, but I still don’t follow how I can define a custom function that, for instance, returns dist.Normal(x, s + 2.), because I’m unclear on how the variable s can be accessed at the time intervention_fn is called. Do I access it through the global param store?

Meanwhile, here’s the complete traceback. I’ve already posted my code above and made the changes word-for-word as you instructed.

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
164             try:
--> 165                 ret = self.fn(*args, **kwargs)
166             except (ValueError, RuntimeError) as e:

/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
 11     with context:
---> 12         return fn(*args, **kwargs)
 13 

<ipython-input-355-cf18069b2011> in dummymodel(x)
  2     s = pyro.param("s", torch.tensor(0.5))
----> 3     z = pyro.sample("z", dist.Normal(x, s))
  4     return z ** 2

/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/primitives.py in sample(name, fn, *args, **kwargs)
112         # apply the stack and return its return value
--> 113         apply_stack(msg)
114         return msg["value"]

/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/poutine/runtime.py in apply_stack(initial_msg)
192 
--> 193         frame._process_message(msg)
194 

/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/poutine/messenger.py in _process_message(self, msg)
138         if method is not None:
--> 139             return method(msg)
140         return None

/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/poutine/do_messenger.py in _pyro_sample(self, msg)
 80                 msg['fn'] = intervention(msg)  # return a reparametrized function for the node
---> 81         return None
 82 

NotImplementedError: Interventions of type <class 'function'> not implemented (yet)

The above exception was the direct cause of the following exception:

NotImplementedError                       Traceback (most recent call last)
<ipython-input-355-cf18069b2011> in <module>
 14 intervened_model = do(dummymodel, data={"z": intervention_fn})
 15 
---> 16 tr = pyro.poutine.trace(intervened_model).get_trace(2.0)

/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
185         Calls this poutine and returns its trace instead of the function's return value.
186         """
--> 187         self(*args, **kwargs)
188         return self.msngr.get_trace()

/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
169                 exc = exc_type(u"{}\n{}".format(exc_value, shapes))
170                 exc = exc.with_traceback(traceback)
--> 171                 raise exc from e
172             self.msngr.trace.add_node("_RETURN", name="_RETURN", type="return", value=ret)
173         return ret

/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
163                                       args=args, kwargs=kwargs)
164             try:
--> 165                 ret = self.fn(*args, **kwargs)
166             except (ValueError, RuntimeError) as e:
167                 exc_type, exc_value, traceback = sys.exc_info()

/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
 10 def _context_wrap(context, fn, *args, **kwargs):
 11     with context:
---> 12         return fn(*args, **kwargs)
 13 
 14 

<ipython-input-355-cf18069b2011> in dummymodel(x)
  1 def dummymodel(x):
  2     s = pyro.param("s", torch.tensor(0.5))
----> 3     z = pyro.sample("z", dist.Normal(x, s))
  4     return z ** 2
  5 

/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/primitives.py in sample(name, fn, *args, **kwargs)
111             msg["is_observed"] = True
112         # apply the stack and return its return value
--> 113         apply_stack(msg)
114         return msg["value"]
115 

/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/poutine/runtime.py in apply_stack(initial_msg)
191         pointer = pointer + 1
192 
--> 193         frame._process_message(msg)
194 
195         if msg["stop"]:

/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/poutine/messenger.py in _process_message(self, msg)
137         method = getattr(self, "_pyro_{}".format(msg["type"]), None)
138         if method is not None:
--> 139             return method(msg)
140         return None
141 

/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/poutine/do_messenger.py in _pyro_sample(self, msg)
 79             else:
 80                 msg['fn'] = intervention(msg)  # return a reparametrized function for the node
---> 81         return None
 82 

NotImplementedError: Interventions of type <class 'function'> not implemented (yet)
Trace Shapes:  
 Param Sites:  
        s  
Sample Sites:  
   z dist |
    value |

Okay, whoops: just restarting the kernel fixed it. I see why Joel Grus hates debugging in Jupyter notebooks. Thanks!


The distribution at a sample site is stored in msg["fn"], so assuming the original distribution is Normal(x, s), you could define an intervention_fn setting the distribution to Normal(x, s+2) as

lambda msg: dist.Normal(msg["fn"].loc, msg["fn"].scale + 2.)

Thanks! My question was missing context, since this is a small example of what I’m trying to do. Assuming I want to use the current estimate of a latent variable parameter, I should be able to use pyro.param and retrieve its value, right?

Can you clarify what exactly you’re trying to do? pyro.param statements correspond to deterministic learnable parameters, not latent random variables. See the SVI tutorial for more on the difference between pyro.param and pyro.sample.

Thanks for correcting me; yes, I meant learnable parameters, not latent variables.
We are working on expanding the Epidemiology tutorial with policy interventions that can reduce the rate of spread of the disease, inspired by this paper. To summarize our takeaway, you can think of what we’re doing as adding a term u that scales the infection rate from the SIR HMC tutorial, i.e. rate_s = (1. - u) * rate_s.

So we proceed as follows:

  1. Simulate some data.
  2. Fit the SIR model to the data and obtain the posterior estimates over the parameters R0 and rho.
  3. Naively find a minimum constant value of u that reduces the rate of spread to acceptable levels (sample grid of images representing such a case)
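Step 3 can be sketched as a grid search over a constant u. This is a hypothetical, deterministic mean-field stand-in for a stochastic model, not the tutorial's code: it assumes the parameterization rate_s = -R0 / (tau * population), a per-step infection probability 1 - exp(rate_s * I) and a recovery probability 1 / (1 + tau) (verify these against your own model). The names simulate_sir and min_constant_u, and the "peak infections stay below a cap" criterion, are illustrative assumptions; in practice R0 would be a point estimate from the fit in step 2.

```python
import math

def simulate_sir(R0, tau, population, I0, u, num_steps):
    """Hypothetical mean-field discrete-time SIR; u[t] in [0, 1] scales rate_s by (1 - u[t]).

    Returns the infected count after each step."""
    rate_s = -R0 / (tau * population)   # assumed parameterization (negative log-rate)
    prob_i = 1.0 / (1.0 + tau)          # assumed per-step recovery probability
    S, I = population - I0, float(I0)
    infected = []
    for t in range(num_steps):
        p_infect = 1.0 - math.exp((1.0 - u[t]) * rate_s * I)
        S2I = S * p_infect              # expected new infections
        I2R = I * prob_i                # expected recoveries
        S, I = S - S2I, I + S2I - I2R
        infected.append(I)
    return infected

def min_constant_u(R0, tau, population, I0, num_steps, max_infected, grid_size=101):
    """Smallest constant u on an evenly spaced grid whose peak infections stay <= max_infected."""
    for k in range(grid_size):
        u = k / (grid_size - 1)
        peak = max(simulate_sir(R0, tau, population, I0, [u] * num_steps, num_steps))
        if peak <= max_infected:
            return u
    return None  # even u = 1 cannot meet the cap

u_star = min_constant_u(R0=2.5, tau=7.0, population=10000, I0=10, num_steps=100, max_infected=200)
print(u_star)
```

Passing a list for u keeps the simulator ready for a time-varying policy, which is the first item below.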

Now we want to make this more realistic, so we’re doing four things, ordered by priority:

  1. Allow u to vary at each time step. So we want to pick a minimally invasive policy at each time step that still regulates the spread of infections as desired.
  2. Build more complex compartmental models: we’ve got an implementation of SEIR and are working on SEI3R, along with policy interventions in each case.
  3. Use data from an agent-based simulator called FRED to offer more granular control over disease simulations (we’re limited to influenza models here).
  4. Build out code so users can import their own data and run everything with a single function call.

So the reason I ask is that at each time step I want to use the (current) point estimates of the parameters R0 and rho to model the disease spread and consider the minimally invasive u that can limit the spread. To be honest, this seems like an inelegant, brute-force approach to me but for want of better ideas I want to start by trying it out.
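The per-step idea can be made concrete with a greedy rule that avoids brute-force search: at each step, plug in a point estimate of R0 and pick the smallest u whose expected new infections stay under a cap c. This is a sketch under assumptions, not the tutorial's model: it assumes rate_s = -R0_hat / (tau * population) and E[S2I] = S * (1 - exp((1 - u) * rate_s * I)); setting E[S2I] <= c and solving for u gives u >= 1 - log(1 - c / S) / (rate_s * I). The names min_step_u and expected_new_infections and the cap itself are hypothetical. Under this assumed model rho does not appear, because it only governs how many infections are reported rather than transmission.

```python
import math

def expected_new_infections(S, I, R0_hat, tau, population, u):
    # Assumed parameterization: E[S2I] = S * (1 - exp((1 - u) * rate_s * I))
    rate_s = -R0_hat / (tau * population)
    return S * (1.0 - math.exp((1.0 - u) * rate_s * I))

def min_step_u(S, I, R0_hat, tau, population, max_new_infections):
    """Smallest u in [0, 1] keeping this step's expected new infections <= max_new_infections."""
    if I <= 0 or max_new_infections >= S:
        return 0.0  # nothing to control, or the cap can never be exceeded
    rate_s = -R0_hat / (tau * population)
    # Closed-form solution of E[S2I] = cap for u
    u = 1.0 - math.log(1.0 - max_new_infections / S) / (rate_s * I)
    return min(max(u, 0.0), 1.0)  # clip: u < 0 means the cap is already met without policy

u_t = min_step_u(S=9000, I=500, R0_hat=2.5, tau=7.0, population=10000, max_new_infections=50)
print(u_t, expected_new_infections(9000, 500, 2.5, 7.0, 10000, u_t))
```

Applied once per time step with updated S_t, I_t and estimates, this yields a time-varying u_t that is minimal for each step taken on its own; it does not account for how today's choice affects future steps.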

Also, if this sounds interesting to you, we’d love to contribute back to the Pyro community once we’re done (in 2-3 weeks). The original paper used Pyprob and had an entirely different focus, so our major effort has been translating what we can, with the goal of a Distill-style article that non-practitioners of ProbProg can follow clearly.