SVI how to debug the ELBO

Hello,

In a complex example where I run SVI on a model, I have some doubts about the ELBO computation for an MVN guide, most likely due to numerical instabilities. To investigate, I would like to carry out a “by hand” verification.

Below is an ‘academic example’, but it can help me set up the code I need: if we can manage it for this simple example, I think I can transpose it to my more complex one. Thanks in advance for your help.

So, here is the simple problem snippet

import numpy as np
from jax.config import config
config.update("jax_enable_x64", True)

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS, init_to_sample
numpyro.util.enable_x64()

#########
# Generation of mock data: noisy samples of a cubic polynomial.
#########
# Ground truth: entries 0-3 are the polynomial coefficients (constant to
# cubic term), entry 4 is the standard deviation of the additive noise.
param_true = np.array([1.0, 0.0, 0.2, 0.5, 1.5])
sample_size = 5_000
sigma_e = param_true[4] 
# Fixed seed so the mock data set is reproducible from run to run.
random_num_generator = np.random.RandomState(0)
# Abscissas: `rand` draws in [0, 1), rescaled here to [-2.5, 2.5).
xi = 5*random_num_generator.rand(sample_size)-2.5
# Zero-mean Gaussian noise with standard deviation sigma_e, one draw per point.
e = random_num_generator.normal(0, sigma_e, sample_size)
# Noisy observations of the cubic polynomial evaluated at xi.
yi = param_true[0] + param_true[1] * xi + param_true[2] * xi**2 + param_true[3] *xi**3 +  e  
# NOTE(review): `plt` is not imported anywhere in this snippet - presumably
# `matplotlib.pyplot`; confirm and import it before running.
plt.hist2d(xi, yi, bins=50);

#######
# Numpyro model
#######

def my_model(Xspls,Yspls=None, sigma=sigma_e):
    """Cubic-polynomial regression model with a Gaussian likelihood.

    Each of the four polynomial coefficients gets its own Normal(0, 10)
    prior, registered under the site names 'a0' to 'a3'. The data are scored
    at the 'obs' site with a Normal distribution centred on the polynomial.

    Args:
        Xspls: abscissa values at which the polynomial is evaluated.
        Yspls: values forwarded as ``obs`` to the 'obs' site (default None).
        sigma: scale of the Normal likelihood (defaults to ``sigma_e``).

    Returns:
        Whatever the 'obs' sample statement returns.
    """
    # Coefficient sites are declared in the order a0, a1, a2, a3.
    a0, a1, a2, a3 = (
        numpyro.sample(site_name, dist.Normal(0.,10.))
        for site_name in ('a0', 'a1', 'a2', 'a3')
    )

    # Build the mean term by term, summed from lowest to highest degree.
    linear_term = a1*Xspls
    quadratic_term = a2*Xspls**2
    cubic_term = a3*Xspls**3
    mu = a0 + linear_term + quadratic_term + cubic_term

    likelihood = dist.Normal(mu, sigma)
    return numpyro.sample('obs', likelihood, obs=Yspls)

Then, I perform an SVI based on an MVN (auto) guide:

# `jax` itself must be imported: `from jax.config import config` does not bind
# the name `jax`, and `jax.random.PRNGKey` is used below.
import jax
import numpyro.infer.autoguide as autoguide
from numpyro.infer import Predictive, SVI, Trace_ELBO
from numpyro.optim import Adam

# Variational family: a multivariate normal guide built automatically from
# the model, with its location initialised via `init_to_sample`.
guide = autoguide.AutoMultivariateNormal(my_model, init_loc_fn=numpyro.infer.init_to_sample())
# Use the `Adam` name imported above instead of going through `numpyro.optim`.
optimizer = Adam(step_size=5e-3)
svi = SVI(my_model, guide,optimizer,loss=Trace_ELBO())
# Run 10000 optimisation steps on the mock data; the fitted parameters are
# then available as `svi_result.params`.
svi_result = svi.run(jax.random.PRNGKey(0), 10000, Xspls=xi, Yspls=yi)

The convergence seems OK, but let us try to perform a “by hand” computation.
First, we can draw a single sample from the guide:

from  numpyro.handlers import trace, seed, substitute, replay

