Bayesian Variational Autoencoder

Hi all - I am trying to code a Bayesian VAE (i.e. a VAE with Bayesian NNs) and I'm having some trouble getting the guide and model to align. Apologies if this is a simple fix, as I am new to Pyro, but if anyone can help with this code I would be very grateful.

The main error I am getting is:

```
ValueError: Expected value argument (Tensor of shape (128, 784)) to be within the support (Boolean()) of the distribution Bernoulli(probs: torch.Size([128, 784])), but found invalid values:
tensor([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]])
```
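I suspect the `obs` tensor isn't strictly binary (the printout truncates whichever entries are invalid). Here is a minimal sketch of what seems to trigger the support check, assuming the inputs are MNIST-style pixels in [0, 1] rather than {0, 1}:

```python
import torch
from pyro.distributions import Bernoulli

probs = torch.full((128, 784), 0.5)   # stand-in for the decoder output
x = torch.rand(128, 784)              # pixels in [0, 1], not strictly 0/1

# This raises the same ValueError, since the Bernoulli support is Boolean:
# Bernoulli(probs).to_event(1).log_prob(x)

# Binarising the observations passes the check:
Bernoulli(probs).to_event(1).log_prob((x > 0.5).float())
```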

Here is my code:

(Thanks in advance for any help)

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

import pyro
from pyro.distributions import Normal, Bernoulli
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam


class VariationalAutoencoder(nn.Module):

    def __init__(self, latent_dims):
        super(VariationalAutoencoder, self).__init__()
        # encoder
        self.linear1 = nn.Linear(784, 512)
        self.linear2 = nn.Linear(512, latent_dims)
        self.linear3 = nn.Linear(512, latent_dims)

        self.N = torch.distributions.Normal(0, 1)
        self.kl = 0

        # decoder
        self.linear4 = nn.Linear(latent_dims, 512)
        self.linear5 = nn.Linear(512, 784)
        self.softplus = nn.Softplus()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.linear1(x))
        mu = self.linear2(x)
        sigma = torch.exp(self.linear3(x))
        z = mu + sigma * self.N.sample(mu.shape)  # reparameterised latent sample
        self.kl = (sigma ** 2 + mu ** 2 - torch.log(sigma) - 1 / 2).sum()
        x = F.relu(self.linear4(z))
        x = torch.sigmoid(self.linear5(x))

        return x, mu, sigma
```

```python
def model(x):
    linear1w_prior = Normal(loc=torch.zeros_like(vae.linear1.weight), scale=torch.ones_like(vae.linear1.weight))
    linear1b_prior = Normal(loc=torch.zeros_like(vae.linear1.bias), scale=torch.ones_like(vae.linear1.bias))

    linear2w_prior = Normal(loc=torch.zeros_like(vae.linear2.weight), scale=torch.ones_like(vae.linear2.weight))
    linear2b_prior = Normal(loc=torch.zeros_like(vae.linear2.bias), scale=torch.ones_like(vae.linear2.bias))

    linear3w_prior = Normal(loc=torch.zeros_like(vae.linear3.weight), scale=torch.ones_like(vae.linear3.weight))
    linear3b_prior = Normal(loc=torch.zeros_like(vae.linear3.bias), scale=torch.ones_like(vae.linear3.bias))

    linear4w_prior = Normal(loc=torch.zeros_like(vae.linear4.weight), scale=torch.ones_like(vae.linear4.weight))
    linear4b_prior = Normal(loc=torch.zeros_like(vae.linear4.bias), scale=torch.ones_like(vae.linear4.bias))

    linear5w_prior = Normal(loc=torch.zeros_like(vae.linear5.weight), scale=torch.ones_like(vae.linear5.weight))
    linear5b_prior = Normal(loc=torch.zeros_like(vae.linear5.bias), scale=torch.ones_like(vae.linear5.bias))

    priors = {'linear1_enc.weight': linear1w_prior, 'linear1_enc.bias': linear1b_prior,
              'linear2_enc.weight': linear2w_prior, 'linear2_enc.bias': linear2b_prior,
              'linear3_enc.weight': linear3w_prior, 'linear3_enc.bias': linear3b_prior,
              'linear4_enc.weight': linear4w_prior, 'linear4_enc.bias': linear4b_prior,
              'linear5_enc.weight': linear5w_prior, 'linear5_enc.bias': linear5b_prior
              }

    lifted_module = pyro.module("module", vae, priors)
    recon, _, _ = lifted_module(x)

    pyro.sample("obs", Bernoulli(recon).to_event(1), obs=x.reshape(-1, 784))
```

```python
def guide(x):

    # First layer weight distribution priors
    linear1w_mu_param = pyro.param("linear1w_mu", torch.randn_like(vae.linear1.weight))
    linear1w_sigma_param = vae.softplus(pyro.param("linear1w_sigma", torch.randn_like(vae.linear1.weight)))
    linear1w_prior = Normal(loc=linear1w_mu_param, scale=linear1w_sigma_param)

    # First layer bias distribution priors
    linear1b_mu_param = pyro.param("linear1b_mu", torch.randn_like(vae.linear1.bias))
    linear1b_sigma_param = vae.softplus(pyro.param("linear1b_sigma", torch.randn_like(vae.linear1.bias)))
    linear1b_prior = Normal(loc=linear1b_mu_param, scale=linear1b_sigma_param)

    # Second layer weight distribution priors
    linear2w_mu_param = pyro.param("linear2w_mu", torch.randn_like(vae.linear2.weight))
    linear2w_sigma_param = vae.softplus(pyro.param("linear2w_sigma", torch.randn_like(vae.linear2.weight)))
    linear2w_prior = Normal(loc=linear2w_mu_param, scale=linear2w_sigma_param)

    # Second layer bias distribution priors
    linear2b_mu_param = pyro.param("linear2b_mu", torch.randn_like(vae.linear2.bias))
    linear2b_sigma_param = vae.softplus(pyro.param("linear2b_sigma", torch.randn_like(vae.linear2.bias)))
    linear2b_prior = Normal(loc=linear2b_mu_param, scale=linear2b_sigma_param)

    # Third layer weight distribution priors
    linear3w_mu_param = pyro.param("linear3w_mu", torch.randn_like(vae.linear3.weight))
    linear3w_sigma_param = vae.softplus(pyro.param("linear3w_sigma", torch.randn_like(vae.linear3.weight)))
    linear3w_prior = Normal(loc=linear3w_mu_param, scale=linear3w_sigma_param)

    # Third layer bias distribution priors
    linear3b_mu_param = pyro.param("linear3b_mu", torch.randn_like(vae.linear3.bias))
    linear3b_sigma_param = vae.softplus(pyro.param("linear3b_sigma", torch.randn_like(vae.linear3.bias)))
    linear3b_prior = Normal(loc=linear3b_mu_param, scale=linear3b_sigma_param)

    # Fourth layer weight distribution priors
    linear4w_mu_param = pyro.param("linear4w_mu", torch.randn_like(vae.linear4.weight))
    linear4w_sigma_param = vae.softplus(pyro.param("linear4w_sigma", torch.randn_like(vae.linear4.weight)))
    linear4w_prior = Normal(loc=linear4w_mu_param, scale=linear4w_sigma_param)

    # Fourth layer bias distribution priors
    linear4b_mu_param = pyro.param("linear4b_mu", torch.randn_like(vae.linear4.bias))
    linear4b_sigma_param = vae.softplus(pyro.param("linear4b_sigma", torch.randn_like(vae.linear4.bias)))
    linear4b_prior = Normal(loc=linear4b_mu_param, scale=linear4b_sigma_param)

    # Fifth layer weight distribution priors
    linear5w_mu_param = pyro.param("linear5w_mu", torch.randn_like(vae.linear5.weight))
    linear5w_sigma_param = vae.softplus(pyro.param("linear5w_sigma", torch.randn_like(vae.linear5.weight)))
    linear5w_prior = Normal(loc=linear5w_mu_param, scale=linear5w_sigma_param)

    # Fifth layer bias distribution priors
    linear5b_mu_param = pyro.param("linear5b_mu", torch.randn_like(vae.linear5.bias))
    linear5b_sigma_param = vae.softplus(pyro.param("linear5b_sigma", torch.randn_like(vae.linear5.bias)))
    linear5b_prior = Normal(loc=linear5b_mu_param, scale=linear5b_sigma_param)

    priors = {'linear1_enc.weight': linear1w_prior, 'linear1_enc.bias': linear1b_prior,
              'linear2_enc.weight': linear2w_prior, 'linear2_enc.bias': linear2b_prior,
              'linear3_enc.weight': linear3w_prior, 'linear3_enc.bias': linear3b_prior,
              'linear4_enc.weight': linear4w_prior, 'linear4_enc.bias': linear4b_prior,
              'linear5_enc.weight': linear5w_prior, 'linear5_enc.bias': linear5b_prior,
              }

    return pyro.module("module", vae, priors)
```

```python
LEARNING_RATE = 1.0e-3
NUM_EPOCHS = 50
TEST_FREQUENCY = 1

# clear param store
pyro.clear_param_store()

# setup the VAE
vae = VariationalAutoencoder(latent_dims=2)

# setup the optimizer
adam_args = {"lr": LEARNING_RATE}
optimizer = Adam(adam_args)

# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
```

```python
def train(svi, train_loader):
    epoch_loss = 0.

    for x, _ in train_loader:
        epoch_loss += svi.step(x)

    normalizer_train = len(train_loader.dataset)
    total_epoch_loss_train = epoch_loss / normalizer_train
    return total_epoch_loss_train


train_elbo = []
test_elbo = []

for epoch in range(NUM_EPOCHS):

    total_epoch_loss_train = train(svi, train_loader)
    train_elbo.append(-total_epoch_loss_train)
    print("[epoch %03d]  average training loss: %.4f" % (epoch, total_epoch_loss_train))

    if epoch % TEST_FREQUENCY == 0:
        total_epoch_loss_test = evaluate(svi, test_loader)
        test_elbo.append(-total_epoch_loss_test)
        print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test))
```
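(`evaluate` isn't shown in the snippet; presumably it mirrors `train` over `test_loader` without taking gradient steps, e.g. something like this using `svi.evaluate_loss`:)

```python
def evaluate(svi, test_loader):
    test_loss = 0.
    for x, _ in test_loader:
        # estimate the ELBO without updating any parameters
        test_loss += svi.evaluate_loss(x)
    return test_loss / len(test_loader.dataset)
```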

This code is unlikely to ever work in practice: the weight matrices are large and the model doesn't use the "local reparameterization trick" (see ref), so the ELBO gradient estimates will be extremely noisy. I suggest using TyXe if you want to use Bayesian neural networks in Pyro.
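Roughly, the more idiomatic pure-Pyro pattern (a sketch only, not a drop-in replacement for the code above) is to build the network from `pyro.nn.PyroModule` with `PyroSample` priors and let an autoguide construct the variational posterior, instead of handing a priors dict to `pyro.module`. For illustration, here is just a Bayesian decoder with a Bernoulli likelihood (note the observations must be binarised):

```python
import torch
import torch.nn as nn

import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample
from pyro.infer.autoguide import AutoNormal
from pyro.infer import SVI, Trace_ELBO


class BayesianDecoder(PyroModule):
    def __init__(self, latent_dims=2, hidden=512, out=784):
        super().__init__()
        # PyroModule[nn.Linear] lets weight/bias become sample sites once PyroSample priors are set
        self.fc1 = PyroModule[nn.Linear](latent_dims, hidden)
        self.fc1.weight = PyroSample(dist.Normal(0., 1.).expand([hidden, latent_dims]).to_event(2))
        self.fc1.bias = PyroSample(dist.Normal(0., 1.).expand([hidden]).to_event(1))
        self.fc2 = PyroModule[nn.Linear](hidden, out)
        self.fc2.weight = PyroSample(dist.Normal(0., 1.).expand([out, hidden]).to_event(2))
        self.fc2.bias = PyroSample(dist.Normal(0., 1.).expand([out]).to_event(1))

    def forward(self, z, x=None):
        # weights and biases are sampled here, outside the data plate
        h = torch.relu(self.fc1(z))
        probs = torch.sigmoid(self.fc2(h))
        with pyro.plate("data", z.shape[0]):
            pyro.sample("obs", dist.Bernoulli(probs).to_event(1), obs=x)
        return probs


decoder = BayesianDecoder()
guide = AutoNormal(decoder)  # mean-field normal posterior over every weight and bias
svi = SVI(decoder, guide, pyro.optim.Adam({"lr": 1.0e-3}), loss=Trace_ELBO())
# svi.step(z_batch, x_binary)  # x_binary must be in {0, 1} for the Bernoulli likelihood
```

That still leaves the gradient-variance problem from the large weight matrices, which is what the local reparameterization trick (and hence TyXe) is meant to address.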