Save SVI state for best loss

Hi,

In the context of SVI, I wonder how to save the state returned by svi.update whenever the loss improves.

Here is a snippet

def body_fn(i, carry):
    svi_state, svi_state_best, loss_best = carry
    svi_state, loss = svi.update(svi_state, cl_obs)
    if loss < loss_best:
        loss_best = loss
        svi_state_best = svi_state
    return (svi_state, svi_state_best, loss_best)


svi_state = svi.init(jax.random.PRNGKey(42),obs)
svi_state_best = svi_state
carry_init = (svi_state, svi_state_best, 1e10)
carry = jax.lax.fori_loop(0,100,body_fn,carry_init)

But I get:

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=1/0)>
The problem arose with the `bool` function. 
While tracing the function scanned_fun at python3.8/site-packages/jax/_src/lax/control_flow.py:142 for scan, this concrete value was not available in Python because it depends on the values of the argument 'loop_carry'.

So, how can I fix this problem? Thanks
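Inside `fori_loop` the loss is a tracer, so a Python `if` cannot convert it to a concrete `bool`. That is what the error above reports. Below is a minimal pure-JAX reproduction with a fix based on `jnp.where`. The array `xs` is made-up data standing in for the per-step losses, and no NumPyro is involved:

```python
import jax
import jax.numpy as jnp

xs = jnp.array([3.0, 1.5, 2.0, 0.5, 0.7], dtype=jnp.float32)  # made-up per-step losses

def bad_body(i, best):
    loss = xs[i]
    if loss < best:  # Python `if` needs a concrete bool, but `loss` is a tracer here
        best = loss
    return best

def good_body(i, best):
    loss = xs[i]
    # jnp.where operates on traced values, so no Python bool is needed
    return jnp.where(loss < best, loss, best)

init = jnp.array(jnp.inf, dtype=jnp.float32)

try:
    jax.lax.fori_loop(0, xs.shape[0], bad_body, init)
    failed = False
except jax.errors.ConcretizationTypeError:
    # newer JAX raises TracerBoolConversionError, a subclass of this error
    failed = True

best = jax.lax.fori_loop(0, xs.shape[0], good_body, init)
```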

Well, by googling and proceeding by trial and error, I managed to get something working:

def body_fn(i, carry):
    svi_state, svi_state_best, losses = carry
    svi_state, loss = svi.update(svi_state, obs)

    # jax.lax.cond replaces the Python `if`: both branches are traced,
    # and the one selected by the traced predicate is returned
    def update_fn(dummy):
        return losses.at[i].set(loss), svi_state
    def keep_fn(dummy):
        return losses.at[i].set(losses[i-1]), svi_state_best

    losses, svi_state_best = jax.lax.cond(loss < losses[i-1], update_fn, keep_fn, None)

    return (svi_state, svi_state_best, losses)
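You can also avoid `jax.lax.cond` and its dummy operand by merging the two states leaf by leaf with `jnp.where` through `jax.tree_util.tree_map`. Here is a minimal sketch. `ToyState` is a hypothetical stand-in for the SVI state, and the approach assumes every leaf is a plain numeric array. Legacy `jax.random.PRNGKey` keys are `uint32` arrays, so they qualify:

```python
from collections import namedtuple
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the SVI state; any pytree of numeric arrays works the same way
ToyState = namedtuple("ToyState", ["params", "rng_key"])

def select_state(pred, on_true, on_false):
    """Leaf by leaf, take on_true if pred (a scalar bool) holds, else on_false.
    Both states must have identical pytree structure, shapes and dtypes."""
    return jax.tree_util.tree_map(lambda a, b: jnp.where(pred, a, b), on_true, on_false)

@jax.jit
def keep_if_better(loss, loss_best, state, state_best):
    improved = loss < loss_best
    return select_state(improved, state, state_best), jnp.minimum(loss, loss_best)

s_new = ToyState(jnp.array([1.0, 2.0]), jax.random.PRNGKey(1))
s_old = ToyState(jnp.array([0.0, 0.0]), jax.random.PRNGKey(0))
best, best_loss = keep_if_better(jnp.float32(3.0), jnp.float32(5.0), s_new, s_old)
```

Here `jnp.where` evaluates both candidates and selects elementwise. That costs little in this case because both states already exist. `lax.cond` is the more general tool when a branch does real work.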

then

import jax
import jax.numpy as jnp

svi_state = svi.init(jax.random.PRNGKey(42), obs)

num_steps = 100
losses = jnp.zeros(num_steps)  # must be a JAX array: plain NumPy arrays have no .at
losses = losses.at[0].set(1e10)
svi_state_best = svi_state
carry = (svi_state, svi_state_best, losses)
carry = jax.lax.fori_loop(1, num_steps, body_fn, carry)

The losses array does seem to decrease step by step:

DeviceArray([ 1.00000000e+10, -9.71529487e+03, -9.93338223e+03,
             -9.93338223e+03, -9.93338223e+03, -9.99908896e+03,
             -9.99908896e+03, -9.99908896e+03, -9.99908896e+03,
             -9.99908896e+03, -9.99908896e+03, -1.01572489e+04,
             -1.01572489e+04, -1.01572489e+04, -1.01572489e+04,
             -1.01572489e+04, -1.01572489e+04, -1.01572489e+04,
             -1.01902815e+04, -1.01902815e+04, -1.01902815e+04,
             -1.01902815e+04, -1.01902815e+04, -1.01902815e+04,
             -1.01902815e+04, -1.01902815e+04, -1.01902815e+04,
             -1.01902815e+04, -1.01902815e+04, -1.01902815e+04,
             -1.01902815e+04, -1.01902815e+04, -1.01902815e+04,
             -1.01902815e+04, -1.01963629e+04, -1.01963629e+04,
             -1.01963629e+04, -1.01963629e+04, -1.01963629e+04,
             -1.01963629e+04, -1.01963629e+04, -1.01963629e+04,
             -1.01963629e+04, -1.01963629e+04, -1.01963629e+04,
             -1.01963629e+04, -1.01963629e+04, -1.01963629e+04,
             -1.01963629e+04, -1.01963629e+04, -1.01963629e+04,
             -1.01963629e+04, -1.01963629e+04, -1.01963629e+04,
             -1.01963629e+04, -1.01963629e+04, -1.01963629e+04,
             -1.01963629e+04, -1.01963629e+04, -1.01963629e+04,
             -1.01963629e+04, -1.01963629e+04, -1.01963629e+04,
             -1.01963629e+04, -1.01963629e+04, -1.01963629e+04,
             -1.01963629e+04, -1.01963629e+04, -1.01963629e+04,
             -1.01963629e+04, -1.02294512e+04, -1.02294512e+04,
             -1.02514454e+04, -1.02514454e+04, -1.02514454e+04,
             -1.02514454e+04, -1.02514454e+04, -1.02514454e+04,
             -1.02514454e+04, -1.02514454e+04, -1.02514454e+04,
             -1.02514454e+04, -1.02514454e+04, -1.02514454e+04,
             -1.02514454e+04, -1.02514454e+04, -1.02514454e+04,
             -1.02514454e+04, -1.02514454e+04, -1.02514454e+04,
             -1.02514454e+04, -1.02514454e+04, -1.02514454e+04,
             -1.02514454e+04, -1.02514454e+04, -1.02514454e+04,
             -1.02514454e+04, -1.02514454e+04, -1.02514454e+04,
             -1.02514454e+04], dtype=float64)
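This stepwise curve is the running (cumulative) minimum of the per-step losses. If you record the raw losses instead, you can recover the same curve afterwards with `jax.lax.cummin`. Here is a minimal sketch with made-up values:

```python
import jax
import jax.numpy as jnp

# made-up raw per-step losses, as a training loop might record them
raw = jnp.array([5.0, 3.0, 4.0, 2.5, 2.8, 1.0], dtype=jnp.float32)

# running minimum: entry k is min(raw[0], ..., raw[k])
running_min = jax.lax.cummin(raw)
```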

It would be nice if someone could confirm that this is the right approach (i.e. the body_fn code) and/or suggest a better way. Thanks

That’s a nice solution for getting the best state and the cumulative minimum of the losses. Note that SVI is stochastic, so a smaller loss does not imply a more optimal state, but I guess it is the most reasonable choice. To make things more robust, you can set num_particles to a higher value in the ELBO constructor. I would prefer carrying a variable loss_min and comparing against it, rather than modifying the losses.
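Here is a minimal sketch of the loss_min suggestion, which carries a scalar instead of rewriting `losses`. To stay self-contained, it makes three substitutions:

- A hypothetical `toy_update` (a noisy gradient step on a quadratic) stands in for `svi.update`.
- A tuple `(params, key)` stands in for the SVI state.
- A made-up `obs` stands in for the data.

As far as I can tell, NumPyro's `svi.update` returns the loss evaluated at the parameters *before* the optimizer step. The sketch therefore stores the incoming state, which is the one that loss actually describes. Treat this as an assumption worth checking against your NumPyro version.

```python
import jax
import jax.numpy as jnp

obs = jnp.array([1.0, -2.0, 3.0])  # made-up data

def toy_update(state, obs):
    # Hypothetical stand-in for svi.update(state, obs): returns (new_state, loss),
    # with the loss evaluated at the incoming parameters, before the step
    params, key = state
    key, sub = jax.random.split(key)
    loss, grad = jax.value_and_grad(lambda p: jnp.sum((p - obs) ** 2))(params)
    noise = 0.5 * jax.random.normal(sub, params.shape)
    return (params - 0.1 * (grad + noise), key), loss

def body_fn(i, carry):
    state, best_state, loss_min = carry
    new_state, loss = toy_update(state, obs)
    # keep the pre-step state, since `loss` was computed from its parameters
    best_state = jax.lax.cond(loss < loss_min,
                              lambda s, b: s, lambda s, b: b,
                              state, best_state)
    loss_min = jnp.minimum(loss, loss_min)
    return new_state, best_state, loss_min

state0 = (jnp.zeros(3), jax.random.PRNGKey(0))
loss_min0 = jnp.array(jnp.inf, dtype=state0[0].dtype)
final_state, best_state, loss_min = jax.lax.fori_loop(
    0, 100, body_fn, (state0, state0, loss_min0))
```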

Thanks @fehiepsi