I’m trying to teach myself a bit of Pyro by translating some of the examples in the probmods book from WebPPL to Pyro. I’m getting stuck on an error that is probably just a torch newbie mistake and hoping someone here can’t tell me what I’m doing wrong.
I’ve reached chapter 3’s section on conditioning in some small causal bayes nets, and wanted to try implementing this with exact inference, equivalent to webPPL’s Infer({method: "enumerate"}, ...
method.
Here’s the model and guide:
## minimal exmaple: works w/ Enumerate() but not exact
import torch
import pyro
from pyro.infer import Importance, EmpiricalMarginal
import pyro.distributions as dist
import seaborn as sns
import numpy as np
pyro.clear_param_store()
def noisy_or_p(c_list, w_list):
ps = torch.prod(torch.tensor([c_list, w_list]), dim = 0)
return torch.tensor(1.) - torch.prod(torch.tensor(1.) - ps)
@pyro.infer.config_enumerate
def model():
breastCancer = pyro.sample("breastCancer", dist.Bernoulli(.01))
benignCyst = pyro.sample("benignCyst", dist.Bernoulli(.2))
p_pos = noisy_or_p([breastCancer, benignCyst], [.8, .5])
pyro.sample("positiveMammogram", dist.Bernoulli(p_pos), obs=torch.tensor(1.))
return breastCancer
def guide(**kwargs):
pass
This works with Importance()
as called below:
posterior = Importance(model, num_samples=2000)
marginal = EmpiricalMarginal(posterior.run())
print("est. p(Cancer = 1|positiveMammogram) = {:.4f}".format(
np.mean([marginal.sample() for _ in range(2000)]))
)
Output: > est. p(Cancer = 1|positiveMammogram) = 0.0810
But doesn’t work when I try to do enumeration:
elbo = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0)
conditional_marginal = elbo.compute_marginals(model, guide)
print("exact p(Cancer = 1|positiveMammogram) = {:.4f}".format(
float(conditional_marginal["breastCancer"].probs))
)
This produces ValueError: only one element tensors can be converted to Python scalars
. Appreciate any help!
Full error below (folded):
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
164 try:
--> 165 ret = self.fn(*args, **kwargs)
166 except (ValueError, RuntimeError) as e:
/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
11 with context:
---> 12 return fn(*args, **kwargs)
13
/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
11 with context:
---> 12 return fn(*args, **kwargs)
13
/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
11 with context:
---> 12 return fn(*args, **kwargs)
13
<ipython-input-956-9d3193f54304> in model()
20
---> 21 p_pos = noisy_or_p([breastCancer, benignCyst], [.8, .5])
22
<ipython-input-956-9d3193f54304> in noisy_or_p(c_list, w_list)
12 def noisy_or_p(c_list, w_list):
---> 13 ps = torch.prod(torch.tensor([c_list, w_list]), dim = 0)
14 return torch.tensor(1.) - torch.prod(torch.tensor(1.) - ps)
ValueError: only one element tensors can be converted to Python scalars
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
<ipython-input-956-9d3193f54304> in <module>
30
31 elbo = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0)
---> 32 conditional_marginal = elbo.compute_marginals(model, guide)
33
34 # posterior = Importance(model, num_samples=2000)
/usr/local/lib/python3.6/dist-packages/pyro/infer/traceenum_elbo.py in compute_marginals(self, model, guide, *args, **kwargs)
432 raise NotImplementedError("TraceEnum_ELBO.compute_marginals() is not "
433 "compatible with multiple particles.")
--> 434 model_trace, guide_trace = next(self._get_traces(model, guide, args, kwargs))
435 for site in guide_trace.nodes.values():
436 if site["type"] == "sample":
/usr/local/lib/python3.6/dist-packages/pyro/infer/traceenum_elbo.py in _get_traces(self, model, guide, args, kwargs)
345 q.put(poutine.Trace())
346 while not q.empty():
--> 347 yield self._get_trace(model, guide, args, kwargs)
348
349 def loss(self, model, guide, *args, **kwargs):
/usr/local/lib/python3.6/dist-packages/pyro/infer/traceenum_elbo.py in _get_trace(self, model, guide, args, kwargs)
298 """
299 model_trace, guide_trace = get_importance_trace(
--> 300 "flat", self.max_plate_nesting, model, guide, args, kwargs)
301
302 if is_validation_enabled():
/usr/local/lib/python3.6/dist-packages/pyro/infer/enum.py in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
46 guide_trace.detach_()
47 model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
---> 48 graph_type=graph_type).get_trace(*args, **kwargs)
49 if is_validation_enabled():
50 check_model_guide_match(model_trace, guide_trace, max_plate_nesting)
/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
185 Calls this poutine and returns its trace instead of the function's return value.
186 """
--> 187 self(*args, **kwargs)
188 return self.msngr.get_trace()
/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
169 exc = exc_type(u"{}\n{}".format(exc_value, shapes))
170 exc = exc.with_traceback(traceback)
--> 171 raise exc from e
172 self.msngr.trace.add_node("_RETURN", name="_RETURN", type="return", value=ret)
173 return ret
/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
163 args=args, kwargs=kwargs)
164 try:
--> 165 ret = self.fn(*args, **kwargs)
166 except (ValueError, RuntimeError) as e:
167 exc_type, exc_value, traceback = sys.exc_info()
/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
13
14
/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
13
14
/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
13
14
<ipython-input-956-9d3193f54304> in model()
19 benignCyst = pyro.sample("benignCyst", dist.Bernoulli(.2))
20
---> 21 p_pos = noisy_or_p([breastCancer, benignCyst], [.8, .5])
22
23 pyro.sample("positiveMammogram", dist.Bernoulli(p_pos), obs=torch.tensor(1.))
<ipython-input-956-9d3193f54304> in noisy_or_p(c_list, w_list)
11
12 def noisy_or_p(c_list, w_list):
---> 13 ps = torch.prod(torch.tensor([c_list, w_list]), dim = 0)
14 return torch.tensor(1.) - torch.prod(torch.tensor(1.) - ps)
15
ValueError: only one element tensors can be converted to Python scalars
Trace Shapes:
Param Sites:
Sample Sites:
breastCancer dist |
value 2 |
benignCyst dist |
value 2 1 |
</details>