Normalising flow manual guide with mini-batch

Hello Pyro Gurus,

I am trying to setup a normalizing flow refining the posterior over the weights of the last layer of a neural network, starting from a Gaussian distribution.
I have implemented a class for my normalizing flow with a model and guide, which can be seen here with a few comments:

import math

import torch
import torch.nn as nn

import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule


class NF(PyroModule):
    
    def __init__(self, dim, nf_len, device, params_start_dist, nb_classes, prior_precision, nb_datapoints):
        
        super(NF, self).__init__()
        
        self.dim = dim
        self.nf_len = nf_len
        self.device = device
        self.nb_classes = nb_classes
        self.prior_precision = prior_precision
        self.nb_datapoints = nb_datapoints
        
        self.base_dist = dist.MultivariateNormal(params_start_dist['mean'].to(device), params_start_dist['covariance_m'].to(device))

        trans = dist.transforms.Radial
        self.transforms = [trans(dim) for _ in range(nf_len)]
        
        self.flow_dist = dist.TransformedDistribution(self.base_dist, self.transforms)
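        # Note: flow_dist shares the same transform objects as self.transforms,
        # so the parameters registered via pyro.module in the guide are the ones
        # used when sampling from and scoring flow_dist.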
        
    
    def guide(self, x, y):

        pyro.module("nf", nn.ModuleList(self.transforms).to(self.device))
        pyro.sample("w", self.flow_dist)
                
    def model(self, x, y):

        w = pyro.sample("w", dist.Normal(torch.zeros(self.dim).to(self.device), math.sqrt(1/(self.prior_precision))).to_event(1))    #prior
        class_pbs  = act_lm1(w, x, y, self.nb_classes)     #returns the class probabilities for each class for the sampled weight.

        with pyro.plate("data", size=self.nb_datapoints, subsample_size=len(y.squeeze())):
            pyro.sample("obs", dist.Categorical(probs=class_pbs), obs=y)    #likelihood
    
    
    def sample(self, n):
        
        return self.flow_dist.sample(torch.Size([n]))
    
    def log_prob(self, z):

        return self.flow_dist.log_prob(z)

I am training the model with the following:

import torch

from pyro import optim
from pyro.infer import SVI, Trace_ELBO

optimizer = torch.optim.Adam
n_steps = n_epochs * len(train_loader_lm1)
params_scheduler = {'optimizer': optimizer, 'optim_args': {'lr': 1e-3, 'weight_decay': 0}, 'T_max': n_steps}
scheduler = optim.CosineAnnealingLR(params_scheduler)
nf = NF(dim, nf_len, device, posterior_params, nb_classes, prec, nb_datapoints)

svi = SVI(nf.model, nf.guide, optim=scheduler, loss=Trace_ELBO())

for epoch in range(n_epochs):
    for x, y in train_loader_lm1:     # activations from the last hidden layer of the network are passed here
        loss = svi.step(x.to(device), y.to(device))
        scheduler.step()
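
For reference, a quick way to confirm that the transforms registered via pyro.module actually end up in Pyro's parameter store (and change across steps) is something like the following debugging sketch, which is not part of the training code:

# Debugging sketch: after a few SVI steps, list the parameters Pyro is tracking
# together with their norms, to confirm the radial-transform parameters are
# registered and are being updated between prints.
for name, value in pyro.get_param_store().items():
    print(name, tuple(value.shape), value.norm().item())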

When running things as they are here, I get results that make no sense.
When I replace the guide with one created by AutoNormalizingFlow and pass that to the SVI class instead of nf.guide, everything works as intended.

I am trying to replicate the results of a paper so that I have something to compare my results against. I do get the same results as in the paper when using the AutoNormalizingFlow guide, which makes me think that everything except the guide should be correct.

Is there anything in my guide that isn’t set up correctly? (The guide is intended to be a distribution over the weights of the last layer of the network.) If not, where do you think the problem is coming from? I have been looking through the source code of AutoNormalizingFlow to try to find the answer, but it seems like quite a lot is going on there.
My gut feeling is that the mini-batching could be the issue, because as far as I recall I was getting the same results with the manual guide and the autoguide before I introduced mini-batching.
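
For reference, the way I understand the subsampling in my model, the plate with subsample_size should just rescale the mini-batch likelihood by N/B, i.e. be equivalent to something like this (a sketch of my understanding, not code I actually run):

# Sketch: a plate of full size N with subsample_size B scales the batch
# log-likelihood by N / B, so the "obs" statement in model() should be
# equivalent to an unscaled plate over the batch plus an explicit scale factor.
with pyro.plate("data", size=len(y)):
    with pyro.poutine.scale(scale=self.nb_datapoints / len(y)):
        pyro.sample("obs", dist.Categorical(probs=class_pbs), obs=y)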
Hope one of you can enlighten me :slight_smile:

what’s trans? maybe you’re using a wacky overcomplicated set of flows? have you tried something simple?

My bad. I tried to simplify the code here by removing options such as the type of normalizing flow, so that it would be shorter and easier to read. I have corrected the mistake in the code above now. trans is simply

dist.transforms.Radial

(where dist is pyro.distributions).

I have so far mainly tried with a simple flow of length 1.

can you show how you’re using AutoNormalizingFlow for comparison?

Yes, sure! I am actually using a subclass of the AutoNormalizingFlow class. It mainly just overrides some methods of the AutoNormalizingFlow class to be able to take a multivariate normal distribution as input. (This is because I am using a Gaussian from a Laplace approximation as the base distribution, which hence does not have a diagonal covariance matrix.) So it shouldn’t change much.

import functools

import torch

import pyro.distributions as dist
from pyro.infer.autoguide import AutoNormalizingFlow

FLOW_TYPES = {
    'planar': dist.transforms.planar,
    'radial': dist.transforms.radial
}

class AutoNormalizingFlowCustom(AutoNormalizingFlow):
    """
    AutoNormalizingFlow guide with a custom base distribution
    """

    def __init__(self, model, base_dist_mean, base_dist_Cov, diag=False, flow_type='radial', flow_len=5, cuda=False):
        init_transform_fn = functools.partial(dist.transforms.iterated, flow_len, FLOW_TYPES[flow_type])
        super().__init__(model, init_transform_fn)

        self.base_dist_mean = base_dist_mean
        self.base_dist_Cov = base_dist_Cov

        if diag:
            assert self.base_dist_Cov.shape == self.base_dist_mean.shape
            self.base_dist = dist.Normal(self.base_dist_mean, torch.sqrt(self.base_dist_Cov))
        else:
            self.base_dist = dist.MultivariateNormal(self.base_dist_mean, self.base_dist_Cov)

        self.transform = None
        self._prototype_tensor = torch.tensor(0.0, device='cuda' if cuda else 'cpu')
        self.cuda = cuda

    def get_base_dist(self):
        return self.base_dist

    def get_posterior(self, *args, **kwargs):
        if self.transform is None:
            self.transform = self._init_transform_fn(self.latent_dim)

            if self.cuda:
                self.transform.to('cuda:0')

            # Update prototype tensor in case transform parameters
            # device/dtype is not the same as default tensor type.
            for _, p in self.named_pyro_params():
                self._prototype_tensor = p
                break
        return super().get_posterior(*args, **kwargs)

The training is then as follows:

optimizer = torch.optim.Adam
n_steps = n_epochs * len(train_loader_lm1)
params_scheduler = {'optimizer': optimizer, 'optim_args': {'lr': 1e-3, 'weight_decay': 0}, 'T_max': n_steps}
scheduler = optim.CosineAnnealingLR(params_scheduler)

nf = NF(dim, nf_len, device, posterior_params, nb_classes, best_prec, nb_datapoints)
guide = AutoNormalizingFlowCustom(nf.model, posterior_params['mean'].to(device), posterior_params['covariance_m'].to(device), diag=False, flow_type='radial', flow_len=nf_len, cuda=True)
svi = SVI(nf.model, guide, optim=scheduler, loss=Trace_ELBO())

for epoch in range(n_epochs):
    for x, y in train_loader_lm1:
        loss = svi.step(x.to(device), y.to(device))
        scheduler.step()
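
For what it's worth, to put the two guides on equal footing I could also compare their ELBO on a single fixed batch before any training, roughly like this (sketch; x0 and y0 stand for one batch taken from train_loader_lm1):

# Sketch: evaluate the initial ELBO loss of both guides on the same batch.
# x0, y0 stand for a single batch from train_loader_lm1.
elbo = Trace_ELBO()
loss_manual = elbo.loss(nf.model, nf.guide, x0.to(device), y0.to(device))
loss_auto = elbo.loss(nf.model, guide, x0.to(device), y0.to(device))
print(loss_manual, loss_auto)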

have you checked to see if there are any parameters that aren’t being updated? what exactly does “results that make no sense” mean?

The parameters of the flow seem to get updated. I have printed them at the beginning and the end of my runs.
To produce the following results, I have drawn 600 weight samples from the original distribution and passed them through the flow after each epoch.
For the version with the manual guide, I have done it as follows:

refined_posterior_samples = nf.sample(600)

As for the guide using AutoNormalizingFlow, I have done it as follows:

refined_posterior_samples = guide.get_posterior().sample((600,))

These samples were then used to monitor the performance of the model: after each epoch of training, last-layer weights were sampled from the base distribution and passed through the flow, and from these refined-posterior weights I computed the average accuracy, expected calibration error, and negative log likelihood.
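
Schematically, the accuracy part of that evaluation looks something like the following simplified sketch (x_val and y_val stand for the held-out activations and labels, act_lm1 is the same helper used in the model, and ECE/NLL are computed analogously but omitted here):

# Simplified sketch of the per-epoch accuracy evaluation.
# x_val, y_val stand for the validation activations and labels;
# refined_posterior_samples comes from either of the two calls above.
with torch.no_grad():
    accs = []
    for w in refined_posterior_samples:
        probs = act_lm1(w, x_val, y_val, nb_classes)      # class probabilities for one weight sample
        accs.append((probs.argmax(-1) == y_val.squeeze()).float().mean())
    avg_acc = torch.stack(accs).mean()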

Here is a run of the above code for a flow of length 1 on Cifar-10 with the guide from AutoNormalizingFlow:

epoch: 0
Initial parameters: [Parameter containing:
tensor([ 0.0090,  0.0040,  0.0082,  ..., -0.0118,  0.0143, -0.0066],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([0.0016], device='cuda:0', requires_grad=True), Parameter containing:
tensor([0.0125], device='cuda:0', requires_grad=True)]
loss: 12795386.334999084
Train: [Refined posterior nf_len: 1] Acc.: 100.0%; ECE: 0.4%; NLL: 0.003045
Val: [Refined posterior nf_len: 1] Acc.: 94.4%; ECE: 2.3%; NLL: 0.1644
epoch: 1
loss: 12565156.90934372
Train: [Refined posterior nf_len: 1] Acc.: 100.0%; ECE: 0.3%; NLL: 0.003103
Val: [Refined posterior nf_len: 1] Acc.: 94.4%; ECE: 2.1%; NLL: 0.1637
epoch: 2
loss: 12285179.66997528
Train: [Refined posterior nf_len: 1] Acc.: 100.0%; ECE: 0.3%; NLL: 0.003131
Val: [Refined posterior nf_len: 1] Acc.: 94.4%; ECE: 2.4%; NLL: 0.1634
epoch: 3
loss: 12005859.918476105
Train: [Refined posterior nf_len: 1] Acc.: 100.0%; ECE: 0.3%; NLL: 0.003203
Val: [Refined posterior nf_len: 1] Acc.: 94.4%; ECE: 2.3%; NLL: 0.1632
epoch: 4
loss: 11834738.629882812
Train: [Refined posterior nf_len: 1] Acc.: 100.0%; ECE: 0.2%; NLL: 0.003219
Val: [Refined posterior nf_len: 1] Acc.: 94.4%; ECE: 2.0%; NLL: 0.1628
epoch: 5
loss: 11608440.30405426
Train: [Refined posterior nf_len: 1] Acc.: 100.0%; ECE: 0.4%; NLL: 0.003297
Val: [Refined posterior nf_len: 1] Acc.: 94.4%; ECE: 2.1%; NLL: 0.1621
epoch: 6
loss: 11378924.625205994
Train: [Refined posterior nf_len: 1] Acc.: 100.0%; ECE: 0.3%; NLL: 0.003332
Val: [Refined posterior nf_len: 1] Acc.: 94.4%; ECE: 2.4%; NLL: 0.1625
epoch: 7
loss: 11182454.25932312
Train: [Refined posterior nf_len: 1] Acc.: 100.0%; ECE: 0.3%; NLL: 0.003351
Val: [Refined posterior nf_len: 1] Acc.: 94.4%; ECE: 2.1%; NLL: 0.1621
epoch: 8
loss: 11069251.063743591
Train: [Refined posterior nf_len: 1] Acc.: 100.0%; ECE: 0.4%; NLL: 0.003396
Val: [Refined posterior nf_len: 1] Acc.: 94.4%; ECE: 1.9%; NLL: 0.1619
epoch: 9
loss: 10926388.541629791
Train: [Refined posterior nf_len: 1] Acc.: 100.0%; ECE: 0.3%; NLL: 0.003461
Val: [Refined posterior nf_len: 1] Acc.: 94.4%; ECE: 2.2%; NLL: 0.1616
epoch: 10
loss: 10843495.305335999
Train: [Refined posterior nf_len: 1] Acc.: 100.0%; ECE: 0.3%; NLL: 0.003482
Val: [Refined posterior nf_len: 1] Acc.: 94.4%; ECE: 1.9%; NLL: 0.1616
epoch: 11
loss: 10744848.569850922
Train: [Refined posterior nf_len: 1] Acc.: 100.0%; ECE: 0.3%; NLL: 0.003478
Val: [Refined posterior nf_len: 1] Acc.: 94.4%; ECE: 2.0%; NLL: 0.1613
epoch: 12
loss: 10646590.513072968
Train: [Refined posterior nf_len: 1] Acc.: 100.0%; ECE: 0.4%; NLL: 0.00349
Val: [Refined posterior nf_len: 1] Acc.: 94.4%; ECE: 2.0%; NLL: 0.1616
epoch: 13
loss: 10606677.239040375
Train: [Refined posterior nf_len: 1] Acc.: 100.0%; ECE: 0.3%; NLL: 0.003486
Val: [Refined posterior nf_len: 1] Acc.: 94.5%; ECE: 2.1%; NLL: 0.1613
epoch: 14
loss: 10581528.76145935
Train: [Refined posterior nf_len: 1] Acc.: 100.0%; ECE: 0.4%; NLL: 0.00354
Val: [Refined posterior nf_len: 1] Acc.: 94.4%; ECE: 1.8%; NLL: 0.1615
epoch: 15
loss: 10505899.216732025
Train: [Refined posterior nf_len: 1] Acc.: 100.0%; ECE: 0.3%; NLL: 0.003534
Val: [Refined posterior nf_len: 1] Acc.: 94.4%; ECE: 1.9%; NLL: 0.1615
epoch: 16
loss: 10542939.202266693
Train: [Refined posterior nf_len: 1] Acc.: 100.0%; ECE: 0.4%; NLL: 0.003543
Val: [Refined posterior nf_len: 1] Acc.: 94.4%; ECE: 2.0%; NLL: 0.1614
epoch: 17
loss: 10484477.208011627
Train: [Refined posterior nf_len: 1] Acc.: 100.0%; ECE: 0.5%; NLL: 0.003516
Val: [Refined posterior nf_len: 1] Acc.: 94.4%; ECE: 1.8%; NLL: 0.1615
epoch: 18
loss: 10467579.383476257
Train: [Refined posterior nf_len: 1] Acc.: 100.0%; ECE: 0.4%; NLL: 0.003594
Val: [Refined posterior nf_len: 1] Acc.: 94.4%; ECE: 1.9%; NLL: 0.1617
epoch: 19
loss: 10522578.139362335
Train: [Refined posterior nf_len: 1] Acc.: 100.0%; ECE: 0.4%; NLL: 0.003558
Val: [Refined posterior nf_len: 1] Acc.: 94.4%; ECE: 2.0%; NLL: 0.1616
Final parameters: [Parameter containing:
tensor([-0.0747, -0.0240, -0.0002,  ...,  0.0170, -0.0477, -0.0353],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([4.1169], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-2.7297], device='cuda:0', requires_grad=True)]

Here is the equivalent for the manual guide for which the code can be seen in my post above:

Initial parameters: [Parameter containing:
tensor([ 0.0011, -0.0112, -0.0081,  ..., -0.0137, -0.0133,  0.0153],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([0.0110], device='cuda:0', requires_grad=True), Parameter containing:
tensor([0.0195], device='cuda:0', requires_grad=True)]
epoch: 0
loss: 171560011.37402344
Train [Refined posterior nf_len: 1] Acc.: 10.7%; ECE: 59.3%; NLL: 2.266
Val [Refined posterior nf_len: 1] Acc.: 10.6%; ECE: 58.8%; NLL: 2.275
epoch: 1
loss: 164302134.35913086
Train [Refined posterior nf_len: 1] Acc.: 12.2%; ECE: 60.7%; NLL: 2.144
Val [Refined posterior nf_len: 1] Acc.: 12.0%; ECE: 60.6%; NLL: 2.158
epoch: 2
loss: 154584200.39697266
Train [Refined posterior nf_len: 1] Acc.: 14.3%; ECE: 58.6%; NLL: 1.988
Val [Refined posterior nf_len: 1] Acc.: 14.0%; ECE: 59.9%; NLL: 2.011
epoch: 3
loss: 143199055.17285156
Train [Refined posterior nf_len: 1] Acc.: 17.4%; ECE: 73.2%; NLL: 1.791
Val [Refined posterior nf_len: 1] Acc.: 16.9%; ECE: 71.7%; NLL: 1.824
epoch: 4
loss: 132526197.07617188
Train [Refined posterior nf_len: 1] Acc.: 19.5%; ECE: 62.8%; NLL: 1.676
Val [Refined posterior nf_len: 1] Acc.: 18.8%; ECE: 62.3%; NLL: 1.716
epoch: 5
loss: 124575260.64111328
Train [Refined posterior nf_len: 1] Acc.: 22.2%; ECE: 49.9%; NLL: 1.542
Val [Refined posterior nf_len: 1] Acc.: 21.3%; ECE: 49.4%; NLL: 1.59
epoch: 6
loss: 115497900.87866211
Train [Refined posterior nf_len: 1] Acc.: 23.8%; ECE: 57.2%; NLL: 1.473
Val [Refined posterior nf_len: 1] Acc.: 22.8%; ECE: 56.8%; NLL: 1.525
epoch: 7
loss: 108210262.4050293
Train [Refined posterior nf_len: 1] Acc.: 27.5%; ECE: 48.6%; NLL: 1.335
Val [Refined posterior nf_len: 1] Acc.: 26.2%; ECE: 49.2%; NLL: 1.396
epoch: 8
loss: 102092282.08642578
Train [Refined posterior nf_len: 1] Acc.: 29.4%; ECE: 53.4%; NLL: 1.273
Val [Refined posterior nf_len: 1] Acc.: 28.0%; ECE: 52.4%; NLL: 1.335
epoch: 9
loss: 99723217.81591797
Train [Refined posterior nf_len: 1] Acc.: 31.6%; ECE: 53.7%; NLL: 1.196
Val [Refined posterior nf_len: 1] Acc.: 30.0%; ECE: 54.1%; NLL: 1.265
epoch: 10
loss: 93545905.40258789
Train [Refined posterior nf_len: 1] Acc.: 33.9%; ECE: 60.7%; NLL: 1.135
Val [Refined posterior nf_len: 1] Acc.: 32.1%; ECE: 60.9%; NLL: 1.209
epoch: 11
loss: 92619292.33007812
Train [Refined posterior nf_len: 1] Acc.: 34.6%; ECE: 29.9%; NLL: 1.109
Val [Refined posterior nf_len: 1] Acc.: 32.7%; ECE: 31.9%; NLL: 1.182
epoch: 12
loss: 88005605.0805664
Train [Refined posterior nf_len: 1] Acc.: 35.8%; ECE: 69.2%; NLL: 1.076
Val [Refined posterior nf_len: 1] Acc.: 33.8%; ECE: 67.9%; NLL: 1.152
epoch: 13
loss: 87215300.15893555
Train [Refined posterior nf_len: 1] Acc.: 35.7%; ECE: 41.7%; NLL: 1.078
Val [Refined posterior nf_len: 1] Acc.: 33.8%; ECE: 42.4%; NLL: 1.154
epoch: 14
loss: 85808482.27832031
Train [Refined posterior nf_len: 1] Acc.: 36.7%; ECE: 18.2%; NLL: 1.051
Val [Refined posterior nf_len: 1] Acc.: 34.7%; ECE: 22.7%; NLL: 1.128
epoch: 15
loss: 84359475.83886719
Train [Refined posterior nf_len: 1] Acc.: 37.0%; ECE: 40.8%; NLL: 1.047
Val [Refined posterior nf_len: 1] Acc.: 34.9%; ECE: 42.1%; NLL: 1.125
epoch: 16
loss: 84794714.33544922
Train [Refined posterior nf_len: 1] Acc.: 37.7%; ECE: 46.3%; NLL: 1.028
Val [Refined posterior nf_len: 1] Acc.: 35.5%; ECE: 46.3%; NLL: 1.108
epoch: 17
loss: 81745417.88500977
Train [Refined posterior nf_len: 1] Acc.: 37.8%; ECE: 56.9%; NLL: 1.021
Val [Refined posterior nf_len: 1] Acc.: 35.6%; ECE: 56.9%; NLL: 1.099
epoch: 18
loss: 81712124.76098633
Train [Refined posterior nf_len: 1] Acc.: 37.0%; ECE: 44.5%; NLL: 1.043
Val [Refined posterior nf_len: 1] Acc.: 35.0%; ECE: 45.8%; NLL: 1.121
epoch: 19
loss: 83449416.94702148
Train [Refined posterior nf_len: 1] Acc.: 38.2%; ECE: 37.2%; NLL: 1.015
Val [Refined posterior nf_len: 1] Acc.: 36.1%; ECE: 38.5%; NLL: 1.094
Final parameters: [Parameter containing:
tensor([-1.7001, -1.2074, -0.0183,  ...,  0.3942,  0.1517, -0.0794],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([4.6404], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-3.2050], device='cuda:0', requires_grad=True)]

The initial parameters of the flow are initialized to very similar values in both cases, and they clearly do get changed by training, as can be seen at the end; likewise, the loss clearly decreases in both cases.
What I find particularly interesting, though, is that for the guide from the AutoNormalizingFlow class, the accuracy of the model is already 100% on the training set after just one epoch (and this is already the case for samples drawn from the base distribution of the normalizing flow). The loss also starts at a value one order of magnitude lower than for the manual guide.
It is also interesting to note that for the manual version, the initially sampled weights give an accuracy of 10%, which corresponds to chance level, since it is a 10-class classification problem.

I made equivalent runs for a flow of length 5 instead and here are the results:

AutoNormalizingFlow guide:

epoch: 0
Initial parameters: [Parameter containing:
tensor([-0.0099, -0.0080, -0.0159,  ...,  0.0138,  0.0062, -0.0135],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.0046], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.0034], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.0189, -0.0077,  0.0131,  ..., -0.0119,  0.0132, -0.0060],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.0072], device='cuda:0', requires_grad=True), Parameter containing:
tensor([0.0009], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.0152, -0.0180, -0.0192,  ...,  0.0011, -0.0062,  0.0030],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.0081], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.0080], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.0072, -0.0002,  0.0170,  ...,  0.0013, -0.0064, -0.0107],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([0.0204], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.0067], device='cuda:0', requires_grad=True), Parameter containing:
tensor([ 0.0042,  0.0065,  0.0014,  ..., -0.0032, -0.0157,  0.0041],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.0032], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.0099], device='cuda:0', requires_grad=True)]
loss: 12423581.450374603
Train: [Refined posterior nf_len: 5] Acc.: 100.0%; ECE: 0.3%; NLL: 0.003276
Val : [Refined posterior nf_len: 5] Acc.: 94.4%; ECE: 2.1%; NLL: 0.1626
epoch: 1
loss: 11454646.946655273
Train: [Refined posterior nf_len: 5] Acc.: 100.0%; ECE: 0.3%; NLL: 0.003545
Val: [Refined posterior nf_len: 5] Acc.: 94.4%; ECE: 2.3%; NLL: 0.1616
epoch: 2
loss: 10479299.46893692
Train: [Refined posterior nf_len: 5] Acc.: 100.0%; ECE: 0.4%; NLL: 0.003707
Val: [Refined posterior nf_len: 5] Acc.: 94.4%; ECE: 1.9%; NLL: 0.1608
epoch: 3
loss: 9581285.133712769
Train: [Refined posterior nf_len: 5] Acc.: 100.0%; ECE: 0.3%; NLL: 0.003885
Val: [Refined posterior nf_len: 5] Acc.: 94.4%; ECE: 2.2%; NLL: 0.1607
epoch: 4
loss: 8816411.550674438
Train: [Refined posterior nf_len: 5] Acc.: 100.0%; ECE: 0.3%; NLL: 0.004113
Val: [Refined posterior nf_len: 5] Acc.: 94.5%; ECE: 1.7%; NLL: 0.16
epoch: 5
loss: 8112117.11333847
Train: [Refined posterior nf_len: 5] Acc.: 100.0%; ECE: 0.3%; NLL: 0.004203
Val: [Refined posterior nf_len: 5] Acc.: 94.5%; ECE: 2.0%; NLL: 0.1597
epoch: 6
loss: 7544694.462738037
Train: [Refined posterior nf_len: 5] Acc.: 100.0%; ECE: 0.5%; NLL: 0.004369
Val: [Refined posterior nf_len: 5] Acc.: 94.5%; ECE: 1.7%; NLL: 0.1593
epoch: 7
loss: 6967643.862586975
Train: [Refined posterior nf_len: 5] Acc.: 100.0%; ECE: 0.5%; NLL: 0.004486
Val: [Refined posterior nf_len: 5] Acc.: 94.5%; ECE: 1.7%; NLL: 0.1599
epoch: 8
loss: 6548382.631622314
Train: [Refined posterior nf_len: 5] Acc.: 100.0%; ECE: 0.3%; NLL: 0.004585
Val: [Refined posterior nf_len: 5] Acc.: 94.5%; ECE: 1.6%; NLL: 0.1593
epoch: 9
loss: 6166367.627822876
Train: [Refined posterior nf_len: 5] Acc.: 100.0%; ECE: 0.6%; NLL: 0.004598
Val: [Refined posterior nf_len: 5] Acc.: 94.5%; ECE: 1.1%; NLL: 0.159
epoch: 10
loss: 5911045.613105774
Train: [Refined posterior nf_len: 5] Acc.: 100.0%; ECE: 0.6%; NLL: 0.004755
Val: [Refined posterior nf_len: 5] Acc.: 94.5%; ECE: 1.4%; NLL: 0.159
epoch: 11
loss: 5680019.000030518
Train: [Refined posterior nf_len: 5] Acc.: 100.0%; ECE: 0.5%; NLL: 0.004736
Val: [Refined posterior nf_len: 5] Acc.: 94.5%; ECE: 1.3%; NLL: 0.1591
epoch: 12
loss: 5468559.282722473
Train: [Refined posterior nf_len: 5] Acc.: 100.0%; ECE: 0.6%; NLL: 0.004783
Val: [Refined posterior nf_len: 5] Acc.: 94.5%; ECE: 1.1%; NLL: 0.159
epoch: 13
loss: 5312881.596405029
Train: [Refined posterior nf_len: 5] Acc.: 100.0%; ECE: 0.5%; NLL: 0.0048
Val: [Refined posterior nf_len: 5] Acc.: 94.5%; ECE: 1.5%; NLL: 0.1589
epoch: 14
loss: 5186559.1647872925
Train: [Refined posterior nf_len: 5] Acc.: 100.0%; ECE: 0.5%; NLL: 0.004861
Val: [Refined posterior nf_len: 5] Acc.: 94.5%; ECE: 1.4%; NLL: 0.1589
epoch: 15
loss: 5098488.17074585
Train: [Refined posterior nf_len: 5] Acc.: 100.0%; ECE: 0.5%; NLL: 0.004865
Val: [Refined posterior nf_len: 5] Acc.: 94.6%; ECE: 1.4%; NLL: 0.1588
epoch: 16
loss: 5079465.265968323
Train: [Refined posterior nf_len: 5] Acc.: 100.0%; ECE: 0.5%; NLL: 0.004912
Val: [Refined posterior nf_len: 5] Acc.: 94.6%; ECE: 1.6%; NLL: 0.1589
epoch: 17
loss: 5032036.387924194
Train: [Refined posterior nf_len: 5] Acc.: 100.0%; ECE: 0.5%; NLL: 0.004838
Val: [Refined posterior nf_len: 5] Acc.: 94.6%; ECE: 1.4%; NLL: 0.159
epoch: 18
loss: 5041070.383125305
Train: [Refined posterior nf_len: 5] Acc.: 100.0%; ECE: 0.4%; NLL: 0.004902
Val: [Refined posterior nf_len: 5] Acc.: 94.6%; ECE: 1.5%; NLL: 0.159
epoch: 19
loss: 4984916.695335388
Train: [Refined posterior nf_len: 5] Acc.: 100.0%; ECE: 0.5%; NLL: 0.004861
Val: [Refined posterior nf_len: 5] Acc.: 94.6%; ECE: 1.5%; NLL: 0.1588
Final parameters: [Parameter containing:
tensor([-0.0099, -0.0080, -0.0159,  ...,  0.0138,  0.0062, -0.0135],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.0046], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.0034], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.1345, -0.0414, -0.0009,  ...,  0.0121, -0.0086, -0.0103],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([3.9336], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-2.5960], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.1348, -0.0415, -0.0009,  ...,  0.0122, -0.0085, -0.0102],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([3.9297], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-2.6036], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.1350, -0.0415, -0.0009,  ...,  0.0123, -0.0083, -0.0101],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([3.9473], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-2.6028], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.1354, -0.0417, -0.0009,  ...,  0.0123, -0.0081, -0.0100],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([3.9255], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-2.6059], device='cuda:0', requires_grad=True)]

Manual guide:

Initial parameters: [Parameter containing:
tensor([-0.0174,  0.0139,  0.0070,  ...,  0.0017,  0.0171, -0.0130],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([0.0110], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.0158], device='cuda:0', requires_grad=True), Parameter containing:
tensor([ 0.0137,  0.0095,  0.0192,  ..., -0.0121,  0.0003,  0.0156],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.0057], device='cuda:0', requires_grad=True), Parameter containing:
tensor([0.0113], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.0045, -0.0108,  0.0104,  ..., -0.0040,  0.0026,  0.0131],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([0.0152], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.0146], device='cuda:0', requires_grad=True), Parameter containing:
tensor([0.0061, 0.0136, 0.0085,  ..., 0.0146, 0.0027, 0.0062], device='cuda:0',
       requires_grad=True), Parameter containing:
tensor([-0.0176], device='cuda:0', requires_grad=True), Parameter containing:
tensor([0.0049], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.0166,  0.0093,  0.0119,  ...,  0.0030,  0.0116, -0.0098],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.0039], device='cuda:0', requires_grad=True), Parameter containing:
tensor([0.0100], device='cuda:0', requires_grad=True)]
epoch: 0
loss: 166367113.0830078
Train [Refined posterior nf_len: 5] Acc.: 12.5%; ECE: 64.2%; NLL: 2.119
Val [Refined posterior nf_len: 5] Acc.: 12.3%; ECE: 64.2%; NLL: 2.136
epoch: 1
loss: 136611726.76660156
Train [Refined posterior nf_len: 5] Acc.: 21.7%; ECE: 47.5%; NLL: 1.578
Val [Refined posterior nf_len: 5] Acc.: 20.9%; ECE: 50.1%; NLL: 1.623
epoch: 2
loss: 100487249.58691406
Train [Refined posterior nf_len: 5] Acc.: 35.3%; ECE: 56.2%; NLL: 1.094
Val [Refined posterior nf_len: 5] Acc.: 33.4%; ECE: 56.7%; NLL: 1.169
epoch: 3
loss: 69214443.7644043
Train [Refined posterior nf_len: 5] Acc.: 49.5%; ECE: 38.8%; NLL: 0.7567
Val [Refined posterior nf_len: 5] Acc.: 46.4%; ECE: 40.9%; NLL: 0.8515
epoch: 4
loss: 49766237.83691406
Train [Refined posterior nf_len: 5] Acc.: 61.4%; ECE: 13.0%; NLL: 0.5384
Val [Refined posterior nf_len: 5] Acc.: 57.1%; ECE: 17.0%; NLL: 0.6467
epoch: 5
loss: 37055855.17236328
Train [Refined posterior nf_len: 5] Acc.: 70.3%; ECE: 13.6%; NLL: 0.401
Val [Refined posterior nf_len: 5] Acc.: 65.2%; ECE: 16.7%; NLL: 0.5166
epoch: 6
loss: 30721167.755126953
Train [Refined posterior nf_len: 5] Acc.: 77.8%; ECE: 14.6%; NLL: 0.2938
Val [Refined posterior nf_len: 5] Acc.: 72.0%; ECE: 19.3%; NLL: 0.417
epoch: 7
loss: 24503803.168701172
Train [Refined posterior nf_len: 5] Acc.: 82.1%; ECE: 3.0%; NLL: 0.2351
Val [Refined posterior nf_len: 5] Acc.: 75.8%; ECE: 7.3%; NLL: 0.3635
epoch: 8
loss: 21341726.696044922
Train [Refined posterior nf_len: 5] Acc.: 85.3%; ECE: 0.5%; NLL: 0.193
Val [Refined posterior nf_len: 5] Acc.: 78.9%; ECE: 4.5%; NLL: 0.3221
epoch: 9
loss: 19175277.17236328
Train [Refined posterior nf_len: 5] Acc.: 87.4%; ECE: 15.1%; NLL: 0.1655
Val [Refined posterior nf_len: 5] Acc.: 80.7%; ECE: 20.6%; NLL: 0.2964
epoch: 10
loss: 18014168.88470459
Train [Refined posterior nf_len: 5] Acc.: 89.6%; ECE: 1.6%; NLL: 0.1383
Val [Refined posterior nf_len: 5] Acc.: 82.8%; ECE: 5.5%; NLL: 0.2717
epoch: 11
loss: 16157617.845863342
Train [Refined posterior nf_len: 5] Acc.: 90.7%; ECE: 1.1%; NLL: 0.1236
Val [Refined posterior nf_len: 5] Acc.: 83.8%; ECE: 3.8%; NLL: 0.2572
epoch: 12
loss: 15569808.990112305
Train [Refined posterior nf_len: 5] Acc.: 91.2%; ECE: 1.0%; NLL: 0.1176
Val [Refined posterior nf_len: 5] Acc.: 84.2%; ECE: 4.4%; NLL: 0.2511
epoch: 13
loss: 14984843.924865723
Train [Refined posterior nf_len: 5] Acc.: 92.7%; ECE: 0.7%; NLL: 0.09905
Val [Refined posterior nf_len: 5] Acc.: 85.6%; ECE: 6.5%; NLL: 0.2351
epoch: 14
loss: 14715402.514190674
Train [Refined posterior nf_len: 5] Acc.: 93.1%; ECE: 5.0%; NLL: 0.0945
Val [Refined posterior nf_len: 5] Acc.: 86.0%; ECE: 8.8%; NLL: 0.2305
epoch: 15
loss: 14326710.75213623
Train [Refined posterior nf_len: 5] Acc.: 93.1%; ECE: 3.3%; NLL: 0.09284
Val [Refined posterior nf_len: 5] Acc.: 86.0%; ECE: 7.2%; NLL: 0.23
epoch: 16
loss: 14124777.824798584
Train [Refined posterior nf_len: 5] Acc.: 93.3%; ECE: 1.0%; NLL: 0.09222
Val [Refined posterior nf_len: 5] Acc.: 86.2%; ECE: 4.5%; NLL: 0.2272
epoch: 17
loss: 13546639.528060913
Train [Refined posterior nf_len: 5] Acc.: 93.7%; ECE: 11.1%; NLL: 0.08685
Val [Refined posterior nf_len: 5] Acc.: 86.6%; ECE: 15.0%; NLL: 0.2232
epoch: 18
loss: 14327481.138336182
Train [Refined posterior nf_len: 5] Acc.: 93.4%; ECE: 0.4%; NLL: 0.08979
Val [Refined posterior nf_len: 5] Acc.: 86.3%; ECE: 4.6%; NLL: 0.225
epoch: 19
loss: 13787830.698356628
Train [Refined posterior nf_len: 5] Acc.: 93.4%; ECE: 0.8%; NLL: 0.08965
Val [Refined posterior nf_len: 5] Acc.: 86.3%; ECE: 4.3%; NLL: 0.2252
Final parameters: [Parameter containing:
tensor([-0.0174,  0.0139,  0.0070,  ...,  0.0017,  0.0171, -0.0130],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([0.0110], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.0158], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-1.2465, -0.8038,  0.0082,  ...,  0.1336,  0.1185,  0.0014],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([3.5557], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-2.4482], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-1.2572, -0.8112,  0.0083,  ...,  0.1392,  0.1148,  0.0086],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([3.5625], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-2.4613], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-1.2653, -0.8074,  0.0084,  ...,  0.1483,  0.1214, -0.0074],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([3.5465], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-2.4575], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-1.2869, -0.8113,  0.0086,  ...,  0.1400,  0.1278, -0.0128],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([3.5515], device='cuda:0', requires_grad=True), Parameter containing:
tensor([-2.4493], device='cuda:0', requires_grad=True)]

It can be seen for the manual guide that training clearly does improve the flow.
It seems like more training could potentially yield results similar to the ones obtained with the AutoNormalizingFlow guide, which makes me wonder whether the difference could actually lie in some kind of weight initialization of the flow done inside the AutoNormalizingFlow class?
I’ve tried to look a bit through the source code of the class, but there are a lot of dependencies, which makes it hard to keep track of what is going on.
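
One way I could probe the initialization hypothesis directly would be to draw samples from both guides before any training step and push them through the same evaluation, e.g. (sketch; x0 and y0 stand for one batch, used only so the autoguide can set up its prototype):

# Sketch: compare the two guides before the first svi.step().
guide(x0.to(device), y0.to(device))                      # one warm-up call to initialize the autoguide
with torch.no_grad():
    w_auto_init = guide.get_posterior().sample((600,))   # AutoNormalizingFlow guide, untrained
    w_manual_init = nf.sample(600)                       # manual guide, untrained
# If the auto-guide samples already give ~100% train accuracy while the manual
# ones give ~10%, the gap is there before training, i.e. in the initialization.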

maybe the problem is you’re missing init_loc_fn=init_to_feasible in the init method? what are the initial locs in the various cases?

Hi @s163669 did you ever get this running? I have a similar problem and it would help to look at a working example if you have one.

Thanks!