Reparameterization with AutoGuides - Strange Behavior


I’ve run into some strange behavior when using reparameterization with an autoguide (such as AutoDiagonalNormal or AutoNormal). As an example, below is part of a simple time series model that uses a local latent ‘drift’ variable for each time period, so there should be a separate time-local latent variable (each with its own variational parameters) for every time period.

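import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.reparam import LocScaleReparam

def local_level_model(data, features, predict=False):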
    data_dim = data.shape[-1]
    features_dim = features.shape[-1]

    # - Local level trend.
    # Global scale parameter for local latent drifts.
    drift_scale = pyro.sample('drift_scale', dist.LogNormal(-5, 5).expand([data_dim]).to_event(1))
    with pyro.plate('time', len(data), dim=-1):
        # Reparameterize local latents to help optimization.
        with poutine.reparam(config={'drift': LocScaleReparam()}):
            drift = pyro.sample('drift', dist.Normal(0, drift_scale).to_event(1))

    # Cumulative sum of local latent drift terms.
    trend = drift.cumsum(dim=0)

...<<<Remaining code is just likelihood, etc.>>>...

The above model suggests the drift latent variable should have different variational parameters for each time period. However, when I use the poutine.reparam with an autoguide, I only get 1 variational parameter for the local latent drift variable. In contrast, if I either don’t use reparameterization or don’t use an InitMessenger with the autoguide, then I get all the expected variational parameters for the local latent drift variable (equal to the number of time periods).

I was looking at the source code and I think I found why it’s doing this. When the autoguide builds its prototype_trace, InitMessenger sets the initial value of the drift variable first (which is out of the usual order, but working as intended). The outer time plate then expands the drift site’s ‘fn’ distribution as expected, but the ‘value’ still has shape (1,). The autoguide then uses this (1,) shape from the prototype trace to create the variational parameter for this latent variable, so the local latent drift variable ends up with a single variational parameter.

In contrast, if reparameterization isn’t used, then when the autoguide runs its prototype_trace, its InitMessenger applies the value to the drift site after the plate expands the site, and the variational parameters are created over the full shape of (time periods,). Basically, poutine.reparam is what makes InitMessenger set the value of the site before the plate expands it, due to some workaround logic that is intended and documented in the code. However, this only seems to be an issue when using an autoguide, because of the prototype_trace (maybe?). poutine.reparam itself creates a pyro.sample site for the decentered distribution that is properly expanded, but this doesn’t seem to affect the shape of the site in the prototype_trace.

My main question is… shouldn’t local latent variables be able to have their own variational parameters regardless of whether or not reparameterization is used? And is there a possible workaround I could use to get both reparameterization and local latents in this model? Or is there something I’m missing about why it’s set up that way, and it should give a good fit regardless?

Thanks for any help you can provide!

So I did figure out a workaround to get variational parameters on each local latent variable. Basically, I just had to .expand() the drift distribution instead of waiting for the plate’s internal broadcast messenger to expand it.

However, strangely, having local variational parameters actually causes the fit to be a bit worse even with reparameterization. Simultaneously though, the trend component is more readable instead of looking like random noise. Anyway, I would have thought the fit should be better since it’s tailoring each local drift variable. Not sure if anyone has any idea why the fit would be worse this way? Possibly the trend equation is too stiff compared to random noise…?

Anyway, I’ll play around with the optimizer and see if I can get a lower elbo loss.


Okay, so I was able to get a really good fit with variational parameters on each local latent variable.

The main thing I needed to keep in mind is that the model is harder to fit when there are so many more local latents, so I needed to decrease the step size. I also had a really bad prior on the initial trend (intercept) component, and fixing that also helped tremendously.

Nice sleuthing and great writeup! Yeah, it has been tricky getting all these effect handlers to work together: plates, reparam, and InitMessenger. There are still some sharp edges. poutine.reparam should at least give a loud error in your case. Feel free to file a bug report.