Thanks.

`pickle` gave the following error: `Can't get attribute 'model' on <module '__main__'>`
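 For anyone who ends up here: below is a simple linear regression fitted with SVI, saved so that the model can be reloaded later to draw from the posterior, make plots, and so on. I used `dill` instead of `pickle`. `pickle` stores functions by reference, so loading fails in a session where `model` is not defined in `__main__`. `dill` can serialize such functions by value.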
import numpy as np
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import Predictive, SVI, autoguide, init_to_sample, NUTS, MCMC
from numpyro.infer import Trace_ELBO
import dill
# Select the CPU backend for NumPyro/JAX before any computation runs.
numpyro.set_platform('cpu')
# Split the root key: `rng_key` drives data generation / fitting,
# `rng_key_predict` is kept aside for posterior-predictive sampling.
rng_key, rng_key_predict = random.split(random.PRNGKey(0))

size = 200  # number of synthetic observations
true_intercept = 1
true_slope = 2

# Inputs as a column vector of shape (size, 1). reshape(-1, 1) keeps all
# `size` points as rows; reshape(1, 1) would fail for size != 1.
x = jnp.linspace(0, 1, size).reshape(-1, 1)
# y = intercept + w*x
true_regression_line = true_intercept + true_slope * x
# Add standard-normal noise with the same (size, 1) shape as the inputs.
y = true_regression_line + random.normal(rng_key, shape=(size, 1))
def model(x=None, y=None):
    """Bayesian simple linear regression: obs ~ Normal(intercept + w * x, sigma).

    Priors: sigma ~ HalfCauchy(10), Intercept ~ Normal(0, 20), w ~ Normal(0, 20).

    Args:
        x: Inputs. ``x.shape[0]`` is the number of data points. Required;
            assumes a 2-D column such as (n, 1), as produced by the data
            script — TODO confirm for other callers.
        y: Observed targets with the same shape as ``intercept + w * x``, or
            ``None`` to sample ``"obs"`` (e.g. under ``Predictive``).

    Raises:
        ValueError: if ``x`` is None.
    """
    if x is None:
        raise ValueError("model() requires x; pass the inputs explicitly.")
    # Model without subsample
    sigma = numpyro.sample("sigma", dist.HalfCauchy(scale=10))
    intercept = numpyro.sample("Intercept", dist.Normal(loc=0, scale=20))
    w = numpyro.sample("w", dist.Normal(loc=0, scale=20))
    with numpyro.plate("data", x.shape[0]):
        # .to_event(1) turns the trailing axis into the event dimension, so the
        # "data" plate ranges only over the leading (row) axis of x.
        numpyro.sample(
            "obs",
            dist.Normal(intercept + w * x, scale=sigma).to_event(1),
            obs=y,
        )
# SVI for model without subsample
guide = autoguide.AutoNormal(model, init_loc_fn=init_to_sample)
# A conventional Adam learning rate; 1e3 (= 1000) is far too large for the
# optimisation to converge and was almost certainly meant to be small.
optimizer = numpyro.optim.Adam(step_size=1e-2)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
# rng_key already generated the noise in y; derive a fresh key instead of
# reusing it, since JAX PRNG keys must not be consumed twice.
rng_key_svi = random.fold_in(rng_key, 1)
svi_result = svi.run(rng_key_svi, 10000, x, y)

# Bundle everything needed to rebuild the posterior later. dill (unlike
# pickle) can serialize `model`, which is defined in __main__.
output_dict = {
    'model': model,
    'guide': guide,
    'params': svi_result.params,
    # Also store the data so post-processing need not regenerate it.
    'x': x,
    'y': y,
}
with open('file.pkl', 'wb') as handle:
    dill.dump(output_dict, handle)
 Later, in the post-processing step, one can use:
# NOTE(review): dill.load, like pickle.load, can execute arbitrary code —
# only load files you created yourself or otherwise trust.
with open('file.pkl', 'rb') as f:
    input_dict = dill.load(f)

# Restore the fitted model, its guide, and the learned variational parameters.
model, guide, params = (input_dict[key] for key in ('model', 'guide', 'params'))

# Posterior predictive: sample "obs" at inputs x using the fitted guide.
# `x` and `rng_key_predict` are NOT restored from file.pkl here; they must
# already exist in this session.
num_sample_dist = 200
posterior_predictive = Predictive(
    model=model,
    guide=guide,
    params=params,
    num_samples=num_sample_dist,
)
y_pred = posterior_predictive(rng_key_predict, x=x)['obs']
 One might want to `dill` the `x` and `y` data as well.