Normalizing loss in NumPyro

Hi NumPyro,

I was following the Pyro tutorial on stochastic variational inference and translating parts of it to NumPyro. Part IV: Tips and Tricks recommends “10. Consider normalizing your ELBO” and refers to the Pyro tutorial on Scaling the Loss, which uses the @poutine.scale(...) decorator. However, I couldn’t find how to scale the model and guide functions in a similar way in NumPyro.

How could I scale the (ELBO) loss in NumPyro?

see scale in the docs and see here for a usage example

Thanks, the following seems to scale the losses when applied to the observations:

from numpyro import handlers

def model(...):
    with handlers.scale(scale=....):
        # observations

For some reason, however, scaling by the inverse of the number of samples does not seem to have much effect on the variance of the losses (I was hoping it would reduce the variance), and in some cases it even increases the variance.

note that scaling should not have any meaningful effect on the variance: multiplying the whole loss by a constant c also multiplies its standard deviation by c, so the noise relative to the size of the loss stays the same. if computers used exact arithmetic instead of floating point arithmetic, scaling would have essentially no effect. the main reason to potentially scale your loss is to bring the numerics into a regime where under/overflows are less likely. note also that you need to scale the entire model (and the guide), not just some of the sample statements; otherwise you're actually changing the model.