Hi!
I have implemented my own version of a VGAE (inspired by [1611.07308] Variational Graph Auto-Encoders) for graph classification in Pyro (preferably unsupervised), but I may be missing some crucial steps. Could you spot something?
My architecture is something like this:
The guide is a method of a subclass of the `EasyGuide` class:
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.contrib.easyguide import EasyGuide


class GUIDES(EasyGuide):
    def __init__(self, ....):
        super(GUIDES, self).__init__(model)
        # Init params

    def guide(self, batch_data, batch_laplacian, batch_adjacency):
        """data_blosum is the data encoded with BLOSUM vectors"""
        pyro.module("guide_data_embedder", self.guide_data_embedder)
        pyro.module("guide_laplacian_embedder", self.guide_laplacian_embedder)
        pyro.module("guide_gcn", self.guide_gcn)
        batch_data_blosum = batch_data["blosum"]
        # skip labels and concatenate all sequences
        batch_seq_blosum = torch.flatten(batch_data_blosum[:, 1:], start_dim=2).reshape(
            batch_data_blosum.shape[0], self.n_sequences * self.max_len, self.input_dim)
        batch_mask = batch_data["mask"]
        mask_flat = torch.flatten(batch_mask, start_dim=2).reshape(
            batch_data_blosum.shape[0], self.n_sequences * self.max_len, self.input_dim)
        mask_flat_ = mask_flat[:, :, 0]
        with pyro.plate("data", batch_data_blosum.shape[0], dim=-2):
            batch_seq_blosum = self.guide_data_embedder(batch_seq_blosum, mask=None)
            # Highlight: the mask seems to mess things up, this layer might be unnecessary
            batch_laplacian = self.guide_laplacian_embedder(batch_adjacency, mask=None)
            z_mean, z_std = self.guide_gcn(batch_seq_blosum, batch_laplacian, mask_flat_)
            # z_mean.shape = [batch_size, n_seq*max_len, z_dim]
            z_mean = torch.mean(z_mean, dim=1)  # [batch_size, z_dim]
            z_std = torch.mean(z_std, dim=1)
            latent_z = pyro.sample("latent_z", dist.Normal(z_mean, z_std))  # [batch_size, z_dim]
            # COMMENT 1:
            # class_logits = self.guide_class_logits(latent_z, mask=None)
            # pyro.sample("predictions", dist.Categorical(logits=class_logits).to_event(1),
            #             infer={'is_auxiliary': True})
        return {"latent_z": latent_z,
                "z_mean": z_mean,
                "z_std": z_std}
```
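One detail that may be worth checking in the guide: `torch.mean(..., dim=1)` averages over all `n_sequences * max_len` positions, padded ones included, and averaging per-node standard deviations is only a heuristic. A common alternative is to pool the node embeddings with the padding mask first and then map the pooled vector to a mean and a strictly positive scale. The PyTorch sketch below assumes node embeddings of shape [B, N, H] and a mask that is 1 for real nodes and 0 for padding; the module name `MaskedPoolingHead` and the use of `softplus` for the scale are hypothetical illustrative choices, not part of the original code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedPoolingHead(nn.Module):
    """Hypothetical head: node embeddings [B, N, H] + mask [B, N] -> (loc, scale), each [B, z_dim]."""

    def __init__(self, hidden_dim: int, z_dim: int):
        super().__init__()
        self.loc = nn.Linear(hidden_dim, z_dim)
        self.scale = nn.Linear(hidden_dim, z_dim)

    def forward(self, node_emb: torch.Tensor, mask: torch.Tensor):
        mask = mask.to(node_emb.dtype).unsqueeze(-1)  # [B, N, 1]
        # Masked mean: padded nodes contribute nothing; the clamp avoids dividing by zero.
        pooled = (node_emb * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        loc = self.loc(pooled)
        # softplus keeps the scale strictly positive; the small constant keeps it away from 0.
        scale = F.softplus(self.scale(pooled)) + 1e-4
        return loc, scale


torch.manual_seed(0)
head = MaskedPoolingHead(hidden_dim=16, z_dim=4)
node_emb = torch.randn(2, 5, 16)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])  # first graph has 2 padded nodes
loc, scale = head(node_emb, mask)
```

The resulting `loc` and `scale` could be passed to `dist.Normal(loc, scale)` for the `latent_z` site; because padded positions are multiplied by zero, their content cannot leak into the posterior.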
```python
def model(self, batch_data, batch_laplacian, batch_adjacency):
    pyro.module("decoder_class", self.decoder_class)
    pyro.module("decoder_nodes", self.decoder_nodes)
    pyro.module("decoder_adjacency", self.decoder_adjacency)
    batch_data_int = batch_data["int"]
    # skip labels and concatenate all sequences
    batch_seq_int = torch.flatten(batch_data_int[:, 1:], start_dim=2).reshape(
        batch_data_int.shape[0], self.n_sequences * self.max_len, 1)
    batch_data_blosum = batch_data["blosum"]
    # skip labels and concatenate all sequences
    batch_seq_blosum = torch.flatten(batch_data_blosum[:, 1:], start_dim=2).reshape(
        batch_data_blosum.shape[0], self.n_sequences * self.max_len, self.input_dim)
    batch_mask = batch_data["mask"]
    # Standard Normal prior. dist.Normal takes a scale (std), not a variance; torch.ones
    # gives a float tensor, whereas torch.full(..., 1) infers an integer dtype.
    z_mean = torch.zeros((batch_data_blosum.shape[0], self.z_dim))
    z_scale = torch.ones((batch_data_blosum.shape[0], self.z_dim))
    with pyro.plate("data", batch_seq_blosum.shape[0], dim=-2):
        latent_z = pyro.sample("latent_z", dist.Normal(z_mean, z_scale))
        class_logits = self.decoder_class(latent_z)
        with pyro.plate("plate_seq", batch_seq_int.shape[1], dim=-1):
            nodes_logits = self.decoder_nodes(latent_z)
            # NOTE: batch_seq_onehot is not defined in this excerpt
            pyro.sample("nodes", dist.OneHotCategorical(logits=nodes_logits), obs=batch_seq_onehot)
        triu_idx = torch.triu_indices(batch_adjacency.shape[1], batch_adjacency.shape[2])
        # needed to have the same size as the original adjacency
        latent_z = self.decoder_adjacency(latent_z, None)
        # I have also tried some other alternatives to the adjacency reconstruction
        reconstructed_adjacency = torch.sigmoid(torch.matmul(latent_z[:, :, None], latent_z[:, None, :]))
        triu_reconstructed_adjacency = reconstructed_adjacency[:, triu_idx[0], triu_idx[1]]
        if self.supervised:
            pyro.sample("predictions", dist.Categorical(logits=class_logits).to_event(1),
                        obs=batch_data["blosum"][:, 0, 0, 0])
        else:
            pyro.sample("predictions", dist.Categorical(logits=class_logits).to_event(1), obs=None)
    return {"triu_reconstructed_adjacency": triu_reconstructed_adjacency,
            "batch_size": batch_data_blosum.shape[0],
            "n_sequences": self.n_sequences,
            "max_len": self.max_len,
            "input_dim": self.input_dim}
```
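For comparison, in the original VGAE each node i keeps its own latent vector z_i and the decoder is σ(z_iᵀ z_j), i.e. `sigmoid(Z @ Z.T)` with `Z` of shape [N, d]. Assuming `decoder_adjacency` returns a [batch, N] tensor, `torch.matmul(latent_z[:,:,None], latent_z[:,None,:])` is instead the outer product of one vector per graph, so the logit matrix has rank one. The hedged PyTorch sketch below shows the per-node inner-product decoder with a reconstruction loss on the upper-triangular entries, assuming per-node latents `z` of shape [B, N, d] and a binary float adjacency; the helper names are hypothetical. Computing the loss from logits with `binary_cross_entropy_with_logits` avoids taking the log of a sigmoid that has saturated to exactly 0 or 1.

```python
import torch
import torch.nn.functional as F


def inner_product_adjacency_logits(z: torch.Tensor) -> torch.Tensor:
    """VGAE-style decoder: z is [B, N, d]; returns edge logits z_i . z_j of shape [B, N, N]."""
    return torch.matmul(z, z.transpose(-1, -2))


def triu_reconstruction_loss(z: torch.Tensor, adjacency: torch.Tensor) -> torch.Tensor:
    """Binary cross-entropy over upper-triangular entries (diagonal included, like triu_indices(n, n))."""
    n = adjacency.shape[-1]
    rows, cols = torch.triu_indices(n, n)
    logits = inner_product_adjacency_logits(z)[:, rows, cols]
    targets = adjacency[:, rows, cols]
    # Summed over edges, averaged over the graphs in the batch.
    return F.binary_cross_entropy_with_logits(logits, targets, reduction="sum") / z.shape[0]


torch.manual_seed(0)
z = torch.randn(2, 4, 3, requires_grad=True)  # 2 graphs, 4 nodes, 3 latent dims
a = (torch.rand(2, 4, 4) > 0.5).float()
adjacency = ((a + a.transpose(-1, -2)) > 0).float()  # symmetric 0/1 adjacency
loss = triu_reconstruction_loss(z, adjacency)
loss.backward()
```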
First, a small, seemingly unrelated issue, which I suspect comes from PyTorch (although I have not been able to find it). When I register hooks to record the gradient norms, ONLY the parameters of the `decoder_adjacency` module are always missing (the rest are fine). There is no apparent error (unless something is off with the architecture); the hooks registered with `register_hook()` simply never record anything for them. I inspected the `param.norm()` values manually: they are there and are not NaN, but they never change.
```python
from collections import defaultdict

loss = TraceELBO_graph()  # custom loss
gradient_norms = defaultdict(list)
while epoch <= num_epochs:
    for batch in dataloader:
        ....
        svi.step(...)
        for name_i, value in pyro.get_param_store().named_parameters():
            value.register_hook(lambda g, name_i=name_i: gradient_norms[name_i].append(g.norm().item()))
```
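Two things may be worth checking here. A tensor hook only fires when autograd actually computes a gradient for that tensor, so one possible cause is that the adjacency reconstruction returned by the model never reaches the quantity that `backward()` is called on (for example, if the custom loss reads it from the trace but does not add it to the differentiated surrogate). In that case the `decoder_adjacency` parameters get no gradient, their hooks never run and their values never change, which would match the symptoms above. Separately, calling `register_hook` inside the training loop adds a new hook on every iteration, so later backward passes append the same norm several times. The plain-PyTorch sketch below illustrates both effects with two hypothetical stand-in modules, `used` and `unused`; it is not Pyro's internal mechanism.

```python
from collections import defaultdict

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: `used` feeds the loss, `unused` only produces a side output.
used = nn.Linear(3, 1)
unused = nn.Linear(3, 1)

gradient_norms = defaultdict(list)

# Register each hook ONCE, before the training loop, so every backward pass
# appends exactly one norm per parameter that actually receives a gradient.
for prefix, module in (("used", used), ("unused", unused)):
    for name, param in module.named_parameters():
        key = f"{prefix}.{name}"
        param.register_hook(lambda g, key=key: gradient_norms[key].append(g.norm().item()))

optimizer = torch.optim.SGD(list(used.parameters()) + list(unused.parameters()), lr=0.1)

x = torch.randn(8, 3)
for _ in range(3):
    optimizer.zero_grad()
    loss = used(x).pow(2).mean()
    side_output = unused(x)  # computed and "returned", but never added to `loss`
    loss.backward()
    optimizer.step()  # parameters without a gradient are skipped, so `unused` never changes

print(sorted(gradient_norms))  # only the `used.*` parameters appear
print({k: len(v) for k, v in gradient_norms.items()})  # one entry per step for each
print(unused.weight.grad)  # no gradient ever reached this module
```

In Pyro the parameters only exist in the param store after the first call to the model and guide, so the hooks can be registered once after an initial `svi.step(...)` or `svi.loss(...)` call rather than on every iteration.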
Second, something is wrong with the implementation (or with the design, but I want to rule out implementation mistakes first):
a) The model does train: the training error goes down at first and then plateaus (at a value higher than the validation error).
b) The losses are huge (on the order of 10^23).
c) The latent space projection looks completely random.
d) The graph classifications are also random, regardless of the technique I use to classify.
e) The gradient norms decrease but plateau at some point.
I use a custom loss that adds the graph reconstruction errors by modifying the `TraceELBO` class (using the return values and inputs of the model trace in the `loss_and_grads` function). As mentioned above, the loss is huge, but the problem comes from the ELBO term, so I am afraid that q(Z|X) is very far from p(Z). The reconstruction losses seem irrelevant for now; the huge ELBO term dominates everything.
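To get a feel for how large the latent part of the ELBO can become, it may help to compute the closed-form KL divergence between a diagonal Gaussian q(z|x) = N(μ, diag(σ²)) and the standard Normal prior used in the model: KL = Σ_k ( −log σ_k + (σ_k² + μ_k²)/2 − 1/2 ). It grows quadratically in μ and σ, so a single term near 10^23 needs |μ| or σ of order 10^11 (or an enormous number of moderately large terms), which would point at unbounded encoder outputs rather than at the reconstruction terms. The pure-Python sketch below illustrates this under that assumption; `kl_diag_normal_to_std_normal` is a hypothetical helper, and in PyTorch/Pyro code the same quantity is available through `torch.distributions.kl_divergence`.

```python
import math


def kl_diag_normal_to_std_normal(mu, sigma):
    """KL( N(mu, diag(sigma^2)) || N(0, I) ) for per-dimension means and scales."""
    if any(s <= 0 for s in sigma):
        raise ValueError("scales must be strictly positive")
    return sum(-math.log(s) + 0.5 * (s * s + m * m) - 0.5 for m, s in zip(mu, sigma))


# A posterior equal to the prior costs nothing.
print(kl_diag_normal_to_std_normal([0.0, 0.0], [1.0, 1.0]))
# A moderate mismatch costs a few nats.
print(kl_diag_normal_to_std_normal([2.0, -1.0], [0.5, 1.0]))
# Exploding encoder outputs dominate everything else in the loss.
print(kl_diag_normal_to_std_normal([1e11, 0.0], [1.0, 1e11]))
```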
I have inspected the graph Laplacian and related quantities, and there should be some clustering, but the model does not pick it up.
There seem to be no outright code errors, but perhaps something in the architecture is missing or unnecessary. For example, should the guide also approximate the labels/classification of the graph (see `#COMMENT 1` in the code), or should I leave that task solely to the model? If I add it to the guide, Pyro warns me to add `"is_auxiliary": True`, and I am not entirely sure why.
Otherwise, my guess is that the graph itself is the problem, or that I need much larger networks, more training time and more computing power.
Thank you in advance for your reply and your time.