# Default guide for `pyro.infer.Importance` with some strange behaviour

#1

Hi, I’m having some trouble understanding why the following simple model gives a different standard deviation (and hence a different posterior, right?) with the default guide than with a guide that is exactly the model itself. For instance:

``````python
import numpy as np
import matplotlib.pyplot as plt
try:
    import seaborn as sns
    sns.set()
except ImportError:
    pass

import torch
from torch.autograd import Variable

import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist

torch.manual_seed(101)

# Model: a single latent "weight" with prior Normal(guess, 1) and no observations.
def scale(guess):
    weight = pyro.sample("weight", dist.normal, guess,
                         Variable(torch.ones(1)))
    return weight

# Guide identical to the model.
def scale_prior_guide(guess):
    return pyro.sample("weight", dist.normal, guess, Variable(torch.ones(1)))

# Deliberately mismatched guide: Normal(guess + 3, 1 + 2).
def scale_prior_guide_weird(guess):
    return pyro.sample("weight", dist.normal, guess+3, Variable(torch.ones(1))+2)

# 1) Default guide
posterior = pyro.infer.Importance(scale, num_samples=1000)
guess = Variable(torch.Tensor([8.5]))
marginal = pyro.infer.Marginal(posterior)
res_importance = np.array([marginal(guess).data for _ in range(10000)])

# 2) Guide identical to the model
posterior = pyro.infer.Importance(scale, num_samples=1000, guide=scale_prior_guide)
guess = Variable(torch.Tensor([8.5]))
marginal = pyro.infer.Marginal(posterior)
res_importance_guide = np.array([marginal(guess).data for _ in range(10000)])

# 3) Mismatched guide
posterior = pyro.infer.Importance(scale, num_samples=1000, guide=scale_prior_guide_weird)
guess = Variable(torch.Tensor([8.5]))
marginal = pyro.infer.Marginal(posterior)
res_importance_guide_weird = np.array([marginal(guess).data for _ in range(10000)])

print("Mean")
print(res_importance.mean())
print(res_importance_guide.mean())
print(res_importance_guide_weird.mean())

print("Std")
print(res_importance.std())
print(res_importance_guide.std())
print(res_importance_guide_weird.std())
``````

would output something like:

``````
Mean
8.49218823838
8.44276266947
8.46633415203
Std
0.703899055717 # <<<<<<< here's what's strange
0.983066442502
1.01712713134
``````
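For reference, here is a minimal NumPy sketch of what importance sampling followed by resampling should give for this model. It assumes that `Importance` + `Marginal` behave roughly like self-normalized importance sampling with resampling; the helpers `normal_logpdf` and `importance_resample` are hypothetical and not part of Pyro. Because the model has no observed data, the target is just the prior Normal(8.5, 1), so any valid guide should yield a standard deviation near 1, which is what the two guided runs show.

```python
import numpy as np

def normal_logpdf(x, loc, scale):
    # Log-density of Normal(loc, scale) evaluated at x.
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale * np.sqrt(2 * np.pi))

def importance_resample(rng, guide_loc, guide_scale, prior_loc=8.5,
                        prior_scale=1.0, num_samples=1000, num_draws=10000):
    # Hypothetical helper: draw candidate values from the guide q.
    x = rng.normal(guide_loc, guide_scale, size=num_samples)
    # Importance log-weights log p(x) - log q(x); with no observed data,
    # p is just the prior Normal(prior_loc, prior_scale).
    log_w = (normal_logpdf(x, prior_loc, prior_scale)
             - normal_logpdf(x, guide_loc, guide_scale))
    # Self-normalize the weights in a numerically stable way.
    w = np.exp(log_w - log_w.max())
    w /= w.sum()
    # Resample according to the weights to get approximate posterior draws.
    return rng.choice(x, size=num_draws, p=w)

rng = np.random.default_rng(0)
same = importance_resample(rng, guide_loc=8.5, guide_scale=1.0)    # guide == model
weird = importance_resample(rng, guide_loc=11.5, guide_scale=3.0)  # Normal(guess + 3, 1 + 2)
print(same.std(), weird.std())  # both close to 1, up to Monte Carlo noise
```

The mismatched guide has fewer effective samples, so its estimate is noisier, but its weights still correct for the mismatch.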

I inspected the Pyro source code to check what the default guide for `pyro.infer.Importance` is, and for this simple model it seems to be the model itself. Sampling from both gives matching results:

``````python
import pyro.poutine as poutine
# What the default guide appears to be: the model with only "observe" sites blocked.
guide = poutine.block(scale, hide_types=["observe"])
res_model = np.array([scale(guess).data for _ in range(10000)])
res_guide = np.array([guide(guess).data for _ in range(10000)])
print(res_model.mean())
print(res_guide.mean())
print(res_model.std())
print(res_guide.std())
``````

would output something like:

``````
8.47813085794
8.49829654074
1.00555316964
0.998794367065
``````

#2

I believe this is a bug in a decision rule in `poutine.block`. See also this issue. I’ve assigned myself to push a fix but haven’t had a chance to get to it yet.

#3

Thanks. I think this is just a matter of setting `hide_all = False` when `expose_types` or `hide_types` is not `None` (just as when `hide` is not `None`). See https://github.com/uber/pyro/compare/dev...randommm:patch-1
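The decision rule under discussion can be illustrated with a small pure-Python sketch. It is a hypothetical simplification, not the actual `poutine.block` code: `make_hide_fn` and its `apply_fix` flag are made up for illustration, and `expose_types` is left out. A site is hidden if it is named in `hide`, has a type in `hide_types`, or if `hide_all` is on and the site is not in `expose`.

```python
def make_hide_fn(hide_all=True, hide=None, expose=None, hide_types=None,
                 apply_fix=True):
    """Hypothetical sketch of a block-style hiding rule (not Pyro's code).

    Returns a predicate hide_fn(name, site_type) -> bool.
    """
    hide_set = set(hide or [])
    expose_set = set(expose or [])
    hide_type_set = set(hide_types or [])

    if hide is not None:
        # An explicit name-based hide list switches off "hide everything".
        hide_all = False
    if apply_fix and hide_types is not None:
        # Proposed change: a type-based hide list should do the same.
        hide_all = False

    def hide_fn(name, site_type):
        if name in hide_set or site_type in hide_type_set:
            return True
        # Under hide_all, every site that is not explicitly exposed is hidden.
        return hide_all and name not in expose_set

    return hide_fn

# Without the change, hide_types=["observe"] also hides the sample site "weight".
buggy = make_hide_fn(hide_types=["observe"], apply_fix=False)
fixed = make_hide_fn(hide_types=["observe"])
print(buggy("weight", "sample"), fixed("weight", "sample"))
```

With `apply_fix=False`, the guide built with `hide_types=["observe"]` would hide its only sample site, which fits the symptom of the default guide contributing nothing to the importance weights.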

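With the patch applied, the code from rstebbing works as expected. The comments in the snippet describe the behaviour before the fix; afterwards the default guide also gives a `log_weight` of exactly zero, so every comparison with `model_trace.log_pdf()` prints `False`: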

``````python
# `gaussian`, `infer` and `poutine` come from rstebbing's original snippet (not reproduced here).

# Using `Importance` with the default `guide`, the `log_weight` is equal to the
# `model_trace.log_pdf()`. That is, the `guide_trace.log_pdf()` (evaluated
# internally) is incorrectly `0.0`.
print('importance_default_guide:')
importance_default_guide = infer.Importance(gaussian, num_samples=10)
for model_trace, log_weight in importance_default_guide._traces():
    model_trace_log_pdf = model_trace.log_pdf()
    are_equal = log_weight.data == model_trace_log_pdf.data
    print(log_weight.data, are_equal)

# However, setting the `guide` to expose `x` ensures that it is replayed so
# that the `log_weight` is exactly zero for each sample.
print('importance_exposed_guide:')
importance_exposed_guide = infer.Importance(
    gaussian,
    guide=poutine.block(gaussian, expose=['x']),
    num_samples=10)
for model_trace, log_weight in importance_exposed_guide._traces():
    print(log_weight.data)
``````

``````
importance_default_guide:
0.0 False
0.0 False
0.0 False
0.0 False
0.0 False
0.0 False
0.0 False
0.0 False
0.0 False
0.0 False
importance_exposed_guide:
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
``````
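This also accounts for the standard deviation of about 0.70 in the first post. Assuming, as the thread suggests, that the unfixed default guide hides the `weight` site, the guide's log-density is 0 and nothing is replayed, so the model draws `weight` from its prior and each log-weight is just log p(x) with x ~ p. Resampling in proportion to p(x) then targets a density proportional to p(x)^2 ∝ exp(-(x - 8.5)^2), which is Normal(8.5, 1/2) with standard deviation 1/√2 ≈ 0.707. The NumPy sketch below checks this numerically; `buggy_importance_resample` is a hypothetical helper that mimics the bug, not Pyro code.

```python
import numpy as np

def buggy_importance_resample(rng, loc=8.5, scale=1.0,
                              num_samples=1000, num_draws=10000):
    """Hypothetical helper: importance resampling where the guide's
    log-density is wrongly taken to be 0."""
    # Nothing is replayed, so x comes straight from the model's prior p.
    x = rng.normal(loc, scale, size=num_samples)
    # log_weight = log p(x) - log q(x), with log q(x) stuck at 0.
    log_w = -0.5 * ((x - loc) / scale) ** 2 - np.log(scale * np.sqrt(2 * np.pi))
    w = np.exp(log_w - log_w.max())
    w /= w.sum()
    # Resampling by w = p(x) from draws of p targets a density proportional to p(x)^2.
    return rng.choice(x, size=num_draws, p=w)

rng = np.random.default_rng(0)
draws = buggy_importance_resample(rng)
print(draws.mean(), draws.std())  # std close to 1 / sqrt(2), up to Monte Carlo noise
```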