I have added a max-pooling layer to the VAE, so in the decoder I need to use the indices returned by nn.MaxPool2d(return_indices=True) (these are what nn.MaxUnpool2d consumes). How do I get those indices inside def model(self, x)?
The code is as follows:
class VAE(nn.Module):
    """Variational autoencoder whose encoder max-pools with
    return_indices=True; the pooling indices are threaded through to the
    decoder so its nn.MaxUnpool2d layers can invert the pooling.

    By default the latent space is 50-dimensional and 400 hidden units
    are used.
    """

    def __init__(self, z_dim=50, hidden_dim=400, use_cuda=False):
        super().__init__()
        # create the encoder and decoder networks
        self.encoder = Encoder(z_dim, hidden_dim)
        self.decoder = Decoder(z_dim, hidden_dim)
        if use_cuda:
            # calling cuda() here will put all the parameters of
            # the encoder and decoder networks into gpu memory
            self.cuda()
        self.use_cuda = use_cuda
        self.z_dim = z_dim

    # define the model p(x|z)p(z)
    def model(self, x):
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        with pyro.plate("data", x.shape[0]):
            # setup hyperparameters for prior p(z)
            z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))
            z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))
            # sample from prior (value will be sampled by guide when computing the ELBO)
            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            # FIX: indices1/2/3 were undefined names here.  The pooling
            # indices come from the encoder's max-pool layers, so run the
            # encoder on x to obtain them (z_loc/z_scale are discarded —
            # the latent sample z above is what the decoder consumes).
            _, _, indices1, indices2, indices3 = self.encoder(x)
            # decode the latent code z, unpooling with the encoder's indices
            loc_img = self.decoder(z, indices1, indices2, indices3)
            loc_img = loc_img.reshape(-1, 1000000)
            # score against actual images
            # NOTE(review): pyro.sample expects a distribution object here,
            # so weighted_binary_cross_entropy(loc_img) must return one —
            # confirm; if it returns a scalar loss, use
            # pyro.factor("myfactor", -loss) instead (see commented lines
            # in the original).
            pyro.sample("obs",
                        weighted_binary_cross_entropy(loc_img).to_event(1),
                        obs=x.reshape(-1, 1000000))

    # define the guide (i.e. variational distribution) q(z|x)
    def guide(self, x):
        # register PyTorch module `encoder` with Pyro
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            # use the encoder to get the parameters used to define q(z|x);
            # the pooling indices are not needed by the guide itself
            z_loc, z_scale, indices1, indices2, indices3 = self.encoder(x)
            # sample the latent code z
            pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

    # define a helper function for reconstructing images
    def reconstruct_img(self, x):
        # encode image x, keeping the pooling indices for unpooling
        z_loc, z_scale, indices1, indices2, indices3 = self.encoder(x)
        # decode from the posterior mean, reusing the encoder's indices
        loc_img = self.decoder(z_loc, indices1, indices2, indices3)
        return loc_img

    def getZ(self, x):
        # encode image x
        z_loc, z_scale, indices1, indices2, indices3 = self.encoder(x)
        # NOTE(review): returns loc + scale rather than a sample or the
        # mean alone — confirm this is the intended latent summary.
        return z_loc + z_scale