Hello,
I’m wondering how one can enforce the exact same weights initialization when using haiku_module in NumPyro (e.g., in PyTorch, I would just set torch.manual_seed() before each initialization). Thanks in advance!
Code snippet:
import numpyro
from numpyro.contrib.module import haiku_module
import haiku as hk
import jax
import numpy as np
# Define a simple MLP encoder
class MLP(hk.Module):
def __init__(self, zdim=2):
super().__init__()
self._zdim = zdim
def __call__(self, x):
x = hk.Linear(64)(x)
x = jax.nn.relu(x)
x = hk.Linear(64)(x)
x = jax.nn.relu(x)
x = hk.Linear(self._zdim)(x)
return x
z_dim = 3
input_dim = (144,)
X = np.random.randn(32, 144)
nn_module = hk.transform(lambda x: MLP(z_dim)(x))
with numpyro.handlers.seed(rng_seed=1):
nn1 = haiku_module("nn1", nn_module, input_shape=(1, *input_dim))
nn2 = haiku_module("nn2", nn_module, input_shape=(1, *input_dim))
z1 = nn1(X)
z2 = nn2(X)
print(np.allclose(z1, z2)) # False