I am trying to replicate Code 2.6 from Statistical Rethinking by Richard McElreath.
def model():
p = pyro.sample("p", dist.Uniform(0, 1.0))
return pyro.sample("w", dist.Binomial(9, p))
def guide():
p_param = pyro.param("p_param", torch.tensor(0.1), constraint=constraints.unit_interval)
return pyro.sample("p", dist.Delta(p_param))
conditioned = pyro.condition(model, data={"w": tensor(6.)})
pyro.clear_param_store()
svi = pyro.infer.SVI(model=conditioned,
guide=guide,
optim=pyro.optim.SGD({"lr": 0.001, "momentum":0.1}),
loss=pyro.infer.Trace_ELBO())
losses = []
num_steps = 5000
for t in range(num_steps):
losses.append(svi.step())
I would also like to get the standard deviation and the 89% percentile interval for p. Looking at other examples of pyro on this thread and tutorials, here is my attempt (with incorrect stdev of 0.28 instead of 0.16):
pred = Predictive(model, posterior_samples={"w": tensor([6.]*5000)})
pred.get_samples()['p'].std()
> 0.2856
How can I get the stdev of the posterior?