SVI optimizer scheduling

The question is related to the optimizer used in SVI:

Imagine that I have a valid numpyro model, and I would like to use an MVN (multivariate normal) variational optimisation:

import numpyro.infer.autoguide as autoguide
from numpyro.infer import Predictive, SVI, Trace_ELBO,  TraceMeanField_ELBO
from numpyro.optim import Adam
guide = autoguide.AutoMultivariateNormal(model_spl, init_loc_fn=numpyro.infer.init_to_median())
optimizer = numpyro.optim.Adam(step_size=5e-3)
svi = SVI(model_spl, guide,optimizer,loss=Trace_ELBO())
svi_result = svi.run(jax.random.PRNGKey(0), 1000, obs)

In deep neural network optimisation, learning-rate scheduling is usually used alongside the optimizer. Is there something like that in NumPyro SVI?

You can use optax for scheduling (see the SVI.optim docs for an example) - this is recommended. Alternatively, you can pass a schedule directly, like optim.Adam(schedule).

OK, thanks.
Is there also a way to save the result/state of

optimizer = numpyro.optim.Adam(exponential_decay(5e-3,1000,0.1, end_value=1e-7))
svi = SVI(model_spl, guide,optimizer,loss=Trace_ELBO())
svi_result = svi.run(jax.random.PRNGKey(0), 1000, cl_obs)

in order to start a new optimisation from where the last one ended?
And is there a way to save the state every time the loss decreases?

perform a new optimisation from the last one

It is tricky, I guess. Currently, you can pass svi_result.state to svi.run(..., init_state) (using the dev version of numpyro), but we can’t renew the optimizer state. I guess you can do

def model(..., init_params):
    a = numpyro.param("a", init_params["a"])
    ...

Then, after getting the optimized params, you can create another SVI instance and run it with the new init_params argument.

save the status for every new loss decrease?

You can use the following pattern (see the SVI docs)

state = svi.init(...)
for i in range(1000):
    state, loss = svi.update(state, ...)

to customize your training process.

1 Like

Thanks @fehiepsi

In the following snippet:

guide = autoguide.AutoMultivariateNormal(model, init_loc_fn=numpyro.infer.init_to_median())
optimizer = numpyro.optim.Adam(exponential_decay(5e-3,1000,0.1, end_value=1e-7))
svi = SVI(model_spl, guide,optimizer,loss=Trace_ELBO())
svi_result = svi.run(jax.random.PRNGKey(0), 1000, model_obs)

how are the model and guide initialized when svi.run() is launched?

Looking at the code of SVI.init(self, rng_key, *args, **kwargs), I am not sure whether I can initialize both the guide and the model, or only the guide. And if initializing the guide is the only thing the user has to do, how should numpyro.infer.init_to_value() be used?

For instance, if I do

init_params = {'var0':0.2545, 'var2':0.801, 'var3':0.682,...,'var20':0.5}

which specifies the initial values for all my model variables, and

guide = autoguide.AutoMultivariateNormal(model_spl, init_loc_fn=numpyro.infer.init_to_value(values=init_params))

I get the following error when calling svi.run(jax.random.PRNGKey(0), 1000, model_obs):

      2 optimizer = optax.noisy_sgd(1e-5)
      3 svi = SVI(model_spl, guide,optimizer,loss=Trace_ELBO())
----> 4 svi_result = svi.run(jax.random.PRNGKey(0), 1000, cl_obs)

/numpyro/numpyro/infer/svi.py in run(self, rng_key, num_steps, progress_bar, stable_update, init_state, *args, **kwargs)
    333 
    334         if init_state is None:
--> 335             svi_state = self.init(rng_key, *args, **kwargs)
    336         else:
    337             svi_state = init_state

/numpyro/numpyro/infer/svi.py in init(self, rng_key, *args, **kwargs)
    172         model_init = seed(self.model, model_seed)
    173         guide_init = seed(self.guide, guide_seed)
