Default guide for pyro.infer.Importance with some strange behaviour


#1

Hi, I’m having some trouble understanding why the following simple model yields a different standard deviation (and hence a different posterior, right?) with the default guide than with a guide that is exactly the model itself. For instance:

import numpy as np
import matplotlib.pyplot as plt
try:
    import seaborn as sns
    sns.set()
except ImportError:
    pass

import torch
from torch.autograd import Variable

import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist

torch.manual_seed(101)

def scale(guess):
    weight = pyro.sample("weight", dist.normal, guess,
                         Variable(torch.ones(1)))
    return weight

def scale_prior_guide(guess):
    return pyro.sample("weight", dist.normal, guess, Variable(torch.ones(1)))

def scale_prior_guide_weird(guess):
    return pyro.sample("weight", dist.normal, guess+3, Variable(torch.ones(1))+2)

posterior = pyro.infer.Importance(scale, num_samples=1000)
guess = Variable(torch.Tensor([8.5]))
marginal = pyro.infer.Marginal(posterior)
res_importance = np.array([marginal(guess).data[0] for _ in range(10000)])

posterior = pyro.infer.Importance(scale, num_samples=1000, guide=scale_prior_guide)
guess = Variable(torch.Tensor([8.5]))
marginal = pyro.infer.Marginal(posterior)
res_importance_guide = np.array([marginal(guess).data[0] for _ in range(10000)])

posterior = pyro.infer.Importance(scale, num_samples=1000, guide=scale_prior_guide_weird)
guess = Variable(torch.Tensor([8.5]))
marginal = pyro.infer.Marginal(posterior)
res_importance_guide_weird = np.array([marginal(guess).data[0] for _ in range(10000)])

print("Mean")
print(res_importance.mean())
print(res_importance_guide.mean())
print(res_importance_guide_weird.mean())

print("Std")
print(res_importance.std())
print(res_importance_guide.std())
print(res_importance_guide_weird.std())

would output something like:

Mean
8.49218823838
8.44276266947
8.46633415203
Std
0.703899055717 # <<<<<<< here's what's strange
0.983066442502
1.01712713134
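Here is my reading of where the 0.70 comes from, as a NumPy-only sketch (nothing Pyro-specific; the premise that the guide's log-pdf collapses to 0 is my assumption about the bug). If the guide trace contributes nothing, each prior sample ends up weighted by its own model density, so the weighted samples follow a density proportional to p(x)², i.e. N(8.5, 1/2), whose std is 1/sqrt(2) ≈ 0.707:

```python
import numpy as np

rng = np.random.default_rng(0)

# Proposal samples from the prior N(8.5, 1), as the model draws them.
samples = rng.normal(8.5, 1.0, size=200_000)

# If the guide trace contributes log-pdf 0, the importance weight of each
# sample collapses to its own model density: w(x) proportional to N(x; 8.5, 1).
log_w = -0.5 * (samples - 8.5) ** 2   # unnormalized log-density
w = np.exp(log_w - log_w.max())
w /= w.sum()

# Reweighting N(8.5, 1) by its own density gives a density proportional to
# p(x)^2, i.e. N(8.5, 1/2), whose std is 1/sqrt(2).
mean = np.sum(w * samples)
std = np.sqrt(np.sum(w * (samples - mean) ** 2))
print(f"{std:.3f}")  # ~0.707, matching the "strange" std above
```

With a correct guide equal to the prior, the weights are uniform instead, and the std stays near 1, as in the other two runs.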

I inspected the Pyro source code to check what the default guide for pyro.infer.Importance is, and for this simple model it seems to be the model itself with the observe sites blocked:

import pyro.poutine as poutine
guide = poutine.block(scale, hide_types=["observe"])
res_model = np.array([scale(guess).data[0] for _ in range(10000)])
res_guide = np.array([guide(guess).data[0] for _ in range(10000)])
print(res_model.mean())
print(res_guide.mean())
print(res_model.std())
print(res_guide.std())

would output something like:

8.47813085794
8.49829654074
1.00555316964
0.998794367065

#2

I believe this is a bug in a decision rule in poutine.block. See also this issue. I’ve assigned myself to push a fix but haven’t had a chance to get to it yet.


#3

Thanks, I think this is just a matter of setting hide_all = False when expose_types or hide_types is not None (just as when hide is not None). See https://github.com/uber/pyro/compare/dev...randommm:patch-1
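If I read the patch correctly, the decision rule boils down to something like the following plain-Python sketch (names taken from this discussion, not from Pyro's actual source):

```python
def resolve_hide_all(hide=None, expose=None, hide_types=None, expose_types=None):
    """Sketch: should poutine.block hide every site by default?

    The original rule only flipped hide_all to False for hide/expose;
    the fix treats hide_types/expose_types the same way, so that
    block(model, hide_types=["observe"]) keeps the sample sites visible.
    """
    if hide is not None or expose is not None:
        return False
    if hide_types is not None or expose_types is not None:
        return False  # the proposed fix
    return True

print(resolve_hide_all())                        # True: block everything
print(resolve_hide_all(hide_types=["observe"]))  # False: hide only observes
```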

After this, the code from rstebbing works as expected:

In [2]: # Using `Importance` with the default `guide`, the `log_weight` is equal to the
   ...: # `model_trace.log_pdf()`. That is, the `guide_trace.log_pdf()` (evaluated
   ...: # internally) is incorrectly `0.0`.
   ...: print('importance_default_guide:')
   ...: importance_default_guide = infer.Importance(gaussian, num_samples=10)
   ...: for model_trace, log_weight in importance_default_guide._traces():
   ...:     model_trace_log_pdf = model_trace.log_pdf()
   ...:     are_equal = log_weight.data[0] == model_trace_log_pdf.data[0]
   ...:     print(log_weight.data[0], are_equal)
   ...:
   ...: # However, setting the `guide` to expose `x` ensures that it is replayed so
   ...: # that the `log_weight` is exactly zero for each sample.
   ...: print('importance_exposed_guide:')
   ...: importance_exposed_guide = infer.Importance(
   ...:     gaussian,
   ...:     guide=poutine.block(gaussian, expose=['x']),
   ...:     num_samples=10)
   ...: for model_trace, log_weight in importance_exposed_guide._traces():
   ...:     print(log_weight.data[0])
   ...:
importance_default_guide:
0.0 False
0.0 False
0.0 False
0.0 False
0.0 False
0.0 False
0.0 False
0.0 False
0.0 False
0.0 False
importance_exposed_guide:
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
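As a sanity check on why these runs print exactly 0.0: when the model has no observed sites and the guide is the prior itself, the importance weight is log p(x) - log q(x), which vanishes identically for every draw (assuming that is the situation in rstebbing's gaussian example). A standalone sketch with a generic normal log-density, nothing Pyro-specific:

```python
import math

def log_normal(x, mu, sigma):
    # log N(x; mu, sigma^2)
    return -0.5 * math.log(2 * math.pi * sigma ** 2) - (x - mu) ** 2 / (2 * sigma ** 2)

# Draw from the guide (here, the prior itself) and form the log weight:
# log p(x) - log q(x) with p = q cancels exactly, for any x.
x = 9.2  # any draw
log_weight = log_normal(x, 8.5, 1.0) - log_normal(x, 8.5, 1.0)
print(log_weight)  # 0.0
```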