When I introduce a BatchNorm layer in a flax linen module and run SVI inference, I get assertion errors when attempting posterior prediction. Following https://github.com/pyro-ppl/numpyro/issues/1446, I changed line 611 of numpyro/infer/util.py from ["param"]
to ["param", "mutable"]
which fixed the training phase, but it does not seem to have fixed the prediction phase. Any ideas/suggestions?
I get the following assertion error when executing the prediction part (the `preds = pred(pred_key, …` line):
File /opt/conda/envs/py39/lib/python3.9/site-packages/numpyro/contrib/module.py:82, in flax_module(name, nn_module, input_shape, apply_rng, mutable, *args, **kwargs)
     81 assert nn_state is None or isinstance(nn_state, dict)
---> 82 assert (nn_state is None) == (nn_params is None)
     84 if nn_params is None:
     85     # feed in dummy data to init params
     86     args = (jnp.ones(input_shape),) if input_shape is not None else args

AssertionError
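For context, the assertion fires because flax_module expects the network parameters and the mutable state to be substituted together: either both are present (predicting with a trained model) or neither is (fresh initialization). A minimal sketch of that consistency check (the function name and messages are mine; only the asserted condition is taken from the traceback above):

```python
def check_state_consistency(nn_params, nn_state):
    """Mimic the consistency check seen in numpyro.contrib.module.flax_module:
    parameters and mutable state must be substituted together or not at all."""
    assert nn_state is None or isinstance(nn_state, dict)
    assert (nn_state is None) == (nn_params is None), (
        "params were substituted without the mutable state (or vice versa)"
    )

# Both present: fine (what training with mutable=["batch_stats"] produces).
check_state_consistency({"layer_0": {}}, {"batch_stats": {}})
# Both absent: fine (fresh initialization from dummy input).
check_state_consistency(None, None)
# Params without the mutable state: the AssertionError seen at prediction time.
try:
    check_state_consistency({"layer_0": {}}, None)
except AssertionError:
    print("AssertionError: mutable state missing")
```

This suggests the prediction path is substituting the learned parameters but not the BatchNorm running statistics.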
Here are some fragments of the code. First, the flax linen model:
class BNN_Net(nn.Module):
    hidden_dim: int

    @nn.compact
    def __call__(self, x, is_training: bool):
        x = nn.Dense(self.hidden_dim, name='layer_0')(x[..., None].squeeze())  # make sure input data is squeezed
        x = nn.BatchNorm(name='batch_norm_0', use_bias=False, use_scale=False, momentum=0.9, use_running_average=not is_training)(x)
        x = nn.relu(x)
        x = nn.Dense(self.hidden_dim, name='layer_1')(x)
        x = nn.BatchNorm(name='batch_norm_1', use_bias=False, use_scale=False, momentum=0.9, use_running_average=not is_training)(x)
Next, part of the model:
def model(x, y=None, hidden_dim=args['hidden_dim'], subsample_size=args['subsample_size'], is_training=False):
    bnn_module = BNN_Net(hidden_dim=hidden_dim)
    bnn_net = random_flax_module("bnn", bnn_module, dist.Normal(0, 0.1/jnp.sqrt(prec_bnn_prior)), input_shape=x.shape, mutable=["batch_stats"], is_training=True)
    with numpyro.plate("batch", x.shape[0], subsample_size=subsample_size, dim=-1):
        batch_x = numpyro.subsample(x, event_dim=1)
        if y is not None:
            batch_y = numpyro.subsample(y, event_dim=0)
        else:
            batch_y = y
        out = bnn_net(batch_x, is_training)  # forward data through the NN
        return numpyro.sample("y_obs", dist.Normal(out.squeeze(), 0.1/jnp.sqrt(prec_obs)), obs=batch_y)  # likelihood
Next, part of the SVI inference section:
data = load_data()
inf_key, pred_key, data_key = random.split(random.PRNGKey(args['rng_key']), 3)
# normalize data and labels to zero mean, unit variance!
x, xtr_mean, xtr_std = normalize(data.xtr)
y, ytr_mean, ytr_std = normalize(data.ytr)
rng_key, inf_key = random.split(inf_key)
optimizer = numpyro.optim.Adam(step_size=0.0005)
guide = AutoNormal(model, init_loc_fn=partial(init_to_uniform, radius=0.1))
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(rng_key, args['num_optimisation_steps'], x, y, is_training=True, progress_bar=True)
Finally, part of the prediction section:
pred = Predictive(model=model, guide=guide, params=params, num_samples=100)
preds = pred(pred_key, xte, subsample_size=xte.shape[0])["y_obs"]