# Draw one joint sample of the latent sites from the fitted guide; with
# sample_shape=() each site comes back as a scalar (see the output below).
one_sample = guide.sample_posterior(jax.random.PRNGKey(0), params=svi_result.params, sample_shape=())

one_sample
{'a0': DeviceArray(1.01964425, dtype=float64),
 'a1': DeviceArray(0.02260495, dtype=float64),
 'a2': DeviceArray(0.19142779, dtype=float64),
 'a3': DeviceArray(0.50435349, dtype=float64)}
# Record one execution of the guide into `guide_trace`.
# The handlers listed after `trace()` provide a PRNG key for the guide's
# sample statement and inject `one_sample` and the fitted SVI parameters, so
# the recorded param sites carry the optimised values (see the dump below).
with trace() as guide_trace, \
         seed(rng_seed=jax.random.PRNGKey(0)),\
         substitute(data=one_sample),\
         substitute(data=svi_result.params):
    guide(Xspls=xi, Yspls=yi)

# Record one execution of the model into `model_trace`.
# `replay(trace=guide_trace)` makes the latent sites a0..a3 take the values
# recorded in the guide trace, so both traces are evaluated at the same point.
# NOTE(review): the model declares no param sites, so substituting
# `svi_result.params` here presumably has no effect - confirm.
with trace() as model_trace,\
         seed(rng_seed=jax.random.PRNGKey(0)), \
         replay(trace=guide_trace), \
         substitute(data=svi_result.params):
    my_model(Xspls=xi, Yspls=yi)

But I am stuck here. Here are the results of the tracing:

  1. first for the guide_trace
OrderedDict([('auto_loc',
              {'type': 'param',
               'name': 'auto_loc',
               'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
               'args': (DeviceArray([-12.28707488,  -3.68158543,  -5.16809691,  -1.05775531], dtype=float64),),
               'kwargs': {},
               'value': DeviceArray([1.01701011, 0.03684789, 0.18448   , 0.49643362], dtype=float64),
               'scale': None,
               'cond_indep_stack': []}),
             ('auto_scale_tril',
              {'type': 'param',
               'name': 'auto_scale_tril',
               'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
               'args': (DeviceArray([[0.1, 0. , 0. , 0. ],
                             [0. , 0.1, 0. , 0. ],
                             [0. , 0. , 0.1, 0. ],
                             [0. , 0. , 0. , 0.1]], dtype=float64),),
               'kwargs': {'constraint': <numpyro.distributions.constraints._ScaledUnitLowerCholesky at 0x7fa63d090610>},
               'value': DeviceArray([[ 4.38151238e-02,  0.00000000e+00,  0.00000000e+00,
                              0.00000000e+00],
                            [ 3.86819831e-04,  2.65221321e-02,  0.00000000e+00,
                              0.00000000e+00],
                            [-8.11780194e-03, -1.12051544e-05,  1.89843813e-02,
                              0.00000000e+00],
                            [-4.61649160e-05, -4.21053256e-03,  6.07590163e-05,
                              6.50359953e-03]], dtype=float64),
               'scale': None,
               'cond_indep_stack': []}),
             ('_auto_latent',
              {'type': 'sample',
               'name': '_auto_latent',
               'fn': <numpyro.distributions.continuous.MultivariateNormal at 0x7fa5f44440a0>,
               'args': (),
               'kwargs': {'rng_key': array([2718843009, 1272950319], dtype=uint32),
                'sample_shape': ()},
               'value': DeviceArray([1.01964425, 0.02260495, 0.19142779, 0.50435349], dtype=float64),
               'scale': None,
               'is_observed': False,
               'intermediates': [],
               'cond_indep_stack': [],
               'infer': {'is_auxiliary': True}}),
             ('a0',
              {'type': 'sample',
               'name': 'a0',
               'fn': <numpyro.distributions.distribution.Delta at 0x7fa62c6e8a30>,
               'args': (),
               'kwargs': {'rng_key': None, 'sample_shape': ()},
               'value': DeviceArray(1.01964425, dtype=float64),
               'scale': None,
               'is_observed': False,
               'intermediates': [],
               'cond_indep_stack': [],
               'infer': {}}),
             ('a1',
              {'type': 'sample',
               'name': 'a1',
               'fn': <numpyro.distributions.distribution.Delta at 0x7fa5f44446d0>,
               'args': (),
               'kwargs': {'rng_key': None, 'sample_shape': ()},
               'value': DeviceArray(0.02260495, dtype=float64),
               'scale': None,
               'is_observed': False,
               'intermediates': [],
               'cond_indep_stack': [],
               'infer': {}}),
             ('a2',
              {'type': 'sample',
               'name': 'a2',
               'fn': <numpyro.distributions.distribution.Delta at 0x7fa5f4458be0>,
               'args': (),
               'kwargs': {'rng_key': None, 'sample_shape': ()},
               'value': DeviceArray(0.19142779, dtype=float64),
               'scale': None,
               'is_observed': False,
               'intermediates': [],
               'cond_indep_stack': [],
               'infer': {}}),
             ('a3',
              {'type': 'sample',
               'name': 'a3',
               'fn': <numpyro.distributions.distribution.Delta at 0x7fa5f443cdc0>,
               'args': (),
               'kwargs': {'rng_key': None, 'sample_shape': ()},
               'value': DeviceArray(0.50435349, dtype=float64),
               'scale': None,
               'is_observed': False,
               'intermediates': [],
               'cond_indep_stack': [],
               'infer': {}})])
  2. and for the model_trace, I get
