How to reset parameters for multiple retrainings (normalizing flow model)

I have a (conditional) normalizing flow model very similar to the one outlined in the Normalizing Flows - Introduction (Part 1) tutorial. I am training the model multiple times in a loop to assess its stability and average its results, but I am having trouble resetting the parameters for each retraining:

# the transforms are constructed once, before the bootstrap loop
modules   = torch.nn.ModuleList([transform_1, transform_2]).to(device)
optimizer = torch.optim.Adam(modules.parameters(), lr=lr)

steps = 201
bootstraps = 10
for bootstrap in range(bootstraps):
    pyro.get_param_store().clear()
    for step in range(steps):
        optimizer.zero_grad()
        # log likelihood
        ln_p_1 = dist_1.log_prob(data_1)
        ln_p_2_given_1 = dist_2_given_1.condition(data_1).log_prob(data_2)
        ...
        loss_train.backward()
        optimizer.step()
        dist_1.clear_cache()
        dist_2_given_1.clear_cache()

which gives:

bootstrap: 0 step: 0, train loss: 11.535920143127441
bootstrap: 0 step: 200, train loss: 4.8237104415893555
bootstrap: 1 step: 0, train loss: 4.820072174072266
bootstrap: 1 step: 200, train loss: 4.592116832733154
bootstrap: 2 step: 0, train loss: 4.59176778793335
...

As the losses show, training continues from the previous bootstrap's parameters, so effectively only one model is being trained. Here, transform_1 is a SplineAutoregressive transform and transform_2 is a ConditionalSplineAutoregressive transform.

I’d appreciate it if anyone could guide me on how to resolve this issue.
Thanks!

Could you try calling pyro.clear_param_store() after each loop? Pyro synchronizes parameters between each PyroModule and the param store, so you’ll need to clear both.
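For reference, that call is just a convenience wrapper around the method your loop already uses:

import pyro

pyro.clear_param_store()        # equivalent to:
pyro.get_param_store().clear()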

pyro.get_param_store().clear() clears parameters that have been registered via pyro.param or pyro.module, but it looks like you aren’t using those mechanisms. I suppose you need to move the construction of your transforms into your for bootstrap loop. Let us know if that doesn’t fix your issue.
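For concreteness, here is a minimal sketch of that suggestion, assuming the transforms come from the spline_autoregressive and conditional_spline_autoregressive helpers, and that base_dist_1, base_dist_2, input_dim, and context_dim stand in for your actual setup:

import torch
import pyro.distributions as dist
import pyro.distributions.transforms as T

for bootstrap in range(bootstraps):
    # fresh transforms => fresh parameters for every bootstrap
    transform_1 = T.spline_autoregressive(input_dim)
    transform_2 = T.conditional_spline_autoregressive(input_dim, context_dim)

    dist_1 = dist.TransformedDistribution(base_dist_1, [transform_1])
    dist_2_given_1 = dist.ConditionalTransformedDistribution(base_dist_2, [transform_2])

    # rebuild the optimizer so it tracks the new parameters and does not
    # carry over Adam's moment estimates from the previous run
    modules   = torch.nn.ModuleList([transform_1, transform_2]).to(device)
    optimizer = torch.optim.Adam(modules.parameters(), lr=lr)

    for step in range(steps):
        ...  # training loop as before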

Thank you! Moving the construction of the transforms into the loop worked.