Hi,
Is there currently a way to lift a haiku module from an nn to a bnn? I would like to able to do something like this:
class HaikuCNN:
def __init__(self):
self.act = jax.nn.relu
def __call__(self, x):
"""Simple CNN for diabetic retinopathy classification."""
x = x.astype(jnp.float32)
cnn = hk.Sequential([
hk.Conv2D(output_channels=32, kernel_shape=3, padding='SAME'),
self.act,
hk.Conv2D(output_channels=32, kernel_shape=3, padding='SAME'),
self.act,
hk.Linear(2),
])
return cnn(x)
def model(x, y=None, subsample_size=100):
net = haiku_module('cnn', hk.transform(HaikuCNN()), input_shape=(1, *x.shape[1:]))
with numpyro.plate('data', x.shape[0], subsample_size=subsample_size):
batch_x = numpyro.subsample(x, event_dim=3) # image with 3 color channels
batch_y = numpyro.subsample(y, event_dim=0) if y is not None else None
logits = net(batch_x)
numpyro.sample('obs', Categorical(logits=logits), obs=batch_y)
if __name__ == '__main__':
lifted_model = lift(model, prior={'cnn$params': Normal(0, .05)}) # add prior to (haiku) CNN params
svi = SVI(lifted_model, AutoDelta(lifted_model), Adagrad(1e-4), Trace_ELBO())
results = svi.run(rng_key, 1000, data.x, data.y)
However, this produces a scalar for cnn$params
and not a dictionary of sample statements for each of the init_val in the param msg.