Quick question… I know that if a model has only local (per-data) latent variables, then we don’t need to poutine.scale the likelihood log prob when using mini-batches (which is why we don’t with a VAE).
However, if we have both global and local latent variables, I’m pretty sure we do have to scale the likelihood log prob for mini-batches… but my question is, do we also need to scale the local latent log probs in both the model and the guide? It seems like this would give too much weight to a per-datapoint prior relative to the data, wouldn’t it?
I wondered the same thing, and all I can say is that NumPyro’s autoguide does scale them as well:
import numpy as np
import jax.numpy as jnp
from jax import random
import numpyro
from numpyro import handlers
from numpyro.distributions import Gamma, LogNormal
from numpyro.infer.autoguide import AutoNormal

N = 15
X = jnp.array(np.linspace(.001, 1., N))

def model(subsample_size=None):
    # global latent
    noise = numpyro.sample('noise', Gamma(1., 100.))
    with numpyro.plate('n', N, subsample_size=subsample_size) as ind:
        # local (per-datapoint) latent
        z = numpyro.sample('z', Gamma(1., 1.))
        # likelihood over the subsampled observations
        numpyro.sample('x', LogNormal(jnp.log(z), noise), obs=X[ind])

guide = AutoNormal(model)

subsample_size = 2  # plate scale is then N / subsample_size = 7.5
model_seed, guide_seed = random.split(random.key(0))
seeded_guide = handlers.seed(guide, guide_seed)
seeded_model = handlers.seed(model, model_seed)
# params = svi.get_params(self.model.svi_state)
trace_m, trace_g = numpyro.infer.util.get_importance_trace(
    seeded_model, seeded_guide, [subsample_size], {}, {})
# trace_m, trace_g = next(elbo._get_traces(model, guide, [subsample_size], {}))

print('Model')
for _, node in trace_m.items():
    if 'scale' in node:
        print(node['name'], node['scale'])
print('Guide')
for _, node in trace_g.items():
    if 'scale' in node:
        print(node['name'], node['scale'])
Outputs:
Model
noise None
n 1.0
z 7.5
x 7.5
Guide
n 1.0
noise_auto_loc None
noise_auto_scale None
noise None
z_auto_loc 7.5
z_auto_scale 7.5
z 7.5
Ah nice, thanks! I didn’t think to check how AutoGuides would deal with them, that’s a good idea.
I asked ChatGPT too and it did seem to agree that you should scale them in both the model and guide, even after I tried to challenge it in the usual ways that can sometimes get it to correct itself when it gives a bogus answer.
I’m thinking the reasoning is that when we scale the likelihood we’re pretending the mini-batch is the full dataset, so we also need to scale the local latent log probs to give them the proportional contribution they would have if the batch really were the full dataset.
@fehiepsi does it make sense to scale the local latent log probs?
If you use plates everywhere as intended (e.g. using subsample_size), any necessary scaling is handled for you automatically under the hood; see e.g. this example.
Ah yeah, I probably should’ve been more specific. If we do mini-batching with subsample_size inside plate, then the scaling is applied automatically. But if we build mini-batches outside Pyro, with something like a DataLoader, then we have to apply the scaling manually ourselves.
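For concreteness, here is a rough NumPyro sketch of what I mean by the manual version (batch_x, batch_size and full_size are just placeholder names, not anything from the code above); the scale handler covers both the local latent and the likelihood:

import jax.numpy as jnp
import numpyro
from numpyro import handlers
from numpyro.distributions import Gamma, LogNormal

full_size = 15  # N, the size of the full dataset

def model(batch_x):
    batch_size = batch_x.shape[0]
    # global latent: never scaled
    noise = numpyro.sample('noise', Gamma(1., 100.))
    # scale both the local latent's log prob and the likelihood by N / batch_size,
    # since together they stand in for the full dataset's per-datapoint terms
    with handlers.scale(scale=full_size / batch_size):
        with numpyro.plate('n', batch_size):
            z = numpyro.sample('z', Gamma(1., 1.))
            numpyro.sample('x', LogNormal(jnp.log(z), noise), obs=batch_x)

A hand-written guide would need the same handlers.scale(scale=full_size / batch_size) around its local sites, which matches what the AutoNormal trace above shows.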
Thanks all!
The code above shows that this is indeed handled automatically; I think the question is why the local sites are also scaled.
Generally speaking, whenever you are doing mini-batching/subsampling in the presence of global latent variables, various terms in the ELBO need to be scaled. Some of this is explained here. From the point of view of subsampling, the log probability of an (unobserved) local latent variable is no different from the log likelihood of an observed data point, so the same scaling logic applies.
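Schematically (a sketch, assuming conditionally i.i.d. data points $x_i$ with local latents $z_i$, a global latent $\theta$, and a mini-batch $B$ of size $b$ drawn from $N$ points), the subsampled ELBO estimator looks like

$$
\widehat{\mathrm{ELBO}} \;=\; \mathbb{E}_{q}\big[\log p(\theta) - \log q(\theta)\big] \;+\; \frac{N}{b}\sum_{i \in B} \mathbb{E}_{q}\big[\log p(x_i \mid z_i, \theta) + \log p(z_i) - \log q(z_i)\big]
$$

so the local prior and local guide terms carry the same $N/b$ factor as the likelihood; with $N = 15$ and $b = 2$ that is the 7.5 you see on z in both traces above.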