Hello,
In a complex example where I run SVI on a model, I have some doubts about the ELBO computation for an MVN guide, most likely because of numerical instabilities. To investigate, I would like to carry out a “by hand” verification.
Below is an ‘academic example’, but it should help me set up the code I need: if we can manage it for this simple example, I should be able to transpose it to my more complex one. Thanks in advance for your help.
So, here is the snippet for the simple problem:
import numpy as np
from jax.config import config
config.update("jax_enable_x64", True)
import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS, init_to_sample
numpyro.util.enable_x64()
#########
#generation of mock data
#########
# Ground-truth parameters: entries 0-3 are the cubic coefficients used to
# build yi below; entry 4 is the standard deviation of the additive noise.
param_true = np.array([1.0, 0.0, 0.2, 0.5, 1.5])
sample_size = 5_000
sigma_e = param_true[4]
# Fixed seed so the mock data set is reproducible.
random_num_generator = np.random.RandomState(0)
# rand() draws from [0, 1), so xi lies in [-2.5, 2.5).
xi = 5*random_num_generator.rand(sample_size)-2.5
# Zero-mean Gaussian noise with standard deviation sigma_e.
e = random_num_generator.normal(0, sigma_e, sample_size)
# Noisy cubic polynomial in xi.
yi = param_true[0] + param_true[1] * xi + param_true[2] * xi**2 + param_true[3] *xi**3 + e
# NOTE(review): `plt` is not imported in the visible snippet -- presumably
# `import matplotlib.pyplot as plt`; confirm before running.
plt.hist2d(xi, yi, bins=50);
#######
# Numpyro model
#######
def my_model(Xspls,Yspls=None, sigma=sigma_e):
    """Cubic polynomial regression with Gaussian observation noise.

    Sample sites ``a0``..``a3`` each get an independent ``Normal(0, 10)``
    prior; the ``obs`` site is a ``Normal(mu, sigma)`` where ``mu`` is the
    cubic polynomial in ``Xspls`` built from those coefficients.

    Args:
        Xspls: abscissa values at which the polynomial is evaluated.
        Yspls: values passed as ``obs`` to the ``obs`` site (default ``None``).
        sigma: scale of the observation noise (defaults to ``sigma_e``).

    Returns:
        The value returned by the ``obs`` sample statement.
    """
    # One sample statement per coefficient, issued in the order a0, a1, a2, a3.
    c0, c1, c2, c3 = [
        numpyro.sample(site, dist.Normal(0., 10.))
        for site in ('a0', 'a1', 'a2', 'a3')
    ]
    mu = c0 + c1*Xspls + c2*Xspls**2 + c3*Xspls**3
    likelihood = dist.Normal(mu, sigma)
    return numpyro.sample('obs', likelihood, obs=Yspls)
Then I run SVI with an MVN (auto) guide:
# `jax.random.PRNGKey` is used below, so the top-level package name must be
# bound; `from jax.config import config` alone does not bind `jax`.
import jax
import numpyro.infer.autoguide as autoguide
from numpyro.infer import Predictive, SVI, Trace_ELBO
from numpyro.optim import Adam

# Multivariate-normal auto guide built from my_model.
guide = autoguide.AutoMultivariateNormal(my_model, init_loc_fn=numpyro.infer.init_to_sample())
optimizer = Adam(step_size=5e-3)
svi = SVI(my_model, guide, optimizer, loss=Trace_ELBO())
# Run 10000 optimisation steps on the mock data; the fitted parameters are
# later read from `svi_result.params`.
svi_result = svi.run(jax.random.PRNGKey(0), 10000, Xspls=xi, Yspls=yi)
The convergence seems OK, but let us try to perform a “by hand” computation.
First, we can generate a single sample from the guide:
from numpyro.handlers import trace, seed, substitute, replay
# Draw from the fitted guide using the optimised parameters; with
# sample_shape=() the result holds one scalar per latent site (a0..a3).
one_sample = guide.sample_posterior(jax.random.PRNGKey(0), params=svi_result.params, sample_shape=())
# Bare expression: only has a visible effect in an interactive session,
# where it displays the sampled values.
one_sample
{'a0': DeviceArray(1.01964425, dtype=float64),
'a1': DeviceArray(0.02260495, dtype=float64),
'a2': DeviceArray(0.19142779, dtype=float64),
'a3': DeviceArray(0.50435349, dtype=float64)}
# Record every site of the guide. The wrappers are nested so that the handler
# stack is, from outermost to innermost: trace, seed, substitute(one_sample),
# substitute(svi_result.params).
guide_with_params = substitute(guide, data=svi_result.params)
guide_with_sample = substitute(guide_with_params, data=one_sample)
seeded_guide = seed(guide_with_sample, rng_seed=jax.random.PRNGKey(0))
guide_trace = trace(seeded_guide).get_trace(Xspls=xi, Yspls=yi)

