I want to use a 1000×1000 matrix as input data, but I get a GPU out-of-memory error.
Is there any way to reduce the VAE's memory usage, or to optimize my model so that it uses less memory?
class Encoder(nn.Module):
    """Encoder network q(z|x): maps a flattened input to latent Normal parameters.

    Memory note: ``fc1`` holds a weight matrix of ``hidden_dim * input_dim``
    entries. With the default ``input_dim=1000000`` and ``hidden_dim=400``
    (the VAE class default), that is 400 million parameters, about 1.6 GB in
    PyTorch's default float32. This is before gradients and optimizer state,
    and it is typically what exhausts GPU memory. Lowering ``hidden_dim``, or
    shrinking the input (e.g. pooling/convolution) before this dense layer,
    are the main levers.

    Args:
        z_dim: Dimensionality of the latent space.
        hidden_dim: Width of the single hidden layer.
        input_dim: Number of elements per sample after flattening. Defaults
            to 1,000,000 (one 1000x1000 matrix), matching the original
            hard-coded size.
    """

    def __init__(self, z_dim, hidden_dim, input_dim=1000000):
        super().__init__()
        self.input_dim = input_dim
        # setup the three linear transformations used
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)
        # setup the non-linearities
        self.softplus = nn.Softplus()

    def forward(self, x):
        """Encode ``x`` into ``(z_loc, z_scale)``, each of shape ``(batch, z_dim)``.

        ``x`` may have any shape whose total size is a multiple of
        ``input_dim``. It is flattened to ``(-1, input_dim)`` first.
        """
        # put each sample's elements in the rightmost dimension
        x = x.reshape(-1, self.input_dim)
        hidden = self.softplus(self.fc1(x))
        z_loc = self.fc21(hidden)
        # exp keeps the scale positive
        z_scale = torch.exp(self.fc22(hidden))
        return z_loc, z_scale
class Decoder(nn.Module):
    """Decoder network p(x|z): maps a latent code to per-element values in [0, 1].

    Memory note: ``fc21`` holds a weight matrix of ``output_dim * hidden_dim``
    entries. That is 400 million parameters with the default ``output_dim``
    and the VAE class default ``hidden_dim=400``. Its output is a
    ``(batch, output_dim)`` tensor, so both the layer size and the batch size
    drive GPU memory use.

    Args:
        z_dim: Dimensionality of the latent space.
        hidden_dim: Width of the single hidden layer.
        output_dim: Number of output elements per sample. Defaults to
            1,000,000, matching the original hard-coded size.
    """

    def __init__(self, z_dim, hidden_dim, output_dim=1000000):
        super().__init__()
        self.output_dim = output_dim
        # setup the two linear transformations used
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, output_dim)
        # setup the non-linearities
        self.softplus = nn.Softplus()
        self.sigmoid = nn.Sigmoid()

    def forward(self, z):
        """Decode ``z`` of shape ``(batch, z_dim)`` into ``(batch, output_dim)``.

        The sigmoid output is used by ``VAE.model`` as the parameter of the
        observation likelihood.
        """
        hidden = self.softplus(self.fc1(z))
        # squash each output element into the unit interval
        loc_img = self.sigmoid(self.fc21(hidden))
        return loc_img
class VAE(nn.Module):
    """Variational autoencoder written as a Pyro ``model``/``guide`` pair.

    ``model`` defines p(x|z)p(z) with a standard Normal prior over z.
    ``guide`` defines q(z|x) through the encoder network.

    Args:
        z_dim: Latent dimensionality (default 50).
        hidden_dim: Hidden width of the encoder and decoder (default 400).
        use_cuda: If True, move all encoder/decoder parameters to the GPU.
    """

    # by default our latent space is 50-dimensional
    # and we use 400 hidden units
    def __init__(self, z_dim=50, hidden_dim=400, use_cuda=False):
        super().__init__()
        # create the encoder and decoder networks
        self.encoder = Encoder(z_dim, hidden_dim)
        self.decoder = Decoder(z_dim, hidden_dim)
        if use_cuda:
            # calling cuda() here will put all the parameters of
            # the encoder and decoder networks into gpu memory
            self.cuda()
        self.use_cuda = use_cuda
        self.z_dim = z_dim

    def model(self, x):
        """Generative model p(x|z)p(z) for the mini-batch ``x``."""
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        with pyro.plate("data", x.shape[0]):
            # standard Normal prior p(z)
            z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))
            z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))
            # the value is supplied by the guide when computing the ELBO
            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            # decode the latent code z
            loc_img = self.decoder(z)
            # Flatten x to the decoder's actual output width rather than a
            # repeated hard-coded 1000000, so the two cannot drift apart.
            pyro.sample(
                "obs",
                weighted_binary_cross_entropy(loc_img).to_event(1),
                obs=x.reshape(-1, loc_img.shape[-1]),
            )

    def guide(self, x):
        """Variational distribution q(z|x) produced by the encoder."""
        # register PyTorch module `encoder` with Pyro
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            z_loc, z_scale = self.encoder(x)
            pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

    def reconstruct_img(self, x):
        """Encode ``x`` and decode the latent mean (no sampling).

        NOTE(review): autograd records this call unless the caller wraps it in
        ``torch.no_grad()``. For inference-only use, doing so avoids keeping
        the large intermediate activations alive.
        """
        z_loc, _ = self.encoder(x)
        return self.decoder(z_loc)

    def getZ(self, x):
        """Return ``z_loc + z_scale`` from the encoder for input ``x``.

        NOTE(review): this is deterministic and is not a reparameterized
        sample (that would be ``z_loc + z_scale * eps``). Confirm this is the
        intended latent summary.
        """
        z_loc, z_scale = self.encoder(x)
        return z_loc + z_scale
# NOTE(review): DataLoaderX is assumed to be a torch DataLoader subclass, and
# each item of celldata presumably holds 1,000,000 values (the encoder
# flattens to that width). GPU activation memory in the VAE therefore grows
# linearly with batch_size. Reduce it if training runs out of memory.
cell_loader = DataLoaderX(
    celldata,
    batch_size=250,
    num_workers=32,
    pin_memory=True,
)