Loss returns NaNs

It’s my first time writing Pyro and I’m running into an issue where the loss returns NaN.

The model that I am implementing is as follows:

logF follows a 2-component mixture of normals, and logR | logF also follows a mixture of normals whose component means and standard deviations depend linearly on logF.
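In symbols, using the same names as in the code below:

logF        ~ Σ_k w_F[k] · Normal(mu_logF[k], scale_logF[k])
logR | logF ~ Σ_j w_R[j] · Normal(alpha[j] + beta[j]·logF, gamma[j])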

Below is my code for this model. I found that gamma can take negative values even though it has a LogNormal prior, and the negative gamma then makes logR unsampleable and produces NaNs. Any suggestions about what’s wrong in the code?

import numpy as np
import torch

import pyro
import pyro.distributions as dist
import pyro.optim
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.infer.autoguide import AutoDelta

K = 3
@config_enumerate
def model(F_obs, F_tau, R_obs, R_tau):
    
    # global variables for F
    w_F = pyro.sample('w_F', dist.Dirichlet(1 * torch.ones(2)))
    with pyro.plate('components_F', 2):
        mu_logF = pyro.sample('mu_logF', dist.Normal(0., 5.))
        scale_logF = pyro.sample('scale_logF', dist.LogNormal(0., 2.))
     
    # global variables for R
    w_R = pyro.sample('w_R', dist.Dirichlet(1 * torch.ones(K)))
    assert not np.isnan(w_R.data.numpy().sum())
    with pyro.plate('components_R', K):
        alpha = pyro.sample('alpha', dist.Normal(0., 5.))
        beta = pyro.sample('beta', dist.Normal(0., 5.))
        gamma = pyro.sample('gamma', dist.LogNormal(0., 2.))
        print(gamma)  # debug: check that gamma stays positive

    with pyro.plate('F_data', len(F_obs)):
        assig_F = pyro.sample('assig_F', dist.Categorical(w_F))
        logF = pyro.sample('logF', dist.Normal(mu_logF[assig_F], scale_logF[assig_F]))
        F = 10**logF
        pyro.sample('F_obs', dist.Normal(F, F_tau), obs=F_obs)
    
    with pyro.plate('R_data', len(R_obs)):
        assig_R = pyro.sample('assig_R', dist.Categorical(w_R))
        mu_logR = alpha[assig_R] + beta[assig_R]*logF  # reuses logF, so this assumes len(R_obs) == len(F_obs)
        scale_logR = gamma[assig_R] 
        logR = pyro.sample('logR', dist.Normal(mu_logR, scale_logR))
        R = 10**logR
        pyro.sample('R_obs', dist.Normal(R, R_tau), obs=R_obs)

global_guide = AutoDelta(poutine.block(model, expose=['w_F', 'mu_logF', 'scale_logF',
                                                      'w_R', 'alpha', 'beta', 'gamma']))
optim = pyro.optim.Adam({'lr': 0.05, 'betas': [0.8, 0.99]})
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model, global_guide, optim, loss=elbo)

pyro.clear_param_store()
losses = []
for i in range(200):
    loss = svi.step(F_obs, F_tau, R_obs, R_tau)
    losses.append(loss)
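For reference, F_obs/R_obs are my observed values and F_tau/R_tau are measurement noise scales. A toy stand-in to reproduce the setup would be something like this (illustrative values only; my real data differ):

F_obs = torch.rand(100) * 10  # placeholder observations for F
F_tau = torch.tensor(0.1)     # placeholder measurement noise scale for F
R_obs = torch.rand(100) * 10  # placeholder observations for R (same length as F_obs)
R_tau = torch.tensor(0.1)     # placeholder measurement noise scale for R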

Hi @qma, that is indeed weird; gamma shouldn’t have negative values, since a LogNormal distribution has strictly positive support. All I can imagine is a value that converges to 0.0 and maybe a sign being flipped. What negative value do you see?

Often, to avoid NaNs, I need to clamp values in models, e.g. you might try

scale_logR = gamma[assig_R].clamp(min=1e-3)  # choose the floor to match your data scale
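In context, the R_data plate would then look something like this (a sketch; the 1e-3 floor is an arbitrary choice):

with pyro.plate('R_data', len(R_obs)):
    assig_R = pyro.sample('assig_R', dist.Categorical(w_R))
    mu_logR = alpha[assig_R] + beta[assig_R]*logF
    scale_logR = gamma[assig_R].clamp(min=1e-3)  # keep the Normal scale strictly positive
    logR = pyro.sample('logR', dist.Normal(mu_logR, scale_logR))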

Hi @fritzo, thanks for your reply! I played with the code many times today and made some new observations.

  1. Gamma does NOT turn negative in every execution. There are two situations: 1) gamma becomes negative (a typical value is about -0.05), and then logR cannot be sampled and gives NaNs; 2) all sample sites remain valid until the first defined sample site, w_F, becomes NaN. (I check the sites with the trace sketch after the traceback below.)

  2. I turned on pyro.enable_validation(True) and got the following error message. However, I couldn’t find a torch distribution that has a parameter named v.

~/anaconda3/envs/dl/lib/python3.7/site-packages/torch/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
     34                     continue  # skip checking lazily-constructed args
     35                 if not constraint.check(getattr(self, param)).all():
---> 36                     raise ValueError("The parameter {} has invalid values".format(param))
     37         super(Distribution, self).__init__()
     38 
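To see which site goes bad first, I inspect the guide trace after each step (a rough debugging sketch, assuming the model and global_guide defined above):

trace = poutine.trace(global_guide).get_trace(F_obs, F_tau, R_obs, R_tau)
for name, site in trace.nodes.items():
    if site['type'] == 'sample' and torch.isnan(site['value']).any():
        print('NaN at guide site', name)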

Delta distributions have a parameter named v; an AutoDelta guide represents each latent site with a Delta, so this error means one of the guide’s point estimates became invalid. You can also print(type(self)) in the debugger to see which distribution class raised the error.
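For reference, a minimal sketch showing where that parameter name comes from:

import torch
import pyro.distributions as dist

d = dist.Delta(torch.tensor(1.0))
print(d.v)  # tensor(1.) -- the "v" parameter named in the validation error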