It’s my first time writing Pyro, and I am running into an issue where the loss returns NaN.
The model that I am implementing is as follows:
logF follows a 2-component mixture of normals, and logR | logF also follows a mixture of normals whose component means depend linearly on logF (each component has its own scale gamma).
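In equations (my own notation, where $i$ indexes data points and $k$, $j$ index mixture components), the generative process I am trying to express is roughly:

$$
\begin{aligned}
z^F_i &\sim \mathrm{Categorical}(w_F), &
\log F_i \mid z^F_i = k &\sim \mathcal{N}(\mu_k, \sigma_k),\\
z^R_i &\sim \mathrm{Categorical}(w_R), &
\log R_i \mid z^R_i = j,\ \log F_i &\sim \mathcal{N}(\alpha_j + \beta_j \log F_i,\ \gamma_j),\\
F^{\mathrm{obs}}_i &\sim \mathcal{N}(10^{\log F_i}, F_\tau), &
R^{\mathrm{obs}}_i &\sim \mathcal{N}(10^{\log R_i}, R_\tau).
\end{aligned}
$$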
Below is my code for this model. I found that gamma can take negative values even though it has a LogNormal prior; the negative gamma then introduces NaNs (which also show up in w_R). Any suggestion about what’s wrong in the code?
import numpy as np
import torch

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.infer.autoguide import AutoDelta

K = 3

@config_enumerate
def model(F_obs, F_tau, R_obs, R_tau):
    # global variables for F
    w_F = pyro.sample('w_F', dist.Dirichlet(1 * torch.ones(2)))
    with pyro.plate('components_F', 2):
        mu_logF = pyro.sample('mu_logF', dist.Normal(0., 5.))
        scale_logF = pyro.sample('scale_logF', dist.LogNormal(0., 2.))

    # global variables for R
    w_R = pyro.sample('w_R', dist.Dirichlet(1 * torch.ones(K)))
    assert not np.isnan(w_R.data.numpy().sum())
    with pyro.plate('components_R', K):
        alpha = pyro.sample('alpha', dist.Normal(0., 5.))
        beta = pyro.sample('beta', dist.Normal(0., 5.))
        gamma = pyro.sample('gamma', dist.LogNormal(0., 2.))
    print(gamma)  # debugging: this sometimes shows negative values

    # per-datum mixture assignment and latent logF
    with pyro.plate('F_data', len(F_obs)):
        assig_F = pyro.sample('assig_F', dist.Categorical(w_F))
        logF = pyro.sample('logF', dist.Normal(mu_logF[assig_F], scale_logF[assig_F]))
        F = 10**logF
        pyro.sample('F_obs', dist.Normal(F, F_tau), obs=F_obs)

    # per-datum mixture assignment; mean of logR depends linearly on logF
    with pyro.plate('R_data', len(R_obs)):
        assig_R = pyro.sample('assig_R', dist.Categorical(w_R))
        mu_logR = alpha[assig_R] + beta[assig_R] * logF
        scale_logR = gamma[assig_R]
        logR = pyro.sample('logR', dist.Normal(mu_logR, scale_logR))
        R = 10**logR
        pyro.sample('R_obs', dist.Normal(R, R_tau), obs=R_obs)
global_guide = AutoDelta(poutine.block(model, expose=['w_F', 'mu_logF', 'scale_logF',
'w_R', 'alpha', 'beta', 'gamma']))
optim = pyro.optim.Adam({'lr': 0.05, 'betas': [0.8, 0.99]})
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model, global_guide, optim, loss=elbo)
pyro.clear_param_store()
losses = []
for i in range(200):
    loss = svi.step(F_obs, F_tau, R_obs, R_tau)
    losses.append(loss)
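For what it’s worth, this is roughly how I check where things go wrong; calling the AutoDelta guide returns its current point estimates as a dict, so I print gamma from there (the NaN check on the loss history is just my own diagnostic, not part of the model):

```python
import math

# Minimal diagnostic sketch, assuming the same F_obs, F_tau, R_obs, R_tau as above.
# Find the first step at which the ELBO loss became NaN.
for step, loss in enumerate(losses):
    if math.isnan(loss):
        print(f'loss first became NaN at step {step}')
        break

# An AutoDelta guide returns its point estimates when called with the model args.
map_estimates = global_guide(F_obs, F_tau, R_obs, R_tau)
print('gamma =', map_estimates['gamma'])            # per-component scales for logR
print('scale_logF =', map_estimates['scale_logF'])  # per-component scales for logF
```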