How to get the parameters of a distribution if they are not pyro.param variables?


#1

Hi, maybe it's a basic question, so please forgive me.

Say we sample in a model like this:

import pyro
import pyro.distributions as dist

def model(data):
    with pyro.plate('dim', 5):
        pyro.sample('obs', dist.Normal(0, 1.0))

After training, how can I get the trained parameters (mu and scale) of the Normal distributions?

Thanks.


#2

It depends on what you are doing. If you are using SVI, you need pyro.param() statements in your guide; otherwise nothing is being learned. If you are using HMC, you can get the marginal posterior over your variables as in this example.


#3

If you are using SVI, you'd want to write your model as

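import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

def model(data):
    # Learnable parameters, registered in Pyro's param store
    loc = pyro.param('loc', torch.tensor(0.))
    scale = pyro.param('scale', torch.tensor(1.), constraint=constraints.positive)
    with pyro.plate('dim', 5):
        pyro.sample('obs', dist.Normal(loc, scale), obs=data)

def guide(data):
    # No latent variables, so the guide is empty
    pass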

Then, after training, you can read them out:

for name, value in pyro.get_param_store().items():
    print(name, value)