Dear Pyro team,
I’m having some trouble at a conceptual level with the best way to structure model() and guide() code. In short, it would be useful to have access to the posterior of certain transformed variables in guide(), and I wondered whether Poutine effect handlers can be used in this nested context? Any help greatly appreciated!
Many thanks!!
def sample_group_vars():
    vars = pyro.sample('group', ...)
    # Further not necessarily deterministic transformations
    vars = transform(vars)
    # etc
    return vars

def model(data):
    group_vars = sample_group_vars()
    subject_vars = pyro.sample('subject', ...)  # Function of group_vars
    pyro.sample('data', ..., obs=data)  # Function of subject_vars
    return

def guide(data):
    group_vars = pyro.sample('group', distribution(pyro.param('a')))
    # Logic copied from sample_group_vars()
    group_vars = transform(group_vars)
    # etc
    subject_param = amortise(data, group_vars)
    subject_vars = pyro.sample('subject', distribution(subject_param))
    return
The crux of the problem is that it is best if the amortise() function in guide() has access to the group variables in the transformed space (i.e. so that we can combine them with data more easily). Is there an easier way of doing this than what is sketched above?
- Current approach: recapitulate all the logic in sample_group_vars() in guide(). Disadvantage: code duplication, as all deterministic code has to be repeated.
- Alternative 1: As part of guide(), replay sample_group_vars() with a sample from the posterior over 'group'. Question: is it possible to nest trace() and replay() statements within guide(), or will this be problematic when we then need to trace guide() as a whole during inference?
- Alternative 2: Define a context-aware function, sample_group_vars(model=True), that will sample from the prior or posterior accordingly. We can then call it in both model() and guide() just by changing the flag, but this feels very much like reinventing the wheel.
Thanks again!