Autoguide in class method

Hi everybody,
I am trying to find a way to include an autoguide as the guide method in my model class:

class MYMODEL(object):

    def __init__(self, ...):
        ...

    def model(self):
        ...

    def guide(self):
        pyro.infer.autoguide.AutoNormal(self.model)

mm = MYMODEL(...)
svi = SVI(mm.model, mm.guide, optimizer, loss=Trace_ELBO())

However, this always raises a warning:

UserWarning: Found vars in model but not guide: {'...', '...'}
warnings.warn(f"Found vars in model but not guide: {bad_sites}")

Also, the optimized values after SVI look different. So I guess something is wrong with this approach, or autoguides are not intended to be used like this?

Thanks in advance!

I believe in that usage the autoguide itself is never called. Each call to mm.guide() simply creates a new AutoNormal and then throws it away, so the guide runs no sample statements, which is why the warning appears. One alternative would be to make .guide an attribute instead:

class MYMODEL(object):
    def __init__(self, ...):
        self.guide = pyro.infer.autoguide.AutoNormal(self.model)
        ...

    def model(self):
        ...

mm = MYMODEL(...)
svi = SVI(mm.model, mm.guide, optimizer, loss=Trace_ELBO())

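It is safe to pass self.model to AutoNormal in the constructor, even before the rest of __init__ has run, because the autoguide is initialized lazily: it only traces the model the first time it is called, i.e. during the first SVI step.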


Thanks a bunch!