# Record every site of the model, replaying the values stored in guide_trace.
# Handler stack, outermost to innermost: trace, seed, replay,
# substitute(svi_result.params).
model_with_params = substitute(my_model, data=svi_result.params)
replayed_model = replay(model_with_params, trace=guide_trace)
seeded_model = seed(replayed_model, rng_seed=jax.random.PRNGKey(0))
model_trace = trace(seeded_model).get_trace(Xspls=xi, Yspls=yi)
But I am stuck. Here are the results of the tracing:
- first for the
guide_trace
OrderedDict([('auto_loc',
{'type': 'param',
'name': 'auto_loc',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (DeviceArray([-12.28707488, -3.68158543, -5.16809691, -1.05775531], dtype=float64),),
'kwargs': {},
'value': DeviceArray([1.01701011, 0.03684789, 0.18448 , 0.49643362], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('auto_scale_tril',
{'type': 'param',
'name': 'auto_scale_tril',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (DeviceArray([[0.1, 0. , 0. , 0. ],
[0. , 0.1, 0. , 0. ],
[0. , 0. , 0.1, 0. ],
[0. , 0. , 0. , 0.1]], dtype=float64),),
'kwargs': {'constraint': <numpyro.distributions.constraints._ScaledUnitLowerCholesky at 0x7fa63d090610>},
'value': DeviceArray([[ 4.38151238e-02, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00],
[ 3.86819831e-04, 2.65221321e-02, 0.00000000e+00,
0.00000000e+00],
[-8.11780194e-03, -1.12051544e-05, 1.89843813e-02,
0.00000000e+00],
[-4.61649160e-05, -4.21053256e-03, 6.07590163e-05,
6.50359953e-03]], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('_auto_latent',
{'type': 'sample',
'name': '_auto_latent',
'fn': <numpyro.distributions.continuous.MultivariateNormal at 0x7fa5f44440a0>,
'args': (),
'kwargs': {'rng_key': array([2718843009, 1272950319], dtype=uint32),
'sample_shape': ()},
'value': DeviceArray([1.01964425, 0.02260495, 0.19142779, 0.50435349], dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {'is_auxiliary': True}}),
('a0',
{'type': 'sample',
'name': 'a0',
'fn': <numpyro.distributions.distribution.Delta at 0x7fa62c6e8a30>,
'args': (),
'kwargs': {'rng_key': None, 'sample_shape': ()},
'value': DeviceArray(1.01964425, dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('a1',
{'type': 'sample',
'name': 'a1',
'fn': <numpyro.distributions.distribution.Delta at 0x7fa5f44446d0>,
'args': (),
'kwargs': {'rng_key': None, 'sample_shape': ()},
'value': DeviceArray(0.02260495, dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('a2',
{'type': 'sample',
'name': 'a2',
'fn': <numpyro.distributions.distribution.Delta at 0x7fa5f4458be0>,
'args': (),
'kwargs': {'rng_key': None, 'sample_shape': ()},
'value': DeviceArray(0.19142779, dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('a3',
{'type': 'sample',
'name': 'a3',
'fn': <numpyro.distributions.distribution.Delta at 0x7fa5f443cdc0>,
'args': (),
'kwargs': {'rng_key': None, 'sample_shape': ()},
'value': DeviceArray(0.50435349, dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}})])
- and for the
model_trace
, I get
OrderedDict([('a0',
{'type': 'sample',
'name': 'a0',
'fn': <numpyro.distributions.continuous.Normal at 0x7fa5d0660c40>,
'args': (),
'kwargs': {'rng_key': None, 'sample_shape': ()},
'value': DeviceArray(1.01964425, dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('a1',
{'type': 'sample',
'name': 'a1',
'fn': <numpyro.distributions.continuous.Normal at 0x7fa5d0660190>,
'args': (),
'kwargs': {'rng_key': None, 'sample_shape': ()},
'value': DeviceArray(0.02260495, dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('a2',
{'type': 'sample',
'name': 'a2',
'fn': <numpyro.distributions.continuous.Normal at 0x7fa5d0660040>,
'args': (),
'kwargs': {'rng_key': None, 'sample_shape': ()},
'value': DeviceArray(0.19142779, dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('a3',
{'type': 'sample',
'name': 'a3',
'fn': <numpyro.distributions.continuous.Normal at 0x7fa5d06600d0>,
'args': (),
'kwargs': {'rng_key': None, 'sample_shape': ()},
'value': DeviceArray(0.50435349, dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('obs',
{'type': 'sample',
'name': 'obs',
'fn': <numpyro.distributions.continuous.Normal at 0x7fa62c6df3d0>,
'args': (),
'kwargs': {'rng_key': None, 'sample_shape': ()},
'value': array([-1.11367224, 0.45030923, 0.82579152, ..., -1.6035713 ,
1.05997719, 7.81216947]),
'scale': None,
'is_observed': True,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}})])
How can I use this information to compute and verify the ELBO loss?