Hi,
I’ve run into some strange behavior when using reparameterization with an autoguide (such as AutoDiagonalNormal/AutoNormal). As an example, below is part of a simple time series model that uses a local latent ‘drift’ variable for each time period, so there should be a separate time-local latent variable (each with its own variational parameters) for every time period.
```python
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import LocScaleReparam

def local_level_model(data, features, predict=False):
    data_dim = data.shape[-1]
    features_dim = features.shape[-1]
    # - Local level trend.
    # Global scale parameter for local latent drifts.
    drift_scale = pyro.sample('drift_scale', dist.LogNormal(-5, 5).expand([data_dim]).to_event(1))
    with pyro.plate('time', len(data), dim=-1):
        # Reparameterize local latents to help optimization.
        with poutine.reparam(config={'drift': LocScaleReparam()}):
            # One drift vector of size data_dim per time period.
            drift = pyro.sample('drift', dist.Normal(0, drift_scale).to_event(1))
        # Cumulative sum of local latent drift terms.
        trend = drift.cumsum(dim=0)
        # ...<<<Remaining code is just likelihood, etc.>>>...
```
The above model suggests the `drift` latent variable should have different variational parameters for each time period. However, when I use `poutine.reparam` with an autoguide, I only get 1 variational parameter for the local latent `drift` variable. In contrast, if I either don’t use reparameterization or don’t use an `InitMessenger` with the autoguide, then I get all the expected variational parameters for the local latent `drift` variable (equal to the number of time periods).
I was looking at the source code and I think I found why it’s doing this. When the autoguide builds its `prototype_trace`, it sets the initial value of the `drift` variable using `InitMessenger` first (which is out of the usual order, but working as intended). The outer `time` plate then expands this `drift` variable’s `fn` distribution as expected, but the `value` still has shape `(1,)`. The prototype trace then uses this `(1,)` shape of `drift` to create the associated variational parameter for this latent variable, which leads to only 1 variational parameter for the local latent `drift` variable.
In contrast, if reparameterization isn’t used, then when the autoguide runs its prototype trace, its `InitMessenger` applies the value to the `drift` site after the plate expands the site. It then creates variational parameters over the full shape of `(time periods,)`.
Basically, `poutine.reparam` is what makes `InitMessenger` set the value of the site before the plate expands it, due to some workaround logic that is intended (and documented) in the code. However, this only seems to be an issue with autoguides because of the prototype trace (maybe?). `poutine.reparam` itself creates a `pyro.sample` site for the decentered distribution that is properly expanded, but this doesn’t seem to affect the shape of the `drift` site in the prototype trace.
My main question is: shouldn’t local latent variables be able to have their own variational parameters regardless of whether reparameterization is used? Is there a possible workaround I could use to get both reparameterization and local latents in this model? Or am I missing a reason it’s set up this way, such that it should still give a good fit regardless?
Thanks for any help you can provide!