--> 174         guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
    175         model_trace = trace(replay(model_init, guide_trace)).get_trace(
    176             *args, **kwargs, **self.static_kwargs

/numpyro/numpyro/handlers.py in get_trace(self, *args, **kwargs)
    169         :return: `OrderedDict` containing the execution trace.
    170         """
--> 171         self(*args, **kwargs)
    172         return self.trace
    173 

/numpyro/numpyro/primitives.py in __call__(self, *args, **kwargs)
     85             return self
     86         with self:
---> 87             return self.fn(*args, **kwargs)
     88 
     89 

/numpyro/numpyro/primitives.py in __call__(self, *args, **kwargs)
     85             return self
     86         with self:
---> 87             return self.fn(*args, **kwargs)
     88 
     89 

/numpyro/numpyro/infer/autoguide.py in __call__(self, *args, **kwargs)
    545         if self.prototype_trace is None:
    546             # run model to inspect the model structure
--> 547             self._setup_prototype(*args, **kwargs)
    548 
    549         latent = self._sample_latent(*args, **kwargs)

/numpyro/numpyro/infer/autoguide.py in _setup_prototype(self, *args, **kwargs)
    507 
    508     def _setup_prototype(self, *args, **kwargs):
--> 509         super()._setup_prototype(*args, **kwargs)
    510         self._init_latent, shape_dict = _ravel_dict(self._init_locs)
    511         unpack_latent = partial(_unravel_dict, shape_dict=shape_dict)

/numpyro/numpyro/infer/autoguide.py in _setup_prototype(self, *args, **kwargs)
    144                 postprocess_fn,
    145                 self.prototype_trace,
--> 146             ) = initialize_model(
    147                 rng_key,
    148                 self.model,

/numpyro/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
    676                                     line=w.line,
    677                                 )
--> 678             raise RuntimeError(
    679                 "Cannot find valid initial parameters. Please check your model again."
    680             )

RuntimeError: Cannot find valid initial parameters. Please check your model again.

SVI optimizes the parameters, so you can specify init values in

numpyro.param(name, init value)

statements.
Here the init value can be either an array or a function that takes a random key and returns an array (see the numpyro.param docs). This applies to all parameters in the model and in a customized guide. For autoguides, the param statements are hidden, but you can specify init values for their loc parameter by using init_loc_fn.

With the init_to_value strategy, the loc parameter is obtained from the init values specified in that strategy. Here, in your code, init_params should contain init values for some latent variables of your model. Make sure that they belong to the support of their corresponding distributions. For example, -1 cannot be the init value of a LogNormal site.

Nice.
I was wondering whether the error is due to the naming convention of variables in the autoguide.
In my model I have var0, but in the autoguide there is an ‘auto’ prefix, so should I specify

{"auto_var0": 0.2545, ...}

Or, depending on the AutoGuide (e.g. AutoMultivariateNormal), is there an array auto_loc that I can initialise? But then, what is the order of the model variables in this array?

I’m not sure how to set the loc parameter directly. You would probably need a hack: use svi.init(...) to get the initial state, modify the init loc value in that state, then call svi.run with that init state. Which kind of values do you want to set for the init loc parameters? If they are zeros, you can use the init_to_feasible strategy instead.

If you use the dev branch, the order is the same as in the model. Note that the values in the auto loc parameters are in unconstrained space (see the AutoContinuous docs). I don’t think you need to play with the auto loc values. Why not use the init_to_value strategy instead?

Well, I would like to initialise the parameters of the MultivariateNormal to some numerical values, but

  1. how are the variables named: auto_<name>,
    or is there a single array, with one entry per model variable, named auto_loc?
    And
  2. if it is the array auto_loc, what is the order of the variables: the same as the model variables (i.e. the order of numpyro.sample(‘var0’, dist…), numpyro.sample(‘var1’, dist…))?

It is tricky and easy to make mistakes, but if you stick with modifying the SVI init state, see my last comment on the order of the parameters stored in auto loc, scale, scale_tril, … Those are in unconstrained space, and they are reshaped and concatenated (to turn the latent values into a single 1D real vector).

I would recommend building your own custom guide instead of modifying the SVI state.

Note that you can make the init value of the loc parameter correspond to values of variables in the model by using the init_to_value strategy. You can also modify the initial global scale of the guide by using the init_scale argument in the constructor of autoguides.

Thanks but,

guide = autoguide.AutoMultivariateNormal(model_spl,init_loc_fn=numpyro.infer.init_to_value(values=init_params))

fails when init_params is a dictionary of the form {'var0':val0,...'var20':val20} with

def model(the_obs=None):
  var0 = numpyro.sample('var0', dist.Uniform(-5,5)) # exemple
 ....
 var20 = numpyro.sample('var20', dist.Uniform(-5,5))

return numpyro.sample('my_obs', dist..., obs=the_obs)

Notice also that trying to look at the initial params of the guide (AutoMultivariateNormal) using

numpyro.infer.util.find_valid_initial_params(jax.random.PRNGKey(0),guide)

fails with AttributeError: 'DeviceArray' object has no attribute 'items'

In fact, what may be missing, besides SVI.get_params(svi_state), is

SVI.set_params(values_dictionary)

It is likely that your init params are not valid. You can check:

  • if those values belong to the supports
  • run the model with those init values substituted at each site, compute the log probability at each site, and inspect them.
  • check whether the grads of the sum of log probabilities w.r.t. the init values are valid.

The guide plays no role in finding init values for your model.

Having set_params is error-prone and unnecessary. You can replace the value directly in the init state if you want, but I don’t see any reason to go in that error-prone direction. By unnecessary, I mean that in SVI you can always set init values using numpyro.param(name, init_value). If you want to hack the system, just subclass AutoMVN and modify the implementation as you like.

Overall, I highly recommend inspecting your model first, then trying AutoDelta and AutoNormal before AutoMVN. If those work, then the problem is just that AutoMVN is hard to optimize. This will save you a lot of time.

OK @fehiepsi, thanks for the debugging guidance. By ‘support’ do you mean the distribution support of each variable? (If so, I only have Uniform(min,max) and Normal(mean,sigma) distributions.) Let's see.

I was looking to use

numpyro.infer.util.log_density(model,params={'var0':0.2545})

with

def model(the_obs=None):
     var0=numpyro.sample('var0',dist.Uniform(-5,5))
     var1=numpyro.sample('var1',dist.Uniform(-5,5))
     ...
    compute signal = a_func(var0,var1...)
   return numpyro.sample('obs',dist.MultivariateNormal(signal, Cov), obs=the_obs)

but then

 log_density() missing 2 required positional arguments: 'model_args' and 'model_kwargs'

I can imagine providing model_args = some_obs vector, but I do not know how to provide ‘model_kwargs’.

I don’t think that’s the utility you need. You might want to use log_likelihood with batch_ndims=0 (there are several examples of using log_likelihood). But it is better to calculate by hand if you want to avoid diving deep into the implementation of log_likelihood to debug.

Regarding your question, if the model has signature (a, b, c=3), then model_args=(1, 2) corresponds to (a, b) and model_kwargs={"c": 1}. You can also specify model_args=(1, 2, 3) and model_kwargs={} for that model. Generally, model_args and model_kwargs come from the model's signature

def model(*model_args, **model_kwargs):
    ...

Hi @fehiepsi
I have cross-checked the initial value of each parameter and found that one was set at the lower edge of a Uniform(min,max) distribution.

Now, with all the initial values within the [min,max] range of the corresponding Uniform distributions of my model, I still get an error, but a different one:

guide = autoguide.AutoMultivariateNormal(model, init_loc_fn=numpyro.infer.init_to_value(values={'var0':0.2545, 'var1':0.801, 'var2':0.682,..., 'var20':0.0}))
optimizer = numpyro.optim.Adam(5e-3)
svi = SVI(model_spl, guide,optimizer,loss=Trace_ELBO())
svi.init(jax.random.PRNGKey(0), the_obs)  # the_obs are a 1D-vector argument of model

I no longer get the error Cannot find valid initial parameters, but I get this new one:

AttributeError                            Traceback (most recent call last)
<ipython-input-14-9ef47034e090> in <module>
----> 1 svi.init(jax.random.PRNGKey(0), the_obs)

/numpyro/numpyro/infer/svi.py in init(self, rng_key, *args, **kwargs)
    172         model_init = seed(self.model, model_seed)
    173         guide_init = seed(self.guide, guide_seed)
--> 174         guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
    175         model_trace = trace(replay(model_init, guide_trace)).get_trace(
    176             *args, **kwargs, **self.static_kwargs

/numpyro/numpyro/handlers.py in get_trace(self, *args, **kwargs)
    169         :return: `OrderedDict` containing the execution trace.
    170         """
--> 171         self(*args, **kwargs)
    172         return self.trace
    173 

/numpyro/numpyro/primitives.py in __call__(self, *args, **kwargs)
     85             return self
     86         with self:
---> 87             return self.fn(*args, **kwargs)
     88 
     89 

/numpyro/numpyro/primitives.py in __call__(self, *args, **kwargs)
     85             return self
     86         with self:
---> 87             return self.fn(*args, **kwargs)
     88 
     89 

/numpyro/numpyro/infer/autoguide.py in __call__(self, *args, **kwargs)
    545         if self.prototype_trace is None:
    546             # run model to inspect the model structure
--> 547             self._setup_prototype(*args, **kwargs)
    548 
    549         latent = self._sample_latent(*args, **kwargs)

/numpyro/numpyro/infer/autoguide.py in _setup_prototype(self, *args, **kwargs)
    508     def _setup_prototype(self, *args, **kwargs):
    509         super()._setup_prototype(*args, **kwargs)
--> 510         self._init_latent, shape_dict = _ravel_dict(self._init_locs)
    511         unpack_latent = partial(_unravel_dict, shape_dict=shape_dict)
    512         # this is to match the behavior of Pyro, where we can apply

/numpyro/numpyro/infer/autoguide.py in _ravel_dict(x)
    479     for name, value in x.items():
    480         shape_dict[name] = jnp.shape(value)
--> 481         x_flat.append(value.reshape(-1))
    482     x_flat = jnp.concatenate(x_flat) if x_flat else jnp.zeros((0,))
    483     return x_flat, shape_dict

AttributeError: 'float' object has no attribute 'reshape'

Any idea?

It is a bug. You can work around it by replacing the Python scalar values in the init values with JAX NumPy scalar arrays, like jnp.array(2.)

1 Like