Importance sampling and EmpiricalMarginal

I am trying to code up Chapter 1 of the book “Model-Based Machine Learning”.

Here is my code. It models a crime scene with two possible murderers, represented by a Bernoulli distribution (see the function model5()). There are two possible weapons, also modeled by a Bernoulli distribution whose probability p depends on who the murderer is. Finally, a bullet is found at the scene, which is an observation of the murder weapon. When I run this program, I get the following warning at line 71:

/Users/erlebach/anaconda3/lib/python3.7/site-packages/pyro/primitives.py:71: RuntimeWarning: trying to observe a value outside of inference at weapon
  RuntimeWarning)

(Line 71 is past the last line of my code.)

Why does this warning occur?

I also get a runtime error:

RuntimeError: output with type torch.LongTensor doesn't match the desired type torch.FloatTensor

which I can trace back to the line

emp = EmpiricalMarginal(posterior.run(), sites="murderer")

The full traceback is included below the source code.

import sys, os
import numpy as np
import torch
import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist

from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate, infer_discrete
from pyro.infer import EmpiricalMarginal

import torch.distributions.constraints as constraints
from torch.distributions.beta import Beta
from torch.distributions import Normal, Bernoulli

from pyro.infer.mcmc import HMC, NUTS, MCMC
from functools import reduce, partial

import time

assert pyro.__version__.startswith('0.3.0')

# Check function arguments
pyro.enable_validation()

# Fix the random seed so that results are reproducible across runs
pyro.set_rng_seed(0)

def model5():
    print("enter model5")


    # murderer = 1 (Mr. Auburn)
    # murderer = 0 (Ms. Gray)
    murderer = dist.Bernoulli(0.7)
    # prior
    murderer = pyro.sample("murderer", murderer)
    
    # weapon == 1 (Gun)
    # weapon == 0 (dagger)
    if murderer == 0:  # Ms. Gray
        p = 0.9
    else: # Mr. Auburn
        p = 0.2
        
    # conditional probability: weapon given the murderer
    weapon = dist.Bernoulli(p)
    
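    # A bullet is found at the murder scene, so the weapon is observed (gun = 1).
    # The observed value must be a float tensor: torch.tensor(1.), not torch.tensor(1).
    obs = pyro.sample("weapon", weapon, obs=torch.tensor(1.))

    # Latent variable: murderer
    # Observed variable: weapon

# Calling the model directly, outside of any inference algorithm, is what emits
# "RuntimeWarning: trying to observe a value outside of inference at weapon":
# with no inference handler active, pyro.sample just returns the observed value.
model5()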

# guide=None: Importance proposes the latent site "murderer" from the model's prior
posterior = pyro.infer.Importance(model5, None, num_samples=10)
print("posterior= ", posterior)
emp = EmpiricalMarginal(posterior.run(), sites="murderer")
print(emp)
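As a sanity check for whatever Pyro returns, the posterior of this two-suspect model can be computed exactly by enumerating both values of murderer and applying Bayes' rule. The sketch below uses only the probabilities that appear in model5() (prior 0.7 for murderer = 1, and P(weapon = 1) of 0.9 for Ms. Gray or 0.2 for Mr. Auburn); the function and constant names are hypothetical.

```python
# Exact posterior for model5() by enumeration (Bayes' rule).
# murderer = 1 is Mr. Auburn, murderer = 0 is Ms. Gray; weapon = 1 is the gun.

PRIOR_AUBURN = 0.7                  # P(murderer = 1), as in dist.Bernoulli(0.7)
P_GUN_GIVEN = {0: 0.9, 1: 0.2}      # P(weapon = 1 | murderer), as in model5()


def exact_posterior(observed_weapon=1):
    """Return {murderer: P(murderer | weapon = observed_weapon)}."""
    prior = {0: 1.0 - PRIOR_AUBURN, 1: PRIOR_AUBURN}
    joint = {}
    for m in (0, 1):
        p_gun = P_GUN_GIVEN[m]
        likelihood = p_gun if observed_weapon == 1 else 1.0 - p_gun
        joint[m] = prior[m] * likelihood      # P(murderer = m, weapon = obs)
    evidence = sum(joint.values())            # P(weapon = obs)
    return {m: joint[m] / evidence for m in joint}


if __name__ == "__main__":
    post = exact_posterior()
    print(f"P(Gray | gun)   = {post[0]:.4f}")
    print(f"P(Auburn | gun) = {post[1]:.4f}")
```

With the gun observed, this gives P(Gray | gun) = 0.27/0.41 ≈ 0.659 and P(Auburn | gun) = 0.14/0.41 ≈ 0.341. Since murderer = 1 encodes Mr. Auburn, the mean of the empirical marginal over "murderer" should move toward about 0.34 as num_samples grows; with only 10 samples it will be noisy.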

===================================

RuntimeError                              Traceback (most recent call last)
~/Documents/src/2018/model-based_machine_learning/chapter1/infer.py in <module>()
     59 posterior = pyro.infer.Importance(model5, None, num_samples=10)
     60 print("posterior= ", posterior)
---> 61 emp = EmpiricalMarginal(posterior.run(), sites="murderer")
     62 print(emp)
     63 

~/anaconda3/lib/python3.7/site-packages/pyro/infer/abstract_infer.py in run(self, *args, **kwargs)
    198         self._reset()
    199         with poutine.block():
--> 200             for i, vals in enumerate(self._traces(*args, **kwargs)):
    201                 if len(vals) == 2:
    202                     chain_id = 0

~/anaconda3/lib/python3.7/site-packages/pyro/infer/importance.py in _traces(self, *args, **kwargs)
     43             model_trace = poutine.trace(
     44                 poutine.replay(self.model, trace=guide_trace)).get_trace(*args, **kwargs)
---> 45             log_weight = model_trace.log_prob_sum() - guide_trace.log_prob_sum()
     46             yield (model_trace, log_weight)
     47 

~/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_struct.py in log_prob_sum(self, site_filter)
    134                 else:
    135                     try:
--> 136                         log_p = site["fn"].log_prob(site["value"], *site["args"], **site["kwargs"])
    137                     except ValueError:
    138                         _, exc_value, traceback = sys.exc_info()

~/anaconda3/lib/python3.7/site-packages/torch/distributions/bernoulli.py in log_prob(self, value)
     92             self._validate_sample(value)
     93         logits, value = broadcast_all(self.logits, value)
---> 94         return -binary_cross_entropy_with_logits(logits, value, reduction='none')
     95 
     96     def entropy(self):

~/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py in binary_cross_entropy_with_logits(input, target, weight, size_average, reduce, reduction, pos_weight)
   2075         raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
   2076 
-> 2077     return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)
   2078 
   2079 

RuntimeError: output with type torch.LongTensor doesn't match the desired type torch.FloatTensor
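The traceback shows what Importance does for each sample: it runs the guide, replays the model against the guide's choices, and computes log_weight = model_trace.log_prob_sum() - guide_trace.log_prob_sum(). Assuming that with guide=None the proposal is the model's prior (with the observed site hidden), the prior term for murderer cancels and each weight reduces to the likelihood of the observed weapon. Below is a minimal stdlib sketch of that self-normalized importance sampling for model5(); it is an illustration, not Pyro's implementation, and all function names are hypothetical.

```python
import math
import random

PRIOR_AUBURN = 0.7                 # P(murderer = 1), Mr. Auburn
P_GUN_GIVEN = {0: 0.9, 1: 0.2}     # P(weapon = 1 | murderer)


def bernoulli_log_prob(p, value):
    """log P(value) for a Bernoulli(p) variable, value in {0, 1}."""
    return math.log(p) if value == 1 else math.log(1.0 - p)


def importance_samples(num_samples, observed_weapon=1, rng=None):
    """Yield (murderer, log_weight) pairs, proposing murderer from the prior."""
    rng = rng if rng is not None else random.Random()
    for _ in range(num_samples):
        # "guide" = the model with the observation hidden: sample from the prior
        m = 1 if rng.random() < PRIOR_AUBURN else 0
        log_q = bernoulli_log_prob(PRIOR_AUBURN, m)
        # "model" replayed with the same m, plus the observed weapon site
        log_p = log_q + bernoulli_log_prob(P_GUN_GIVEN[m], observed_weapon)
        yield m, log_p - log_q           # = log P(weapon = obs | m)


def posterior_auburn(samples):
    """Self-normalized estimate of P(murderer = 1 | weapon), i.e. the
    weighted mean of the sampled "murderer" values."""
    pairs = list(samples)
    max_lw = max(lw for _, lw in pairs)  # subtract the max for numerical stability
    weights = [math.exp(lw - max_lw) for _, lw in pairs]
    total = sum(weights)
    return sum(w * m for w, (m, _) in zip(weights, pairs)) / total


if __name__ == "__main__":
    rng = random.Random(0)
    for n in (10, 100_000):
        est = posterior_auburn(importance_samples(n, rng=rng))
        print(f"num_samples={n:>6}: P(Auburn | gun) ~ {est:.3f}")
```

Comparing the two sample sizes shows why num_samples=10 in the original script gives a rough answer: the estimate only settles near the exact value of about 0.341 once many samples are drawn.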

I answered my second question. The error

RuntimeError: output with type torch.LongTensor doesn't match the desired type torch.FloatTensor

arises because I set my observed value as a long (integer) tensor when it should be a float. So my new question is:

Why don't Pyro or PyTorch convert longs to floats automatically? Is there a reason? Thanks.
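The traceback in the first post shows where the integer observation enters floating-point arithmetic: torch.distributions.Bernoulli.log_prob broadcasts the observed value against the float logits and returns -binary_cross_entropy_with_logits(logits, value), so the observation acts as the "target" of a cross-entropy. The stdlib sketch below spells out that formula for a single value; it illustrates the math only, not PyTorch's implementation, and the function name is hypothetical. In the actual script, the corresponding fix is to pass a float tensor, obs=torch.tensor(1.).

```python
import math


def bernoulli_log_prob_from_logits(logit, value):
    """Negative binary cross-entropy with logits for a single target value.

    Equals value * log(sigmoid(logit)) + (1 - value) * log(1 - sigmoid(logit)),
    written in the numerically stable form
    BCE = max(x, 0) - x * y + log(1 + exp(-|x|)).
    """
    value = float(value)  # the target takes part in float arithmetic
    bce = max(logit, 0.0) - logit * value + math.log1p(math.exp(-abs(logit)))
    return -bce


p_gun = 0.2                                 # P(weapon = 1 | Mr. Auburn) in model5()
logit = math.log(p_gun / (1.0 - p_gun))     # log-odds of the Bernoulli
print(bernoulli_log_prob_from_logits(logit, 1))
print(math.log(p_gun))                      # should agree with the line above up to rounding
```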

In general, there's no reason these frameworks should do "guesswork" for the programmer. It makes the code unnecessarily complicated and unreadable.


So you are saying that the programmer should program more precisely. I understand, although it seems inconsistent with Python being a dynamically typed language, where conversions occur all over the place. Thank you for the feedback.

Thanks to this presentation and this thread, I was able to put together a notebook for Model-Based Machine Learning, Chapter 1 (A Murder Mystery), which seems to work: MBML_Chapter1_MurderMystery.ipynb in the MicPie/pyro repository on GitHub.

Feedback is very welcome, as I am still getting started with Pyro!


I also have Chapter 1 coded up and am willing to share my notebook with you.

Let me know.

Yes, please, I would be curious to see how you solved it.

I solved it after spending lots of time with Pyro :-). Here is a link to Google Drive:

https://drive.google.com/file/d/1SXX2j-RV3AFTqDSTPVjgZQnSgJV7_t3g/view?usp=sharing

Please provide any feedback you deem appropriate to improve what I have done. Thanks!

Gordon


Thanks for sharing these notebooks. They really demonstrate some of the basic Pyro functions quite well.