The correct approach to save the SVI results

Hi All,

What is an efficient and correct way to save SVI results to a file so that I can:

  1. Load the file later and draw samples from the guide/model as a post-processing step.
  2. Continue training the model.

For MCMC I used the ArviZ interface and was able to make all my plots after training from the NetCDF file.


I think you can pickle the guide and the svi_state, then run SVI with your saved state as the initial state. You can use Predictive to draw samples and ArviZ to store the samples in the desired format. Let me know if anything is unclear.


  • pickle gave the following error: Can't get attribute 'model' on <module '__main__'>
  • For anyone who ends up here, below is a simple linear regression with SVI where the model is saved so that it can be loaded later to draw from the posterior, make plots, and so on. I used dill here:
import numpy as np
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import Predictive, SVI, autoguide, init_to_sample, NUTS, MCMC
from numpyro.infer import Trace_ELBO
import dill
rng_key, rng_key_predict = random.split(random.PRNGKey(0))

size = 200
true_intercept = 1
true_slope = 2

x = jnp.linspace(0, 1, size).reshape(-1,1)
# y = intercept + w*x
true_regression_line = true_intercept + true_slope * x
# add noise
y = true_regression_line + random.normal(rng_key, shape=(size,1))

def model(x=None, y=None):
    # Model without subsampling
    sigma = numpyro.sample("sigma", dist.HalfCauchy(scale=10))
    intercept = numpyro.sample("Intercept", dist.Normal(loc=0, scale=20))
    w = numpyro.sample("w", dist.Normal(loc=0, scale=20))  
    with numpyro.plate("data", x.shape[0]):
        numpyro.sample("obs", dist.Normal(intercept+w*x, scale=sigma).to_event(1), obs=y)
# SVI for model without subsample
guide = autoguide.AutoNormal(model, init_loc_fn=init_to_sample)
optimizer = numpyro.optim.Adam(step_size=1e-3)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
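svi_result = svi.run(rng_key, 10000, x, y)

# Collect everything needed for post-processing in one dictionary
output_dict = {'model': model, 'guide': guide, 'params': svi_result.params}
with open('file.pkl', 'wb') as handle:
    dill.dump(output_dict, handle)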
  • Then later in the postprocessing step one can use:
with open('file.pkl', 'rb') as f:
    input_dict = dill.load(f)
model = input_dict['model']
guide = input_dict['guide']
params = input_dict['params']
# Draw posterior predictive samples
num_sample_dist = 200
posterior_predictive = Predictive(model=model, guide=guide, params=params,
                                  num_samples=num_sample_dist)
y_pred = posterior_predictive(rng_key_predict, x=x)['obs']
  • One might want to dill the x and y data as well.
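
As background on the `Can't get attribute 'model'` error mentioned above: plain pickle stores a function as a reference (module name plus function name), not as code, so loading fails when the function is not defined under that name in the loading session. dill can serialize functions defined in `__main__` by value, which is presumably why it worked here. Below is a small standard-library sketch of the by-reference behaviour; the module name `demo_models` is hypothetical and is created on the fly only to imitate "the session that defined `model`".

```python
import pickle
import sys
import types

# Build a throwaway module that defines `model` (hypothetical name)
mod = types.ModuleType("demo_models")
exec("def model(x):\n    return 2 * x\n", mod.__dict__)
sys.modules["demo_models"] = mod

# pickle records only the reference "demo_models.model", not the function body
payload = pickle.dumps(mod.model)

# Loading works while the function is still importable under that name
restored_ok = pickle.loads(payload) is mod.model

# Imitate a new session in which `model` was never defined
del mod.model
try:
    pickle.loads(payload)
    load_failed = False
except AttributeError:
    # Same kind of error as in the thread: the attribute cannot be found
    load_failed = True

del sys.modules["demo_models"]  # clean up the throwaway module
```

With plain pickle, the usual remedy is to define `model` in an importable module (or re-run its definition) before calling `pickle.load`.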

Thanks for sharing, @mahdik!

Hi @mahdik, thank you for sharing that code. I am trying to do something similar, but dill is unable to pickle the guide. May I ask which version of numpyro your example used? I am on 0.8.0, so maybe I need to upgrade.

I think you need to install the latest version (which seems to fix some pickle issues). We have tests to cover pickling autoguides, but we might have missed something. For example, we don’t need dill for the above model:

import optax

optimizer = optax.adam(1e-3)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(rng_key, 10000, x, y)

import pickle

with open("result.pkl", "wb") as f:
    pickle.dump((model, guide, svi_result), f)

with open("result.pkl", "rb") as g:
    model, guide, svi_result = pickle.load(g)

# Draw posterior predictive samples
num_sample_dist = 200
posterior_predictive = Predictive(model=model, guide=guide, params=svi_result.params,
                                  num_samples=num_sample_dist)
y_pred = posterior_predictive(rng_key_predict, x=x)['obs']

Hi @shatfield, here are the relevant package versions in my environment, in which the above example works.

numpy:   '1.21.5'
jax:     '0.2.19'
numpyro: '0.9.0'
dill:    '0.3.4'

Hope that helps.

I came back to this topic today:

  • I can confirm that the method provided by @fehiepsi works if I use optax for optimization.
  • However, using pickle to save the SVI results obtained with the native numpyro optimizer
    optimizer = numpyro.optim.Adam(step_size=1e-3)
    gave me the following error:
cannot pickle 'jaxlib.xla_extension.pytree.PyTreeDef' object

Here is a summary of my installed libraries:

numpyro: 0.10.1
jax    : 0.3.17
optax  : 0.1.3
numpy  : 1.23.3
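
If only post-processing is needed (not resuming training), one option that avoids pickling the optimizer state altogether is to save just the learned parameters as plain NumPy arrays. The sketch below assumes the parameters form a flat dictionary of arrays, as `svi_result.params` does for an autoguide. The dictionary here is a stand-in with illustrative names and values, not real results, and the file name `svi_params.npz` is arbitrary.

```python
import os
import tempfile

import numpy as np

# Stand-in for svi_result.params (illustrative names and values only)
params = {
    "Intercept_auto_loc": np.array(1.0),
    "Intercept_auto_scale": np.array(0.1),
    "w_auto_loc": np.array(2.0),
    "w_auto_scale": np.array(0.2),
}

path = os.path.join(tempfile.mkdtemp(), "svi_params.npz")

# np.asarray also converts JAX arrays to NumPy arrays before saving
np.savez(path, **{name: np.asarray(value) for name, value in params.items()})

# Later, in the post-processing session
with np.load(path) as data:
    loaded_params = {name: data[name] for name in data.files}
```

The loaded dictionary can then be passed as `params=loaded_params` to Predictive, together with a model and guide that are re-created from code in the post-processing session.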