AutoDAIS nan loss

I am trying to figure out an issue with the choice of autoguide for my NumPyro model fitted with SVI. Right now I get a NaN ELBO loss when using AutoDAIS, but I do not run into this issue with AutoNormal or AutoLowRankMultivariateNormal.

I have tried reducing the learning rate, but that hasn’t changed anything.

Any suggestions?

You might need to inspect your model/guide to see where the issue is. You can run the guide and get its trace via

with numpyro.handlers.trace() as tr, numpyro.handlers.seed(rng_seed=0):
    out = guide(*args, **kwargs)
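Once you have the trace, one way to narrow things down is to scan it for sites whose values are not finite. A minimal sketch, assuming each trace entry is a dict with a "value" key (as NumPyro trace sites are); the helper name find_nonfinite_sites is made up for illustration:

```python
import numpy as np


def find_nonfinite_sites(tr):
    """Return the names of trace sites whose value contains NaN or inf."""
    bad = []
    for name, site in tr.items():
        value = site.get("value")
        if value is None:
            continue
        arr = np.asarray(value)
        # Only floating-point values can hold NaN/inf; skip integer sites such as plates.
        if np.issubdtype(arr.dtype, np.floating) and not np.all(np.isfinite(arr)):
            bad.append(name)
    return bad
```

With the tr from the snippet above, find_nonfinite_sites(tr) lists the first places to look. An empty list means the guide itself produces finite values for that seed, so the NaN more likely shows up in the model's log density or in the gradients.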

AutoDAIS requires some tuning:

  • ideally you init params with mean-field params trained with a vanilla ELBO and AutoDiagonalNormal
  • you might need to tune eta_init and eta_max to make sure the etas aren’t too large
  • you might need to tune your learning rate so it is not too large

for the first point, something like:

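from jax import random
from numpyro.infer import SVI, Trace_ELBO

# mf_params: params obtained from the previous SVI run with AutoDiagonalNormal
init_params = {'auto_z_0_loc': mf_params['auto_loc'], 'auto_z_0_scale': mf_params['auto_scale']}
svi = SVI(model, guide, optimizer, Trace_ELBO(num_particles=1))  # guide is the AutoDAIS guide
# also pass your model's args/kwargs to init, as needed
svi_state = svi.init(random.PRNGKey(seed), init_params=init_params)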

Thanks for this! Appreciate how responsive you guys are! I tried to init params with an AutoNormal guide. Still getting NaNs. Will try tuning eta_init and eta_max. Any suggestions for how to do this would be much appreciated.

how high dimensional is your latent space?

make sure the mean field guide is well trained

try e.g. eta_init=0.001 and eta_max=0.1

Pretty high-dimensional latent space. This is for a mixed effects model with many random effect parameters. I will try that… I believe the mean-field guide is well trained. The estimates are in the ballpark of what I would expect given my domain knowledge. I used a learning rate scheduler for 35000 steps and saw a seemingly good ELBO loss curve.

Still getting NaNs… can’t quite figure out what the problem is.

for a hierarchical model it is probably much better to use AutoSemiDAIS

https://num.pyro.ai/en/stable/autoguide.html#numpyro.infer.autoguide.AutoSemiDAIS