How to pass MaxPool indices to the decoder in a Pyro VAE model

I have added max-pooling layers to the VAE, so the decoder needs the indices returned by `nn.MaxPool2d(..., return_indices=True)` in order to call `nn.MaxUnpool2d()`.
How do I get these indices inside `def model(self, x):`?
The code is as follows:

class VAE(nn.Module):
    """Pyro VAE whose decoder un-pools using the encoder's MaxPool indices.

    The encoder is called as ``self.encoder(x)`` and returns
    ``(z_loc, z_scale, indices1, indices2, indices3)``; the decoder is called
    as ``self.decoder(z, indices1, indices2, indices3)``.
    """

    # by default our latent space is 50-dimensional
    # and we use 400 hidden units
    def __init__(self, z_dim=50, hidden_dim=400, use_cuda=False):
        # nn.Module.__init__ must run before any submodule is assigned,
        # otherwise PyTorch refuses the `self.encoder = ...` assignment.
        super().__init__()
        # create the encoder and decoder networks
        self.encoder = Encoder(z_dim, hidden_dim)
        self.decoder = Decoder(z_dim, hidden_dim)

        if use_cuda:
            # calling cuda() here will put all the parameters of
            # the encoder and decoder networks into gpu memory
            self.cuda()
        self.use_cuda = use_cuda
        self.z_dim = z_dim

    def _pool_indices(self, x):
        """Return the three MaxPool index tensors the encoder produces for ``x``.

        MaxUnpool2d only needs the positions of the maxima; those index tensors
        are integer-valued and carry no gradient, so the encoder is run under
        ``torch.no_grad()`` and the model's learnable parameters remain the
        decoder's alone.
        """
        # NOTE(review): if the encoder contains BatchNorm/Dropout, this extra
        # forward pass also touches running stats / dropout masks — confirm
        # that is acceptable, or cache the indices computed in `guide`.
        with torch.no_grad():
            _, _, indices1, indices2, indices3 = self.encoder(x)
        return indices1, indices2, indices3

    # define the model p(x|z)p(z)
    def model(self, x):
        """Generative model p(x|z)p(z) for a batch of images ``x``.

        The unpooling indices are a deterministic function of the observed
        image ``x`` (not of the latent ``z``), so they are recomputed here
        from the encoder rather than passed in from the guide.
        """
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        batch_size = x.shape[0]
        indices1, indices2, indices3 = self._pool_indices(x)
        with pyro.plate("data", batch_size):
            # setup hyperparameters for prior p(z)
            z_loc = x.new_zeros(torch.Size((batch_size, self.z_dim)))
            z_scale = x.new_ones(torch.Size((batch_size, self.z_dim)))
            # sample from prior (value will be sampled by guide when computing the ELBO)
            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            # decode the latent code z, un-pooling at the encoder's max positions
            loc_img = self.decoder(z, indices1, indices2, indices3)
            # flatten each image into one event dimension; the size comes from
            # the data instead of a hard-coded 1000000
            loc_img = loc_img.reshape(batch_size, -1)
            # Score against the actual images. weighted_binary_cross_entropy is
            # a loss function (see its (output, target, weights) call signature),
            # not a distribution, so it cannot be used as a sample-site
            # distribution. A Bernoulli likelihood is the unweighted BCE.
            # Assumes the decoder outputs probabilities in [0, 1] — TODO confirm.
            # validate_args=False allows non-binary (grayscale) observations.
            pyro.sample(
                "obs",
                dist.Bernoulli(loc_img, validate_args=False).to_event(1),
                obs=x.reshape(batch_size, -1),
            )

    # define the guide (i.e. variational distribution) q(z|x)
    def guide(self, x):
        """Variational distribution q(z|x) parameterised by the encoder."""
        # register PyTorch module `encoder` with Pyro
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            # use the encoder to get the parameters used to define q(z|x);
            # the pooling indices are not needed here
            z_loc, z_scale, _, _, _ = self.encoder(x)
            # sample the latent code z
            pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

    # define a helper function for reconstructing images
    def reconstruct_img(self, x):
        """Encode ``x`` and decode its posterior mean, reusing the pooling indices."""
        z_loc, z_scale, indices1, indices2, indices3 = self.encoder(x)
        loc_img = self.decoder(z_loc, indices1, indices2, indices3)
        return loc_img

    def getZ(self, x):
        """Return ``z_loc + z_scale`` from the encoder for ``x``."""
        # NOTE(review): this is the posterior mean plus one standard deviation,
        # not a sample and not the mean — confirm this is intended.
        z_loc, z_scale, _, _, _ = self.encoder(x)
        return z_loc + z_scale