Deterministic sites in AutoGuideList

For each component of an AutoGuideList, the model is seeded and then only the sites that belong to that sub-guide are exposed. However, if deterministic sites are exposed in any of the sub-guides, their values can violate the relations defined in the model and can even change between evaluations.

Consider the following simple model with a partitioned guide, i.e., an AutoGuideList with sub-guides capturing subsets of the latent variables. We evaluate the median of the guide after using SVI to initialize parameters; we don’t run any optimization.

>>> from jax import random
>>> import numpyro
>>> from numpyro import distributions as dists
>>> from numpyro import handlers
>>> from numpyro.infer import autoguide


>>> def model():
...     a = numpyro.sample("a", dists.Normal(0, 1))
...     b = numpyro.sample("b", dists.Normal(0, 1))
...     numpyro.deterministic("c", a + b)


>>> def init_partitioned_guide(model, *parts):
...     """
...     Create a structured AutoDiagonalNormal guide. Equivalent to just using one big
...     AutoDiagonalNormal guide, but useful for illustration.
...     """
...     guide = autoguide.AutoGuideList(model)
...     for i, part in enumerate(parts):
...         subguide = autoguide.AutoDiagonalNormal(
...             handlers.block(handlers.seed(model, rng_seed=0), expose=part),
...             prefix=str(i),
...         )
...         guide.append(subguide)
...     return guide


>>> def init_params(model, guide):
...     optim = numpyro.optim.Adam(step_size=0.1)  # Dummy optimizer to init SVI.
...     svi = numpyro.infer.SVI(model, guide, optim, None)
...     state = svi.init(random.key(382))
...     return svi.get_params(state)


>>> guide = init_partitioned_guide(model, ["a"], ["b"])
>>> params = init_params(model, guide)
>>> guide.median(params)
{'a': Array(-1.1310773, dtype=float32), 'b': Array(-0.4418149, dtype=float32)}

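So far so good. If we expose the deterministic site in either of the sub-guides, however, the constraint c == a + b is no longer satisfied. In the example below, the second sub-guide exposes b and c but not a. When that sub-guide computes c, a is drawn from its prior by the seeded model inside block instead of being taken from the first sub-guide. Exposing c together with a would cause the same problem for b.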

>>> guide = init_partitioned_guide(model, ["a"], ["b", "c"])
>>> params = init_params(model, guide)
>>> median = guide.median(params)
>>> median, median["c"], median["a"] + median["b"]
({'a': Array(-1.1310773, dtype=float32),
  'b': Array(-0.4418149, dtype=float32),
  'c': Array(-1.0904484, dtype=float32)},
 Array(-1.0904484, dtype=float32),
 Array(-1.5728922, dtype=float32))

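Evaluating the median again gives a different value for c (which also violates the deterministic relation), while a and b are unchanged. That’s because the seed handler wrapped inside block is created only once and advances its RNG state every time the model runs, so each evaluation draws a new value of a.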

>>> guide.median(params)
{'a': Array(-1.1310773, dtype=float32),
 'b': Array(-0.4418149, dtype=float32),
 'c': Array(0.596624, dtype=float32)}

The issue can of course be avoided by not exposing deterministic sites in any of the sub-guides. It caught me out when I was using reparametrizations, because a reparametrized site is effectively a deterministic site. For example, consider the following model.

>>> def model():
...     a = numpyro.sample("a", dists.Normal(0, 1))
...     b = numpyro.sample("b", dists.TransformedDistribution(
...         dists.Normal(0, 1),
...         dists.transforms.AffineTransform(a, 1),
...     ))

>>> config = {"b": numpyro.infer.reparam.TransformReparam()}
>>> reparametrized = handlers.reparam(model, config)
>>> guide = init_partitioned_guide(reparametrized, ["a"], ["b", "b_base"])
>>> params = init_params(reparametrized, guide)
>>> median = guide.median(params)
>>> median, median["b_base"] + median["a"]
({'a': Array(-1.1310773, dtype=float32),
  'b_base': Array(-0.4418149, dtype=float32),
  'b': Array(-1.0904484, dtype=float32)},
 Array(-1.5728922, dtype=float32))

We would have expected b = b_base + a, but exposing the original variable leads to the inconsistent result shown above. If we want the median including the original variable, we can compute the median of the non-deterministic sites only, substitute those values into the model, and trace it again so that the deterministic sites are recomputed from one consistent set of latent values.

>>> guide = init_partitioned_guide(reparametrized, ["a"], ["b_base"])
>>> params = init_params(reparametrized, guide)
>>> median = guide.median(params)

>>> substituted = handlers.substitute(reparametrized, median)
>>> median2 = {key: site["value"] for key, site in handlers.trace(substituted).get_trace().items()}
>>> median2, median2["b_base"] + median2["a"]
({'a': Array(-1.1310773, dtype=float32),
  'b_base': Array(-0.4418149, dtype=float32),
  'b': Array(-1.5728922, dtype=float32)},
 Array(-1.5728922, dtype=float32))

Is there a better way to get around this peculiarity? It’s not quite a bug, but nevertheless surprising (at least to me).

This sounds reasonable to me. I think that, while looping over the sub-guides, we can check whether any site of type “deterministic” appears in the trace. If so, we can raise a warning or an error.