Dear Pyro team,
I’m having some conceptual trouble with the best way to structure guide() code. In short, it would be useful to have access to the posterior of certain transformed variables in guide(), and I wondered whether Poutine effect handlers can be used in this nested context. Any help greatly appreciated!
def sample_group_vars():
    vars = pyro.sample('group', ...)
    # Further not necessarily deterministic transformations
    vars = transform(vars)
    # etc
    return vars

def model(data):
    group_vars = sample_group_vars()
    subject_vars = pyro.sample('subject', ...)  # Function of group_vars
    pyro.sample('data', ..., obs=data)  # Function of subject_vars
    return

def guide(data):
    group_vars = pyro.sample('group', distribution(pyro.param('a')))
    # Logic copied from sample_group_vars()
    group_vars = transform(group_vars)
    # etc
    subject_param = amortise(data, group_vars)
    subject_vars = pyro.sample('subject', distribution(subject_param))
    return
The crux of the problem is that the amortise() function in guide() works best when it has access to the group variables in the transformed space (i.e. so that we can combine them with data more easily). Is there an easier way of doing this than what is sketched above?
Current approach: recapitulate all the logic in guide(). Disadvantage: code duplication, as all deterministic code has to be repeated.
Alternative 1: As part of guide(), call sample_group_vars() with the 'group' site replaced by a sample from the posterior over 'group'. Question: is it possible to nest effect handlers inside guide(), or will this be problematic when we then need to trace guide() as a whole during inference?
Alternative 2: Define a context-aware function, sample_group_vars(model=True), that will sample from the prior or posterior accordingly. We can then call it in both model() and guide() just by changing the flag, but this feels very much like reinventing the wheel.