import torch
import pyro
import pyro.distributions as dist

def factor(name, value):
    """
    Like factor in webPPL: adds a scalar weight to the log-probability of the trace.
    """
    # Coerce the weight to a tensor so it can parameterize a distribution.
    value = value if torch.is_tensor(value) else torch.tensor(value)
    # Observing 1 at a Bernoulli(logits=value) site adds that site's
    # log-probability, log(sigmoid(value)), to the trace.
    d = dist.Bernoulli(logits=value)
    pyro.sample(name, d, obs=torch.ones(value.size()))
An example use is demonstrated in the RSA implicature example, in the definition of literal_listener.
I can't understand how sampling with obs conditions the model in this context.
Hi @true, we consider how pyro.factor() works internally to be an implementation detail: its effect should simply be to add a log-prob term to the trace, as seen by inference algorithms. pyro.factor() is currently implemented using the idiom
pyro.sample(name, distribution, obs=value)
which creates the log-prob term distribution.log_prob(value) in the trace. This idiom is kind of a hack: the sample site isn't a real sample site, in the sense that it samples from a non-normalized Unit distribution that has only a single possible value, the empty tensor torch.zeros((0,)) == torch.ones((0,)). This is like the unit type in type theory, and the entire purpose of the Unit distribution is to serve in the pyro.factor statement. This might seem weird, but when we introduced pyro.factor, a bunch of inference code assumed all log-prob terms originated at sample sites, and the Unit hack was by far the easiest implementation.
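Concretely, the Unit idiom amounts to something like the following (a minimal sketch, not necessarily Pyro's exact source; my_factor and tiny_model are illustrative names):

import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

def my_factor(name, log_factor):
    # Unit(log_factor) has a single possible value, the empty tensor,
    # and assigns it log-probability log_factor; observing that value
    # adds log_factor to the trace's log joint.
    unit_dist = dist.Unit(torch.as_tensor(log_factor))
    pyro.sample(name, unit_dist, obs=unit_dist.sample())

def tiny_model():
    my_factor("penalty", -2.0)

trace = poutine.trace(tiny_model).get_trace()
trace.compute_log_prob()
print(trace.nodes["penalty"]["log_prob"])  # tensor(-2.)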
Someday Pyro may introduce a more native factor statement, but for now this hack gets the job done and is quite useful.
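For example, the following two models should behave identically (model_1 here is an illustrative observe-style counterpart of the factor-style model_2 below):

def model_1(data):
    loc = pyro.sample("loc", dist.Normal(0, 1))
    pyro.sample("obs", dist.Normal(loc, 1), obs=data)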
def model_2(data):
    loc = pyro.sample("loc", dist.Normal(0, 1))
    pyro.factor("obs", dist.Normal(loc, 1).log_prob(data))
The latter form using pyro.factor() can be useful if you have a bunch of PyTorch code that computes a (possibly non-normalized) likelihood, say fn(loc, data), but it is inconvenient to wrap that code in a Distribution interface.
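A sketch of that pattern (custom_loglik and model_3 are hypothetical names; any PyTorch code producing a scalar log-weight would do):

def custom_loglik(loc, data):
    # arbitrary PyTorch code; need not be a normalized density in data
    return -((data - loc) ** 2).sum()

def model_3(data):
    loc = pyro.sample("loc", dist.Normal(0, 1))
    pyro.factor("custom_obs", custom_loglik(loc, data))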
BTW I just realized the factor() function you pointed to has been superseded by a public pyro.factor() primitive. I’ll submit a little PR to replace that custom usage with our standard pyro.factor().
It’s been a couple years, but IIRC the two implementations of Marginal are supposed to correspond to the two different Infer implementations in the webPPL tutorials from which these examples were translated. Is there any particular reason you think they should be reconciled?
I'd like to move the Marginal definitions to search_inference.py to avoid the repeated definitions in the examples. Combining both definitions with appropriate parameters looks to me like the way to go.
The inference functionality in the RSA examples is not used anywhere else, so I’m fine either way. Feel free to put up a PR with your proposed refactoring.