Hi! We are trying to build an MNIST classifier (batches of 128) by constructing a module (http://num.pyro.ai/en/v0.2.4/primitives.html#numpyro.primitives.module) from a stax neural network, and to use NUTS for inference, as shown below.

    def net(hidden_dim, out_dim):
        # Two-layer stax network: Dense -> Softplus -> Dense -> Sigmoid
        return stax.serial(
            stax.Dense(hidden_dim, W_init=stax.randn()),
            stax.Softplus,
            stax.Dense(out_dim, W_init=stax.randn()),
            stax.Sigmoid,
        )

    def model(batch, y):
        # 128, 784, 10
        batch_dim, input_dim, out_dim = np.shape(batch)[0], np.shape(batch)[1], 10
        nn = numpyro.module('nn', net(1024, out_dim), (batch_dim, input_dim))
        logits = nn(batch)  # fails
        return numpyro.sample('Y', dist.Categorical(logits=logits), obs=y)

    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
    key = random.PRNGKey(10)
    mcmc.run(rng_key=key, batch=x, y=y)

However, the call logits = nn(batch) raises the following error. It seems that rng is always None, which might be the cause:

    /usr/local/lib/python3.6/dist-packages/jax/experimental/stax.py in apply_fun(params, inputs, **kwargs)
        298     rng = kwargs.pop('rng', None)
        299     rngs = random.split(rng, nlayers) if rng is not None else (None,) * nlayers
    --> 300     for fun, param, rng in zip(apply_funs, params, rngs):
        301       inputs = fun(param, inputs, rng=rng, **kwargs)
        302     return inputs

    TypeError: zip argument #2 must support iteration

According to the documentation, the function returned by module seems to take only the input array, while params and **kwargs are supplied somewhere internally, so we are not sure where to look for the error. Could someone help us out?
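
One observation that may narrow things down: in zip(apply_funs, params, rngs), argument #2 is params, not rngs. When rng is None, line 299 of the traceback still builds a tuple of None values, which is iterable. The message therefore suggests that params itself is not iterable (for example None) when apply_fun is called. The following is only a pure-Python sketch of that loop, not the stax implementation: the scale "layer" and the explicit apply_funs argument are hypothetical stand-ins, and the splitting of rng is simplified.

```python
# Pure-Python stand-in for the loop at stax.py lines 298-302 in the traceback.
# No JAX is needed; names other than params, rng, rngs are illustrative only.

def apply_fun(apply_funs, params, inputs, **kwargs):
    nlayers = len(apply_funs)
    rng = kwargs.pop('rng', None)
    # With rng=None this is still a tuple, so zip argument #3 is iterable.
    # (The real code calls random.split(rng, nlayers) when rng is given;
    # here the same rng is simply repeated, as a simplification.)
    rngs = (rng,) * nlayers if rng is not None else (None,) * nlayers
    for fun, param, layer_rng in zip(apply_funs, params, rngs):
        inputs = fun(param, inputs, rng=layer_rng, **kwargs)
    return inputs


def scale(param, inputs, rng=None):
    # Hypothetical "layer": multiplies the input by its parameter.
    return param * inputs


layers = [scale, scale]

# rng=None together with iterable params runs without error.
ok = apply_fun(layers, [2.0, 3.0], 1.0)

# params=None makes zip raise a TypeError, as in the traceback.
try:
    apply_fun(layers, None, 1.0)
    error = None
except TypeError as exc:
    error = exc
```

If this reading is right, the place to look would be how the parameters reach apply_fun inside the model under NUTS, rather than the rng handling. The exact wording of the TypeError may differ between Python versions.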