I have this fairly straightforward VAE that I wish to train and save:
import logging
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
import wandb
from torch.utils.data import DataLoader
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Encoder(nn.Module):
    """Simple MLP encoder."""
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim * 2),  # mean and log-variance of the latent
        ).to(device)

    def forward(self, x):
        h = self.encoder(x)
        z_mu, z_logvar = torch.chunk(h, 2, dim=-1)
        return z_mu, z_logvar
class Decoder(nn.Module):
    """Simple MLP decoder."""
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim * 2),  # mean and log-variance of the MVN
        ).to(device)

    def forward(self, z):
        x = self.decoder(z)
        x_mu, x_logvar = torch.chunk(x, 2, dim=-1)
        return x_mu, x_logvar
class XYEncoder(nn.Module):
    """Simple MLP encoder over the concatenation of x and y."""
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim * 2),  # mean and log-variance of the latent
        ).to(device)

    def forward(self, x, y):
        xy = torch.cat((x, y), dim=1)
        h = self.encoder(xy)
        z_mu, z_logvar = torch.chunk(h, 2, dim=-1)
        return z_mu, z_logvar
class VAE(nn.Module):
    def __init__(self, input_dim, z_dim=50, hidden_dim=400, beta=1.0, device="cuda"):
        super().__init__()
        # create the encoder and decoder networks
        self.encoder = Encoder(input_dim, hidden_dim, z_dim)
        self.decoder = Decoder(z_dim, hidden_dim, input_dim)
        self.device = device
        self.z_dim = z_dim
        self.beta = beta
        if device == "cuda":
            # calling cuda() here (after the submodules exist) puts all the
            # parameters of the encoder and decoder networks into GPU memory
            self.cuda()
    # define the model p(x|z)p(z)
    def model(self, x):
        # register the PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        batch_size = x.shape[0]
        with pyro.plate("data", batch_size):
            # hyperparameters for the standard-normal prior p(z)
            z_loc = x.new_zeros(torch.Size((batch_size, self.z_dim)))
            z_scale = x.new_ones(torch.Size((batch_size, self.z_dim)))
            with pyro.poutine.scale(scale=self.beta):
                # sample from the prior (the value will be drawn from the guide when computing the ELBO)
                z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            # decode the latent code z
            x_mu, x_logvar = self.decoder(z)
            # score against the observed data
            # (exp(0.5 * logvar) converts the log-variance into a standard deviation)
            pyro.sample("obs", dist.Normal(x_mu, torch.exp(0.5 * x_logvar)).to_event(1), obs=x)
    # define the guide (i.e. the variational distribution) q(z|x)
    def guide(self, x):
        # register the PyTorch module `encoder` with Pyro
        pyro.module("encoder", self.encoder)
        batch_size = x.shape[0]
        with pyro.plate("data", batch_size):
            # use the encoder to get the parameters of q(z|x)
            z_loc, z_logvar = self.encoder(x)
            # sample the latent code z
            pyro.sample("latent", dist.Normal(z_loc, torch.exp(0.5 * z_logvar)).to_event(1))
    def predict(self, maldi_data, density, sample="mean"):
        z_loc, z_logvar = self.encoder(maldi_data)
        z = dist.Normal(z_loc, torch.exp(0.5 * z_logvar)).sample()
        x_mu, x_logvar = self.decoder(z)
        if sample == "mean":
            return x_mu
        return dist.Normal(x_mu, torch.exp(0.5 * x_logvar)).sample()
When I then call torch.save on the trained model I get an error due to a weakref. Any idea where the weakref is?
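For completeness, the vae used below is constructed roughly like this (a sketch; the real input_dim comes from my dataset, so treat the shape lookup as illustrative):

vae = VAE(input_dim=dataset[0][0].shape[-1])  # hypothetical construction; dims match my data

The training and saving code: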
# train the model
logging.info("Training the VAE model")
print(model_name)
optimizer = Adam({"lr": 0.001})  # Pyro optimizer; SVI manages the parameters itself
pyro.get_param_store().clear()
svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
for epoch in range(epochs):
    total_loss = 0.0
    for idx, (data, density) in enumerate(data_loader):
        x = data.to(vae.device)
        y = density.to(vae.device)  # density is moved to the device but not used by svi.step
        loss = svi.step(x)
        wandb.log({"loss": loss})
        total_loss += loss
    print("Epoch", epoch, "Loss", total_loss)
logging.info("Saving the model")
torch.save(vae, str(model_file))
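In case it matters, the fallback I am considering is to save only the weights plus Pyro's parameter store instead of pickling the whole module. A minimal sketch, assuming the same vae and model_file as above (the ".pyro" suffix is just my own naming):

# save the weights and the Pyro parameter store instead of pickling the module
torch.save(vae.state_dict(), str(model_file))
pyro.get_param_store().save(str(model_file) + ".pyro")

# restoring later: rebuild the model with the same constructor arguments,
# then load both the weights and the parameter store
# vae.load_state_dict(torch.load(str(model_file)))
# pyro.get_param_store().load(str(model_file) + ".pyro")

But I would still like to understand where the weakref in the pickled object comes from.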
Any help would be highly appreciated. Thank you and best regards.