Hi! We are trying to build an MNIST classifier (batches of 128) by constructing a module (Pyro Primitives, NumPyro documentation) with a stax neural network and using NUTS for inference, as shown below.
```python
import jax.numpy as np
from jax import random
from jax.experimental import stax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def net(hidden_dim, out_dim):
    return stax.serial(
        stax.Dense(hidden_dim, W_init=stax.randn()), stax.Softplus,
        stax.Dense(out_dim, W_init=stax.randn()), stax.Sigmoid,
    )

def model(batch, y):
    batch_dim, input_dim, out_dim = np.shape(batch)[0], np.shape(batch)[1], 10  # 128, 784, 10
    nn = numpyro.module('nn', net(1024, out_dim), (batch_dim, input_dim))
    logits = nn(batch)  # fails
    return numpyro.sample('Y', dist.Categorical(logits=logits), obs=y)

# x: (128, 784) MNIST batch, y: (128,) integer labels, loaded elsewhere
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
key = random.PRNGKey(10)
mcmc.run(rng_key=key, batch=x, y=y)
```
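For a sense of scale, the number of parameters in `net(1024, 10)` applied to 784-dimensional inputs can be tallied by hand. This sketch assumes each stax `Dense` layer holds a weight matrix of shape `(in_dim, out_dim)` plus a bias of length `out_dim`, and that `Softplus` and `Sigmoid` carry no parameters; `dense_param_count` is a hypothetical helper, not a stax function.

```python
def dense_param_count(in_dim, out_dim):
    # Assumed Dense layout: weight matrix W (in_dim x out_dim) plus bias b (out_dim,).
    return in_dim * out_dim + out_dim

input_dim, hidden_dim, out_dim = 784, 1024, 10

# Softplus and Sigmoid are assumed parameter-free, so only the two Dense layers count.
total = dense_param_count(input_dim, hidden_dim) + dense_param_count(hidden_dim, out_dim)
print(total)  # 814090
```

If these weights were treated as latent variables for NUTS, each posterior sample would be a vector of roughly this length.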
Calling `logits = nn(batch)`, however, results in the following error. It seems that `rng` is always `None`, which may be the cause:
```
/usr/local/lib/python3.6/dist-packages/jax/experimental/stax.py in apply_fun(params, inputs, **kwargs)
    298     rng = kwargs.pop('rng', None)
    299     rngs = random.split(rng, nlayers) if rng is not None else (None,) * nlayers
--> 300     for fun, param, rng in zip(apply_funs, params, rngs):
    301       inputs = fun(param, inputs, rng=rng, **kwargs)
    302     return inputs

TypeError: zip argument #2 must support iteration
```
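The failing loop can be reproduced with a simplified, hypothetical stand-in for the serial combinator's `apply_fun` (pure Python, no JAX; `serial_apply`, `scale` and `shift` are illustrative names, and the real code's rng splitting is omitted). In `zip(apply_funs, params, rngs)`, argument #2 is `params`, and when `rng` is `None` the code in the traceback already substitutes `(None,) * nlayers`. The sketch therefore suggests that the value failing to iterate is `params` rather than `rng`.

```python
def serial_apply(apply_funs, params, inputs, rng=None):
    """Hypothetical stand-in for stax.serial's apply_fun (rng splitting omitted)."""
    nlayers = len(apply_funs)
    # With rng=None every layer gets rng=None, mirroring the traceback; zip accepts this.
    # Otherwise this toy reuses the same rng for every layer instead of splitting it.
    rngs = (None,) * nlayers if rng is None else (rng,) * nlayers
    # zip's argument #2 is `params`: it must be a per-layer sequence.
    for fun, param, layer_rng in zip(apply_funs, params, rngs):
        inputs = fun(param, inputs, rng=layer_rng)
    return inputs

# Two toy "layers", each consuming its own parameter.
def scale(param, x, rng=None):
    return [param * v for v in x]

def shift(param, x, rng=None):
    return [v + param for v in x]

layers = [scale, shift]

# Works: one parameter per layer, rng left as None.
ok = serial_apply(layers, [2.0, 1.0], [1.0, 2.0])
print(ok)  # [3.0, 5.0]

# Fails: params is None, so zip raises a TypeError before any layer runs.
try:
    serial_apply(layers, None, [1.0, 2.0])
    failure = None
except TypeError as err:
    failure = err
    print("TypeError:", err)
```

The exact wording of the `TypeError` message depends on the Python version; Python 3.6, as in the traceback, reports "zip argument #2 must support iteration".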
Since, according to the documentation, `module` takes only the input array, and `params` and `**kwargs` are supplied internally, we are not sure where to look for the error. Could someone help us out?
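As documented, `numpyro.module(name, nn, input_shape)` takes a stax-style `(init_fn, apply_fn)` pair, registers the network's parameters via `param()` statements, and returns `apply_fn` with those parameters bound, so the caller passes only the input array. The simplified mimic below (`toy_module`, `toy_net` and the dict `param_store` are hypothetical, not NumPyro's actual source) shows how parameters can reach `apply_fn` without the caller supplying them.

```python
import functools

param_store = {}  # hypothetical stand-in for NumPyro's parameter storage

def toy_module(name, nn, input_shape):
    """Hypothetical, simplified stand-in for numpyro.module (not its real source)."""
    init_fn, apply_fn = nn
    key = name + "$params"
    params = param_store.get(key)
    if params is None:
        # Initialise once from the input shape; this toy ignores the rng argument.
        _, params = init_fn(None, input_shape)
        param_store[key] = params
    # Bind params so callers pass only the input, as in `logits = nn(batch)`.
    return functools.partial(apply_fn, params)

def toy_net():
    """A stax-style (init_fn, apply_fn) pair for a scale-then-shift 'network'."""
    def init_fn(rng, input_shape):
        # Returns (output_shape, params), following the stax convention.
        return input_shape, [2.0, 1.0]

    def apply_fn(params, inputs, **kwargs):
        scale, shift = params
        return [scale * x + shift for x in inputs]

    return init_fn, apply_fn

nn = toy_module("nn", toy_net(), (2,))
out = nn([1.0, 2.0])
print(out)  # [3.0, 5.0]
```

The NumPyro documentation describes these parameters as registered "for optimization via `param()` statements", that is, as point estimates rather than sampled latent variables. This may be worth keeping in mind when pairing `module` with NUTS.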
Thanks!