GNNs using random_flax_module with Jraph

Hi,

I was trying to Bayesianize graph neural networks using NumPyro and Jraph.

First, if I do not use Jraph to define the graphs of interest, there is no real issue: one can define the graph operations in base JAX, couple them with Flax for the NN definitions, and then proceed with random_flax_module.
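For context, a minimal sketch of that baseline (non-graph) case, with illustrative names (MLP, bnn_model, x, y); since the module takes a plain array, input_shape is all random_flax_module needs:

import jax.numpy as jnp
import flax.linen as nn
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.module import random_flax_module

class MLP(nn.Module):
    """A plain Flax MLP that takes a static array input."""
    hidden: int = 16

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(self.hidden)(x))
        return nn.Dense(1)(x)

def bnn_model(x, y=None):
    # input_shape is enough here because __call__ takes a plain array
    net = random_flax_module(
        "mlp", MLP(), prior=dist.Normal(0.0, 1.0), input_shape=(1, x.shape[-1])
    )
    mu = net(x).squeeze(-1)
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
    numpyro.sample("y", dist.Normal(mu, sigma), obs=y)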

However, Jraph provides a nice object for dealing with standard graphs. I have copied a model from Jraph’s examples to illustrate:

from typing import Sequence

import jax.numpy as jnp
import flax.linen as nn
import jraph

class ExplicitMLP(nn.Module):
    """A flax MLP."""
    features: Sequence[int]

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        for i, lyr in enumerate([nn.Dense(feat) for feat in self.features]):
            x = lyr(x)
            if i != len(self.features) - 1:
                x = nn.relu(x)
        return x
    
def make_embed_fn(latent_size):
    def embed(inputs):
        return nn.Dense(latent_size)(inputs)
    return embed

def make_mlp(features):
    @jraph.concatenated_args
    def update_fn(inputs):
        return ExplicitMLP(features)(inputs)
    return update_fn

class GraphNetwork(nn.Module):
    """A flax GraphNetwork."""
    mlp_features: Sequence[int]
    latent_size: int

    @nn.compact
    def __call__(self, graph):
    
        # Add a global feature used for the graph-level classification output
        
        graph = graph._replace(globals=jnp.zeros([graph.n_node.shape[0], 1]))

        embedder = jraph.GraphMapFeatures(
            embed_node_fn=make_embed_fn(self.latent_size),
            embed_edge_fn=make_embed_fn(self.latent_size),
            embed_global_fn=make_embed_fn(self.latent_size))
        
        net = jraph.GraphNetwork(
            update_node_fn=make_mlp(self.mlp_features),
            update_edge_fn=make_mlp(self.mlp_features),
            # The global update outputs size 2 for binary classification.
            update_global_fn=make_mlp(tuple(self.mlp_features) + (2,)))

        return net(embedder(graph))

The flax.linen module GraphNetwork is where the main pieces of the GNN are put together. However, instead of accepting a static tensor as input, it accepts a graph object (a Jraph GraphsTuple). This graph object has many attributes that cannot be reduced to a single constant array (or at least doing so defeats the purpose), so I'm struggling to figure out what to pass for the input_shape argument of random_flax_module.
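For concreteness, this is roughly what such an input looks like; a toy GraphsTuple with made-up sizes (3 nodes, 2 edges, 4 features) that would have to stand in for input_shape:

import jax.numpy as jnp
import jraph

toy_graph = jraph.GraphsTuple(
    nodes=jnp.ones((3, 4)),       # 3 nodes, 4 features each
    edges=jnp.ones((2, 4)),       # 2 edges, 4 features each
    senders=jnp.array([0, 1]),
    receivers=jnp.array([1, 2]),
    globals=None,                 # filled in inside GraphNetwork.__call__
    n_node=jnp.array([3]),
    n_edge=jnp.array([2]),
)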

Would appreciate any pointers or if I’ve missed something.

For the NumPyro part, doing something like the code below seems to initiate sampling (albeit very slowly because of my for loop), but I'm not sure it is correct. I can pass a graph keyword argument to random_flax_module, but I don't know whether that argument is only used to initialize the Flax module or whether it is passed every time the model is called.

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.module import random_flax_module
from numpyro.infer import MCMC, NUTS, init_to_feasible
from clu import parameter_overview  # for the parameter overview printout below

def numpyro_model(list_of_graphs, target):
    
    # Define base flax module

    module = GraphNetwork(mlp_features=[16, 16], latent_size=16)
    
    # Register as a random flax module (parameter inference done outside)
    # TODO: check if graph argument makes sense. Would it just keep the graph constant for all calls of net?
    
    net = random_flax_module(
        "Mutag_GraphNet_Model",
        module,
        prior=dist.StudentT(df=4.0, scale=0.1),
        input_shape=None, 
        graph=list_of_graphs[0]['input_graph']
    )
        
    # Prediction over a list of graphs (right now it must be looped because of issues with graph size)
    # TODO: need to vectorize this
    
    preds = jnp.zeros(shape=(len(list_of_graphs), 1))
    
    for idx, g in enumerate(list_of_graphs):
        # pred_graph = net.apply(params, g['input_graph'])
        pred_graph = net(g['input_graph'])
        pred = jax.nn.log_softmax(pred_graph.globals)
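        # keep the log-probability of the first class for graph idx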
        preds = preds.at[idx].set(pred.flatten()[0])
    
    # Compute likelihood function
    # TODO: make this the standard classification target (predict probabilities and observe binary responses)
    
    return numpyro.sample("target", dist.Normal(jnp.exp(preds), 0.03), obs=target)
    
    # return preds
    
# Initialize model

model2 = GraphNetwork(mlp_features=[16, 16], latent_size=16)
key = jax.random.PRNGKey(0)
params = model2.init(key, train_mutag_ds[0]['input_graph'])
print(parameter_overview.get_parameter_overview(params)) # uses clu (common loop utils)
del model2

# Initialize MCMC

kernel = NUTS(numpyro_model,
              init_strategy=init_to_feasible(),
              target_accept_prob=0.80,
              max_tree_depth=10,
              )

mcmc = MCMC(
    kernel,
    num_warmup=1000,
    num_samples=1000,
    num_chains=1,
    progress_bar=True, # TOGGLE this...
    chain_method="vectorized", 
    # jit_model_args=True,
)

# Run MCMC

mcmc.run(jax.random.PRNGKey(3), train_mutag_ds, y_train)
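
After sampling, the posterior draws over the network weights can be inspected with the standard MCMC accessors (a small sketch, nothing graph-specific here):

# Inspect the posterior over the Bayesian GNN weights
mcmc.print_summary()
posterior_samples = mcmc.get_samples()
print({k: v.shape for k, v in posterior_samples.items()})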

Yes, the arguments passed to random_flax_module are only used to initialize/substitute the required parameters. Your code looks correct to me.


Great, thanks for checking @fehiepsi! It is pretty nice that a more complicated set of objects can be handled with this functionality!

@fehiepsi do you have any thoughts on how I might go about speeding up the forward-pass here? Particularly this bit of code:

preds = jnp.zeros(shape=(len(list_of_graphs), 1))
    
for idx, g in enumerate(list_of_graphs):
    # pred_graph = net.apply(params, g['input_graph'])
    pred_graph = net(g['input_graph'])
    pred = jax.nn.log_softmax(pred_graph.globals)
    preds = preds.at[idx].set(pred.flatten()[0])

The list_of_graphs holds a list of entries whose input_graph attribute is a Jraph GraphsTuple, and that attribute needs to be accessed for the forward pass. The standard jax.lax.scan or jax.lax.fori_loop constructs don't seem to work (or at least I can't get them to work), so this chunk of code appears to be non-jittable. Would appreciate any pointers!

I think you can try batching instead of a list of graphs. There are some helpers for this in the Jraph API docs: Jraph API — Jraph 0.0.1.dev documentation
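
For example, something along these lines (a sketch, assuming net and list_of_graphs as in the snippet above; jraph.batch concatenates a list of GraphsTuples into a single batched graph):

import jax
import jraph

# Merge the list of graphs into one batched GraphsTuple and run a single forward pass
batched_graph = jraph.batch([g['input_graph'] for g in list_of_graphs])
pred_graph = net(batched_graph)                  # one call instead of a Python loop
preds = jax.nn.log_softmax(pred_graph.globals)   # shape (num_graphs, 2) of log-probabilities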

Completely missed that. Seems to work now. Thanks so much for the pointer!

For future reference, linking to the final implementation of GNNs using Jraph, Flax and NumPyro. Thanks again for the help @fehiepsi!
