Advice on building a Variational Graph Autoencoder in Pyro

Hi!
I have implemented my take on the VGAE (inspired by [1611.07308] Variational Graph Auto-Encoders) for graph classification in Pyro (preferably unsupervised), but I may be missing some crucial steps. Perhaps you can spot something? :sweat_smile:

My architecture is something like this:

The guide is a method of an EasyGuide subclass:

import torch
import pyro
import pyro.distributions as dist
from pyro.contrib.easyguide import EasyGuide

class GUIDES(EasyGuide):
    def __init__(self, model, ....):
        super(GUIDES, self).__init__(model)
        #Init params

    def guide(self, batch_data, batch_laplacian, batch_adjacency):
        """batch_data["blosum"] is the data encoded with BLOSUM vectors.
        The guide takes the same arguments as the model."""
        pyro.module("guide_data_embedder", self.guide_data_embedder)
        pyro.module("guide_laplacian_embedder", self.guide_laplacian_embedder)
        pyro.module("guide_gcn", self.guide_gcn)

        batch_data_blosum = batch_data["blosum"]
        batch_size = batch_data_blosum.shape[0]
        # skip labels and concatenate all sequences
        batch_seq_blosum = torch.flatten(batch_data_blosum[:, 1:], start_dim=2).reshape(
            batch_size, self.n_sequences * self.max_len, self.input_dim)

        batch_mask = batch_data["mask"]
        mask_flat = torch.flatten(batch_mask, start_dim=2).reshape(
            batch_size, self.n_sequences * self.max_len, self.input_dim)
        mask_flat_ = mask_flat[:, :, 0]

        with pyro.plate("data", batch_size, dim=-2):
            batch_seq_blosum = self.guide_data_embedder(batch_seq_blosum, mask=None)
            # Note: the mask seems to break things; this layer might be unnecessary
            batch_laplacian = self.guide_laplacian_embedder(batch_adjacency, mask=None)
            z_mean, z_std = self.guide_gcn(batch_seq_blosum, batch_laplacian, mask_flat_)
            # z_mean.shape = [batch_size, n_seq*max_len, z_dim]
            z_mean = torch.mean(z_mean, dim=1)  # [batch_size, z_dim]
            z_std = torch.mean(z_std, dim=1)
            latent_z = pyro.sample("latent_z", dist.Normal(z_mean, z_std))  # [batch_size, z_dim]

            # COMMENT 1:
            # class_logits = self.guide_class_logits(latent_z, mask=None)
            # pyro.sample("predictions", dist.Categorical(logits=class_logits).to_event(1), infer={'is_auxiliary': True})

        return {"latent_z": latent_z,
                "z_mean": z_mean,
                "z_std": z_std}
    # Method of the model class (class definition not shown)
    def model(self, batch_data, batch_laplacian, batch_adjacency):
        """Prior over latent_z plus node and label likelihoods; the adjacency
        reconstruction is returned for use in the custom loss."""
        pyro.module("decoder_class", self.decoder_class)
        pyro.module("decoder_nodes", self.decoder_nodes)
        pyro.module("decoder_adjacency", self.decoder_adjacency)

        batch_data_int = batch_data["int"]
        batch_size = batch_data_int.shape[0]
        # skip labels and concatenate all sequences
        batch_seq_int = torch.flatten(batch_data_int[:, 1:], start_dim=2).reshape(
            batch_size, self.n_sequences * self.max_len, 1)
        batch_data_blosum = batch_data["blosum"]
        batch_seq_blosum = torch.flatten(batch_data_blosum[:, 1:], start_dim=2).reshape(
            batch_size, self.n_sequences * self.max_len, self.input_dim)

        batch_mask = batch_data["mask"]

        # Standard Normal prior. dist.Normal takes a standard deviation, not a variance,
        # and torch.ones gives a float tensor (torch.full(..., 1) infers an integer dtype in recent PyTorch).
        z_mean = torch.zeros((batch_size, self.z_dim))
        z_std = torch.ones((batch_size, self.z_dim))

        with pyro.plate("data", batch_size, dim=-2):
            latent_z = pyro.sample("latent_z", dist.Normal(z_mean, z_std))
            class_logits = self.decoder_class(latent_z)
            with pyro.plate("plate_seq", batch_seq_int.shape[1], dim=-1):
                nodes_logits = self.decoder_nodes(latent_z)
                pyro.sample("nodes", dist.OneHotCategorical(logits=nodes_logits), obs=batch_seq_onehot)

            triu_idx = torch.triu_indices(batch_adjacency.shape[1], batch_adjacency.shape[2])
            # needed to have the same size as the original adjacency
            latent_z = self.decoder_adjacency(latent_z, None)
            # I have also tried some other alternatives to the adjacency reconstruction
            reconstructed_adjacency = torch.sigmoid(torch.matmul(latent_z[:, :, None], latent_z[:, None, :]))
            triu_reconstructed_adjacency = reconstructed_adjacency[:, triu_idx[0], triu_idx[1]]

            if self.supervised:
                pyro.sample("predictions", dist.Categorical(logits=class_logits).to_event(1),
                            obs=batch_data["blosum"][:, 0, 0, 0])
            else:
                pyro.sample("predictions", dist.Categorical(logits=class_logits).to_event(1), obs=None)

        return {"triu_reconstructed_adjacency": triu_reconstructed_adjacency,
                "batch_size": batch_size,
                "n_sequences": self.n_sequences,
                "max_len": self.max_len,
                "input_dim": self.input_dim}

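For comparison, the VGAE paper uses a two-layer GCN encoder on a symmetrically normalized adjacency matrix (a shared first layer, then separate layers for the mean and the log standard deviation, with one latent vector per node and no separately learned Laplacian embedding), and an inner-product decoder sigmoid(Z Z^T) scored with a Bernoulli (binary cross-entropy) likelihood. The model above instead builds the adjacency from a single pooled latent vector. The sketch below is a minimal PyTorch version under assumptions: dense adjacency matrices of shape [B, N, N], the self-loop normalization D^{-1/2}(A+I)D^{-1/2} commonly used with GCNs, and hypothetical names (normalize_adjacency, GCNEncoder, adjacency_recon_loss) that are not part of Pyro or PyTorch. Taking exp of the log std keeps the scale strictly positive, which averaging raw network outputs does not guarantee, and working with logits is numerically safer than a sigmoid followed by a log.

```python
import torch
import torch.nn.functional as F
from torch import nn


def normalize_adjacency(adj: torch.Tensor) -> torch.Tensor:
    """D^{-1/2} (A + I) D^{-1/2} for a batch of dense adjacency matrices [B, N, N]."""
    n = adj.shape[-1]
    a_hat = adj + torch.eye(n, dtype=adj.dtype, device=adj.device)
    d_inv_sqrt = a_hat.sum(dim=-1).pow(-0.5)  # degrees are >= 1 thanks to the self-loops
    return d_inv_sqrt.unsqueeze(-1) * a_hat * d_inv_sqrt.unsqueeze(-2)


class GCNEncoder(nn.Module):
    """Hypothetical VGAE-style encoder: shared GCN layer, then one GCN layer each for mu and log sigma."""

    def __init__(self, in_dim: int, hidden_dim: int, z_dim: int):
        super().__init__()
        self.w0 = nn.Linear(in_dim, hidden_dim, bias=False)
        self.w_mu = nn.Linear(hidden_dim, z_dim, bias=False)
        self.w_log_sigma = nn.Linear(hidden_dim, z_dim, bias=False)

    def forward(self, x: torch.Tensor, adj: torch.Tensor):
        a_norm = normalize_adjacency(adj)                  # [B, N, N]
        h = torch.relu(a_norm @ self.w0(x))                # [B, N, hidden_dim]
        z_loc = a_norm @ self.w_mu(h)                      # [B, N, z_dim], one latent per node
        z_scale = torch.exp(a_norm @ self.w_log_sigma(h))  # strictly positive
        return z_loc, z_scale


def adjacency_recon_loss(z: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
    """Inner-product decoder sigmoid(z z^T), scored with BCE on the strict upper triangle; returns [B]."""
    n = adj.shape[-1]
    logits = z @ z.transpose(-1, -2)         # [B, N, N]
    iu = torch.triu_indices(n, n, offset=1)  # excludes the diagonal
    bce = F.binary_cross_entropy_with_logits(
        logits[:, iu[0], iu[1]], adj[:, iu[0], iu[1]], reduction="none")
    return bce.sum(dim=-1)


# Usage on random data: 2 graphs, 5 nodes, 3 features per node
torch.manual_seed(0)
adj = torch.triu((torch.rand(2, 5, 5) > 0.5).float(), diagonal=1)
adj = adj + adj.transpose(-1, -2)  # symmetric, no self-loops
x = torch.randn(2, 5, 3)
encoder = GCNEncoder(in_dim=3, hidden_dim=8, z_dim=2)
z_loc, z_scale = encoder(x, adj)
z = z_loc + z_scale * torch.randn_like(z_scale)  # reparameterized sample
loss = adjacency_recon_loss(z, adj)
```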
First, a small and possibly unrelated issue, which I suspect comes from PyTorch (although I have not been able to track it down). When I record the gradient norms, only the parameters of the “decoder_adjacency” module are always missing (all the others are fine). There is no visible error (unless something is off with the architecture); the hooks registered with register_hook() simply never record anything for them. When I inspect param.norm() manually, the value is there and is not NaN (but it never changes…).

loss = TraceELBO_graph()  # custom loss
gradient_norms = defaultdict(list)
while epoch <= num_epochs:
    for batch in dataloader:
        ....
        svi.step(...)
    for name_i, value in pyro.get_param_store().named_parameters():
        value.register_hook(lambda g, name_i=name_i: gradient_norms[name_i].append(g.norm().item()))
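One possible explanation for the missing decoder_adjacency norms, offered as a guess rather than a diagnosis: a hook registered with register_hook only fires when a gradient actually flows into that tensor during backward(). If a module's output never reaches the loss through a differentiable path (for example, it is only returned from the model and enters the loss after .detach() or .item()), its hooks never run, its .grad stays None, and the optimizer never changes it, which matches the symptoms described above. A minimal pure-PyTorch sketch with two hypothetical toy modules standing in for the decoders; it also registers each hook once before training, since registering inside the epoch loop adds a duplicate hook every epoch.

```python
from collections import defaultdict

import torch
from torch import nn

torch.manual_seed(0)
used = nn.Linear(3, 1)    # toy stand-in for a decoder whose output reaches the loss
unused = nn.Linear(3, 1)  # toy stand-in for a decoder whose output is cut off from the loss

params = {f"used.{n}": p for n, p in used.named_parameters()}
params.update({f"unused.{n}": p for n, p in unused.named_parameters()})

# Register each hook exactly once, before the training loop
gradient_norms = defaultdict(list)
for name, p in params.items():
    p.register_hook(lambda g, name=name: gradient_norms[name].append(g.norm().item()))

optimizer = torch.optim.SGD(list(params.values()), lr=0.1)
x = torch.randn(8, 3)
for step in range(3):
    optimizer.zero_grad()
    side_output = unused(x).detach()  # e.g. only returned and used after .detach()
    loss = used(x).pow(2).mean() + 0.0 * side_output.sum()
    loss.backward()
    optimizer.step()

for name, p in params.items():
    # "used.*" params: 3 hook calls each; "unused.*" params: 0 calls and .grad is None
    print(name, len(gradient_norms[name]), p.grad is None)
```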

Second, something is wrong with the implementation (or with the design, but I want to rule out implementation mistakes first):
a) The model does train: the training loss decreases at first and then plateaus (above the validation loss).
b) The loss values are huge (on the order of 10^23).
c) The latent-space projection looks completely random.
d) The graph classifications are also random, regardless of the classification technique I use.
e) The gradient norms decrease but eventually plateau.

I use a custom loss that adds the required graph reconstruction errors by modifying the Trace_ELBO class (it uses the return values and inputs of the model trace inside the loss_and_grads function). As mentioned above, the loss is huge, but the problem comes from the ELBO term itself, so I am afraid that q(Z|X) is very far from p(Z). The reconstruction losses seem irrelevant for now, because the huge ELBO term dominates everything.
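Because the guide and the prior over latent_z are both diagonal Normals, the KL part of the ELBO can be computed in closed form for a batch of guide outputs, which helps confirm where the magnitude comes from. A minimal sketch using torch.distributions.kl_divergence, assuming z_mean and z_std of shape [batch_size, z_dim] as returned by the guide and a standard Normal prior; kl_to_standard_normal is a hypothetical helper and the numbers are made up. Per dimension the KL is 0.5(σ² + μ² - 1) - log σ, so it grows quadratically with the guide's mean and scale: a value around 10^23 would imply guide outputs many orders of magnitude away from the prior's.

```python
import torch
from torch.distributions import Normal, kl_divergence


def kl_to_standard_normal(z_mean: torch.Tensor, z_std: torch.Tensor) -> torch.Tensor:
    """KL(N(z_mean, z_std^2) || N(0, 1)), summed over the last (latent) dimension."""
    q = Normal(z_mean, z_std)
    p = Normal(torch.zeros_like(z_mean), torch.ones_like(z_std))
    return kl_divergence(q, p).sum(-1)  # shape [batch_size]


# Made-up guide outputs: [batch_size=2, z_dim=3]
z_mean = torch.tensor([[0.1, -0.2, 0.0],
                       [1e6, 0.0, 0.0]])  # second row: one exploded mean
z_std = torch.tensor([[0.9, 1.1, 1.0],
                      [1.0, 1.0, 1.0]])
kl = kl_to_standard_normal(z_mean, z_std)
print(kl)  # first entry is small; second is about 0.5 * (1e6) ** 2 = 5e11
```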

I have inspected the graph Laplacian and related quantities, and there should be some clustering, but the model does not pick it up.

There seem to be no outright “code errors”, but perhaps something is missing from (or unnecessary in) the architecture. For example, should the guide also approximate the labels/classification of the graph (see COMMENT 1 in the code), or should I leave that task solely to the model? If I add it to the guide, Pyro warns me to add “is_auxiliary”: True, and I am not entirely sure why.

Otherwise, my guess is that the graph itself is the problem, or that I need much larger networks, more training time, and more computing power.

Thank you in advance for your reply and time :smiley:

@fritzo @martinjankowiak Any clues :upside_down_face:? Thanks!

May I please have this post deleted if no answer is provided? Thanks!

@artistworking regarding your first issue, it’s hard to say what could be going on without seeing the definition of the model class and self.decoder_adjacency, but you should probably refactor your model to use PyroModule instead of pyro.module and remove pyro.module calls from your model and guide.

It seems likely that the second issue is related to that and to your custom loss, so you should go through the points in the SVI tips and tricks tutorial and try some diagnostic experiments, such as:

  • Verifying that your decoder_adjacency parameters are actually being updated
  • Using .to_event(1) on your distributions for latent_z, to avoid silent shape errors and ensure that every dimension is accounted for by Pyro
  • Replacing your custom loss with a standard ELBO combined with custom distributions or even just pyro.factor statements in the model (to guarantee that Pyro computes gradient estimates correctly)
  • Attempting to fit the guide to simulated data from the model (to confirm that you can train the guide to convergence)
  • Using a simple AutoGuide like AutoDelta and simpler neural network internals to fit the model on a subset of the data (to confirm that you can train the model to convergence).

If these experiments suggest that you need to mitigate the effects of large initial divergence between the prior and guide, you might try improving the initialization of the neural networks by pretraining them separately, or using KL annealing as in the deep Markov model.

Thanks @eb8680_2, you are right; I may have started building the house from the roof. I will take a step back. Thanks!!