I’m trying to make Pyro models callable from Gen.jl. Conceptually there are no barriers to this: all the basic capabilities Gen requires of its generative functions are at least possible in Pyro.
But as I try to code Gen’s “generate(model, args, constraints)” function, I’ve hit a bit of a hiccup. I can turn the constraints into a pyro.condition call and run the model, but the resulting trace has no log_prob_sum on the conditioned nodes, which I need in order to calculate the weight.
I think it would be possible to get Pyro to calculate the necessary log_prob_sum values, but as far as I can tell it would be a nuisance: both computationally inefficient (because it requires a second pass through the model) and programmatically awkward (I think it would require some nasty abuse of poutine).
So, my questions:
- Is there some way to tell Pyro to fill in the log_prob_sum even on conditioned nodes?
- Should there be?
(I realize that in the ordinary Pyro workflow this value isn’t needed, so calculating it would be a wasted extra step. But I think my use case is a valid one, not just for using Pyro from Gen but also if you ever extend Pyro with Gen-like features such as automated Metropolis-Rosenbluth-Hastings.)