Error with init_state when using customized SVI run utility

Hi there,

I wrote a customized SVI run utility, using NumPyro's SVI.run as a reference, and the model fits successfully with it.

However, when I pass svi_result.state as the initial state for another run, I get the following error:

TypeError                                 Traceback (most recent call last)
Input In [88], in <cell line: 1>()
----> 1 svi_result2 = mnl_run_init(svi2, rng_key, model, guide2, dataset, num_steps2, batch, num_samples, progress_bar=progress_bar, target_hit_rate=target_hit_rate, patience=patience, return_best=return_best, init_state=init_state2)

Input In [85], in mnl_run_init(svi, rng_key, model, guide, dataset, num_steps, batch, num_samples, progress_bar, target_hit_rate, patience, return_best, init_state)
     43 with trange(1, num_steps + 1) as t:
     44     for i in t:
---> 45         svi_state, loss = jit(body_fn)(svi_state, None)
     46         losses.append(loss)
     47         if i % batch == 0:

    [... skipping hidden 14 frame]

Input In [85], in mnl_run_init.<locals>.body_fn(svi_state, _)
     31 def body_fn(svi_state, _):
---> 32     svi_state, loss = svi.update(svi_state, dataset)
     33     return svi_state, loss

File ~/PyVENVs/pytorch-py38/lib/python3.8/site-packages/numpyro/infer/, in SVI.update(self, svi_state, *args, **kwargs)
    242 rng_key, rng_key_step = random.split(svi_state.rng_key)
    243 loss_fn = _make_loss_fn(
    244     self.loss,
    245     rng_key_step,
    252     mutable_state=svi_state.mutable_state,
    253 )
--> 254 (loss_val, mutable_state), optim_state = self.optim.eval_and_update(
    255     loss_fn, svi_state.optim_state
    256 )
    257 return SVIState(optim_state, mutable_state, rng_key), loss_val

File ~/PyVENVs/pytorch-py38/lib/python3.8/site-packages/numpyro/, in _NumPyroOptim.eval_and_update(self, fn, state)
     72 """
     73 Performs an optimization step for the objective function `fn`.
     74 For most optimizers, the update is performed based on the gradient
     84 :return: a pair of the output of objective function and the new optimizer state.
     85 """
     86 params = self.get_params(state)
---> 87 (out, aux), grads = value_and_grad(fn, has_aux=True)(params)
     88 return (out, aux), self.update(grads, state)

    [... skipping hidden 8 frame]

File ~/PyVENVs/pytorch-py38/lib/python3.8/site-packages/numpyro/infer/, in _make_loss_fn.<locals>.loss_fn(params)
     57 def loss_fn(params):
---> 58     params = constrain_fn(params)
     59     if mutable_state is not None:
     60         params.update(mutable_state)

TypeError: 'NoneType' object is not callable

From the traceback, it seems that constrain_fn is not initialized (it is None when loss_fn tries to call it).

I would like to know: 1) what is the likely cause, and 2) how can I fix it?

Many thanks!

Here is my run-specification code:

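import numpyro
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from jax.example_libraries.optimizers import exponential_decay  # assumed source; the import is not shown in the original post

guide2 = AutoNormal(model)
optimizer2 = numpyro.optim.Adam(exponential_decay(1e-4, 100000, 0.8))
num_steps2 = 1000000
svi2 = SVI(model, guide2, optimizer2, loss=Trace_ELBO())
init_state2 = svi_result.state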

batch = 10000
num_samples = 2000
progress_bar = True
target_hit_rate = 1.0
patience = 5
return_best = True

svi_result2 = mnl_run_init(svi2, rng_key, model, guide2, dataset, num_steps2, batch, num_samples, progress_bar=progress_bar, target_hit_rate=target_hit_rate, patience=patience, return_best=return_best, init_state=init_state2)

Did you run svi.init?

Yes. My init code looks just like the one in NumPyro's SVI.run:

if init_state is None:
    svi_state = svi.init(rng_key, dataset)
else:
    svi_state = init_state

I think you need to run svi.init(...) first to set up svi; after that you can provide your own state. The code in your last comment does not seem to run svi.init when init_state is given.

Got it. I will try it out.
Thank you very much!

Update: the problem was solved by adding an svi.init(...) call to the init-with-state branch of my customized run utility.

Many thanks for your help!