My guide keeps producing nan values, what am I doing wrong?

I have a model that implements a Generalized Linear Model where the target distribution is a Gamma distribution and the data are censored. The model is defined as follows:

import numpy as np
import torch
import pyro
import pyro.distributions as dist
from scipy.stats import gamma as scipygamma


def model_gamma_cen(X, y, censored_label):

    # intercept
    linear_combination = pyro.sample("beta_0", dist.Normal(0.0, 1.0))

    # one coefficient per feature, named beta_1 ... beta_p so the intercept
    # site "beta_0" is not sampled a second time
    for i in range(X.shape[1]):
        beta_i = pyro.sample(f"beta_{i + 1}", dist.Normal(0.0, 1.0))
        linear_combination = linear_combination + (beta_i * X[:, i])

    # log link: the Gamma mean is exp(linear predictor)
    mean = torch.exp(linear_combination)

    rate = pyro.sample("rate", dist.HalfCauchy(scale=10.0))

    # Gamma mean = shape / rate, so shape = mean * rate
    shape = mean * rate

    with pyro.plate("data", y.shape[0]):

        outcome_dist = dist.Gamma(shape, rate)

        # non-censored data: ordinary Gamma likelihood
        with pyro.poutine.mask(mask=(censored_label == 0.0)):
            observation = pyro.sample("obs", outcome_dist, obs=y)

        # censored data: we only know the outcome exceeds y, so score a
        # Bernoulli(1) on the survival probability 1 - CDF(y); note scipy's
        # gamma is parameterized by (a, loc, scale) with scale = 1 / rate
        with pyro.poutine.mask(mask=(censored_label == 1.0)):
            truncation_prob = torch.tensor(
                (1.0 - scipygamma(a=shape.detach().numpy(),
                                  scale=1.0 / rate.detach().numpy()).cdf(y.numpy())
                 ).astype(np.float32))
            censored_observation = pyro.sample("censorship",
                                               dist.Bernoulli(truncation_prob),
                                               obs=torch.tensor(1.0))

And I have my guide defined as follows:

def guide_gamma_cen(X, y, censored_label):

    # variational posterior for the intercept
    mu_intercept = pyro.param("mu_intercept", torch.tensor(0.0))
    sigma_intercept = pyro.param("sigma_intercept", torch.tensor(1.0))
    pyro.sample("beta_0", dist.Normal(mu_intercept, sigma_intercept))

    # variational posterior for each coefficient, matching the
    # beta_1 ... beta_p site names in the model
    for i in range(X.shape[1]):
        mu_coef = pyro.param(f"mu_{i}", torch.tensor(0.0))
        sigma_coef = pyro.param(f"sigma_{i}", torch.tensor(1.0))
        pyro.sample(f"beta_{i + 1}", dist.Normal(mu_coef, sigma_coef))

    # variational posterior for the Gamma rate
    scale_rate = pyro.param("scale_rate", torch.tensor(10.0))
    pyro.sample("rate", dist.HalfCauchy(scale=scale_rate))

When I run these with SVI using Trace_ELBO and a ClippedAdam optimizer, the value of scale_rate becomes nan after a few iterations.
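
For reference, this is roughly how I wire the pieces together; the learning rate and number of steps below are placeholders, not the values from my real run:

from pyro.infer import SVI, Trace_ELBO
from pyro.optim import ClippedAdam

# placeholder hyperparameters
optimizer = ClippedAdam({"lr": 0.01})
svi = SVI(model_gamma_cen, guide_gamma_cen, optimizer, loss=Trace_ELBO())

for step in range(1000):
    loss = svi.step(X, y, censored_label)
    if step % 100 == 0:
        print(f"step {step}: loss = {loss:.3f}")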

However, if instead of my own guide I use an autoguide (e.g. AutoDiagonalNormal), things work out. Why is my guide failing?

Hi @bshabash,
I haven't fully understood your model, but one quick observation is that you should constrain your pyro.param statements when they are intended to be positive, e.g.

from torch.distributions import constraints

sigma_intercept = pyro.param("sigma_intercept", torch.tensor(1.0),
                             constraint=constraints.positive)
...
    sigma_coef = pyro.param(f"sigma_{i}", torch.tensor(1.0),
                            constraint=constraints.positive)
...
scale_rate = pyro.param("scale_rate", torch.tensor(10.0),
                        constraint=constraints.positive)

Let me know if you still get NaNs after that fix, and I'll take a closer look at the censoring logic.

Hi @fritzo
I added the constraints. I forgot about those.

However, even with the constraints, the value still becomes nan after the first iteration. Also, on the first iteration my loss is infinite after I take the SVI step. If I make the learning rate much smaller, say 0.0000000000001, the SVI runs for two iterations and then one of the other values becomes nan (in this case one of the scale values for a beta_{i}).

For the censoring logic I'm using scipy's CDF for the Gamma distribution, which is why the code there is messier.
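
As an aside, a torch-only sketch of the same survival probability, assuming the installed PyTorch version implements Gamma.cdf (recent releases do), could replace the scipy-based branch inside the data plate; I haven't verified this end to end:

# inside `with pyro.plate("data", y.shape[0]):`, replacing the scipy-based branch
with pyro.poutine.mask(mask=(censored_label == 1.0)):
    surv_prob = 1.0 - dist.Gamma(shape, rate).cdf(y)      # P(outcome > y)
    surv_prob = surv_prob.clamp(1e-6, 1.0 - 1e-6)         # keep Bernoulli probs in (0, 1)
    pyro.sample("censorship", dist.Bernoulli(surv_prob), obs=torch.ones_like(y))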

@fritzo
I also checked, and the issue persists even if I remove the censoring logic. The learnable scale parameter for the HalfCauchy is the problem: if I stop learning it (i.e. fix the scale as torch.tensor(10.0) in the guide), everything else works.
I think something is going wrong in the gradient calculation there. Is there a distribution that is better practice for the rate of the Gamma distribution? I can't use Normal since the rate has to be positive.
I tried HalfNormal and Gamma; they all fail with some values becoming NaN after a few iterations, if not the first one.
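
For concreteness, the Gamma variant I tried for the rate site looks roughly like this (the initial values here are placeholders, not the ones from my run):

from torch.distributions import constraints

# in the guide, replacing the HalfCauchy for the "rate" site
rate_conc = pyro.param("rate_conc", torch.tensor(2.0),
                       constraint=constraints.positive)
rate_rate = pyro.param("rate_rate", torch.tensor(0.5),
                       constraint=constraints.positive)
rate = pyro.sample("rate", dist.Gamma(rate_conc, rate_rate))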

@fritzo
This is solved now. Explicitly setting my learnable parameters to type torch.float64 makes it work, so I suspect there was an overflow at some point that produced an infinite gradient, which then propagated into a nan value, possibly because I'm using torch.exp as my inverse link.
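
For anyone hitting the same thing, the change is just to initialize the params (and the data tensors) in double precision; a minimal sketch, keeping the constraints from earlier in the thread:

from torch.distributions import constraints

# double-precision initial values avoid float32 overflow through torch.exp;
# X and y should be cast to torch.float64 as well so dtypes stay consistent
scale_rate = pyro.param("scale_rate", torch.tensor(10.0, dtype=torch.float64),
                        constraint=constraints.positive)
sigma_intercept = pyro.param("sigma_intercept", torch.tensor(1.0, dtype=torch.float64),
                             constraint=constraints.positive)
# alternatively, calling torch.set_default_dtype(torch.float64) before building
# the params makes every new tensor default to double precision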

This was really frustrating to try and debug, and I don't even know what made me think of trying this, but I think there should be some checkpoint in Pyro that would have alerted me.

Hi @bshabash, glad to hear you were able to fix it. I'd be happy to help you add a checkpoint or validation logic if you can formulate a good idea of where the NaN error arose. In general I find NaN errors difficult to track; we have some NaN checking in Pyro, but it is difficult to trace NaNs through PyTorch gradients, especially as NaN propagation rules are subtle and sometimes change across PyTorch releases.

Hi @fritzo,
I am not that familiar with the details of ELBO maximization and SVI, but I think that, due to my use of torch.exp, an inf value appeared in one of the gradient elements, and a subsequent operation on that gradient then produced a nan.
While trying to find the source of the error I noticed there are quite a few warn_if_nan calls on different operations, but I think you need similar checks for inf values.
I am also aware this may be a PyTorch issue (most likely, actually), but I think the Pyro team is in a better position to raise it on the PyTorch issues page, since you know the implementation details of SVI and the ELBO losses.
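
In case it helps with adding such checks, here is a sketch of how I would try to localize where the first non-finite value appears (the training loop shape and num_steps are assumed, not taken from my actual run):

import math
import torch
import pyro
from pyro.util import warn_if_nan

# point autograd at the operation that produced a nan in the backward pass;
# this is slow, so only enable it while debugging
torch.autograd.set_detect_anomaly(True)

for step in range(num_steps):
    loss = svi.step(X, y, censored_label)
    if not math.isfinite(loss):
        print(f"non-finite loss at step {step}: {loss}")
        break
    # warn_if_nan also attaches a backward hook, so nans appearing in the
    # gradient of a named parameter are reported with that name
    for name, value in pyro.get_param_store().items():
        warn_if_nan(value, name)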

Hi @bshabash, thanks for the tip. Following your advice, I observed something interesting: initially my guide produced NaN on every iteration; after switching to float64 it produced values for a few initial iterations and then went back to NaN. I was running the VAE on non-MNIST images, which were large. I resized the input images to lower dimensions and finally there were no more NaNs. I was running on my local laptop, and that may have been the problem.
