Lifting a haiku module in NumPyro

Hi,

Is there currently a way to lift a Haiku module from a regular neural network (NN) to a Bayesian neural network (BNN)? I would like to be able to do something like this:

class HaikuCNN:
    """Small convolutional classifier intended to be wrapped with ``hk.transform``.

    Two 3x3 ``SAME``-padded convolutions, each followed by ``self.act``, then a
    flatten and a linear layer that produces two logits per example.
    """

    def __init__(self):
        # Activation applied after each convolution.
        self.act = jax.nn.relu

    def __call__(self, x):
        """Simple CNN for diabetic retinopathy classification.

        Args:
            x: batch of images with the batch on axis 0. Presumably laid out as
                (batch, height, width, channels), which matches Haiku's default
                ``Conv2D`` data format. TODO: confirm against the data pipeline.

        Returns:
            Logits of shape ``(batch, 2)``, one pair per input image.
        """
        x = x.astype(jnp.float32)
        cnn = hk.Sequential([
            hk.Conv2D(output_channels=32, kernel_shape=3, padding='SAME'),
            self.act,
            hk.Conv2D(output_channels=32, kernel_shape=3, padding='SAME'),
            self.act,
            # Collapse all non-batch axes. Without this, hk.Linear would act only
            # on the channel axis and return per-pixel logits (batch, H, W, 2),
            # which do not line up with one label per image.
            hk.Flatten(),
            hk.Linear(2),
        ])
        return cnn(x)


def model(x, y=None, subsample_size=100):
    """NumPyro model: a Haiku CNN classifier with a subsampled categorical likelihood.

    Args:
        x: input images, indexed along axis 0 by the ``data`` plate. The
            remaining three axes are treated as a single event (``event_dim=3``).
        y: optional labels, one per image (``event_dim=0``). Pass ``None`` when
            predicting, in which case ``obs`` is sampled instead of observed.
        subsample_size: number of examples drawn per step by the ``data`` plate.
    """
    # Wrap the transformed Haiku network as a NumPyro module named 'cnn'.
    # `input_shape` is a single-example shape, presumably only used to
    # initialize the parameters. Its parameters are what the lift prior in
    # __main__ targets via the key 'cnn$params'.
    net = haiku_module('cnn', hk.transform(HaikuCNN()), input_shape=(1, *x.shape[1:]))

    # Minibatch over the leading axis of x (and y, when given).
    with numpyro.plate('data', x.shape[0], subsample_size=subsample_size):
        batch_x = numpyro.subsample(x, event_dim=3)  # image with 3 color channels
        batch_y = numpyro.subsample(y, event_dim=0) if y is not None else None

        # One categorical observation per image in the minibatch.
        logits = net(batch_x)
        numpyro.sample('obs', Categorical(logits=logits), obs=batch_y)

if __name__ == '__main__':

    # NOTE(review): `rng_key` and `data` are not defined in this snippet. They
    # are assumed to be a JAX PRNG key and a dataset exposing `.x` / `.y`.
    # NOTE(review): as reported in the thread, a single Normal prior on the whole
    # 'cnn$params' site yields one scalar sample instead of one sample per
    # network parameter. The suggested alternative is NumPyro's
    # `random_haiku_module` primitive.
    lifted_model = lift(model, prior={'cnn$params': Normal(0, .05)})  # add prior to  (haiku) CNN params

    # Fit with SVI using an AutoDelta guide, Adagrad(1e-4) and Trace_ELBO.
    # Runs 1000 steps; data.x and data.y are forwarded to the model as (x, y).
    svi = SVI(lifted_model, AutoDelta(lifted_model), Adagrad(1e-4), Trace_ELBO())
    results = svi.run(rng_key, 1000, data.x, data.y)

However, this produces a single scalar sample for `cnn$params`, rather than a dictionary of sample statements with one entry for each array in the `init_val` of the param message.

Hi Ola, I think you can use the `random_haiku_module` primitive. :slight_smile:

Thanks, @fehiepsi, that's exactly what I was looking for! C: