NN weights initialization in haiku_module

Hello,

I’m wondering how to enforce exactly the same weight initialization when using haiku_module in NumPyro (in PyTorch, for example, I would just call torch.manual_seed() before each initialization). Thanks in advance!

Code snippet:

import numpyro
from numpyro.contrib.module import haiku_module
import haiku as hk
import jax
import numpy as np

# Define a simple MLP encoder
class MLP(hk.Module):
    def __init__(self, zdim=2):
        super().__init__()
        self._zdim = zdim

    def __call__(self, x):
        x = hk.Linear(64)(x)
        x = jax.nn.relu(x)
        x = hk.Linear(64)(x)
        x = jax.nn.relu(x)
        x = hk.Linear(self._zdim)(x)
        return x

z_dim = 3
input_dim = (144,)
X = np.random.randn(32, 144)

nn_module = hk.transform(lambda x: MLP(z_dim)(x))

with numpyro.handlers.seed(rng_seed=1):
    nn1 = haiku_module("nn1", nn_module, input_shape=(1, *input_dim))
    nn2 = haiku_module("nn2", nn_module, input_shape=(1, *input_dim))
    z1 = nn1(X)
    z2 = nn2(X)

print(np.allclose(z1, z2))   # False

Interestingly, it seems that replacing

with numpyro.handlers.seed(rng_seed=1):
    nn1 = haiku_module("nn1", nn_module, input_shape=(1, *input_dim))
    nn2 = haiku_module("nn2", nn_module, input_shape=(1, *input_dim))

with

with numpyro.handlers.seed(rng_seed=1):
    nn1 = haiku_module("nn1", nn_module, input_shape=(1, *input_dim))
with numpyro.handlers.seed(rng_seed=1):
    nn2 = haiku_module("nn2", nn_module, input_shape=(1, *input_dim))

does the trick. But how would I enforce this when haiku_module is called inside a NumPyro model?

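I think you can call numpyro.prng_key() once in the model and use the resulting key as the common seed for each haiku_module call, e.g. by wrapping each call in its own numpyro.handlers.seed(rng_seed=key) block.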
