How is the model conditioned using pyro.sample in factor?

search_inference.py implements factor as follows:

import torch
import pyro
import pyro.distributions as dist

def factor(name, value):
    """Like factor in webPPL, adds a scalar weight to the log-probability of the trace."""
    value = value if torch.is_tensor(value) else torch.tensor(value)
    d = dist.Bernoulli(logits=value)
    pyro.sample(name, d, obs=torch.ones(value.size()))
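As far as I can tell, the observed site contributes d.log_prob(1) = log(sigmoid(value)) to the trace's log joint, which is close to value itself when value is a large negative log-weight (a quick check):

import torch
import pyro.distributions as dist

value = torch.tensor(-3.0)
d = dist.Bernoulli(logits=value)
# observing 1 contributes log(sigmoid(value)) to the log joint
print(d.log_prob(torch.tensor(1.0)))    # tensor(-3.0486)
print(torch.log(torch.sigmoid(value)))  # the same quantity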

An example use is demonstrated in the RSA implicature example's definition of literal_listener.

I can’t understand how sampling with obs conditions the model in this context.

Hi @true, we consider how pyro.factor() is implemented to be an implementation detail: its effect should simply be to add a log-prob term to the trace when used by inference algorithms. pyro.factor() is currently implemented using the idiom

pyro.sample(name, distribution, obs=value)

which creates the log-prob term distribution.log_prob(value) in the trace. This idiom is kind of a hack: the sample site isn’t a real sample site, in the sense that it samples from a non-normalized Unit distribution which has only a single possible value, the empty tensor torch.zeros((0,)) == torch.ones((0,)). This is like the unit type in type theory; the entire purpose of the Unit distribution is to serve in the pyro.factor statement.

This might seem weird, but when we introduced pyro.factor, a bunch of inference code assumed that all log-prob terms originated in sample sites, and the Unit hack was by far the easiest implementation.
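For concreteness, here is a minimal sketch of the idiom using pyro.distributions.Unit directly (the site name "my_factor" is arbitrary):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

def model():
    # a Unit site's only effect is to deposit log_factor into the trace
    unit = dist.Unit(log_factor=torch.tensor(-2.0))
    pyro.sample("my_factor", unit, obs=unit.sample())

tr = poutine.trace(model).get_trace()
print(tr.log_prob_sum())  # tensor(-2.)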

Some day Pyro may introduce a more native factor statement, but for now this hack gets the job done and is quite useful 🙂

Hi @fritzo,

Could you please expand on the interpretation of factor?
If you could provide a minimal example of using it for inference, that would be awesome.

The following models result in equivalent inference:

import pyro
import pyro.distributions as dist

def model_1(data):
    loc = pyro.sample("loc", dist.Normal(0, 1))
    pyro.sample("obs", dist.Normal(loc, 1), obs=data)

def model_2(data):
    loc = pyro.sample("loc", dist.Normal(0, 1))
    pyro.factor("obs", dist.Normal(loc, 1).log_prob(data))

The latter form using pyro.factor() can be useful if you have a bunch of PyTorch code to compute a (possibly non-normalized) likelihood like fn(loc, data), but it is inconvenient to wrap that code in a Distribution interface.
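For example, a sketch along those lines (custom_log_likelihood here is just a stand-in for arbitrary PyTorch code, and Importance is one inference algorithm that handles factor sites):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import EmpiricalMarginal, Importance

def custom_log_likelihood(loc, data):
    # stand-in for arbitrary PyTorch code returning a
    # (possibly non-normalized) log-likelihood
    return -0.5 * ((data - loc) ** 2).sum()

def model(data):
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    pyro.factor("obs", custom_log_likelihood(loc, data))

data = torch.tensor([0.5, 1.0, 1.5])
posterior = Importance(model, num_samples=1000).run(data)
print(EmpiricalMarginal(posterior, "loc").mean)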

BTW I just realized the factor() function you pointed to has been superseded by a public pyro.factor() primitive. I’ll submit a little PR to replace that custom usage with our standard pyro.factor().

Thank you for the example; it clears up my confusion.

On a related note, multiple files in examples/rsa and the RSA tutorials independently define Marginal as:

from search_inference import HashingMarginal, Search, memoize

def Marginal(fn):
    return memoize(lambda *args: HashingMarginal(Search(fn).run(*args)))

Defining Marginal in search_inference.py would be a natural choice, but the caveat is that one of the definitions uses best-first search instead:

def Marginal(fn=None, **kwargs):
    if fn is None:
        return lambda _fn: Marginal(_fn, **kwargs)
    return memoize(lambda *args: HashingMarginal(BestFirstSearch(fn, **kwargs).run(*args)))

If you can suggest how to reconcile the two Marginal decorators, I would be happy to submit a PR refactoring the code.
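For concreteness, one possible shape (a sketch; the search_cls parameter name is my own invention):

from search_inference import BestFirstSearch, HashingMarginal, Search, memoize

def Marginal(fn=None, *, search_cls=Search, **kwargs):
    # supports both @Marginal and @Marginal(search_cls=BestFirstSearch, ...)
    if fn is None:
        return lambda _fn: Marginal(_fn, search_cls=search_cls, **kwargs)
    return memoize(lambda *args: HashingMarginal(search_cls(fn, **kwargs).run(*args)))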

cc @eb8680_2 who understands those tutorials more thoroughly.

It’s been a couple years, but IIRC the two implementations of Marginal are supposed to correspond to the two different Infer implementations in the webPPL tutorials from which these examples were translated. Is there any particular reason you think they should be reconciled?

I’d like to move the Marginal definitions to search_inference.py to avoid the repeated definitions in the examples. Combining both definitions with appropriate parameters looks to me like the way to go.

Do you think it makes sense to refactor?

The inference functionality in the RSA examples is not used anywhere else, so I’m fine either way. Feel free to put up a PR with your proposed refactoring.