I have some count data that may follow a negative binomial (NB) distribution. My goal is to get the posterior distribution of the mean (mu) of the NB distribution. I just started learning Pyro because I need to scale the job. I have been stuck for several days and need some help.
# Simulate some data
import numpy as np

mu = 0.3
variance = 1.6
phi = mu**2 / (variance - mu)  # over-dispersion: variance = mu + mu**2 / phi
p = phi / (mu + phi)
y = np.random.negative_binomial(n=phi, p=p, size=10000)
y.mean(), y.var()
# (0.309, 1.6999190000000002)
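To double-check the simulation setup, here is a small sketch of the moment matching in plain Python. The helper names `nb_params_from_moments` and `nb_moments` are made up for illustration, and the formulas assume NumPy's `negative_binomial(n, p)` parameterization, where the mean is n(1 - p)/p and the variance is n(1 - p)/p**2.

```python
def nb_params_from_moments(mu, variance):
    """Return (phi, p) for an NB with the given mean and variance.

    phi plays the role of NumPy's `n` argument; requires variance > mu.
    """
    if variance <= mu:
        raise ValueError("negative binomial needs variance > mean")
    phi = mu**2 / (variance - mu)
    p = phi / (mu + phi)
    return phi, p


def nb_moments(phi, p):
    """Mean and variance of NB(n=phi, p) in NumPy's parameterization."""
    mean = phi * (1 - p) / p
    var = phi * (1 - p) / p**2
    return mean, var


phi, p = nb_params_from_moments(0.3, 1.6)
mean, var = nb_moments(phi, p)
print(phi, p)     # parameters passed to np.random.negative_binomial
print(mean, var)  # should recover the target mean and variance
```

The recovered moments match the targets (0.3 and 1.6), so the sample mean and variance of about 0.31 and 1.70 look consistent with the intended distribution.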
import torch
import pyro
import pyro.distributions as dist
import pyro.optim as optim
from torch.distributions import constraints

def model(y):
    b0 = pyro.sample("b0", dist.Normal(0, 5))
    phi = pyro.sample("phi", dist.HalfCauchy(2))
    # phi = pyro.sample("phi", dist.Uniform(0., 10.))
    mu = torch.exp(b0)
    beta = phi / mu
    alpha = phi
    with pyro.plate("data", len(y)):
        pyro.sample("obs", dist.GammaPoisson(alpha, beta), obs=y)
def guide(y):
    a_loc = pyro.param('a_loc', torch.tensor(0.))
    a_scale = pyro.param('a_scale', torch.tensor(1.),
                         constraint=constraints.positive)
    phi_loc = pyro.param('phi_loc', torch.tensor(1.),
                         constraint=constraints.positive)
    b0 = pyro.sample("b0", dist.Normal(a_loc, a_scale))
    phi = pyro.sample("phi", dist.Normal(phi_loc, torch.tensor(1.0)))
    mu = torch.exp(b0)
    beta = phi / mu
    alpha = phi

y_tensor = torch.tensor(y, dtype=torch.float)
from pyro.infer import SVI, Trace_ELBO

svi = SVI(model,
          guide,
          optim.Adam({"lr": .0005}),
          loss=Trace_ELBO())
pyro.clear_param_store()
num_iters = 10000
loss = []
for i in range(num_iters):
    elbo = svi.step(y_tensor)
    loss.append(elbo)
    if i % 500 == 0:
        print("Elbo loss: {}".format(elbo))
/home/ec2-user/anaconda3/envs/pytorch_p36/lib/python3.6/site-packages/pyro/infer/trace_elbo.py:138: UserWarning: Encountered NaN: loss
warn_if_nan(loss, "loss")
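One thing I have not been able to rule out as the cause of the NaN: the guide draws phi from a Normal, which can go negative, while the HalfCauchy prior in the model (and the GammaPoisson concentration and rate) only make sense for positive values. The sketch below uses only the standard library to estimate how often the guide, at its initial values (phi_loc = 1, scale = 1), would propose a non-positive phi. Whether this is what actually triggers the warning is an assumption on my part.

```python
from statistics import NormalDist

# Guide in the post: phi ~ Normal(phi_loc, 1.0), with phi_loc initialised to 1.0
phi_loc, phi_scale = 1.0, 1.0

# Probability that a single draw of phi falls outside the prior's support (phi <= 0)
prob_negative = NormalDist(phi_loc, phi_scale).cdf(0.0)
print(f"P(phi <= 0) under the initial guide: {prob_negative:.3f}")  # 0.159
```

If this is the problem, a guide distribution with positive support for phi (for example `dist.LogNormal`) might avoid it, but I have not verified that.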
Basically, I want to model log(mu) through a linear model (just an intercept for now). The phi in the code is the over-dispersion parameter, which I map to alpha and beta in the GammaPoisson parameterization (alpha = phi, beta = phi/mu). Since phi is positive, I gave it a HalfCauchy prior. In the end, SVI runs, but the outputs are all NaN.
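To spell out the mapping I have in mind: assuming GammaPoisson(concentration, rate) has mean concentration/rate and variance concentration/rate^2 * (1 + rate), which is how I read the Pyro documentation, setting alpha = phi and beta = phi/mu should give the same mean and variance relationship as in my simulation.

```latex
\begin{aligned}
\alpha &= \phi, \qquad \beta = \frac{\phi}{\mu} \\
\mathbb{E}[y] &= \frac{\alpha}{\beta} = \mu \\
\operatorname{Var}[y] &= \frac{\alpha}{\beta^{2}}\,(1 + \beta) = \mu + \frac{\mu^{2}}{\phi}
\end{aligned}
```

This agrees with phi = mu^2 / (variance - mu) from the simulation, so I believe the parameterization itself is not the issue.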
Can anyone give me some guidance? The only source I can find is here, but the way the over-dispersion parameter is specified there is confusing to me.