Replaying a trace with PyroModule using a different method

Hi,
there is one thing I don't understand about PyroModule (which may expose my limited understanding of Pyro). I want to use different methods of a PyroModule to do different things. Specifically, I want to use one method to compute the likelihood (e.g. forward), and another one to compute some other values from the model (e.g. doing some argmax over the parameters). However, I don't seem to be able to "replay" a specific posterior draw with a different method. Looking at the source code, I see that a context is entered in the __call__() method. Is this the correct one to add?

A clean example to reproduce is the following:

```python
import pyro
import torch
from pyro.nn import PyroSample, PyroModule
import pyro.distributions as dist
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro import poutine

class Model(PyroModule):
    def __init__(self):
        super().__init__()
        self.theta = PyroSample(dist.Uniform(0, 1))

    def forward(self, obs):
        th = self.theta
        pyro.sample("obs", dist.Bernoulli(th), obs=obs)
        return th

    def get_theta(self):
        return self.theta

# %%
model = Model()
guide = AutoDiagonalNormal(model)
obs = torch.zeros((1,))

# Record a single draw from the guide so it can be replayed against the model
trace = poutine.trace(guide).get_trace(obs)

# 1) forward() is called through PyroModule.__call__
print("Theta using .forward()")
for _ in range(3):
    with poutine.replay(trace=trace):
        print(model(obs))

# 2) get_theta() is called directly, with no module context
print("Theta using .get_theta()")
for _ in range(3):
    with poutine.replay(trace=trace):
        print(model.get_theta())

# 3) get_theta() with the module context entered by hand
print("Theta using .get_theta() and context")
for _ in range(3):
    with poutine.replay(trace=trace), model._pyro_context:
        print(model.get_theta())
```

It gives the following results:

```
Theta using .forward()
tensor(0.5523, grad_fn=<ExpandBackward>)
tensor(0.5523, grad_fn=<ExpandBackward>)
tensor(0.5523, grad_fn=<ExpandBackward>)

Theta using .get_theta()
tensor(0.5031)
tensor(0.1464)
tensor(0.0965)

Theta using .get_theta() and context
tensor(0.5523, grad_fn=<ExpandBackward>)
tensor(0.5523, grad_fn=<ExpandBackward>)
tensor(0.5523, grad_fn=<ExpandBackward>)
```

@enemis You probably want to use pyro_method? It runs the decorated method inside the module's context, the same way __call__ does for forward().

Yes, you're right! Silly me, I should have read the documentation more carefully. Thank you :blush: