Training a VAE on multiple GPUs

Hello, I want to put the encoder and decoder on two GPUs respectively, e.g. training the encoder on GPU 0 and the decoder on GPU 1, for distributed training, because my VAE model is a little big. How can I achieve this?

Hi @Amy, my guess is you’d call .to() before and after one of your neural nets, something like this:

decoder = MyDecoder()
encoder = MyEncoder()

data = data.to(device="cuda:0")
decoder.to(device="cuda:0")
encoder.to(device="cuda:1")

def model(data):
    pyro.module("decoder", decoder)
    z = pyro.sample("z", dist.Normal(0, 1))
    # the decoder and the observed data both live on cuda:0
    pyro.sample("x", dist.Normal(decoder(z), 1), obs=data)

def guide(data):
    pyro.module("encoder", encoder)

    # This step lives on cuda:1:
    z_mean = encoder(data.to(device="cuda:1")).to(device="cuda:0")

    pyro.sample("z", dist.Normal(z_mean, 1))

The above assumes the encoder is the bigger of the two nets, so it gets cuda:1 to itself, while the ELBO computation happens on the same device as the decoder (cuda:0).
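
For reference, here is a minimal sketch of driving that model/guide pair with SVI; the Adam / Trace_ELBO setup is just the standard one and nothing in it is specific to the two-GPU split, other than keeping the mini-batch on the decoder's device:

from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

svi = SVI(model, guide, Adam({"lr": 1e-3}), loss=Trace_ELBO())

# keep the mini-batch on cuda:0 (the decoder's device); the guide moves a
# copy to cuda:1 before calling the encoder
loss = svi.step(data.to(device="cuda:0"))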

Good luck, LMK if you get that working!

Hi @fritzo, thank you for your reply!
If I want to implement data parallelism across GPU 0 and GPU 1 in my VAE module, how do I do that?
My VAE code is as follows:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# (DataSet and EarlyStopping are custom helpers defined elsewhere)

def train(svi, train_loader, use_cuda=True):
    # initialize loss accumulator
    epoch_loss = 0.
    # do a training epoch over each mini-batch x returned
    # by the data loader
    for x in train_loader:
        # if on GPU put mini-batch into CUDA memory
        if use_cuda:
            x = x.cuda()
        # do ELBO gradient and accumulate loss
        epoch_loss += svi.step(x)

    # return epoch loss
    normalizer_train = len(train_loader.dataset)
    total_epoch_loss_train = epoch_loss / normalizer_train
    return total_epoch_loss_train

def evaluate(svi, test_loader, use_cuda=True):
    # initialize loss accumulator
    test_loss = 0.
    # compute the loss over each mini-batch x returned
    # by the data loader
    for x in test_loader:
        # if on GPU put mini-batch into CUDA memory
        if use_cuda:
            x = x.cuda()
        # compute ELBO estimate and accumulate loss
        test_loss += svi.evaluate_loss(x)

    # return epoch loss
    normalizer_test = len(test_loader.dataset)
    total_epoch_loss_test = test_loss / normalizer_test
    return total_epoch_loss_test

class VAE(nn.Module):
    def __init__(self, z_dim=32, hidden_dim=1000, use_cuda=True):
        super().__init__()
        # create the encoder and decoder networks
        self.encoder = Encoder(z_dim, hidden_dim)
        self.decoder = Decoder(z_dim, hidden_dim)
        
        if use_cuda:
            # calling cuda() here will put all the parameters of
            # the encoder and decoder networks into gpu memory
            self.cuda()
        self.use_cuda = use_cuda
        self.z_dim = z_dim
    # define the model p(x|z)p(z)
    def model(self, x):
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        with pyro.plate("data", x.shape[0]):
            # setup hyperparameters for prior p(z)
            z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))
            z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))
            # sample from prior (value will be sampled by guide when computing the ELBO)
            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            # decode the latent code z
            loc_img = self.decoder(z)
            loc_img = loc_img.reshape(-1, 1261795)
            pyro.sample("obs", dist.Bernoulli(loc_img).to_event(1), obs=x.reshape(-1, 1261795))

    # define the guide (i.e. variational distribution) q(z|x)
    def guide(self, x):
        # register PyTorch module `encoder` with Pyro
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            # use the encoder to get the parameters used to define q(z|x)
            z_loc, z_scale = self.encoder(x)
            # sample the latent code z
            pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

    # define a helper function for reconstructing images
    def reconstruct_img(self, x):
        # encode image x
        z_loc, z_scale = self.encoder(x)
        loc_img = self.decoder(z_loc)
        return loc_img
    
    def getZ(self, x):
        # encode image x
        z_loc, z_scale = self.encoder(x)
        return (z_loc,z_scale)

# Run options
LEARNING_RATE = 1.0e-3
USE_CUDA = True
NUM_EPOCHS = 100
sc_path = []
### custom Dataset
celldata = DataSet(sc_path)
train_loader = DataLoader(train_dataset, batch_size=batch_siz, num_workers=8)
test_loader = DataLoader(test_dataset, batch_size=batch_siz, num_workers=8)

vae = VAE(use_cuda=USE_CUDA)
optimizer = Adam({"lr": LEARNING_RATE})
svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())

### train VAE model
train_elbo = []
test_elbo = []
early_stopping = EarlyStopping(patience=20, verbose=True)
# training loop
for epoch in range(NUM_EPOCHS):
    sampler.set_epoch(epoch)
    total_epoch_loss_train = train(svi, train_loader, use_cuda=USE_CUDA)
    train_elbo.append(-total_epoch_loss_train)
    total_epoch_loss_test = evaluate(svi, test_loader, use_cuda=USE_CUDA)
    test_elbo.append(-total_epoch_loss_test)
    print("[epoch %03d]  training loss: %.4f" % (epoch, total_epoch_loss_train))
    print("[epoch %03d]  testing loss: %.4f" % (epoch, total_epoch_loss_test))
    early_stopping(total_epoch_loss_test, vae)
    if early_stopping.early_stop:
        print("Early stopping")
        break