Misunderstanding of how to write guides for SVI

Hello - I’m quite new to Pyro/NumPyro. I have read through most of the NumPyro port of Statistical Rethinking (2nd ed.) (fehiepsi/rethinking-numpyro on GitHub), and all the introductory tutorials for Pyro (the SVI tutorials more than once) and NumPyro.

I’m trying to redo a simple linear regression model, implemented by a colleague of mine, using NumPyro to see whether I can recover coefficients similar to the ones they got using statsmodels.

I define a linear model as:

import jax.numpy as jnp
import numpyro
import numpyro.distributions

def linear_regression(X, y=None):
    # intercept
    a = numpyro.sample("a", numpyro.distributions.Normal(0, 100))
    # every column except the last gets its own coefficient in b
    n_scaled_features = X.shape[1] - 1

    with numpyro.plate("bs", size=n_scaled_features):
        b = numpyro.sample("b", numpyro.distributions.Normal(0., 800.))

    # the last column gets a separate coefficient with a tighter prior
    b2 = numpyro.sample("b2", numpyro.distributions.Normal(0., 1.))
    mu = a + b2 * X[:, -1] + jnp.sum(X[:, :-1] * b, axis=1)
    # observation noise
    sigma = numpyro.sample("sigma", numpyro.distributions.HalfNormal(5))
    with numpyro.plate("data", size=X.shape[0]):
        numpyro.sample("obs", numpyro.distributions.Normal(mu, sigma), obs=y)

When I run this model using:

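from jax import random
from numpyro.infer import MCMC, NUTS

num_warmup, num_samples = 1000, 1000
# recent NumPyro versions require num_warmup/num_samples as keyword arguments
mcmc = MCMC(NUTS(model=linear_regression), num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(random.PRNGKey(2), X, y)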

I get essentially the same coefficients as those returned by the statsmodels linear regression. However, when I try to fit the model using SVI, first by following the format shown in the Pyro tutorial "Bayesian Regression - Inference Algorithms (Part 2)" (Pyro Tutorials 1.8.4 documentation),

def manual_guide(X, y=None):
    n_scaled_features = X.shape[1] - 1                            
    a_loc = numpyro.param("a_loc", jnp.zeros(1))
    a_scale = numpyro.param("a_scale", 100. * jnp.ones(1),
                            constraint=numpyro.distributions.constraints.positive)    
    a = numpyro.sample("a", numpyro.distributions.Normal(loc=a_loc, scale=a_scale))

    b_loc = numpyro.param("b_loc", jnp.zeros(n_scaled_features))
    b_scale = numpyro.param("b_scale",  1000 * jnp.ones(n_scaled_features))
    with numpyro.plate("bs", size=n_scaled_features):
        b = numpyro.sample("b", numpyro.distributions.Normal(loc=b_loc, scale=b_scale))
    
    b2_loc = numpyro.param("b2_loc", jnp.zeros(1))
    b2_scale = numpyro.param("b2_scale",  1. * jnp.ones(1))
    b2 = numpyro.sample("b2", numpyro.distributions.Normal(loc=b2_loc, scale=b2_scale))
                           
    sigma_scale = numpyro.param("sigma_scale", 5. * jnp.ones(1))
    sigma = numpyro.sample("sigma", numpyro.distributions.HalfNormal(sigma_scale))

I get totally different results. When I plot the loss, it doesn’t seem to go down. I also find that when I try SVI with the AutoNormal autoguide, the loss goes down but the coefficients are totally different.

I am wondering whether this has to do with the data itself, which has features on very different scales, or with how I am initializing the guide params.

A full worked example of this can be found here:

Thanks for the cool framework and any hints of what I’m doing wrong would be greatly appreciated!

Having wildly different scales certainly has the potential to be a problem for SVI. You’d probably get better results if you standardized all the data (and transformed back to the original space at the end).
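A minimal sketch of that suggestion, assuming a plain linear regression: standardize X and y, fit on the standardized scale, then map the intercept, slopes and noise scale back. The helper names `standardize` and `to_original_scale` and the synthetic data are hypothetical, and least squares stands in for the SVI/MCMC fit so that the back-transformation can be checked on its own.

```python
import jax.numpy as jnp
from jax import random


def standardize(X, y):
    """Center and scale every column of X and y; return the stats needed to undo it."""
    x_mean, x_std = X.mean(axis=0), X.std(axis=0)
    y_mean, y_std = y.mean(), y.std()
    return (X - x_mean) / x_std, (y - y_mean) / y_std, (x_mean, x_std, y_mean, y_std)


def to_original_scale(a_s, b_s, sigma_s, stats):
    """Map intercept, slopes and noise scale fitted on standardized data back to raw units.

    Follows from y_s = a_s + X_s @ b_s with X_s = (X - x_mean) / x_std, y_s = (y - y_mean) / y_std.
    """
    x_mean, x_std, y_mean, y_std = stats
    b = y_std * b_s / x_std
    a = y_mean + y_std * a_s - jnp.sum(b * x_mean)
    sigma = y_std * sigma_s
    return a, b, sigma


# Hypothetical features on very different scales.
key_x, key_noise = random.split(random.PRNGKey(0))
X = random.normal(key_x, (1000, 3)) * jnp.array([1.0, 100.0, 0.01]) + jnp.array([5.0, 300.0, 0.2])
y = 3.0 + X @ jnp.array([2.0, 0.03, 150.0]) + 0.1 * random.normal(key_noise, (1000,))

X_s, y_s, stats = standardize(X, y)

# Stand-in for the SVI/MCMC fit: least squares on the standardized data.
design_s = jnp.column_stack([jnp.ones(X_s.shape[0]), X_s])
coefs_s = jnp.linalg.lstsq(design_s, y_s)[0]
sigma_s = jnp.std(y_s - design_s @ coefs_s)

a, b, sigma = to_original_scale(coefs_s[0], coefs_s[1:], sigma_s, stats)
print(a, b, sigma)  # should be near 3.0, [2.0, 0.03, 150.0] and 0.1
```

With SVI, the same mapping can be applied to each posterior draw of a, the slopes (for the model in this thread, b and b2 concatenated) and sigma fitted on X_s and y_s, since it is linear in the coefficients.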

Thanks @martinjankowiak for the quick reply. I’ve tried scaling the variables and unfortunately still get drastically different results. Am I perhaps using the Predictive class wrong when trying to get posterior samples?

posterior = Predictive(
    model=linear_regression,
    guide=guide,
    params=params,
    return_sites=['b'],
    num_samples=1000
)(random.PRNGKey(0), X=X, y=y)

jnp.mean(posterior['b'], axis=0)

@blazina Your colab link needs access permission.

Hi @fehiepsi, sorry about that; I was writing from my work computer, which restricted access. This link should work:

Thanks so much for taking the time to look at it!

I think you can try the TraceMeanField_ELBO objective instead of Trace_ELBO, use an AutoDelta / AutoLaplaceApproximation guide, or increase the number of steps (e.g. to 1e5) to make sure that SVI converges. The scale parameters in your custom guide are initialized too large; you should probably decrease them. The sigma_scale parameter should have a positive constraint. Overall, your code looks good to me; it just needs a bit of tuning to get the desired results. FYI, with the next version of NumPyro, you can use svi.run instead of the init/update/lax.scan/get_params methods.


Thanks a lot for taking a look at that! Some follow-up questions for my understanding:

  1. Is there somewhere in the documentation or literature that discusses which objective to use (e.g. Trace_ELBO vs. TraceMeanField_ELBO vs. TraceEnum_ELBO)? As someone new to this, it’s a bit tricky to decipher this from the API documentation.

  2. Same question with regard to choosing one of the autoguides (e.g. AutoNormal vs. AutoDelta)? I notice that when I still use the AutoNormal autoguide and just give it more steps, it also gets coefficient estimates similar to MCMC and the statsmodels model.

  3. Are there diagnostics in SVI for understanding when it has converged? I figured it would be when the loss goes down and more or less stabilizes. However, with my original code using the AutoNormal autoguide and 3000 steps, the loss went down and seemed to stabilize when I plotted it. I guess the scale of the y-axis was so large that I missed that it was still going down, and this is why you were looking at the last 2000 steps?

  4. Unfortunately, when I decrease the initial values of the scale parameters, add the positive constraint to the sigma_scale parameter, increase the number of steps to around 1e5 and play around with the optimizer learning rate, I am still not able to get results similar to the autoguides or MCMC. Are there some general hints for how to initialize the parameters in the guide? In the tutorials and documentation I have gone through, the initialization of the guide parameters generally seems to follow the parameter settings of the distributions in the model function, but this does not seem to work in this case. Looking at the source code for AutoNormal, I see (if I’m understanding the code correctly) that it initializes the loc parameters using init_to_uniform, which appears to draw from a Uniform distribution between -2 and 2, and that it initializes the scale parameters to 0.1. However, using similar initializations doesn’t seem to help in my case. Am I misunderstanding the code? I can probably often just use one of the autoguide classes, but I’m trying to understand guide functions a bit more :slight_smile:

Also, very cool about the API update to just use svi.run()! I was just following the same format as your Statistical Rethinking NumPyro port.

Thanks again for taking the time to look at the code and I appreciate any further insights/help you could offer. :pray:


Great questions! I think a good way is still to follow the API documentation and tutorials. When we need something else (e.g. inference with discrete sites), we can search for it to find the right API to use. You are probably already familiar with AutoNormal, but if you only want to learn the MAP point, you will find that AutoDelta is enough.

some diagnostics in SVI for understanding when it converges

I think training SVI is like training a neural network. You can find resources about TensorBoard or similar tools, which are helpful for monitoring the training/validation process.
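One simple check in that spirit, sketched under the assumption that the losses are available as an array (for example `svi_result.losses` from `svi.run`): compare the mean loss over the last two windows of steps instead of eyeballing a linear-scale plot. The function name and window size are hypothetical, and the two curves below are made-up stand-ins for real SVI losses.

```python
import numpy as np


def relative_improvement(losses, window=1000):
    """Relative drop in mean loss between the last two windows of `window` steps.

    Values near 0 suggest the loss has flattened out; a clearly positive value
    means it is still going down, even if a linear-scale plot looks flat.
    """
    losses = np.asarray(losses, dtype=float)
    previous = losses[-2 * window:-window].mean()
    latest = losses[-window:].mean()
    return (previous - latest) / abs(previous)


# Made-up loss curves standing in for svi_result.losses.
steps = np.arange(3000)
# Drops from ~1e6, so on a linear plot the slow second term looks flat.
still_decreasing = 1e6 * np.exp(-steps / 300) + 2e4 * np.exp(-steps / 3000) + 500
plateaued = 500 + np.random.default_rng(0).normal(0.0, 1.0, size=steps.size)

print(relative_improvement(still_decreasing))  # about 0.58: still clearly improving
print(relative_improvement(plateaued))         # close to 0
```

Plotting the losses on a log scale, or only the tail of the run, gives the same information visually.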

still not able to get similar results as with using the auto guides or MCMC

sigma = numpyro.sample("sigma", dist.HalfNormal(sigma_scale))

I think this only learns the scale of sigma and assumes the mode is at 0. AutoNormal instead learns the loc and scale of log(sigma). To do something similar to AutoNormal for a constrained support, you can use something like dist.TransformedDistribution(dist.Normal(loc, scale), dist.transforms.ExpTransform()).

Thanks a ton @fehiepsi for the thorough answers. I’m not just interested in getting the MAP point; I was using this simple example to check that I understand how to specify models and fit them using SVI.

I’m still not fully understanding how to write a guide myself, as implementing:

    sigma_loc = numpyro.param("sigma_loc", jnp.zeros(1), constraint=numpyro.distributions.constraints.positive)            
    sigma_scale = numpyro.param("sigma_scale", 1 * jnp.ones(1), constraint=numpyro.distributions.constraints.positive)
    sigma = numpyro.sample("sigma", dist.TransformedDistribution(dist.Normal(sigma_loc, sigma_scale), dist.transforms.ExpTransform()))

still leads to the SVI loss only getting down to ~1e6 after 1e5 steps.

I’ll have to dig through the API docs a bit more. As I said, I’m new to probabilistic programming, so sometimes understanding takes some time :wink:

Ah, actually, now that I think about it, adding the positive constraint on the sigma_loc and sigma_scale parameters does not make sense. If I remove these constraints and run SVI for 1e5 steps as you suggest, I get results similar to AutoNormal. In any case, I will keep reading up on SVI :slight_smile:
Thanks again!

Actually, sorry, one more question (maybe a dumb one): what is the difference between

    sigma_loc = numpyro.param("sigma_loc", jnp.zeros(1))            
    sigma_scale = numpyro.param("sigma_scale", 1. * jnp.ones(1))
    sigma = numpyro.sample("sigma", dist.TransformedDistribution(dist.Normal(sigma_loc, sigma_scale), dist.transforms.ExpTransform()))

and

exponential_rate = numpyro.param("exp_rate", 0.4 * jnp.ones(1))
sigma = numpyro.sample("sigma", dist.Exponential(rate=exponential_rate))

Looking at the source code for ExpTransform, it seems that the TransformedDistribution just takes jnp.exp of the samples from the Normal distribution.

If I look at histograms of samples from dist.Exponential(rate=0.4) and from dist.TransformedDistribution(dist.Normal(0, 1), dist.transforms.ExpTransform()), they are quite similar. However, if I use dist.Exponential in the guide, it does not work, and I’m not sure I follow why this is the case.

The exponential distribution is the distribution with a pdf of the form p(x) = λ exp(-λx) for x ≥ 0 (p(x) = exp(-x) for rate λ = 1); the exponential transform of a normal distribution yields the so-called log-normal distribution, which is a distinct distribution.

Ah, that makes sense. Thanks!