Is it possible to save SVI.run results?

Hello,

I am trying to save the model artifacts of my SVI session so I can run an inference process later on with a different script. I’m struggling to save the guide (as this is the object used in the inference process).

My first approach was to try to use pickle to save the guide, but it does not support saving some JAX components. Dill, JSON, and CloudPickle do not seem to work either:

optimizer = Adam(constants.LEARNING_RATE)
hbd_guide = AutoGuide(hbd_model)
svi = SVI(hbd_model, hbd_guide, optimizer, loss=Trace_ELBO())

hbd_results = svi.run(
    random.PRNGKey(constants.RNG_SEED),
    n_epochs,
    data,
    train=True,
    stable_update=True,
)

hbd_results_json = {}

# Losses
hbd_results_json["losses"] = hbd_results.losses.to_py().tolist()

# Params
hbd_results_json["params"] = {}
for key in hbd_results.params.keys():
    hbd_results_json["params"][key] = hbd_results.params[key].to_py().tolist()

# Model
hbd_results_json["hbd_model"] = hbd_model
hbd_results_json["hbd_guide"] = hbd_guide

with open(hbd_results_json_path_test, 'wb') as handle:
    pickle.dump(hbd_results_json, handle, protocol=pickle.HIGHEST_PROTOCOL)

I have also tried saving only the model (which pickles without issues) and then rebuilding the guide from the model, but that does not work either. It seems that during training the guide is modified to include a plates field, which I do not get just by running:

with open(hbd_results_json_path_test, 'rb') as f:
    hbd_results_load = pickle.load(f)

for key in hbd_results_load["params"].keys():
    hbd_results_load["params"][key] = jnp.array(hbd_results_load["params"][key])

## Get guide_test
guide_test = AutoNormal(hbd_results_load["hbd_model"])

## Get posterior samples of latent variables
hbd_posterior = Predictive(
    guide_test, params=hbd_results_load["params"], num_samples=n_samples #hbd_guide
)
hbd_samples = hbd_posterior(random.PRNGKey(constants.RNG_SEED))


## Generate model fit -- train sample & test sample
hbd_predict = Predictive(
    model=hbd_model,
    posterior_samples=hbd_samples,
    guide=hbd_guide,
    params=hbd_results_load["params"],
    num_samples=n_samples,
)

This outputs an error when it tries to run the hbd_posterior function:

TypeError: hbd_model() missing 1 required positional arguments: 'data'

I have looked for information on how to separate training from inference using SVI, but there is not much documentation on how to do this in numpyro. All the examples I’ve seen run training and inference in the same script, so there is no need to save the model. Saving the model after training for a later inference process should not be such a headache, so it may be that I’m just using the wrong tools/code.

Has anyone encountered this issue?

Thanks in advance!

The behavior that you observed is a bit strange. Can you run hbd_posterior without pickling?

To save svi_result state, you can install optax and use

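import pickle

import optax

# Use the optax optimizer when building SVI, in place of numpyro.optim.Adam
optimizer = optax.adam(learning_rate)

# ... create SVI with this optimizer and call svi.run(...) to get svi_result ...

# Save the result
with open("result.pkl", "wb") as f:
    pickle.dump(svi_result, f)

# Load it back later
with open("result.pkl", "rb") as g:
    svi_result = pickle.load(g)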

It requires more code (along the lines of what you did) to save svi_result with numpyro.optim.Adam. Either way, I think your error hbd_model() missing 1 required positional arguments: 'data' happens because you didn’t provide data for your model. You might need this pattern:

def model(data=None):
    numpyro.sample(..., obs=data)

so you can skip providing it when calling Predictive.
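To illustrate why the default argument helps, here is a minimal pure-Python sketch. The traceback further down shows that Predictive's `__call__(self, rng_key, *args, **kwargs)` forwards everything after the key to the wrapped callable. The wrapper `call_like_predictive` and both toy models below are hypothetical stand-ins, not numpyro code.

```python
def call_like_predictive(fn, rng_key, *args, **kwargs):
    """Hypothetical stand-in: ignores rng_key and forwards the rest to fn."""
    return fn(*args, **kwargs)


def model_required(data):
    # `data` has no default, so every caller must supply it.
    return "conditioned" if data is not None else "prior"


def model_optional(data=None):
    # `data` defaults to None, so it can be omitted at prediction time.
    return "conditioned" if data is not None else "prior"


# Calling with only a key fails for the model that requires `data`.
try:
    call_like_predictive(model_required, 0)
except TypeError as err:
    print("TypeError:", err)

# Two ways out: give `data` a default, or pass it along with the key.
print(call_like_predictive(model_optional, 0))
print(call_like_predictive(model_required, 0, [1.0, 2.0]))
```

In the original snippet this would mean either declaring the model with `data=None` or passing the data after the key when calling `hbd_posterior`.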

Hi @fehiepsi,

Thank you very much for your fast response. Using that optimizer + upgrading to the latest numpyro version allowed me to save the pickle file. That said, when I loaded and tried to call hbd_posterior, I got a weird error that does not provide too much information:

NotImplementedError:

I am attaching the trace I get… Not sure if this could be because the guide was not correctly saved (if I substitute guide with the guide I got from the training I don’t get the error). Do you have any thoughts on what could be happening?

Thanks again!

Error:

with open("result.pkl", 'rb') as f:
    model, guide, svi_result = pickle.load(f)

hbd_posterior = Predictive(
    guide, params=svi_result.params, num_samples=n_samples #hbd_guide
)
hbd_samples = hbd_posterior(random.PRNGKey(constants.RNG_SEED))
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
/tmp/ipykernel_14574/213312458.py in <module>
----> 1 hbd_samples = hbd_posterior(random.PRNGKey(constants.RNG_SEED))

~/.local/lib/python3.9/site-packages/numpyro/infer/util.py in __call__(self, rng_key, *args, **kwargs)
    959         """
    960         if self.batch_ndims == 0 or self.params == {} or self.guide is None:
--> 961             return self._call_with_params(rng_key, self.params, args, kwargs)
    962         elif self.batch_ndims == 1:  # batch over parameters
    963             batch_size = jnp.shape(tree_flatten(self.params)[0][0])[0]

~/.local/lib/python3.9/site-packages/numpyro/infer/util.py in _call_with_params(self, rng_key, params, args, kwargs)
    936             )
    937         model = substitute(self.model, self.params)
--> 938         return _predictive(
    939             rng_key,
    940             model,

~/.local/lib/python3.9/site-packages/numpyro/infer/util.py in _predictive(rng_key, model, posterior_samples, batch_shape, return_sites, infer_discrete, parallel, model_args, model_kwargs)
    773     rng_key = rng_key.reshape((*batch_shape, 2))
    774     chunk_size = num_samples if parallel else 1
--> 775     return soft_vmap(
    776         single_prediction, (rng_key, posterior_samples), len(batch_shape), chunk_size
    777     )

~/.local/lib/python3.9/site-packages/numpyro/util.py in soft_vmap(fn, xs, batch_ndims, chunk_size)
    408         fn = vmap(fn)
    409 
--> 410     ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)
    411     map_ndims = int(num_chunks > 1) + int(chunk_size > 1)
    412     ys = tree_map(

    [... skipping hidden 16 frame]

~/.local/lib/python3.9/site-packages/numpyro/infer/util.py in single_prediction(val)
    747             )
    748         else:
--> 749             model_trace = trace(
    750                 seed(substitute(masked_model, samples), rng_key)
    751             ).get_trace(*model_args, **model_kwargs)

~/.local/lib/python3.9/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
    169         :return: `OrderedDict` containing the execution trace.
    170         """
--> 171         self(*args, **kwargs)
    172         return self.trace
    173 

~/.local/lib/python3.9/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
    103             return self
    104         with self:
--> 105             return self.fn(*args, **kwargs)
    106 
    107 

~/.local/lib/python3.9/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
    103             return self
    104         with self:
--> 105             return self.fn(*args, **kwargs)
    106 
    107 

~/.local/lib/python3.9/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
    103             return self
    104         with self:
--> 105             return self.fn(*args, **kwargs)
    106 
    107 

~/.local/lib/python3.9/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
    103             return self
    104         with self:
--> 105             return self.fn(*args, **kwargs)
    106 
    107 

~/.local/lib/python3.9/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
    103             return self
    104         with self:
--> 105             return self.fn(*args, **kwargs)
    106 
    107 

~/.local/lib/python3.9/site-packages/numpyro/infer/autoguide.py in __call__(self, *args, **kwargs)
    308 
    309                 site_fn = dist.Normal(site_loc, site_scale).to_event(event_dim)
--> 310                 if site["fn"].support is constraints.real or (
    311                     isinstance(site["fn"].support, constraints.independent)
    312                     and site["fn"].support.base_constraint is constraints.real

~/.local/lib/python3.9/site-packages/numpyro/distributions/distribution.py in support(self)
    986     @property
    987     def support(self):
--> 988         codomain = self.transforms[-1].codomain
    989         codomain_event_dim = codomain.event_dim
    990         assert self.event_dim >= codomain_event_dim

~/.local/lib/python3.9/site-packages/numpyro/distributions/transforms.py in codomain(self)
    195                 )
    196         else:
--> 197             raise NotImplementedError
    198 
    199     def __call__(self, x):

NotImplementedError: 

I think this is likely related to your model. I guess you are using some transform that does not have codomain defined. Could you provide your model?

In addition to what @fehiepsi suggested, can you check the following thread and run the example there to see if it works on your machine? The versions of the numpyro and dill packages are also provided at the end of the thread.

Thank you both for your responses, they are very useful as I navigate through this issue.

@mahdik the upgrade from numpyro 0.8 to 0.9 was key to saving the pickle file without issues, so I appreciate you sharing your working versions of numpyro, JAX, dill, and numpy.

@fehiepsi due to an NDA with the client I can’t share the full model structure, but I believe that what you outlined makes sense as I’m using Affine Transforms in the model. Here is part of the code to show you how I’m using it:

import numpyro
import numpyro.distributions as dist
class Plate:
    ...
    store = "plate_store"

def model(...,data,...):
    ...

    attr_store = get_store_attributes(data)
    plate_store = numpyro.plate(Plate.store, size=n_stores, dim=-1)

    mu_store_beta = numpyro.sample('mu_store_beta', dist.Normal(loc=0.0, scale=1.0), sample_shape=(n_features_store,))
    mu_store_sigma = numpyro.sample('mu_store_sigma', dist.HalfNormal(scale=1.0))

    nu_store_beta = numpyro.sample('nu_store_beta', dist.Normal(loc=0.0, scale=1.0), sample_shape=(n_features_store,))
    nu_store_sigma = numpyro.sample('nu_store_sigma', dist.HalfNormal(scale=1.0))

    with plate_store:
        mu_store = numpyro.sample('mu_store', dist.TransformedDistribution(
            dist.Normal(loc=jnp.zeros((n_stores,)), scale=1.0),
            dist.transforms.AffineTransform(attr_store @ mu_store_beta, mu_store_sigma)))
        nu_store = numpyro.sample('nu_store', dist.TransformedDistribution(
            dist.Normal(loc=jnp.zeros((n_stores,)), scale=1.0),
            dist.transforms.AffineTransform(attr_store @ nu_store_beta, nu_store_sigma)))

I have seen in the AffineTransform definition that its codomain property raises NotImplementedError when the domain is not defined.

Could this mean that when we save the pickle file, the domain is not saved correctly? Maybe that information gets lost somewhere in the pickle saving process.

I’m not sure what’s happening. Your code looks reasonable to me. Does prediction still work if you don’t pickle? It may be better to come up with a small reproducible example (a simple model with an affine transform and fake data) that triggers the error.

Maybe not the most elegant solution to the problem, but in the end I decided to save the hbd_samples as a model output. Then I load them back and run inference using those samples (in numpyro 0.9 the guide is no longer needed if you have the posterior samples).

This approach works fine - the only limitation is that the number of samples you can produce is limited by the num_samples you decide before saving the training outputs. As I’m not changing the number of samples, this solution worked for me.
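For anyone taking the same route, here is a minimal sketch of saving and reloading a dictionary of samples. It assumes the samples are a dict mapping site names to arrays with a leading sample axis. The site names, shapes, and file name are made up for illustration, and the arrays are generated with NumPy rather than coming from a real Predictive call.

```python
import os
import tempfile

import numpy as np

# Stand-in for the samples dict: site name -> array with a leading sample axis.
rng = np.random.default_rng(0)
hbd_samples = {
    "mu_store_sigma": np.abs(rng.normal(size=(100,))),
    "mu_store": rng.normal(size=(100, 5)),
}

path = os.path.join(tempfile.mkdtemp(), "hbd_samples.npz")

# Training script: convert each entry to a NumPy array (np.asarray also
# accepts JAX arrays) and write everything to a single .npz file.
np.savez(path, **{name: np.asarray(value) for name, value in hbd_samples.items()})

# Inference script: read the arrays back into a plain dict.
with np.load(path) as archive:
    loaded_samples = {name: archive[name] for name in archive.files}

print({name: value.shape for name, value in loaded_samples.items()})
```

The reloaded dict can then be passed as `posterior_samples` to `Predictive` together with the model, as in the snippet from the first post. Storing plain arrays avoids pickling the guide or any distribution objects.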

Thank you all for the support, I really appreciate it!