VAE for tabular data: NaN loss

Hello, I am trying to implement a "toy" example on a tabular dataset by following the VAE tutorial, in order to debug a more complicated implementation that returns NaN losses. My dataset has 7 covariates. I am using ReLUs as the activations in both the encoder and the decoder.

My network is as follows:

import numpy as np
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

class VAE(nn.Module):

    def __init__(self, in_dim, z_dim, hid_dim, out_dim=1):
        super().__init__()
        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.z_dim = z_dim

        self.encoder = Encoder(in_dim, hid_dim, z_dim)
        self.decoder = Decoder(z_dim, hid_dim, in_dim)

    def model(self, x):
        # register the decoder with pyro
        pyro.module("decoder", self.decoder)
        with pyro.plate("data", x.shape[0]):
            # set up hyperparameters for the prior p(z)
            z_prior_mu = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))
            z_prior_sig = x.new_ones(torch.Size((x.shape[0], self.z_dim)))
            # sample from the prior
            z = pyro.sample("latent", dist.Normal(z_prior_mu, z_prior_sig).to_event(1))
            # decode the latent representation
            x_pred_mu, x_pred_sig = self.decoder(z)
            # score against the observations
            pyro.sample("x_obs", dist.Normal(x_pred_mu, x_pred_sig).to_event(1), obs=x)

    def guide(self, x):
        # register the encoder with pyro
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            # get the parameters of q(z|x) from the encoder
            zq_mu, zq_sig = self.encoder(x)
            # sample the latent z
            pyro.sample("latent", dist.Normal(zq_mu, zq_sig).to_event(1))

vae = VAE(X_tr_.shape[1], 5, 200)
optimizer = Adam({"lr": 1.e-4})
svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())
epochs = 20
loss_arr = np.zeros(epochs)
for e in range(epochs):
    loss = svi.step(X_tr_)
    print(loss)
    loss_arr[e] = loss
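
(Note: the Encoder and Decoder classes are not shown in the post. To make the snippet self-contained, a minimal pair consistent with the call sites above, each returning a mean and a scale, might look like the sketch below. This is an assumption, not the poster's actual code; the scale heads here use softplus rather than a ReLU, since a ReLU scale head can output exact zeros and a Normal with zero scale yields inf/NaN log-probabilities.)

class Encoder(nn.Module):
    # hypothetical encoder: maps x to the mean and scale of q(z|x)
    def __init__(self, in_dim, hid_dim, z_dim):
        super().__init__()
        self.fc = nn.Linear(in_dim, hid_dim)
        self.fc_mu = nn.Linear(hid_dim, z_dim)
        self.fc_sig = nn.Linear(hid_dim, z_dim)

    def forward(self, x):
        h = torch.relu(self.fc(x))
        # softplus keeps the scale strictly positive
        return self.fc_mu(h), nn.functional.softplus(self.fc_sig(h))

class Decoder(nn.Module):
    # hypothetical decoder: maps z to the mean and scale of p(x|z)
    def __init__(self, z_dim, hid_dim, out_dim):
        super().__init__()
        self.fc = nn.Linear(z_dim, hid_dim)
        self.fc_mu = nn.Linear(hid_dim, out_dim)
        self.fc_sig = nn.Linear(hid_dim, out_dim)

    def forward(self, z):
        h = torch.relu(self.fc(z))
        return self.fc_mu(h), nn.functional.softplus(self.fc_sig(h))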

I tried a few different architectures and learning rates, but I am still getting a NaN loss. I would expect large losses and a network that fails to converge, but why does it return NaN? I am experiencing the same issue while trying to reimplement the CEVAE from scratch. How can I debug a case like this? I've also seen a couple of similar issues on the forum, but their solutions do not seem to apply here (at least IMHO).

Thanks in advance!

it may be that your initial zq_sig is large. this has the potential to quickly lead to NaNs. you might try a different initialization; the simplest way to do that would be pyro.sample("latent", dist.Normal(zq_mu, zq_sig / 100.0).to_event(1))
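
applied to the guide above, that would look like this (the 100.0 is just a ballpark factor to keep early samples near the posterior mean, not a tuned value):

def guide(self, x):
    pyro.module("encoder", self.encoder)
    with pyro.plate("data", x.shape[0]):
        zq_mu, zq_sig = self.encoder(x)
        # shrink the scale so early samples stay close to the mean
        pyro.sample("latent", dist.Normal(zq_mu, zq_sig / 100.0).to_event(1))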

please see here for some general advice that may be of help

Unfortunately, I tried most of the steps in the tutorial and it still happens. Is there any way to print the gradients? Or to add noise to the loss?

EDIT: An additional thing I observed when printing the layer weights of the model and the guide: some of them are not changing at all. Any idea what this might mean?

what is the structure of your encoder/decoder?

if you want to inspect the gradients, add noise to the loss, etc., the best way to do that is to use this lower-level pattern:
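
a sketch of that pattern applied to the code above (it assumes the vae, X_tr_, and epochs already defined; the param-store access follows the custom-objectives example in the pyro docs):

import torch
import pyro
from pyro.infer import Trace_ELBO

# build the ELBO as a differentiable tensor instead of calling SVI.step()
loss_fn = Trace_ELBO().differentiable_loss

# one dry run registers all pyro.module / pyro.param sites in the param store
loss_fn(vae.model, vae.guide, X_tr_)
params = [v.unconstrained() for v in pyro.get_param_store().values()]
optimizer = torch.optim.Adam(params, lr=1.e-4)

for e in range(epochs):
    loss = loss_fn(vae.model, vae.guide, X_tr_)
    loss.backward()
    # gradients are now ordinary tensors: print them, clip them, add noise,
    # or flag non-finite entries before taking a step
    for name, value in pyro.get_param_store().items():
        grad = value.unconstrained().grad
        if grad is not None and not torch.isfinite(grad).all():
            print("non-finite gradient in", name)
    optimizer.step()
    optimizer.zero_grad()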