Using Predictive with BatchNorm (flax.linen)

When I introduce a BatchNorm layer in a flax linen module and use SVI inference, I get assertion errors when attempting posterior prediction. Following an earlier suggestion, I changed line 611 in numpyro/infer/ from ["param"] to ["param", "mutable"]. That solved the training phase, but it does not seem to have solved the prediction phase. Any ideas/suggestions?

I get the following assertion error when executing the prediction part (the ‘preds = pred(pred_key,…’ line):

File /opt/conda/envs/py39/lib/python3.9/site-packages/numpyro/contrib/, in flax_module(name, nn_module, input_shape, apply_rng, mutable, *args, **kwargs)
     81 assert nn_state is None or isinstance(nn_state, dict)
---> 82 assert (nn_state is None) == (nn_params is None)
     84 if nn_params is None:
     85     # feed in dummy data to init params
     86     args = (jnp.ones(input_shape),) if input_shape is not None else args


Here are some fragments of the code. First, the flax linen model (sorry for the poor rendering!):

import flax.linen as nn

class BNN_Net(nn.Module):
  hidden_dim: int
  def __call__(self, x, is_training: bool):
    x = nn.Dense(self.hidden_dim, name=f'layer_0')(x[..., None].squeeze())  # make sure input data is squeezed
    x = nn.BatchNorm(name=f'batch_norm_0', use_bias=False, use_scale=False, momentum=0.9, use_running_average=not is_training)(x)
    x = nn.relu(x)
    x = nn.Dense(self.hidden_dim, name=f'layer_1')(x)
    x = nn.BatchNorm(name=f'batch_norm_1', use_bias=False, use_scale=False, momentum=0.9, use_running_average=not is_training)(x)
    x = nn.relu(x)
    x = nn.Dense(features=1, name=f'layer_output')(x)
    return x
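The "batch_stats" collection made mutable in the model holds each BatchNorm layer's running mean and variance. Here is a rough NumPy sketch of what a BatchNorm layer with use_bias=False and use_scale=False does with that state. It is not Flax's code; the zero/one initialisation and eps=1e-5 are assumed to mirror Flax's defaults. With use_running_average=False (training) the layer normalises with the current batch statistics and updates the running averages; with True (prediction) it only reads the stored averages, so that stored state has to be available at prediction time.

```python
import numpy as np

def batch_norm(x, ra_mean, ra_var, use_running_average, momentum=0.9, eps=1e-5):
    """Sketch of BatchNorm without bias/scale; returns (y, new_ra_mean, new_ra_var)."""
    if use_running_average:
        # Prediction/test: use the stored running statistics and leave them unchanged.
        mean, var = ra_mean, ra_var
    else:
        # Training: normalise with the statistics of the current mini-batch ...
        mean = x.mean(axis=0)
        var = x.var(axis=0)
        # ... and fold them into the exponential moving averages (the "batch_stats").
        ra_mean = momentum * ra_mean + (1.0 - momentum) * mean
        ra_var = momentum * ra_var + (1.0 - momentum) * var
    y = (x - mean) / np.sqrt(var + eps)
    return y, ra_mean, ra_var


# Running statistics start at mean 0 and variance 1 for each feature (assumed defaults).
features = 3
ra_mean, ra_var = np.zeros(features), np.ones(features)
rng = np.random.default_rng(0)
x_train = rng.normal(loc=2.0, scale=3.0, size=(64, features))

# One training step normalises the batch and updates the running statistics.
y_train, ra_mean, ra_var = batch_norm(x_train, ra_mean, ra_var, use_running_average=False)

# At prediction time the running statistics are read but not modified.
x_test = rng.normal(loc=2.0, scale=3.0, size=(5, features))
y_test, ra_mean_after, ra_var_after = batch_norm(x_test, ra_mean, ra_var, use_running_average=True)
```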

Next, part of the model:

def model(x, y=None, hidden_dim=args['hidden_dim'], subsample_size=args['subsample_size'], is_training=False):
  bnn_module = BNN_Net(hidden_dim=hidden_dim)
  bnn_net = random_flax_module("bnn", bnn_module, dist.Normal(0, 0.1/jnp.sqrt(prec_bnn_prior)), input_shape=(x.shape), mutable=["batch_stats"], is_training=True)

  with numpyro.plate("batch", x.shape[0], subsample_size=subsample_size, dim=-1):
    batch_x = numpyro.subsample(x, event_dim=1)
    if y is not None:
      batch_y = numpyro.subsample(y, event_dim=0)
    else:
      batch_y = y  # no observations at prediction time
    out = bnn_net(batch_x, is_training)  # forward data in NN
    return numpyro.sample("y_obs", dist.Normal(out.squeeze(), 0.1/jnp.sqrt(prec_obs)), obs=batch_y)  # likelihood

Next, part of the SVI inference section:

data = load_data()
inf_key, pred_key, data_key = random.split(random.PRNGKey(args['rng_key']), 3)
# normalize data and labels to zero mean unit variance!
x, xtr_mean, xtr_std = normalize(data.xtr)
y, ytr_mean, ytr_std = normalize(data.ytr)

rng_key, inf_key = random.split(inf_key)
optimizer = numpyro.optim.Adam(step_size=0.0005)
guide = AutoNormal(model, init_loc_fn = partial(init_to_uniform, radius=0.1))
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

svi_result = svi.run(rng_key, args['num_optimisation_steps'], x, y, is_training=True, progress_bar=True)

Finally, part of the prediction section:

params = svi_result.params
pred = Predictive(model=model, guide=guide, params=params, num_samples=100)
preds = pred(pred_key, xte, subsample_size=xte.shape[0])["y_obs"]

To format your code nicely, you can put it inside a ```...``` block.

I guess you need to do

pred = Predictive(model=model, guide=guide, params={**svi_result.state.mutable_state, **params}, num_samples=100)
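Why the merge helps can be sketched in plain Python. The site names "bnn$params" and "bnn$state" are assumptions about how NumPyro's flax_module names the module parameters and the Flax mutable state, "prec_obs_auto_loc" stands in for a typical AutoNormal parameter, and all values are placeholders. The function only mimics the two assertions from the traceback above: SVI's params already contain the module parameters, while the BatchNorm state lives in svi_result.state.mutable_state, so without the merge the state site comes back as None.

```python
def flax_module_check(substituted):
    """Mimic the two assertions from the traceback, given the values Predictive substitutes."""
    nn_params = substituted.get("bnn$params")  # module parameters (assumed site name)
    nn_state = substituted.get("bnn$state")    # BatchNorm running stats (assumed site name)
    assert nn_state is None or isinstance(nn_state, dict)
    assert (nn_state is None) == (nn_params is None)
    return nn_params, nn_state


# Hypothetical placeholder contents; the real values are nested dicts of JAX arrays.
params = {"bnn$params": {"layer_0": {"kernel": [[0.1]]}}, "prec_obs_auto_loc": 0.0}
mutable_state = {"bnn$state": {"batch_stats": {"batch_norm_0": {"mean": [0.0], "var": [1.0]}}}}

# Passing only the SVI params: the state site is missing, so the second assertion fails.
try:
    flax_module_check(params)
    failed_without_state = False
except AssertionError:
    failed_without_state = True

# The two dicts have disjoint keys, so {**a, **b} is a plain union supplying both sites.
merged = {**mutable_state, **params}
nn_params, nn_state = flax_module_check(merged)
```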

Thanks for the prompt feedback. I'll try the Predictive modification you suggest and get back to you.
BTW: Thank you for the effort you and the Pyro team have invested in this project.

Thanks. It seems to work (BUT see further below regarding an error when specifying the number of ELBO particles).
For the record, here are the relevant sections of the code:
The NN:

class BNN(nn.Module):
  hidden_dim: int
  def __call__(self, x, is_training: bool):
    x = nn.Dense(self.hidden_dim, name=f'layer_input')(x[..., None].squeeze()) # make sure input data is squeezed
    # BatchNorm batch statistics are only computed during the training phase; the prediction/test phase uses the stored running averages
    x = nn.BatchNorm(name=f'batch_norm_0', use_bias=False, use_scale=False, momentum=0.9, use_running_average=not is_training)(x)
    x = nn.relu(x)
    x = nn.Dense(self.hidden_dim, name=f'layer_1')(x)
    #x = nn.BatchNorm(use_bias=False, use_scale=False, momentum=0.9, use_running_average=not self.is_training)(x)
    x = nn.BatchNorm(name=f'batch_norm_1', use_bias=False, use_scale=False, momentum=0.9, use_running_average=not is_training)(x)
    x = nn.relu(x)
    x = nn.Dense(features=1, name=f'layer_output')(x)
    return x

Parts of the model:

def BNNmodel(x, y=None, hidden_dim=args['hidden_dim'], data_size=None, is_training=False):
  # hyper-prior for precision of bnn weights and biases
  prec_bnn_prior = numpyro.sample("prec_bnn_prior", dist.Gamma(1.0, 0.1))  
  bnn_module = BNN(hidden_dim=hidden_dim) # instantiate/declare model  
  bnn_net = random_flax_module("bnn", bnn_module, dist.Normal(0, 0.1/jnp.sqrt(prec_bnn_prior)), input_shape=(jnp.shape(x)), mutable=["batch_stats"], is_training=is_training)  # make sure 'input_shape' is specified
  # precision hyper-prior on observations
  prec_obs = numpyro.sample("prec_obs", dist.Gamma(1.0, 0.1))
  # No sub-sampling/mini-batching is done in the model plate, as the data are already mini-batched.
  with numpyro.plate("data", size=data_size, subsample_size=jnp.shape(x)[0], dim=-1): # this model assumes x has already been mini-batched before being passed in
    out = bnn_net(x, is_training) #forward data in NN
    return numpyro.sample("y_obs", dist.Normal(out.squeeze(), 0.1/jnp.sqrt(prec_obs)), obs=y) # likelihood
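The plate with size=data_size and subsample_size equal to the batch size is what lets this model see only a pre-batched x while still targeting the full-data posterior: NumPyro rescales the plate's log-likelihood by size / subsample_size. A small NumPy sketch of that estimator, with made-up data and a Gaussian likelihood standing in for the network output, shows why it is unbiased: averaged over disjoint mini-batches that cover the data once, it equals the full-data log-likelihood.

```python
import numpy as np

def normal_logpdf(y, loc, scale):
    return -0.5 * np.log(2 * np.pi * scale**2) - 0.5 * ((y - loc) / scale) ** 2

def scaled_minibatch_loglik(y_batch, loc_batch, scale, data_size):
    # Effect of plate(size=data_size, subsample_size=len(y_batch)) on the likelihood:
    # sum over the mini-batch, rescaled by data_size / subsample_size.
    return data_size / len(y_batch) * normal_logpdf(y_batch, loc_batch, scale).sum()

rng = np.random.default_rng(0)
N, B = 120, 20                     # full data size and mini-batch size (N divisible by B)
y = rng.normal(size=N)
loc = 0.5 * rng.normal(size=N)     # hypothetical stand-in for the network outputs `out`
scale = 0.8

full = normal_logpdf(y, loc, scale).sum()
batches = rng.permutation(N).reshape(N // B, B)  # disjoint mini-batches covering every point once
estimates = [scaled_minibatch_loglik(y[idx], loc[idx], scale, N) for idx in batches]
```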

Parts of the SVI inference section:

BNNoptimizer = numpyro.optim.Adam(step_size=0.0005)
BNNguide = AutoNormal(model=BNNmodel, init_loc_fn = partial(init_to_uniform, radius=0.1))
svi = SVI(model=BNNmodel, guide=BNNguide, optim=BNNoptimizer, loss=Trace_ELBO())
# Initialise the SVI state
batch_data_itr = data_stream_gen(rng_data_key, num_complete_batches, num_training_data) # generate an init batch to initialise SVI
batch_data_x_init, batch_data_y_init = next(batch_data_itr)
svi_state_init = svi.init(rng_key_init, batch_data_x_init, batch_data_y_init) 
svi_state = svi_state_init
from jax import jit
# A Python generator cannot be advanced inside lax.fori_loop: the loop body is traced only once,
# so next(...) would run a single time and every iteration would reuse the same mini-batch.
# A plain Python loop over a jit-compiled update step avoids this.
update_fn = jit(partial(svi.update, is_training=True))
def epoch_train(svi_state, rng_key):
  loss_sum = 0.0
  batch_data_itr = data_stream_gen(rng_key, num_complete_batches, num_training_data)  # mini-batch generator for this epoch
  for _ in range(num_complete_batches):  # ignore the last incomplete batch
    batch_data_x, batch_data_y = next(batch_data_itr)
    svi_state, loss = update_fn(svi_state, batch_data_x, batch_data_y)
    loss_sum += loss
  return loss_sum, svi_state
start = time()
# epoch loop
epochs_iterator = tqdm(range(args["num_optimisation_steps"]), desc="Epoch count ")
for epoch_count in epochs_iterator:
  rng_key, rng_key_train = random.split(rng_key, 2)  # rng_key_train is used to permute the batch indices
  epoch_train_loss, svi_state = epoch_train(svi_state, rng_key_train)

Parts of the prediction section:

params = svi.get_params(svi_state)
posterior_predictive_distribution = Predictive(model=BNNmodel, guide=BNNguide, params={**svi_state.mutable_state, **params}, num_samples=args['num_posterior_samples'])
preds = posterior_predictive_distribution(pred_key, data_x_test, data_size=num_test_data_total)["y_obs"]

There is still one problem, though: training fails with an error when the number of ELBO particles is larger than 1. Is there a fix for this? Using a single ELBO particle might compromise the quality of the training phase.

File /opt/conda/envs/py39/lib/python3.9/site-packages/numpyro/infer/, in Trace_ELBO.loss_with_mutable_state(self, rng_key, param_map, model, guide, *args, **kwargs)
    164 else:
    165     rng_keys = random.split(rng_key, self.num_particles)
--> 166     elbos, mutable_state = vmap(single_particle_elbo)(rng_keys)
    167     return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state}

    [... skipping hidden 3 frame]

File /opt/conda/envs/py39/lib/python3.9/site-packages/numpyro/infer/, in Trace_ELBO.loss_with_mutable_state.<locals>.single_particle_elbo(rng_key)
    151         return elbo_particle, mutable_params
    152     else:
--> 153         raise ValueError(
    154             "Currently, we only support mutable states with num_particles=1."
    155         )
    156 else:
    157     return elbo_particle, None

ValueError: Currently, we only support mutable states with num_particles=1.
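The traceback shows what makes this hard: with num_particles > 1, Trace_ELBO vmaps a single-particle ELBO over split keys and averages the results. The rough NumPy analogue below is not NumPyro's code; the toy guide, the placeholder "ELBO" and the sizes are made up for illustration. Because each particle draws different network weights, each BatchNorm layer sees different activations, so the particles return different updated running means. Averaging the ELBOs is well defined, but it is not obvious which running mean should become the new state.

```python
import numpy as np

def single_particle(rng, x, ra_mean, momentum=0.9):
    # Each particle draws its own network weights from a toy "guide" ...
    w = rng.normal(size=(x.shape[1], ra_mean.shape[0]))
    h = x @ w
    # ... so the BatchNorm layer sees different activations and updates its
    # running mean differently for every particle.
    new_ra_mean = momentum * ra_mean + (1.0 - momentum) * h.mean(axis=0)
    toy_elbo = -np.mean(h**2)  # hypothetical placeholder for the per-particle ELBO
    return toy_elbo, new_ra_mean

num_particles, hidden_dim = 4, 3
rng = np.random.default_rng(0)
x = rng.normal(size=(32, 2))
ra_mean = np.zeros(hidden_dim)

results = [single_particle(np.random.default_rng(seed), x, ra_mean) for seed in range(num_particles)]
elbos = np.array([r[0] for r in results])
states = np.stack([r[1] for r in results])  # one candidate running mean per particle

loss = -elbos.mean()  # the loss side reduces cleanly; the state side does not
```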

I don't know how to update mutable states from multiple particles. Do you have any idea/formula?

I don't really understand the details, but the Stein VI inference technique seems to handle this (I have tried it and it doesn't complain). See

Here’s a dumb question: What’s the relationship or the interaction, if any, between the ELBO particles and the Stein particles?

cc @OlaRonning

Probably SteinVI makes some assumptions about how mutable states are updated across particles. Or it just misses a check in the implementation.

Hi @oli42,

I’m sorry to disappoint, but @fehiepsi is right; there is no check in the implementation.

SteinVI’s development is currently motivated by my research interest, but if you have a use case, let me know, and we can look at a solution together.

Regarding the connection between ELBO (SVI) and Stein mixtures (SteinVI): if you use a single particle, you recover ELBO optimization exactly. If you use k particles, a delta guide, and set the loss temperature to 1/k, you recover Stein Variational Gradient Descent.
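For readers unfamiliar with the SVGD end of that spectrum, here is a minimal NumPy sketch of the plain SVGD update on point particles. It is not NumPyro's SteinVI implementation; the 1-D Gaussian target, RBF bandwidth, step size and iteration count are arbitrary choices for illustration. Each particle moves along a kernel-weighted average of the target's score, which pulls particles toward high density, plus a kernel-gradient term that pushes them apart so they do not collapse onto the mode.

```python
import numpy as np

def grad_log_p(x, mu=2.0, sigma=0.5):
    # Score of the 1-D Gaussian target N(mu, sigma^2).
    return -(x - mu) / sigma**2

def svgd_step(x, step_size=0.05, bandwidth=0.5):
    diff = x[:, None] - x[None, :]               # diff[j, i] = x_j - x_i
    kern = np.exp(-diff**2 / bandwidth)          # RBF kernel K(x_j, x_i)
    grad_kern = -2.0 * diff / bandwidth * kern   # d/dx_j K(x_j, x_i): the repulsive term
    # phi(x_i) = mean_j [K(x_j, x_i) * grad log p(x_j) + d/dx_j K(x_j, x_i)]
    phi = (kern * grad_log_p(x)[:, None] + grad_kern).mean(axis=0)
    return x + step_size * phi

rng = np.random.default_rng(0)
particles = rng.normal(loc=-3.0, scale=1.0, size=50)  # start far from the target
for _ in range(2000):
    particles = svgd_step(particles)
```

After the loop the particles sit around the target mean with a non-zero spread, rather than all collapsing onto the mode as plain gradient ascent on log p would.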

Best, Ola