# NN weights initialization in haiku_module

Hello,

I’m wondering how to enforce exactly the same weight initialization when using `haiku_module` in NumPyro (in PyTorch, for example, I would just call `torch.manual_seed()` before each initialization). Thanks in advance!

Code snippet:

```python
import numpyro
from numpyro.contrib.module import haiku_module
import haiku as hk
import jax
import numpy as np

# Define a simple MLP encoder
class MLP(hk.Module):
    def __init__(self, zdim=2):
        super().__init__()
        self._zdim = zdim

    def __call__(self, x):
        x = hk.Linear(64)(x)
        x = jax.nn.relu(x)
        x = hk.Linear(64)(x)
        x = jax.nn.relu(x)
        x = hk.Linear(self._zdim)(x)
        return x

z_dim = 3
input_dim = (144,)
X = np.random.randn(32, 144)

nn_module = hk.transform(lambda x: MLP(z_dim)(x))

with numpyro.handlers.seed(rng_seed=1):
    nn1 = haiku_module("nn1", nn_module, input_shape=(1, *input_dim))
    nn2 = haiku_module("nn2", nn_module, input_shape=(1, *input_dim))
    z1 = nn1(X)
    z2 = nn2(X)

print(np.allclose(z1, z2))   # False
```

Interestingly, it seems that replacing

```python
with numpyro.handlers.seed(rng_seed=1):
    nn1 = haiku_module("nn1", nn_module, input_shape=(1, *input_dim))
    nn2 = haiku_module("nn2", nn_module, input_shape=(1, *input_dim))
```

with

```python
with numpyro.handlers.seed(rng_seed=1):
    nn1 = haiku_module("nn1", nn_module, input_shape=(1, *input_dim))
with numpyro.handlers.seed(rng_seed=1):
    nn2 = haiku_module("nn2", nn_module, input_shape=(1, *input_dim))
```

does the trick. But how can I enforce this when calling `haiku_module` inside a NumPyro model?

I think you can call `numpyro.prng_key()` once inside the model and use the resulting key as the common seed for both networks.
