Well, after some googling and a bit of trial and error, I managed to get something working:
def body_fn(i, carry):
    """Run one SVI step inside ``jax.lax.fori_loop`` and track the best state.

    Uses the module-level ``svi`` (NumPyro SVI object) and ``obs`` (data
    passed to the model) defined elsewhere.

    Args:
        i: Current loop index. The loop must start at 1, because the
            previous best loss is read from ``losses[i - 1]``.
        carry: Tuple ``(svi_state, svi_state_best, losses)``:
            ``svi_state`` is the current optimisation state,
            ``svi_state_best`` is the state with the lowest loss so far, and
            ``losses[i - 1]`` holds that lowest loss.

    Returns:
        The updated ``(svi_state, svi_state_best, losses)`` tuple.
        ``losses[i]`` is set to the running minimum of the loss, so the
        ``losses`` array is non-increasing. It is not the raw per-step loss.
    """
    svi_state, svi_state_best, losses = carry
    best_loss = losses[i - 1]

    # Keep the state *before* the step. NumPyro's ``SVI.update`` evaluates
    # the loss at the incoming parameters (via ``optim.eval_and_update``)
    # and only then applies the gradient step. The parameters that produced
    # ``loss`` are therefore those of ``prev_state``, not of the state
    # returned by ``update``.
    prev_state = svi_state
    svi_state, loss = svi.update(svi_state, obs)

    def improved(_):
        # New best: record the loss and the state that achieved it.
        return losses.at[i].set(loss), prev_state

    def not_improved(_):
        # No improvement (a NaN loss also lands here, since NaN < x is
        # False): carry the previous best forward.
        return losses.at[i].set(best_loss), svi_state_best

    losses, svi_state_best = jax.lax.cond(
        loss < best_loss, improved, not_improved, None
    )
    return svi_state, svi_state_best, losses
Then I run the loop:
# Initialise SVI and the buffers carried through the optimisation loop.
svi_state = svi.init(jax.random.PRNGKey(42), obs)
num_steps = 100

# losses[i] will hold the lowest loss seen up to step i. Seed slot 0 with
# +inf so that the first finite loss always counts as an improvement,
# whatever its magnitude. A finite sentinel such as 1e10 would block every
# "best" update if the true loss were larger than the sentinel.
losses = np.full(num_steps, np.inf)
svi_state_best = svi_state

carry = (svi_state, svi_state_best, losses)
# Steps start at 1 because the loop body compares against losses[i - 1].
# After the loop, ``carry`` is (final_state, best_state, running_best_losses).
carry = jax.lax.fori_loop(1, num_steps, body_fn, carry)
It seems that the losses array decreases step by step:
DeviceArray([ 1.00000000e+10, -9.71529487e+03, -9.93338223e+03,
-9.93338223e+03, -9.93338223e+03, -9.99908896e+03,
-9.99908896e+03, -9.99908896e+03, -9.99908896e+03,
-9.99908896e+03, -9.99908896e+03, -1.01572489e+04,
-1.01572489e+04, -1.01572489e+04, -1.01572489e+04,
-1.01572489e+04, -1.01572489e+04, -1.01572489e+04,
-1.01902815e+04, -1.01902815e+04, -1.01902815e+04,
-1.01902815e+04, -1.01902815e+04, -1.01902815e+04,
-1.01902815e+04, -1.01902815e+04, -1.01902815e+04,
-1.01902815e+04, -1.01902815e+04, -1.01902815e+04,
-1.01902815e+04, -1.01902815e+04, -1.01902815e+04,
-1.01902815e+04, -1.01963629e+04, -1.01963629e+04,
-1.01963629e+04, -1.01963629e+04, -1.01963629e+04,
-1.01963629e+04, -1.01963629e+04, -1.01963629e+04,
-1.01963629e+04, -1.01963629e+04, -1.01963629e+04,
-1.01963629e+04, -1.01963629e+04, -1.01963629e+04,
-1.01963629e+04, -1.01963629e+04, -1.01963629e+04,
-1.01963629e+04, -1.01963629e+04, -1.01963629e+04,
-1.01963629e+04, -1.01963629e+04, -1.01963629e+04,
-1.01963629e+04, -1.01963629e+04, -1.01963629e+04,
-1.01963629e+04, -1.01963629e+04, -1.01963629e+04,
-1.01963629e+04, -1.01963629e+04, -1.01963629e+04,
-1.01963629e+04, -1.01963629e+04, -1.01963629e+04,
-1.01963629e+04, -1.02294512e+04, -1.02294512e+04,
-1.02514454e+04, -1.02514454e+04, -1.02514454e+04,
-1.02514454e+04, -1.02514454e+04, -1.02514454e+04,
-1.02514454e+04, -1.02514454e+04, -1.02514454e+04,
-1.02514454e+04, -1.02514454e+04, -1.02514454e+04,
-1.02514454e+04, -1.02514454e+04, -1.02514454e+04,
-1.02514454e+04, -1.02514454e+04, -1.02514454e+04,
-1.02514454e+04, -1.02514454e+04, -1.02514454e+04,
-1.02514454e+04, -1.02514454e+04, -1.02514454e+04,
-1.02514454e+04, -1.02514454e+04, -1.02514454e+04,
-1.02514454e+04], dtype=float64)
It would be nice if someone could confirm that this is the right approach (i.e. the body_fn
code) and/or suggest a better or more correct way to do it. Thanks!