Hi all - I am trying to code a Bayesian VAE (i.e. a VAE with Bayesian NNs), and I'm having trouble getting the guide and model to align. Apologies if this is a simple fix, as I am new to Pyro, but I would be very grateful if anyone could help with this code.
The main error I am getting is:
```
ValueError: Expected value argument (Tensor of shape (128, 784)) to be within the support (Boolean()) of the distribution Bernoulli(probs: torch.Size([128, 784])), but found invalid values:
tensor([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]])
```
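(If I understand the support check correctly: since PyTorch 1.8, distributions validate arguments by default, and `Bernoulli.log_prob` only accepts values that are exactly 0 or 1, so grey-scale pixels in (0, 1) get rejected. A two-line repro I wrote myself, not part of the traceback:)

```python
import torch
from torch.distributions import Bernoulli

# My own repro: log_prob validates that observed values are exactly 0 or 1,
# so a grey-scale pixel like 0.5 fails the Boolean support check
b = Bernoulli(probs=torch.full((3,), 0.5))
b.log_prob(torch.tensor([0., 1., 0.5]))  # raises the same ValueError
```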
Here is my code (thanks in advance for any help):
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VariationalAutoencoder(nn.Module):

    def __init__(self, latent_dims):
        super(VariationalAutoencoder, self).__init__()
        # Encoder
        self.linear1 = nn.Linear(784, 512)
        self.linear2 = nn.Linear(512, latent_dims)  # mean head
        self.linear3 = nn.Linear(512, latent_dims)  # scale head

        self.N = torch.distributions.Normal(0, 1)
        self.kl = 0

        # Decoder
        self.linear4 = nn.Linear(latent_dims, 512)
        self.linear5 = nn.Linear(512, 784)
        self.softplus = nn.Softplus()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.linear1(x))
        mu = self.linear2(x)
        sigma = torch.exp(self.linear3(x))
        # Reparameterisation trick: z = mu + sigma * eps, eps ~ N(0, 1)
        z = mu + sigma * self.N.sample(mu.shape)
        self.kl = (sigma**2 + mu**2 - torch.log(sigma) - 1/2).sum()
        x = F.relu(self.linear4(z))
        x = torch.sigmoid(self.linear5(x))

        return x, mu, sigma
```
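(For what it's worth, I realise the manual reparameterisation and `self.kl` above are invisible to Pyro's `Trace_ELBO`; if `z` is meant to be a latent, I guess it needs to be a `pyro.sample` site in both model and guide. An untested sketch of what I think that would look like, with site names of my own choosing:)

```python
import torch
import pyro
import pyro.distributions as dist

# Untested sketch: exposing z to Pyro so Trace_ELBO accounts for it.
# In the model z gets a N(0, 1) prior; in the guide it comes from the
# encoder's mu/sigma. Site name "z" is my own choice.
def sample_z_in_model(batch_size, latent_dims):
    with pyro.plate("data", batch_size):
        return pyro.sample("z", dist.Normal(torch.zeros(latent_dims),
                                            torch.ones(latent_dims)).to_event(1))

def sample_z_in_guide(mu, sigma):
    with pyro.plate("data", mu.shape[0]):
        return pyro.sample("z", dist.Normal(mu, sigma).to_event(1))
```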
```python
import pyro
from pyro.distributions import Normal, Bernoulli

def model(x):
    # Unit-normal priors over each layer's weights and biases
    linear1w_prior = Normal(loc=torch.zeros_like(vae.linear1.weight), scale=torch.ones_like(vae.linear1.weight))
    linear1b_prior = Normal(loc=torch.zeros_like(vae.linear1.bias), scale=torch.ones_like(vae.linear1.bias))
    linear2w_prior = Normal(loc=torch.zeros_like(vae.linear2.weight), scale=torch.ones_like(vae.linear2.weight))
    linear2b_prior = Normal(loc=torch.zeros_like(vae.linear2.bias), scale=torch.ones_like(vae.linear2.bias))

    linear3w_prior = Normal(loc=torch.zeros_like(vae.linear3.weight), scale=torch.ones_like(vae.linear3.weight))
    linear3b_prior = Normal(loc=torch.zeros_like(vae.linear3.bias), scale=torch.ones_like(vae.linear3.bias))

    linear4w_prior = Normal(loc=torch.zeros_like(vae.linear4.weight), scale=torch.ones_like(vae.linear4.weight))
    linear4b_prior = Normal(loc=torch.zeros_like(vae.linear4.bias), scale=torch.ones_like(vae.linear4.bias))

    linear5w_prior = Normal(loc=torch.zeros_like(vae.linear5.weight), scale=torch.ones_like(vae.linear5.weight))
    linear5b_prior = Normal(loc=torch.zeros_like(vae.linear5.bias), scale=torch.ones_like(vae.linear5.bias))

    priors = {'linear1_enc.weight': linear1w_prior, 'linear1_enc.bias': linear1b_prior,
              'linear2_enc.weight': linear2w_prior, 'linear2_enc.bias': linear2b_prior,
              'linear3_enc.weight': linear3w_prior, 'linear3_enc.bias': linear3b_prior,
              'linear4_enc.weight': linear4w_prior, 'linear4_enc.bias': linear4b_prior,
              'linear5_enc.weight': linear5w_prior, 'linear5_enc.bias': linear5b_prior
              }

    lifted_module = pyro.module("module", vae, priors)
    recon, _, _ = lifted_module(x)

    pyro.sample("obs", Bernoulli(recon).to_event(1), obs=x.reshape(-1, 784))
```
```python
def guide(x):

    # First layer weight distribution priors
    linear1w_mu_param = pyro.param("linear1w_mu", torch.randn_like(vae.linear1.weight))
    linear1w_sigma_param = vae.softplus(pyro.param("linear1w_sigma", torch.randn_like(vae.linear1.weight)))
    linear1w_prior = Normal(loc=linear1w_mu_param, scale=linear1w_sigma_param)

    # First layer bias distribution priors
    linear1b_mu_param = pyro.param("linear1b_mu", torch.randn_like(vae.linear1.bias))
    linear1b_sigma_param = vae.softplus(pyro.param("linear1b_sigma", torch.randn_like(vae.linear1.bias)))
    linear1b_prior = Normal(loc=linear1b_mu_param, scale=linear1b_sigma_param)

    # Second layer weight distribution priors
    linear2w_mu_param = pyro.param("linear2w_mu", torch.randn_like(vae.linear2.weight))
    linear2w_sigma_param = vae.softplus(pyro.param("linear2w_sigma", torch.randn_like(vae.linear2.weight)))
    linear2w_prior = Normal(loc=linear2w_mu_param, scale=linear2w_sigma_param)

    # Second layer bias distribution priors
    linear2b_mu_param = pyro.param("linear2b_mu", torch.randn_like(vae.linear2.bias))
    linear2b_sigma_param = vae.softplus(pyro.param("linear2b_sigma", torch.randn_like(vae.linear2.bias)))
    linear2b_prior = Normal(loc=linear2b_mu_param, scale=linear2b_sigma_param)

    # Third layer weight distribution priors
    linear3w_mu_param = pyro.param("linear3w_mu", torch.randn_like(vae.linear3.weight))
    linear3w_sigma_param = vae.softplus(pyro.param("linear3w_sigma", torch.randn_like(vae.linear3.weight)))
    linear3w_prior = Normal(loc=linear3w_mu_param, scale=linear3w_sigma_param)

    # Third layer bias distribution priors
    linear3b_mu_param = pyro.param("linear3b_mu", torch.randn_like(vae.linear3.bias))
    linear3b_sigma_param = vae.softplus(pyro.param("linear3b_sigma", torch.randn_like(vae.linear3.bias)))
    linear3b_prior = Normal(loc=linear3b_mu_param, scale=linear3b_sigma_param)

    # Fourth layer weight distribution priors
    linear4w_mu_param = pyro.param("linear4w_mu", torch.randn_like(vae.linear4.weight))
    linear4w_sigma_param = vae.softplus(pyro.param("linear4w_sigma", torch.randn_like(vae.linear4.weight)))
    linear4w_prior = Normal(loc=linear4w_mu_param, scale=linear4w_sigma_param)

    # Fourth layer bias distribution priors
    linear4b_mu_param = pyro.param("linear4b_mu", torch.randn_like(vae.linear4.bias))
    linear4b_sigma_param = vae.softplus(pyro.param("linear4b_sigma", torch.randn_like(vae.linear4.bias)))
    linear4b_prior = Normal(loc=linear4b_mu_param, scale=linear4b_sigma_param)

    # Fifth layer weight distribution priors
    linear5w_mu_param = pyro.param("linear5w_mu", torch.randn_like(vae.linear5.weight))
    linear5w_sigma_param = vae.softplus(pyro.param("linear5w_sigma", torch.randn_like(vae.linear5.weight)))
    linear5w_prior = Normal(loc=linear5w_mu_param, scale=linear5w_sigma_param)

    # Fifth layer bias distribution priors
    linear5b_mu_param = pyro.param("linear5b_mu", torch.randn_like(vae.linear5.bias))
    linear5b_sigma_param = vae.softplus(pyro.param("linear5b_sigma", torch.randn_like(vae.linear5.bias)))
    linear5b_prior = Normal(loc=linear5b_mu_param, scale=linear5b_sigma_param)

    priors = {'linear1_enc.weight': linear1w_prior, 'linear1_enc.bias': linear1b_prior,
              'linear2_enc.weight': linear2w_prior, 'linear2_enc.bias': linear2b_prior,
              'linear3_enc.weight': linear3w_prior, 'linear3_enc.bias': linear3b_prior,
              'linear4_enc.weight': linear4w_prior, 'linear4_enc.bias': linear4b_prior,
              'linear5_enc.weight': linear5w_prior, 'linear5_enc.bias': linear5b_prior,
              }

    return pyro.module("module", vae, priors)
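```

(Similarly for the guide: I assume it needs to sample from a lifted module so its sample sites line up with the model's, roughly like this untested sketch, where the underscored param names are my own choice:)

```python
# Untested sketch of a guide that mirrors the model's lifted sites
def guide(x):
    dists = {}
    for name, p in vae.named_parameters():
        pname = name.replace(".", "_")  # param-store-friendly names (my choice)
        mu = pyro.param(f"{pname}_mu", torch.randn_like(p))
        sigma = vae.softplus(pyro.param(f"{pname}_sigma", torch.randn_like(p)))
        dists[name] = Normal(loc=mu, scale=sigma)
    lifted_module = pyro.random_module("module", vae, dists)
    return lifted_module()  # sample sites now match the model's
```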
```python
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

LEARNING_RATE = 1.0e-3
NUM_EPOCHS = 50
TEST_FREQUENCY = 1

# Clear the param store
pyro.clear_param_store()

# Set up the VAE
vae = VariationalAutoencoder(latent_dims=2)

# Set up the optimizer
adam_args = {"lr": LEARNING_RATE}
optimizer = Adam(adam_args)

# Set up the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
```
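(I didn't include my data loaders; they're the standard torchvision MNIST ones below, minus the `Lambda` line. I'm guessing I need to add it to binarize the images so `obs` lands in the Bernoulli's {0, 1} support?)

```python
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

# Path and batch size are just my local setup; the Lambda line is my guess
# at the fix: it binarizes pixels so they sit in Bernoulli's {0, 1} support
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda t: (t > 0.5).float()),
])
train_loader = DataLoader(
    torchvision.datasets.MNIST("./data", train=True, download=True, transform=transform),
    batch_size=128, shuffle=True)
test_loader = DataLoader(
    torchvision.datasets.MNIST("./data", train=False, download=True, transform=transform),
    batch_size=128, shuffle=False)
```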
```python
def train(svi, train_loader):
    epoch_loss = 0.
    for x, _ in train_loader:
        epoch_loss += svi.step(x)

    normalizer_train = len(train_loader.dataset)
    total_epoch_loss_train = epoch_loss / normalizer_train
    return total_epoch_loss_train

train_elbo = []
test_elbo = []

for epoch in range(NUM_EPOCHS):
    total_epoch_loss_train = train(svi, train_loader)
    train_elbo.append(-total_epoch_loss_train)
    print("[epoch %03d] average training loss: %.4f" % (epoch, total_epoch_loss_train))

    if epoch % TEST_FREQUENCY == 0:
        total_epoch_loss_test = evaluate(svi, test_loader)
        test_elbo.append(-total_epoch_loss_test)
        print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test))
```