I've written a simple Gaussian mixture model (GMM) in Pyro. What is the best way to obtain a trace of the values sampled at the `locs` and `assignments` sites? I know I can read `pyro.param` values by calling `pyro.param('parameter name')`, but I'm not sure how to do the equivalent for `pyro.sample` sites. How can I do this?
def model(data):
    """Generative model for a T-component Gaussian mixture over N points.

    Each component location is drawn from a standard normal. Each data point
    is assigned to one of the T components with uniform probability. Each
    observation is Normal around its assigned component's location with
    unit scale.

    ``T`` (number of components), ``N`` (number of data points), ``pyro``,
    ``torch``, ``Normal`` and ``Categorical`` are module-level names defined
    outside this function.

    Args:
        data: Observed values for the ``"obs"`` site; presumably a tensor of
            length N (TODO confirm against caller).

    Returns:
        Tuple ``(locs, assignments)`` with the values sampled at the
        ``"locs"`` and ``"assignments"`` sites, so callers can inspect them
        directly.
    """
    with pyro.plate('components', T):
        # Global variables: one location per component, batched by the plate.
        locs = pyro.sample('locs', Normal(0, 1))
    with pyro.plate('data', N):
        # Local variables: one component index in [0, T) per data point.
        # Because this site is inside the size-N plate, it is a vector of
        # length N (not T).
        assignments = pyro.sample('assignments', Categorical(torch.ones(T) / T))
        # Condition on the observed data; each point uses its assigned
        # component's location.
        pyro.sample('obs', Normal(locs[assignments], 1), obs=data)
    return locs, assignments
def guide(data):
    """Variational guide for ``model``.

    Component locations get a learnable mean ``tau`` (one entry per
    component) with unit scale. Cluster assignments are amortized through
    ``pi_mlp``, which is registered with Pyro via ``pyro.module`` so its
    parameters are included in optimization.

    ``T``, ``N``, ``pi_mlp``, ``pyro``, ``Normal`` and ``Categorical`` are
    module-level names defined outside this function.

    Args:
        data: The same observations passed to ``model``.

    Returns:
        Tuple ``(locs, assignments)`` with the values sampled at the
        ``"locs"`` and ``"assignments"`` sites. A caller can record them
        with ``locs, assignments = guide(data)``; note that a separate call
        draws fresh samples. Alternatively, record the sample sites of one
        execution with
        ``tr = pyro.poutine.trace(guide).get_trace(data)`` and read
        ``tr.nodes["locs"]["value"]`` / ``tr.nodes["assignments"]["value"]``.
    """
    # Amortize the assignment posterior using an MLP.
    pyro.module('pi_mlp', pi_mlp)
    # Variational means for the mixture component locations, initialized
    # from a standard normal draw of length T.
    tau = pyro.param('tau', lambda: Normal(0, 1).sample([T]))
    with pyro.plate('components', T) as i:
        # Sample mixture component locations around their learned means.
        locs = pyro.sample('locs', Normal(tau[i], 1))
    # Per-point cluster probabilities; presumably shape (N, T) -- TODO confirm.
    # NOTE(review): Categorical(pi) treats pi as probabilities; if pi_mlp
    # emits logits, this should be Categorical(logits=pi).
    pi = pi_mlp(data.double())
    with pyro.plate("data", N):
        # Sample cluster assignments: one index per data point (length N).
        assignments = pyro.sample("assignments", Categorical(pi))
    return locs, assignments