Using pyro models from gen: log_prob_sum on conditioned nodes

I’m trying to make pyro models callable/usable from gen.jl. Conceptually, there are no barriers to this; all the basic capabilities gen requires from its generative functions are at least possible from pyro.

But as I try to code the gen “generate(model, args, constraints)” function there is a bit of a hiccup. I successfully turn the constraints into a pyro.condition call and run the model, but the resulting trace has no log_prob_sum on the conditioned nodes, which I need in order to calculate the weight.

I think it would be possible to get pyro to calculate the necessary log_prob_sum values, but as far as I can tell it would be a nuisance: both computationally inefficient (because it requires a second pass through the model) and programmatically awkward (I think it would require some nasty abuse of poutine).

So, my questions:

  1. Is there some way to tell pyro to fill in the log_prob_sum even on conditioned nodes?
  2. Should there be?

(I realize that in the ordinary pyro workflow, this value is not needed and so calculating it is a wasted extra step. But I think my use case is a valid one; not just for using pyro from gen, but also if you ever extend pyro to have gen-like features such as automated metropolis-rosenbluth-hastings.)

Hi @Jameson, IIUC you just need to trigger the computation after tracing. Pyro computes log prob factors lazily.

Conceptually:

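from pyro import poutine

# Record the model's execution with the constraints applied as conditioned values.
with poutine.trace() as tr:
    with poutine.condition(data=constraints):
        model(*args)

trace = tr.trace
# At this point trace.nodes["my_name"]["log_prob_sum"] does not exist.

# Calling log_prob_sum() computes and caches the per-site values, conditioned sites included.
log_prob_sum = trace.log_prob_sum()
# At this point trace.nodes["my_name"]["log_prob_sum"] has been computed.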
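
Note that the value returned by trace.log_prob_sum() is the joint over all sample sites, not the generate weight. Once the per-site entries are filled in, the weight can be read off by summing over the constrained sites only. Here is a minimal sketch of that last step, assuming the unconstrained choices were sampled from the model itself, so that the weight is just the sum of the constrained sites' log probabilities. The helper name generate_weight is hypothetical, not part of Pyro or Gen; it only relies on each node being a dict with "type" and "log_prob_sum" keys, as in trace.nodes.

```python
def generate_weight(nodes, constrained_names):
    """Sum the cached log_prob_sum entries of the constrained sample sites.

    `nodes` is assumed to be a mapping shaped like a Pyro trace's `nodes`
    (site name -> dict with "type" and "log_prob_sum" keys).
    `constrained_names` is assumed to be the keys passed to poutine.condition.
    """
    weight = 0.0
    for name in constrained_names:
        node = nodes[name]
        if node["type"] != "sample":
            raise ValueError(f"site {name!r} is not a sample site")
        if "log_prob_sum" not in node:
            # The entries are filled in lazily, so they must be triggered first.
            raise KeyError(
                f"site {name!r} has no log_prob_sum; "
                "call trace.log_prob_sum() before reading the weight"
            )
        # Works for plain floats and for tensors alike.
        weight = weight + node["log_prob_sum"]
    return weight
```

With the trace from above, this would be called as generate_weight(trace.nodes, constraints.keys()) after trace.log_prob_sum() has run.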