For each component of an AutoGuideList, the model is seeded and then blocked so that only the sites that belong to that sub-guide are exposed. However, if deterministic sites are exposed in any of the sub-guides, the deterministic relations can be violated and the results can change from one evaluation to the next.
Consider the following simple model with a partitioned guide, i.e., an AutoGuideList with sub-guides capturing subsets of the latent variables. We evaluate the median of the guide after using SVI to initialize the parameters; we don’t run any optimization.
>>> from jax import random
>>> import numpyro
>>> from numpyro import distributions as dists
>>> from numpyro import handlers
>>> from numpyro.infer import autoguide
>>> def model():
...     a = numpyro.sample("a", dists.Normal(0, 1))
...     b = numpyro.sample("b", dists.Normal(0, 1))
...     numpyro.deterministic("c", a + b)
>>> def init_partitioned_guide(model, *parts):
...     """
...     Create a structured AutoDiagonalNormal guide. Equivalent to just using one big
...     AutoDiagonalNormal guide, but useful for illustration.
...     """
...     guide = autoguide.AutoGuideList(model)
...     for i, part in enumerate(parts):
...         subguide = autoguide.AutoDiagonalNormal(
...             handlers.block(handlers.seed(model, rng_seed=0), expose=part),
...             prefix=str(i),
...         )
...         guide.append(subguide)
...     return guide
>>> def init_params(model, guide):
...     optim = numpyro.optim.Adam(step_size=0.1)  # Dummy optimizer to init SVI.
...     svi = numpyro.infer.SVI(model, guide, optim, None)
...     state = svi.init(random.key(382))
...     return svi.get_params(state)
>>> guide = init_partitioned_guide(model, ["a"], ["b"])
>>> params = init_params(model, guide)
>>> guide.median(params)
{'a': Array(-1.1310773, dtype=float32), 'b': Array(-0.4418149, dtype=float32)}
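So far so good. If we expose the deterministic site in either of the sub-guides, the constraint c == a + b is no longer satisfied. That’s because the sub-guide that exposes c does not see the other latent variable: in the example below, a is hidden from the second sub-guide, so it is sampled from the seeded model wrapped by block rather than taken from the first sub-guide (and vice versa if c is exposed in the first sub-guide).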
>>> guide = init_partitioned_guide(model, ["a"], ["b", "c"])
>>> params = init_params(model, guide)
>>> median = guide.median(params)
>>> median, median["c"], median["a"] + median["b"]
({'a': Array(-1.1310773, dtype=float32),
'b': Array(-0.4418149, dtype=float32),
'c': Array(-1.0904484, dtype=float32)},
Array(-1.0904484, dtype=float32),
Array(-1.5728922, dtype=float32))
Evaluating the median again gives different results (which also violate the deterministic relation). That’s because the state of the seed handler applied before blocking has advanced.
>>> guide.median(params)
{'a': Array(-1.1310773, dtype=float32),
'b': Array(-0.4418149, dtype=float32),
'c': Array(0.596624, dtype=float32)}
The issue can, of course, be avoided by not exposing any deterministic sites in any of the sub-guides. It tripped me up when I was using reparametrizations, because the original site of a reparametrized variable is effectively a deterministic site. For example, consider the following model.
>>> def model():
...     a = numpyro.sample("a", dists.Normal(0, 1))
...     b = numpyro.sample("b", dists.TransformedDistribution(
...         dists.Normal(0, 1),
...         dists.transforms.AffineTransform(a, 1),
...     ))
>>> config = {"b": numpyro.infer.reparam.TransformReparam()}
>>> reparametrized = handlers.reparam(model, config)
>>> guide = init_partitioned_guide(reparametrized, ["a"], ["b", "b_base"])
>>> params = init_params(reparametrized, guide)
>>> median = guide.median(params)
>>> median, median["b_base"] + median["a"]
({'a': Array(-1.1310773, dtype=float32),
'b_base': Array(-0.4418149, dtype=float32),
'b': Array(-1.0904484, dtype=float32)},
Array(-1.5728922, dtype=float32))
We would’ve expected that b = b_base + a, but exposing the original variable leads to the inconsistent result shown above. If we want the median including the original variable, we can take the median of all non-deterministic sites, substitute it into the model, and trace the model again.
>>> guide = init_partitioned_guide(reparametrized, ["a"], ["b_base"])
>>> params = init_params(reparametrized, guide)
>>> median = guide.median(params)
>>> substituted = handlers.substitute(reparametrized, median)
>>> median2 = {key: site["value"] for key, site in handlers.trace(substituted).get_trace().items()}
>>> median2, median2["b_base"] + median2["a"]
({'a': Array(-1.1310773, dtype=float32),
'b_base': Array(-0.4418149, dtype=float32),
'b': Array(-1.5728922, dtype=float32)},
Array(-1.5728922, dtype=float32))
Is there a better way to get around this peculiarity? It’s not quite a bug, but nevertheless surprising (at least to me).