Thanks @fehiepsi
In the following snippet:
guide = autoguide.AutoMultivariateNormal(model, init_loc_fn=numpyro.infer.init_to_median())
optimizer = numpyro.optim.Adam(exponential_decay(5e-3,1000,0.1, end_value=1e-7))
svi = SVI(model_spl, guide,optimizer,loss=Trace_ELBO())
svi_result = svi.run(jax.random.PRNGKey(0), 1000, model_obs)
How are the model and the guide initialized when svi.run() is launched?
Looking at the code of SVI.init(self, rng_key, *args, **kwargs), I am not sure whether I can initialize both the guide and the model, or only the guide. And if initializing the guide is the only thing the user has to do, how should numpyro.infer.init_to_value() be used?
For instance, if I do
init_params = {'var0':0.2545, 'var2':0.801, 'var3':0.682,...,'var20':0.5}
which specifies the initial values for all my model variables, and
guide = autoguide.AutoMultivariateNormal(model_spl, init_loc_fn=numpyro.infer.init_to_value(values=init_params))
I get the following error when calling svi.run(jax.random.PRNGKey(0), 1000, model_obs):
2 optimizer = optax.noisy_sgd(1e-5)
3 svi = SVI(model_spl, guide,optimizer,loss=Trace_ELBO())
----> 4 svi_result = svi.run(jax.random.PRNGKey(0), 1000, cl_obs)
/numpyro/numpyro/infer/svi.py in run(self, rng_key, num_steps, progress_bar, stable_update, init_state, *args, **kwargs)
333
334 if init_state is None:
--> 335 svi_state = self.init(rng_key, *args, **kwargs)
336 else:
337 svi_state = init_state
/numpyro/numpyro/infer/svi.py in init(self, rng_key, *args, **kwargs)
172 model_init = seed(self.model, model_seed)
173 guide_init = seed(self.guide, guide_seed)
--> 174 guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
175 model_trace = trace(replay(model_init, guide_trace)).get_trace(
176 *args, **kwargs, **self.static_kwargs
/numpyro/numpyro/handlers.py in get_trace(self, *args, **kwargs)
169 :return: `OrderedDict` containing the execution trace.
170 """
--> 171 self(*args, **kwargs)
172 return self.trace
173
/numpyro/numpyro/primitives.py in __call__(self, *args, **kwargs)
85 return self
86 with self:
---> 87 return self.fn(*args, **kwargs)
88
89
/numpyro/numpyro/primitives.py in __call__(self, *args, **kwargs)
85 return self
86 with self:
---> 87 return self.fn(*args, **kwargs)
88
89
/numpyro/numpyro/infer/autoguide.py in __call__(self, *args, **kwargs)
545 if self.prototype_trace is None:
546 # run model to inspect the model structure
--> 547 self._setup_prototype(*args, **kwargs)
548
549 latent = self._sample_latent(*args, **kwargs)
/numpyro/numpyro/infer/autoguide.py in _setup_prototype(self, *args, **kwargs)
507
508 def _setup_prototype(self, *args, **kwargs):
--> 509 super()._setup_prototype(*args, **kwargs)
510 self._init_latent, shape_dict = _ravel_dict(self._init_locs)
511 unpack_latent = partial(_unravel_dict, shape_dict=shape_dict)
/numpyro/numpyro/infer/autoguide.py in _setup_prototype(self, *args, **kwargs)
144 postprocess_fn,
145 self.prototype_trace,
--> 146 ) = initialize_model(
147 rng_key,
148 self.model,
/numpyro/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
676 line=w.line,
677 )
--> 678 raise RuntimeError(
679 "Cannot find valid initial parameters. Please check your model again."
680 )
RuntimeError: Cannot find valid initial parameters. Please check your model again.