Error with init_state when using customized SVI run utility

Hi there,

I wrote a custom SVI run utility, using NumPyro's SVI.run() as a reference, and the model fits successfully with it.

However, when I pass svi_result.state as the initial state for another run, I get the following error:

TypeError                                 Traceback (most recent call last)
Input In [88], in <cell line: 1>()
----> 1 svi_result2 = mnl_run_init(svi2, rng_key, model, guide2, dataset, num_steps2, batch, num_samples, progress_bar=progress_bar, target_hit_rate=target_hit_rate, patience=patience, return_best=return_best, init_state=init_state2)

Input In [85], in mnl_run_init(svi, rng_key, model, guide, dataset, num_steps, batch, num_samples, progress_bar, target_hit_rate, patience, return_best, init_state)
     43 with trange(1, num_steps + 1) as t:
     44     for i in t:
---> 45         svi_state, loss = jit(body_fn)(svi_state, None)
     46         losses.append(loss)
     47         if i % batch == 0:

    [... skipping hidden 14 frame]

Input In [85], in mnl_run_init.<locals>.body_fn(svi_state, _)
     31 def body_fn(svi_state, _):
---> 32     svi_state, loss = svi.update(svi_state, dataset)
     33     return svi_state, loss

File ~/PyVENVs/pytorch-py38/lib/python3.8/site-packages/numpyro/infer/svi.py:254, in SVI.update(self, svi_state, *args, **kwargs)
    242 rng_key, rng_key_step = random.split(svi_state.rng_key)
    243 loss_fn = _make_loss_fn(
    244     self.loss,
    245     rng_key_step,
   (...)
    252     mutable_state=svi_state.mutable_state,
    253 )
--> 254 (loss_val, mutable_state), optim_state = self.optim.eval_and_update(
    255     loss_fn, svi_state.optim_state
    256 )
    257 return SVIState(optim_state, mutable_state, rng_key), loss_val

File ~/PyVENVs/pytorch-py38/lib/python3.8/site-packages/numpyro/optim.py:87, in _NumPyroOptim.eval_and_update(self, fn, state)
     72 """
     73 Performs an optimization step for the objective function `fn`.
     74 For most optimizers, the update is performed based on the gradient
   (...)
     84 :return: a pair of the output of objective function and the new optimizer state.
     85 """
     86 params = self.get_params(state)
---> 87 (out, aux), grads = value_and_grad(fn, has_aux=True)(params)
     88 return (out, aux), self.update(grads, state)

    [... skipping hidden 8 frame]

File ~/PyVENVs/pytorch-py38/lib/python3.8/site-packages/numpyro/infer/svi.py:58, in _make_loss_fn.<locals>.loss_fn(params)
     57 def loss_fn(params):
---> 58     params = constrain_fn(params)
     59     if mutable_state is not None:
     60         params.update(mutable_state)

TypeError: 'NoneType' object is not callable

From the traceback, it seems that constrain_fn has not been initialized.

I would like to know: 1) what is the likely cause, and 2) how can I fix it?

Many thanks!

Here is my run-specification code:

import numpyro
from jax.example_libraries.optimizers import exponential_decay
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal

guide2 = AutoNormal(model)
optimizer2 = numpyro.optim.Adam(exponential_decay(1e-4, 100000, 0.8))
num_steps2 = 1000000
svi2 = SVI(model, guide2, optimizer2, loss=Trace_ELBO())
init_state2 = svi_result.state  # state returned by the previous run

batch = 10000
num_samples = 2000
progress_bar = True
target_hit_rate = 1.0
patience = 5
return_best = True

svi_result2 = mnl_run_init(
    svi2, rng_key, model, guide2, dataset, num_steps2, batch, num_samples,
    progress_bar=progress_bar, target_hit_rate=target_hit_rate,
    patience=patience, return_best=return_best, init_state=init_state2,
)

Did you run svi.init?

Yes. My init code looks just like the one in svi.run():

if init_state is None:
    svi_state = svi.init(rng_key, dataset)
else:
    svi_state = init_state

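I think you need to run svi.init(...) first to set up svi: init is what sets the constrain_fn that svi.update uses, and svi2 is a fresh SVI object whose init has never been called. Then you can provide your own state to svi.run(...). The code in your last comment does not seem to call svi.init when init_state is given.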

Got it. I will try it out.
Thank you very much!

Update: the problem was solved by adding an svi.init(...) call to the init-with-state branch of my custom run utility.

Many thanks for your help!