Hi there,
I wrote a custom SVI run utility, using NumPyro's SVI.run() as a reference, and it fits my model successfully.
However, when I pass svi_result.state as the initial state for another run, I get the following error:
TypeError Traceback (most recent call last)
Input In [88], in <cell line: 1>()
----> 1 svi_result2 = mnl_run_init(svi2, rng_key, model, guide2, dataset, num_steps2, batch, num_samples, progress_bar=progress_bar, target_hit_rate=target_hit_rate, patience=patience, return_best=return_best, init_state=init_state2)
Input In [85], in mnl_run_init(svi, rng_key, model, guide, dataset, num_steps, batch, num_samples, progress_bar, target_hit_rate, patience, return_best, init_state)
43 with trange(1, num_steps + 1) as t:
44 for i in t:
---> 45 svi_state, loss = jit(body_fn)(svi_state, None)
46 losses.append(loss)
47 if i % batch == 0:
[... skipping hidden 14 frame]
Input In [85], in mnl_run_init.<locals>.body_fn(svi_state, _)
31 def body_fn(svi_state, _):
---> 32 svi_state, loss = svi.update(svi_state, dataset)
33 return svi_state, loss
File ~/PyVENVs/pytorch-py38/lib/python3.8/site-packages/numpyro/infer/svi.py:254, in SVI.update(self, svi_state, *args, **kwargs)
242 rng_key, rng_key_step = random.split(svi_state.rng_key)
243 loss_fn = _make_loss_fn(
244 self.loss,
245 rng_key_step,
(...)
252 mutable_state=svi_state.mutable_state,
253 )
--> 254 (loss_val, mutable_state), optim_state = self.optim.eval_and_update(
255 loss_fn, svi_state.optim_state
256 )
257 return SVIState(optim_state, mutable_state, rng_key), loss_val
File ~/PyVENVs/pytorch-py38/lib/python3.8/site-packages/numpyro/optim.py:87, in _NumPyroOptim.eval_and_update(self, fn, state)
72 """
73 Performs an optimization step for the objective function `fn`.
74 For most optimizers, the update is performed based on the gradient
(...)
84 :return: a pair of the output of objective function and the new optimizer state.
85 """
86 params = self.get_params(state)
---> 87 (out, aux), grads = value_and_grad(fn, has_aux=True)(params)
88 return (out, aux), self.update(grads, state)
[... skipping hidden 8 frame]
File ~/PyVENVs/pytorch-py38/lib/python3.8/site-packages/numpyro/infer/svi.py:58, in _make_loss_fn.<locals>.loss_fn(params)
57 def loss_fn(params):
---> 58 params = constrain_fn(params)
59 if mutable_state is not None:
60 params.update(mutable_state)
TypeError: 'NoneType' object is not callable
From the traceback, it looks like constrain_fn is None, i.e. it was never initialized.
I would like to know: 1) what could be causing this, and 2) how can I fix it?
Many thanks!
Here is my run-specification code:
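import numpyro
from jax.example_libraries.optimizers import exponential_decay
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
# svi_result, rng_key, model, dataset and mnl_run_init come from earlier cells
guide2 = AutoNormal(model)
# Adam with an exponentially decaying step size: 1e-4 * 0.8 ** (step / 100000)
optimizer2 = numpyro.optim.Adam(exponential_decay(1e-4, 100000, 0.8))
num_steps2 = 1000000
svi2 = SVI(model, guide2, optimizer2, loss=Trace_ELBO())
init_state2 = svi_result.state  # final state of the previous run
batch = 10000
num_samples = 2000
progress_bar = True
target_hit_rate = 1.0
patience = 5
return_best = True
svi_result2 = mnl_run_init(
    svi2, rng_key, model, guide2, dataset, num_steps2, batch, num_samples,
    progress_bar=progress_bar, target_hit_rate=target_hit_rate,
    patience=patience, return_best=return_best, init_state=init_state2,
)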
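To check my reading of the trace, here is a toy stand-in in plain Python (not NumPyro code). ToySVI and its init()/update() methods are hypothetical. They only mimic the pattern I suspect: an attribute that starts out as None in the constructor and is filled in only by init(). I am assuming, not certain, that NumPyro's SVI handles constrain_fn this way. If it does, a fresh svi2 whose init() was never called would fail with exactly this message.

```python
from typing import Callable, Optional


class ToySVI:
    """Hypothetical stand-in for an init/update pattern (not NumPyro code)."""

    def __init__(self) -> None:
        # Like constrain_fn in the trace, this starts out as None.
        self.constrain_fn: Optional[Callable[[dict], dict]] = None

    def init(self, params: dict) -> dict:
        # Only init() sets constrain_fn; the returned dict plays the role of a state.
        self.constrain_fn = lambda p: {k: abs(v) for k, v in p.items()}
        return dict(params)

    def update(self, state: dict) -> dict:
        # Calls constrain_fn unconditionally, as loss_fn does in the trace.
        return self.constrain_fn(state)


first = ToySVI()
saved_state = first.init({"scale": -2.0})
first.update(saved_state)  # works: init() ran on this object

second = ToySVI()  # fresh object, init() never called
try:
    second.update(saved_state)
except TypeError as err:
    print(err)  # 'NoneType' object is not callable

second.init({"scale": 0.0})  # populate constrain_fn before reusing the saved state
print(second.update(saved_state))  # {'scale': 2.0}
```

In the toy, the saved state itself is fine. The error appears only on the object whose init() was skipped.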