Using stax with numpyro module

Hi! We are trying to build an MNIST classifier (batches of 128) by constructing a module (http://num.pyro.ai/en/v0.2.4/primitives.html#numpyro.primitives.module) from a stax neural network and running NUTS for inference, as shown below.

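        # Imports the snippet relies on
        import jax.numpy as np  # assumed; plain numpy also works for np.shape
        from jax import random
        from jax.experimental import stax
        import numpyro
        import numpyro.distributions as dist
        from numpyro.infer import MCMC, NUTS

        def net(hidden_dim, out_dim):
            return stax.serial(
                stax.Dense(hidden_dim, W_init=stax.randn()), stax.Softplus,
                stax.Dense(out_dim, W_init=stax.randn()), stax.Sigmoid,
            )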

        def model(batch, y):
            batch_dim, input_dim, out_dim = np.shape(batch)[0], np.shape(batch)[1], 10 # 128, 784, 10 
            nn = numpyro.module('nn', net(1024, out_dim), (batch_dim, input_dim))
            logits = nn(batch) # fails
            return numpyro.sample('Y', dist.Categorical(logits=logits), obs=y)
    
        nuts_kernel = NUTS(model)
        mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
        key = random.PRNGKey(10)
        mcmc.run(rng_key=key, batch=x, y=y)

However, calling logits = nn(batch) raises the error below. We also noticed that rng is always None, which may be what causes it:

        /usr/local/lib/python3.6/dist-packages/jax/experimental/stax.py in apply_fun(params, inputs, **kwargs)
        298     rng = kwargs.pop('rng', None)
        299     rngs = random.split(rng, nlayers) if rng is not None else (None,) * nlayers
        --> 300     for fun, param, rng in zip(apply_funs, params, rngs):
        301       inputs = fun(param, inputs, rng=rng, **kwargs)
        302     return inputs

        TypeError: zip argument #2 must support iteration

According to the documentation, the function returned by module takes only the input array, while params and **kwargs are supplied internally, so we are not sure where to look for the cause. Perhaps someone can help us out?

Thanks!

Hi @kiviliis, I think this feature is not supported yet. Could you open an issue on GitHub so we can track progress toward a solution? FYI, here is one thing we need to address:

  • Support param statements whose value is an arbitrary JAX pytree (this is the type of the params in stax) in MCMC; currently, we only support the numpy.ndarray type. One solution is to use the ravel_pytree utility in JAX to transform the nn params into a flattened numpy.ndarray.
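
To make the ravel_pytree idea concrete, here is a minimal sketch of flattening and restoring a stax-style parameter pytree. It only illustrates the JAX utility and is not how numpyro handles this internally. The layer sizes and the hand-built params list are assumptions for illustration; they mimic the structure stax.serial produces (a (W, b) tuple per Dense layer, an empty tuple per activation).

        import jax.numpy as jnp
        from jax import random
        from jax.flatten_util import ravel_pytree

        # Hypothetical stax-like params: one entry per layer in the serial net.
        key_w1, key_w2 = random.split(random.PRNGKey(0))
        params = [
            (random.normal(key_w1, (784, 16)), jnp.zeros(16)),  # Dense: (W, b)
            (),                                                  # activation: no params
            (random.normal(key_w2, (16, 10)), jnp.zeros(10)),   # Dense: (W, b)
            (),                                                  # activation: no params
        ]

        # Flatten the whole pytree into a single 1-D array, and get back a
        # function that rebuilds the original nested structure from it.
        flat_params, unravel_fn = ravel_pytree(params)

        # Round trip: same nesting and shapes as the original params.
        restored = unravel_fn(flat_params)

A sampler that only understands flat arrays could then work on flat_params and call unravel_fn right before applying the network.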