Training Normalising Flow in `model()`

Hi all. I’m new to Pyro and have been playing around with the Normalising Flow example, and I wanted to experiment with a couple of fairly trivial examples. I want to fit an NF to a multivariate Gaussian and generate simple Bernoulli draws via a logit/sigmoid model of the first Gaussian component.

I generated the data using the following:

import pyro
import pyro.distributions as pyro_dist
import pyro.distributions.transforms as T
from pyro.infer import SVI, Trace_ELBO
import torch



size = 20000

z1 = pyro.sample('z1', pyro_dist.Normal(loc=0., scale=1.), sample_shape=(size,)).flatten()
z2 = pyro.sample('z2', pyro_dist.Normal(loc=2. * z1, scale=1.), sample_shape=(1,)).flatten()
p_true = torch.sigmoid(2 * z1)
x = pyro.sample('x', pyro_dist.Bernoulli(probs=p_true), sample_shape=(1,)).reshape((size, -1)).flatten()
z_vals = torch.stack((z1, z2)).T

My aim was to bring the functionality of the Pyro NF demo into a model/guide function framework, rather than following the approach in the example. My attempt is pasted below:

def model_mle(z_vals=None, x_vals=None):
    beta_z1 = pyro.param('beta_z1', torch.rand(1))
    beta_z2 = pyro.param('beta_z2', torch.rand(1))
    with pyro.plate('data', x_vals.shape[0]):
        base_dist = pyro_dist.Normal(torch.zeros(2), torch.ones(2))
        spline_transform = T.spline_coupling(2, count_bins=16)
        flow_dist = pyro.sample('Z', pyro_dist.TransformedDistribution(base_dist, [spline_transform]), obs=z_vals)
        x = pyro.sample('x', 
            pyro_dist.Bernoulli(probs=torch.sigmoid(beta_z1 + beta_z2*flow_dist[:, 0])),
            obs=x_vals
        )

def guide_mle(z_vals=None, x_vals=None):
    pass       

lr=0.00005
n_steps=2000
pyro.clear_param_store()
adam_params = {"lr": lr}
adam = pyro.optim.Adam(adam_params)
svi = SVI(model_mle, guide_mle, adam, loss=Trace_ELBO())

for step in range(n_steps):
    loss = svi.step(z_vals, x)
    if step % 100 == 0:
        print('[iter {}]  loss: {:.4f}'.format(step, loss))

The first issue is that my loss is highly volatile and doesn’t appear to converge:

[iter 0]  loss: 139756.8848
[iter 100]  loss: 128274.8037
[iter 200]  loss: 162184.8350
[iter 300]  loss: 140769.3340
[iter 400]  loss: 129621.6895
[iter 500]  loss: 151650.7041
[iter 600]  loss: 158514.3320
[iter 700]  loss: 134495.2383
[iter 800]  loss: 146055.1777
[iter 900]  loss: 134883.5430
[iter 1000]  loss: 176964.1367

Although my parameter estimates appear to be correct (close-ish to 0 and 1):

pyro.get_param_store().get_state()['params']
>>>{'beta_z1': tensor([0.0111], requires_grad=True),
    'beta_z2': tensor([0.9617], requires_grad=True)}
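
A note on reading the log above: Trace_ELBO reports the loss summed over every data point in the plate (there is no subsampling here), so values in the hundreds of thousands are expected with 20000 observations. A minimal sketch, using only the numbers printed in the log and the dataset size from the data-generation code, that rescales them to a per-data-point loss:

```python
# Losses copied from the log above; 20000 is the dataset size used in the post.
size = 20000
logged = {
    0: 139756.8848, 100: 128274.8037, 200: 162184.8350, 300: 140769.3340,
    400: 129621.6895, 500: 151650.7041, 600: 158514.3320, 700: 134495.2383,
    800: 146055.1777, 900: 134883.5430, 1000: 176964.1367,
}

# Divide each summed loss by the number of observations.
per_datum = {step: loss / size for step, loss in logged.items()}

for step, value in per_datum.items():
    print(f"[iter {step}]  loss per data point: {value:.3f}")
```

On this scale the loss moves between roughly 6.4 and 8.8 per data point without any downward trend, which makes the lack of convergence easier to see than in the raw totals.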

Finally, I would like to draw samples from the trained model, similar to what is done on the example page. I have tried using the Predictive class, but that seems to generate discriminative samples rather than new samples from the NF. How could I do this in Pyro?

To summarise, my general questions are:

  1. Am I training an NF model correctly?
  2. How can I draw generative samples from a trained SVI model?

Thanks!

  • instantiate the flow once outside of model
  • use something like pyro.module("flow", my_flow) in model
  • you should be able to get generative samples if you pass x_vals=None

Thanks for the pointers Martin. I’m happy generating samples following your suggestion, although I still have some issues with the first two points.

I’m a little unsure what pyro.module does, but after browsing online it appears that it registers the parameters of the flow (defined outside of the model() function) with Pyro’s parameter store. I specified the spline transform outside of the model as follows:

spline_transform = T.spline_coupling(2, count_bins=16)

def model_mle(z_vals, x_vals):
    pyro.module("flow", spline_transform)    
    base_dist = pyro_dist.Normal(loc=torch.zeros(2), scale=torch.ones(2))
    with pyro.plate('data', x_vals.shape[0]):
        flow_dist = pyro.sample(
            'flow',
            pyro_dist.TransformedDistribution(base_dist, [spline_transform]),
            obs=z_vals
        )

def guide_mle(z_vals, x_vals):
    pass    

lr=0.0005
n_steps=2000
pyro.clear_param_store()
adam_params = {"lr": lr}
adam = pyro.optim.Adam(adam_params)
svi = SVI(model_mle, guide_mle, adam, loss=Trace_ELBO())

for step in range(n_steps):
    loss = svi.step(z_vals, x)
    if step % 500 == 0:
        print('[iter {}]  loss: {:.4f}'.format(step, loss))

If I run this snippet I run into the following error:

...
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

I’m not sure if I misunderstood your suggestion here?

I believe you need to add spline_transform.clear_cache() at the top of the model.

See the flow tutorial.