Architecture of the network used in Pyro normalizing flows

I am reading the Pyro tutorial on normalizing flows and I would like to better understand how the examples work under the hood. In particular, I am interested in the architecture of the network used to learn the marginal distributions in the concentric circles example. There, the base distribution (in the latent space) is a standard normal and the flow is a rational spline:

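```python
import torch
import pyro.distributions as dist
import pyro.distributions.transforms as T

base_dist = dist.Normal(torch.zeros(2), torch.ones(2))
spline_transform = T.Spline(2, count_bins=16)
flow_dist = dist.TransformedDistribution(base_dist, [spline_transform])
```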

According to the tutorial, the knots of the spline and their derivatives are parameters that can be learnt, e.g., by stochastic gradient descent on a maximum-likelihood objective. The tutorial shows how to do that:

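```python
# X (the circle data) and smoke_test are defined earlier in the tutorial
steps = 1 if smoke_test else 1001
dataset = torch.tensor(X, dtype=torch.float)
optimizer = torch.optim.Adam(spline_transform.parameters(), lr=1e-2)
for step in range(steps):
    optimizer.zero_grad()
    loss = -flow_dist.log_prob(dataset).mean()  # negative log-likelihood
    loss.backward()
    optimizer.step()
    flow_dist.clear_cache()  # drop values cached with the old parameters

    if step % 200 == 0:
        print('step: {}, loss: {}'.format(step, loss.item()))
```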
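For reference, the loss in this loop is the negative log-likelihood given by the change-of-variables formula. Writing $f_d$ for the spline applied to coordinate $d$ (its knots and derivatives make up $\theta$), $N$ for the number of data points and $z_{i,d} = f_d^{-1}(x_{i,d})$: because the base Normal has batch shape 2 and `T.Spline` is element-wise, `flow_dist.log_prob(dataset)` returns one term $\ell_{i,d}$ per sample and coordinate, and `.mean()` averages over all $2N$ of them:

$$
\ell_{i,d}(\theta) = \log \mathcal{N}\big(z_{i,d};\,0,\,1\big) - \log\left|f_d'\big(z_{i,d}\big)\right|,
\qquad
\mathcal{L}(\theta) = -\frac{1}{2N}\sum_{i=1}^{N}\sum_{d=1}^{2}\ell_{i,d}(\theta) = -\frac{1}{2N}\sum_{i=1}^{N}\log p_X(x_i)
$$

Since the spline acts on each coordinate separately, the joint log-density $\log p_X(x_i)$ is just the sum over $d$, which is why this particular flow can only capture the marginals.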

Finally, the tutorial shows how to sample from the learned distribution to obtain new data points:

```python
X_flow = flow_dist.sample(torch.Size([1000,])).detach().numpy()
```

I would like to know what the architecture of the NN used to learn those parameters is, and whether there is a (possibly simple) way to modify this architecture (e.g. add or remove layers).

More generally, I would like to adapt these simple examples to the univariate case of learning the density of time series data.

cc @stefanwebb

Thanks for your contribution. Should I write directly to @stefanwebb?

Hi @bzaffora, that’s a great question! :slight_smile:

There are no NNs in T.Spline: this transform applies element-wise and does not condition on any input vector. Its parameters are simply nn.Parameter objects.

On the other hand, T.ConditionalSpline takes a dense MLP (pyro.nn.DenseNN), and T.SplineAutoregressive uses an autoregressive MLP known as MADE (pyro.nn.AutoRegressiveNN).

Hope this helps!


Dear @stefanwebb, thanks a lot for your answer!
Following your pointers, I managed to obtain some sensible results. I will now experiment with the NN architectures, but the results are already satisfying.
Do you have any advice for handling time series with this machinery? Any pitfalls I should be aware of? I am considering simple univariate time series for my study.

I haven’t worked directly with time series, but I would start with an autoregressive network based on an LSTM. From there, you could consider newer sequence models such as WaveNet and Transformers.