# How is the model conditioned using pyro.sample in factor?

`search_inference.py` implements `factor` as follows:

```python
def factor(name, value):
    """
    Like factor in webPPL, adds a scalar weight to the log-probability of the trace.
    """
    value = value if torch.is_tensor(value) else torch.tensor(value)
    d = dist.Bernoulli(logits=value)
    pyro.sample(name, d, obs=torch.ones(value.size()))
```
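
For concreteness, here is a small check (my own illustration, not from the repo) of what this `factor` contributes to a trace:

```python
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

def model():
    factor("w", torch.tensor(2.0))

# The observed Bernoulli site contributes d.log_prob(obs) to the log-joint,
# i.e. log sigmoid(2.0) here, since obs is a tensor of ones.
trace = poutine.trace(model).get_trace()
print(trace.log_prob_sum())  # tensor(-0.1269)
```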

An example use is demonstrated in the RSA implicature example's definition of `literal_listener`.

I can’t understand how sampling with `obs` conditions the model in this context.

Hi @true, we consider `pyro.factor()` an implementation detail: its effect should simply be to add a log-prob term to the trace when used by inference algorithms. `pyro.factor()` is currently implemented using the idiom

```python
pyro.sample(name, distribution, obs=value)
```

which creates the log-prob term `distribution.log_prob(value)` in the trace. This idiom is kind of a hack: the sample site isn’t a real sample site, in the sense that it samples from a non-normalized `Unit` distribution which has only a single possible value, the empty tensor `torch.zeros((0,)) == torch.ones((0,))`. This is like the unit type in type theory; the entire purpose of the `Unit` distribution is to serve in the `pyro.factor` statement. That might seem weird, but when we introduced `pyro.factor`, a bunch of inference code assumed all log-prob terms originated at sample sites, and the `Unit` hack was by far the easiest implementation.
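
For concreteness, here is a minimal sketch of a `factor` built on `Unit` (my illustration; the actual Pyro source may differ in details):

```python
import torch
import pyro
import pyro.distributions as dist

def my_factor(name, log_factor):
    # Unit's only possible value is the empty tensor, and its log_prob of
    # that value is exactly log_factor, so this site adds log_factor to
    # the trace's log-joint without introducing any real random choice.
    unit_dist = dist.Unit(torch.as_tensor(log_factor))
    pyro.sample(name, unit_dist, obs=unit_dist.sample())
```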

Some day Pyro may introduce a more native `factor` statement, but for now this hack gets the job done and is quite useful.

Hi @fritzo,

Could you please expand on the interpretation of `factor`?
If you could provide a minimal example of using it for inference, that would be awesome.

The following two models result in equivalent inference:

```python
def model_1(data):
    loc = pyro.sample("loc", dist.Normal(0, 1))
    pyro.sample("obs", dist.Normal(loc, 1), obs=data)
```
```python
def model_2(data):
    loc = pyro.sample("loc", dist.Normal(0, 1))
    pyro.factor("obs", dist.Normal(loc, 1).log_prob(data))
```
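
For example, a quick way to see that the two agree (my sketch, using importance sampling from the prior):

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import EmpiricalMarginal, Importance

data = torch.tensor(0.5)
for model in (model_1, model_2):
    posterior = Importance(model, num_samples=1000).run(data)
    # Both posteriors over "loc" should agree up to Monte Carlo error.
    print(EmpiricalMarginal(posterior, "loc").mean)
```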

The latter form using `pyro.factor()` can be useful if you have a bunch of PyTorch code that computes a (possibly non-normalized) likelihood like `fn(loc, data)`, but it is inconvenient to wrap that code in a `Distribution` interface.
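
For instance, something like the following (a hypothetical `custom_log_likelihood`; any PyTorch code returning a log-weight would do):

```python
import torch
import pyro
import pyro.distributions as dist

def custom_log_likelihood(loc, data):
    # A heavy-tailed, unnormalized pseudo-likelihood as an illustration.
    return -torch.log1p((data - loc) ** 2).sum()

def model_3(data):
    loc = pyro.sample("loc", dist.Normal(0., 1.))
    pyro.factor("obs", custom_log_likelihood(loc, data))
```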

BTW I just realized the `factor()` function you pointed to has been superseded by a public `pyro.factor()` primitive. I’ll submit a little PR to replace that custom usage with our standard `pyro.factor()`.

Thank you for the example; it clears up my confusion.

On a related note, multiple files in `examples/rsa` and the RSA tutorials independently define `Marginal` as:

```python
def Marginal(fn):
    return memoize(lambda *args: HashingMarginal(Search(fn).run(*args)))
```

Defining `Marginal` in `search_inference.py` would be a natural choice, but the caveat is that `semantic_parsing.py` uses best-first search instead:

```python
def Marginal(fn=None, **kwargs):
    if fn is None:
        return lambda _fn: Marginal(_fn, **kwargs)
    return memoize(lambda *args: HashingMarginal(BestFirstSearch(fn, **kwargs).run(*args)))
```

If you can suggest how to reconcile the two `Marginal` decorators, I would be happy to submit a PR refactoring the code.

cc @eb8680_2, who understands those tutorials more thoroughly.

It’s been a couple years, but IIRC the two implementations of `Marginal` are supposed to correspond to the two different `Infer` implementations in the webPPL tutorials from which these examples were translated. Is there any particular reason you think they should be reconciled?

I’d like to move the `Marginal` definitions to `search_inference.py` to avoid the repeated definitions in the examples. Combining both definitions with appropriate parameters looks to me like the way to go, roughly as sketched below.
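
For concreteness, one possible combination (a sketch; the keyword names are mine, assuming both search strategies live in `search_inference.py`):

```python
from search_inference import BestFirstSearch, HashingMarginal, Search, memoize

def Marginal(fn=None, *, search=Search, **kwargs):
    if fn is None:
        # Used with arguments, e.g. @Marginal(search=BestFirstSearch).
        return lambda _fn: Marginal(_fn, search=search, **kwargs)
    return memoize(lambda *args: HashingMarginal(search(fn, **kwargs).run(*args)))
```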

Do you think it makes sense to refactor?

The inference functionality in the RSA examples is not used anywhere else, so I’m fine either way. Feel free to put up a PR with your proposed refactoring.