Error with exact inference in simple probmods example

I’m trying to teach myself a bit of Pyro by translating some of the examples in the probmods book from WebPPL to Pyro. I’m getting stuck on an error that is probably just a torch newbie mistake and hoping someone here can’t tell me what I’m doing wrong.

I’ve reached chapter 3’s section on conditioning in some small causal bayes nets, and wanted to try implementing this with exact inference, equivalent to webPPL’s Infer({method: "enumerate"}, ... method.

Here’s the model and guide:

## minimal exmaple: works w/ Enumerate() but not exact 

import torch
import pyro
from pyro.infer import Importance, EmpiricalMarginal
import pyro.distributions as dist
import seaborn as sns
import numpy as np

pyro.clear_param_store()

def noisy_or_p(c_list, w_list):
    ps = torch.prod(torch.tensor([c_list, w_list]), dim = 0)
    return torch.tensor(1.) - torch.prod(torch.tensor(1.) - ps)

@pyro.infer.config_enumerate
def model():
    breastCancer = pyro.sample("breastCancer", dist.Bernoulli(.01))
    benignCyst = pyro.sample("benignCyst", dist.Bernoulli(.2))

    p_pos = noisy_or_p([breastCancer, benignCyst], [.8, .5])
    pyro.sample("positiveMammogram", dist.Bernoulli(p_pos), obs=torch.tensor(1.))
    
    return breastCancer

def guide(**kwargs):
    pass

This works with Importance() as called below:

posterior = Importance(model, num_samples=2000)
marginal = EmpiricalMarginal(posterior.run())

print("est. p(Cancer = 1|positiveMammogram) = {:.4f}".format(
    np.mean([marginal.sample() for _ in range(2000)]))
     )

Output: > est. p(Cancer = 1|positiveMammogram) = 0.0810

But doesn’t work when I try to do enumeration:

elbo = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0)
conditional_marginal = elbo.compute_marginals(model, guide)

print("exact p(Cancer = 1|positiveMammogram) = {:.4f}".format(
    float(conditional_marginal["breastCancer"].probs))
     )

This produces ValueError: only one element tensors can be converted to Python scalars. Appreciate any help!

Full error below (folded):

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    164             try:
--> 165                 ret = self.fn(*args, **kwargs)
    166             except (ValueError, RuntimeError) as e:

/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 

/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 

/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 

<ipython-input-956-9d3193f54304> in model()
     20 
---> 21     p_pos = noisy_or_p([breastCancer, benignCyst], [.8, .5])
     22 

<ipython-input-956-9d3193f54304> in noisy_or_p(c_list, w_list)
     12 def noisy_or_p(c_list, w_list):
---> 13     ps = torch.prod(torch.tensor([c_list, w_list]), dim = 0)
     14     return torch.tensor(1.) - torch.prod(torch.tensor(1.) - ps)

ValueError: only one element tensors can be converted to Python scalars

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
<ipython-input-956-9d3193f54304> in <module>
     30 
     31 elbo = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0)
---> 32 conditional_marginal = elbo.compute_marginals(model, guide)
     33 
     34 # posterior = Importance(model, num_samples=2000)

/usr/local/lib/python3.6/dist-packages/pyro/infer/traceenum_elbo.py in compute_marginals(self, model, guide, *args, **kwargs)
    432             raise NotImplementedError("TraceEnum_ELBO.compute_marginals() is not "
    433                                       "compatible with multiple particles.")
--> 434         model_trace, guide_trace = next(self._get_traces(model, guide, args, kwargs))
    435         for site in guide_trace.nodes.values():
    436             if site["type"] == "sample":

/usr/local/lib/python3.6/dist-packages/pyro/infer/traceenum_elbo.py in _get_traces(self, model, guide, args, kwargs)
    345             q.put(poutine.Trace())
    346             while not q.empty():
--> 347                 yield self._get_trace(model, guide, args, kwargs)
    348 
    349     def loss(self, model, guide, *args, **kwargs):

/usr/local/lib/python3.6/dist-packages/pyro/infer/traceenum_elbo.py in _get_trace(self, model, guide, args, kwargs)
    298         """
    299         model_trace, guide_trace = get_importance_trace(
--> 300             "flat", self.max_plate_nesting, model, guide, args, kwargs)
    301 
    302         if is_validation_enabled():

/usr/local/lib/python3.6/dist-packages/pyro/infer/enum.py in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
     46         guide_trace.detach_()
     47     model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
---> 48                                 graph_type=graph_type).get_trace(*args, **kwargs)
     49     if is_validation_enabled():
     50         check_model_guide_match(model_trace, guide_trace, max_plate_nesting)

/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
    185         Calls this poutine and returns its trace instead of the function's return value.
    186         """
--> 187         self(*args, **kwargs)
    188         return self.msngr.get_trace()

/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    169                 exc = exc_type(u"{}\n{}".format(exc_value, shapes))
    170                 exc = exc.with_traceback(traceback)
--> 171                 raise exc from e
    172             self.msngr.trace.add_node("_RETURN", name="_RETURN", type="return", value=ret)
    173         return ret

/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    163                                       args=args, kwargs=kwargs)
    164             try:
--> 165                 ret = self.fn(*args, **kwargs)
    166             except (ValueError, RuntimeError) as e:
    167                 exc_type, exc_value, traceback = sys.exc_info()

/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 
     14 

/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 
     14 

/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 
     14 

<ipython-input-956-9d3193f54304> in model()
     19     benignCyst = pyro.sample("benignCyst", dist.Bernoulli(.2))
     20 
---> 21     p_pos = noisy_or_p([breastCancer, benignCyst], [.8, .5])
     22 
     23     pyro.sample("positiveMammogram", dist.Bernoulli(p_pos), obs=torch.tensor(1.))

<ipython-input-956-9d3193f54304> in noisy_or_p(c_list, w_list)
     11 
     12 def noisy_or_p(c_list, w_list):
---> 13     ps = torch.prod(torch.tensor([c_list, w_list]), dim = 0)
     14     return torch.tensor(1.) - torch.prod(torch.tensor(1.) - ps)
     15 

ValueError: only one element tensors can be converted to Python scalars
    Trace Shapes:      
     Param Sites:      
    Sample Sites:      
breastCancer dist     |
            value   2 |
  benignCyst dist     |
            value 2 1 |
</details>