Quick question… I know that if a model has only local (per-data) latent variables, then we don’t need to poutine.scale the likelihood log prob when using mini-batches (which is why we don’t with a VAE).
However, if we have both global and local latent variables, I'm pretty sure we do have to scale the likelihood log prob for mini-batches… but my question is: do we also need to scale the local latent log probs in both the model and the guide? It seems like that would give too much weight to a per-datapoint prior relative to the data, wouldn't it?
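To spell out the bookkeeping behind my worry (standard SVI notation, nothing library-specific): with a global latent $g$, locals $z_i$, and observations $x_i$, the full-data ELBO is

$$\mathcal{L} = \mathbb{E}_q\!\left[\log\frac{p(g)}{q(g)}\right] + \sum_{i=1}^{N}\mathbb{E}_q\!\left[\log\frac{p(z_i)\,p(x_i \mid z_i, g)}{q(z_i)}\right],$$

and a mini-batch $B$ estimates the sum by $\frac{N}{|B|}\sum_{i\in B}(\cdot)$. Written out like this, the local prior $\log p(z_i)$ and the local guide term $\log q(z_i)$ would need the same $N/|B|$ factor as the likelihood for the estimator to stay unbiased; hence my question about whether that factor gives the per-datapoint prior too much weight.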
I wondered the same thing, and all I can say is that NumPyro's AutoNormal guide does scale them as well:
import numpy as np
import jax.numpy as jnp
from jax import random
import numpyro
from numpyro import handlers
from numpyro.distributions import Gamma, LogNormal

N = 15
X = jnp.array(np.linspace(.001, 1., N))

def model(subsample_size=None):
    noise = numpyro.sample('noise', Gamma(1., 100.))  # global latent
    with numpyro.plate('n', N, subsample_size=subsample_size) as ind:
        z = numpyro.sample('z', Gamma(1., 1.))  # local latent, one per datapoint
        x = numpyro.sample('x', LogNormal(jnp.log(z), noise), obs=X[ind])
guide = numpyro.infer.autoguide.AutoNormal(model)

subsample_size = 2
model_seed, guide_seed = random.split(random.key(0))
seeded_model = handlers.seed(model, model_seed)
seeded_guide = handlers.seed(guide, guide_seed)

# params = svi.get_params(self.model.svi_state)
# Trace the model and guide once so we can inspect the scale attached to each site.
trace_m, trace_g = numpyro.infer.util.get_importance_trace(
    seeded_model, seeded_guide, [subsample_size], {}, {})
# trace_m, trace_g = next(elbo._get_traces(model, guide, [subsample_size], {}))
print('Model')
for _, node in trace_m.items():
    if 'scale' in node:
        print(node['name'], node['scale'])

print('Guide')
for _, node in trace_g.items():
    if 'scale' in node:
        print(node['name'], node['scale'])
Outputs:
Model
noise None
n 1.0
z 7.5
x 7.5
Guide
n 1.0
noise_auto_loc None
noise_auto_scale None
noise None
z_auto_loc 7.5
z_auto_scale 7.5
z 7.5
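So every per-datapoint site, in both traces, carries scale = N / subsample_size = 15 / 2 = 7.5, while the global 'noise' site (and its auto_loc/auto_scale params) is left unscaled. The local prior 'z', the likelihood 'x', and the guide's 'z' all get the same factor, which, as far as I can tell, is exactly what keeps the mini-batch ELBO an unbiased estimate of the full-data one: the prior isn't overweighted relative to the data, because both are inflated together. As a model-only cross-check that doesn't need a guide at all, you can trace the seeded model directly with the standard handlers.trace API (a quick sketch; same imports as above):

# Trace the model alone and read off each site's scale; every site inside
# the subsampled plate should report N / subsample_size = 15 / 2 = 7.5.
tr = handlers.trace(handlers.seed(model, random.key(1))).get_trace(2)
for name, site in tr.items():
    print(name, site.get('scale'))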