Thanks E.B., I did try that, but this change seems to break some expected sequence of behaviours, and the breakage shows up further downstream. I’ll keep trying to understand what’s breaking.
I might be naive about this, but I still don’t follow how to define a custom function that, for instance, returns `dist.Normal(x, s + 2.)`. I’m unclear on how the variable `s` can be accessed at the time `intervention_fn` is called. Do I access it through the global param store?
Meanwhile, here’s the complete traceback. I’ve already posted my code above and made the changes word-for-word as you instructed.
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
164 try:
--> 165 ret = self.fn(*args, **kwargs)
166 except (ValueError, RuntimeError) as e:
/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
11 with context:
---> 12 return fn(*args, **kwargs)
13
<ipython-input-355-cf18069b2011> in dummymodel(x)
2 s = pyro.param("s", torch.tensor(0.5))
----> 3 z = pyro.sample("z", dist.Normal(x, s))
4 return z ** 2
/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/primitives.py in sample(name, fn, *args, **kwargs)
112 # apply the stack and return its return value
--> 113 apply_stack(msg)
114 return msg["value"]
/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/poutine/runtime.py in apply_stack(initial_msg)
192
--> 193 frame._process_message(msg)
194
/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/poutine/messenger.py in _process_message(self, msg)
138 if method is not None:
--> 139 return method(msg)
140 return None
/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/poutine/do_messenger.py in _pyro_sample(self, msg)
80 msg['fn'] = intervention(msg) # return a reparametrized function for the node
---> 81 return None
82
NotImplementedError: Interventions of type <class 'function'> not implemented (yet)
The above exception was the direct cause of the following exception:
NotImplementedError Traceback (most recent call last)
<ipython-input-355-cf18069b2011> in <module>
14 intervened_model = do(dummymodel, data={"z": intervention_fn})
15
---> 16 tr = pyro.poutine.trace(intervened_model).get_trace(2.0)
/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
185 Calls this poutine and returns its trace instead of the function's return value.
186 """
--> 187 self(*args, **kwargs)
188 return self.msngr.get_trace()
/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
169 exc = exc_type(u"{}\n{}".format(exc_value, shapes))
170 exc = exc.with_traceback(traceback)
--> 171 raise exc from e
172 self.msngr.trace.add_node("_RETURN", name="_RETURN", type="return", value=ret)
173 return ret
/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
163 args=args, kwargs=kwargs)
164 try:
--> 165 ret = self.fn(*args, **kwargs)
166 except (ValueError, RuntimeError) as e:
167 exc_type, exc_value, traceback = sys.exc_info()
/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
13
14
<ipython-input-355-cf18069b2011> in dummymodel(x)
1 def dummymodel(x):
2 s = pyro.param("s", torch.tensor(0.5))
----> 3 z = pyro.sample("z", dist.Normal(x, s))
4 return z ** 2
5
/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/primitives.py in sample(name, fn, *args, **kwargs)
111 msg["is_observed"] = True
112 # apply the stack and return its return value
--> 113 apply_stack(msg)
114 return msg["value"]
115
/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/poutine/runtime.py in apply_stack(initial_msg)
191 pointer = pointer + 1
192
--> 193 frame._process_message(msg)
194
195 if msg["stop"]:
/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/poutine/messenger.py in _process_message(self, msg)
137 method = getattr(self, "_pyro_{}".format(msg["type"]), None)
138 if method is not None:
--> 139 return method(msg)
140 return None
141
/opt/conda/envs/pyprobenv/lib/python3.6/site-packages/pyro/poutine/do_messenger.py in _pyro_sample(self, msg)
79 else:
80 msg['fn'] = intervention(msg) # return a reparametrized function for the node
---> 81 return None
82
NotImplementedError: Interventions of type <class 'function'> not implemented (yet)
Trace Shapes:
Param Sites:
s
Sample Sites:
z dist |
value |