Hi,
I was trying to Bayesianize graph neural networks using NumPyro and Jraph.
Firstly, if I do not use Jraph to define the graphs of interest, then there is no real issue. One can go ahead and define graph operations with base JAX, and couple it with Flax for NN definitions, and then proceed with random_flax_module.
However, Jraph provides a nice object for dealing with standard graphs. I have copied a model from Jraph’s examples to illustrate:
from typing import Sequence

import flax.linen as nn
import jax.numpy as jnp
import jraph


class ExplicitMLP(nn.Module):
    """A flax MLP."""
    features: Sequence[int]

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        for i, lyr in enumerate([nn.Dense(feat) for feat in self.features]):
            x = lyr(x)
            # No activation after the final layer.
            if i != len(self.features) - 1:
                x = nn.relu(x)
        return x


def make_embed_fn(latent_size):
    def embed(inputs):
        return nn.Dense(latent_size)(inputs)
    return embed


def make_mlp(features):
    @jraph.concatenated_args
    def update_fn(inputs):
        return ExplicitMLP(features)(inputs)
    return update_fn


class GraphNetwork(nn.Module):
    """A flax GraphNetwork."""
    mlp_features: Sequence[int]
    latent_size: int

    @nn.compact
    def __call__(self, graph):
        # Add a global parameter for graph classification computation
        graph = graph._replace(globals=jnp.zeros([graph.n_node.shape[0], 1]))
        embedder = jraph.GraphMapFeatures(
            embed_node_fn=make_embed_fn(self.latent_size),
            embed_edge_fn=make_embed_fn(self.latent_size),
            embed_global_fn=make_embed_fn(self.latent_size))
        net = jraph.GraphNetwork(
            update_node_fn=make_mlp(self.mlp_features),
            update_edge_fn=make_mlp(self.mlp_features),
            # The global update outputs size 2 for binary classification.
            update_global_fn=make_mlp(self.mlp_features + (2,)))  # pytype: disable=unsupported-operands
        return net(embedder(graph))
The flax.linen module GraphNetwork is where the main aspects of the GNN are synthesized. However, instead of accepting a static tensor as input, it accepts a graph object (the Jraph GraphsTuple). This object has many attributes that cannot be reduced to a single constant array (or at least doing so defeats the purpose). I’m struggling to figure out how to handle the input_shape field of random_flax_module.
Would appreciate any pointers or if I’ve missed something.
For the NumPyro part, something like the code below does start sampling (albeit very slowly because of my for loop), but I’m not sure it is correct. I can pass a graph argument to random_flax_module, but I don’t know whether this argument is only used to initialize the flax module or whether it is passed every time the model is called.
import jax
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.module import random_flax_module


def numpyro_model(list_of_graphs, target):
    # Define base flax module
    module = GraphNetwork(mlp_features=[16, 16], latent_size=16)

    # Register as a random flax module (parameter inference done outside)
    # TODO: check if graph argument makes sense. Would it just keep the graph constant for all calls of net?
    net = random_flax_module(
        "Mutag_GraphNet_Model",
        module,
        prior=dist.StudentT(df=4.0, scale=0.1),
        input_shape=None,
        graph=list_of_graphs[0]['input_graph']
    )

    # Prediction over a list of graphs (right now it must be looped because of issues with graph size)
    # TODO: need to vectorize this
    preds = jnp.zeros(shape=(len(list_of_graphs), 1))
    for idx, g in enumerate(list_of_graphs):
        # pred_graph = net.apply(params, g['input_graph'])
        pred_graph = net(g['input_graph'])
        pred = jax.nn.log_softmax(pred_graph.globals)
        preds = preds.at[idx].set(pred.flatten()[0])

    # Compute likelihood function
    # TODO: make this the standard classification target (predict probabilities and observe binary responses)
    return numpyro.sample("target", dist.Normal(jnp.exp(preds), 0.03), obs=target)
    # return preds
from clu import parameter_overview
from numpyro.infer import MCMC, NUTS, init_to_feasible

# Initialize model (only to print a parameter overview)
model2 = GraphNetwork(mlp_features=[16, 16], latent_size=16)
key = jax.random.PRNGKey(0)
params = model2.init(key, train_mutag_ds[0]['input_graph'])
print(parameter_overview.get_parameter_overview(params))  # uses clu (common loop utils)
del model2

# Initialize MCMC
kernel = NUTS(
    numpyro_model,
    init_strategy=init_to_feasible(),
    target_accept_prob=0.80,
    max_tree_depth=10,
)
mcmc = MCMC(
    kernel,
    num_warmup=1000,
    num_samples=1000,
    num_chains=1,
    progress_bar=True,  # TOGGLE this...
    chain_method="vectorized",
    # jit_model_args=True,
)

# Run MCMC
mcmc.run(jax.random.PRNGKey(3), train_mutag_ds, y_train)
Yes, the arguments passed to random_flax_module are only used to initialize/substitute the required parameters. Your code looks correct to me.
Great, thanks for checking @fehiepsi! It is pretty nice that a more complicated set of objects can be handled with this functionality!
@fehiepsi do you have any thoughts on how I might go about speeding up the forward-pass here? Particularly this bit of code:
preds = jnp.zeros(shape=(len(list_of_graphs), 1))
for idx, g in enumerate(list_of_graphs):
    # pred_graph = net.apply(params, g['input_graph'])
    pred_graph = net(g['input_graph'])
    pred = jax.nn.log_softmax(pred_graph.globals)
    preds = preds.at[idx].set(pred.flatten()[0])
The list_of_graphs contains a set of Jraph GraphsTuple objects (a list of inputs), where the input_graph entry of each item needs to be accessed for the forward pass. The standard jax.lax.scan or jax.lax.fori_loop constructs don’t seem to work (or at least I can’t get them to work), and this chunk of code seems to be non-jittable. Would appreciate any pointers!
I think you can try batching instead of passing a list of graphs. There are some helpers in the Jraph API docs (Jraph 0.0.1.dev documentation).
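For anyone wondering what batching means for graphs of different sizes, here is a hand-rolled sketch in plain JAX. It is not Jraph's implementation, just the idea behind helpers such as jraph.batch: node features are concatenated and the edge indices of later graphs are shifted, so the whole batch becomes one big graph with disconnected parts. A single jitted forward pass then replaces the Python loop, and a segment sum gives one readout per graph. The toy data and the one-step message passing are assumptions made up for the illustration.

```python
import jax
import jax.numpy as jnp

# Two toy graphs of different sizes (hypothetical data).
# Graph A: 3 nodes, edges 0->1 and 1->2. Graph B: 2 nodes, edge 0->1.
nodes_a = jnp.array([[1.0], [2.0], [3.0]])
senders_a, receivers_a = jnp.array([0, 1]), jnp.array([1, 2])
nodes_b = jnp.array([[10.0], [20.0]])
senders_b, receivers_b = jnp.array([0]), jnp.array([1])

# Batch: concatenate node features and shift graph B's edge indices by the
# number of nodes in graph A, giving one graph with two disconnected parts.
offset = nodes_a.shape[0]
nodes = jnp.concatenate([nodes_a, nodes_b])
senders = jnp.concatenate([senders_a, senders_b + offset])
receivers = jnp.concatenate([receivers_a, receivers_b + offset])

n_graph = 2
n_node = jnp.array([3, 2])
# graph_ids[i] is the graph that node i belongs to: [0, 0, 0, 1, 1].
graph_ids = jnp.repeat(jnp.arange(n_graph), n_node, total_repeat_length=5)


@jax.jit
def forward(nodes, senders, receivers, graph_ids):
    # One message-passing step: every node adds the features sent to it.
    messages = jax.ops.segment_sum(
        nodes[senders], receivers, num_segments=nodes.shape[0]
    )
    updated = nodes + messages
    # Per-graph readout: sum the node features within each graph.
    return jax.ops.segment_sum(updated, graph_ids, num_segments=n_graph)


per_graph = forward(nodes, senders, receivers, graph_ids)
print(per_graph)  # one row per graph: [[9.], [40.]]
```

Because different batches generally have different total node and edge counts, jit would recompile for each new shape. Jraph also ships padding helpers (for example jraph.pad_with_graphs) that pad a batch to a fixed size, which is worth looking at if recompilation becomes the bottleneck.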
Completely missed that. Seems to work now. Thanks so much for the pointer!
For future reference, linking to the final implementation of GNNs using Jraph, Flax and NumPyro. Thanks again for the help @fehiepsi!