OrderedDict([('a0',
              {'type': 'sample',
               'name': 'a0',
               'fn': <numpyro.distributions.continuous.Normal at 0x7fa5d0660c40>,
               'args': (),
               'kwargs': {'rng_key': None, 'sample_shape': ()},
               'value': DeviceArray(1.01964425, dtype=float64),
               'scale': None,
               'is_observed': False,
               'intermediates': [],
               'cond_indep_stack': [],
               'infer': {}}),
             ('a1',
              {'type': 'sample',
               'name': 'a1',
               'fn': <numpyro.distributions.continuous.Normal at 0x7fa5d0660190>,
               'args': (),
               'kwargs': {'rng_key': None, 'sample_shape': ()},
               'value': DeviceArray(0.02260495, dtype=float64),
               'scale': None,
               'is_observed': False,
               'intermediates': [],
               'cond_indep_stack': [],
               'infer': {}}),
             ('a2',
              {'type': 'sample',
               'name': 'a2',
               'fn': <numpyro.distributions.continuous.Normal at 0x7fa5d0660040>,
               'args': (),
               'kwargs': {'rng_key': None, 'sample_shape': ()},
               'value': DeviceArray(0.19142779, dtype=float64),
               'scale': None,
               'is_observed': False,
               'intermediates': [],
               'cond_indep_stack': [],
               'infer': {}}),
             ('a3',
              {'type': 'sample',
               'name': 'a3',
               'fn': <numpyro.distributions.continuous.Normal at 0x7fa5d06600d0>,
               'args': (),
               'kwargs': {'rng_key': None, 'sample_shape': ()},
               'value': DeviceArray(0.50435349, dtype=float64),
               'scale': None,
               'is_observed': False,
               'intermediates': [],
               'cond_indep_stack': [],
               'infer': {}}),
             ('obs',
              {'type': 'sample',
               'name': 'obs',
               'fn': <numpyro.distributions.continuous.Normal at 0x7fa62c6df3d0>,
               'args': (),
               'kwargs': {'rng_key': None, 'sample_shape': ()},
               'value': array([-1.11367224,  0.45030923,  0.82579152, ..., -1.6035713 ,
                       1.05997719,  7.81216947]),
               'scale': None,
               'is_observed': True,
               'intermediates': [],
               'cond_indep_stack': [],
               'infer': {}})])

How can I exploit this information to compute and verify the ELBO loss?

For each site in the trace, you can do site["fn"].log_prob(site["value"]).sum() to compute the log probability of that site. You can use the formula here (there p is your model, q is your guide) to compute the ELBO, $\mathrm{ELBO} = \mathbb{E}_{q(z)}\left[\log p(x, z) - \log q(z)\right]$ - basically, you compute the sum of the log probs of all sites in the model, then subtract the sum of the log probs of all sites in the guide. Note that the ELBO loss is the negative of the ELBO.

Hi @fehiepsi Thanks for your guide lines.

Are the relevant “sites” (this word sounds very strange to a French guy) the following?

  • ‘a0’, ‘a1’, ‘a2’,‘a3’ and ‘obs’ for the model_trace to compute priors+likelihood
  • ‘_auto_latent’ for the guide_trace