Hi everyone,
We are working on a model that uses the scan operator and SVI in NumPyro. To anneal the KL divergence between the guide and the model prior, we apply a scale to x_aux
in the outside variant below; the scale is omitted from this small example.
We expected the inside and outside variants to compute the same ELBO, but they do not. Is our expectation wrong, or is there a problem with scan and masked distributions?
Thanks, Ola
from numpyro import sample
from numpyro.distributions import Normal
from numpyro.contrib.control_flow import scan
from numpyro.infer import Trace_ELBO
from numpyro.infer.util import log_density
from numpyro.handlers import seed
from jax import numpy as jnp, random
### STATIC ARGUMENTS
# One PRNG key, reused by every seeded run and ELBO evaluation below.
key = random.PRNGKey(seed=0)
# Per-step values fed to scan: [0, 1, 2, 3, 4].
x = jnp.arange(0, 5)
### Score samples inside scan ###
def inside_scan_model(x):
    """Model that scores the latent site "x" directly inside ``scan``.

    At each scan step, the current element of ``x`` is the location of a
    ``Normal`` for site "x". The carry (a scalar ``jnp.ones(())``) passes
    through unchanged, and the stacked per-step outputs are discarded.

    Args:
        x: array scanned along its leading axis; ``x.shape[0]`` is passed
            as the scan length.
    """

    def _step(carry, loc):
        # Only the sample-site side effect matters; the drawn value is unused.
        sample("x", Normal(loc))
        return carry, carry

    scan(_step, jnp.ones(()), x, length=x.shape[0])
def inside_scan_guide(x):
    """Guide counterpart of ``inside_scan_model``.

    Samples site "x" at every scan step from ``Normal()`` with its default
    parameters. ``x`` only sets the number of steps; its values are ignored.

    Args:
        x: array scanned along its leading axis; ``x.shape[0]`` is passed
            as the scan length.
    """

    def _step(carry, _unused):
        sample("x", Normal())
        return carry, carry

    scan(_step, jnp.ones(()), x, length=x.shape[0])
# Joint log density of one seeded, unconditioned run of each program
# (with an empty params dict); the returned traces are not inspected.
model_in_ld, tr = log_density(
    seed(inside_scan_model, rng_seed=key), (x,), {}, {}
)
guide_in_ld, tr = log_density(
    seed(inside_scan_guide, rng_seed=key), (x,), {}, {}
)
# Trace_ELBO loss (i.e. the negative ELBO estimate) with no learnable params.
inside_elbo = Trace_ELBO(num_particles=1).loss(
    key, {}, inside_scan_model, inside_scan_guide, x
)
### Score samples outside scan ###
def outside_scan_model(x):
    """Model variant that scores the latent draws after ``scan`` returns.

    Site "x" is sampled inside ``scan`` from ``Normal(x[t])`` with its log
    probability masked out (``mask(False)``). The stacked draws are then
    scored in one shot by the observed site "x_aux" under ``Normal(x)``,
    which gives a single site to attach an annealing scale to.

    NOTE(review): this presumably assumes that, under ``Trace_ELBO``, the
    ``draws`` returned by ``scan`` are the guide's values for "x". Trace_ELBO
    runs the model under ``numpyro.handlers.replay`` of the guide trace.
    ``replay`` appears to overwrite only the *recorded* "x" site after the
    scan has run, while ``draws`` remains a fresh model sample. That would
    explain the ELBO mismatch reported for this repro; confirm against the
    NumPyro version in use. For KL annealing, wrapping the latent "x" site in
    ``numpyro.handlers.scale`` (in both model and guide) avoids relying on it.

    Args:
        x: array scanned along its leading axis; ``x.shape[0]`` is passed
            as the scan length, and ``x`` is also the location for "x_aux".
    """

    def _step(carry, loc):
        masked_prior = Normal(loc).mask(False)
        return carry, sample("x", masked_prior)

    _, draws = scan(_step, jnp.ones(()), x, length=x.shape[0])
    sample("x_aux", Normal(x), obs=draws)
def outside_scan_guide(x):
    """Guide counterpart of ``outside_scan_model``.

    Draws site "x" from a masked ``Normal()`` inside ``scan``, then scores
    the stacked draws through the observed site "x_aux" under ``Normal()``.
    This presumably places the guide's log density for the draws on a single
    site. ``x`` only sets the number of scan steps.

    Args:
        x: array scanned along its leading axis; ``x.shape[0]`` is passed
            as the scan length.
    """

    def _step(carry, _unused):
        masked_std_normal = Normal().mask(False)
        return carry, sample("x", masked_std_normal)

    _, draws = scan(_step, jnp.ones(()), x, length=x.shape[0])
    sample("x_aux", Normal(), obs=draws)
# Same measurements as for the inside variant, using the same key.
model_out_ld, tr = log_density(
    seed(outside_scan_model, rng_seed=key), (x,), {}, {}
)
guide_out_ld, tr = log_density(
    seed(outside_scan_guide, rng_seed=key), (x,), {}, {}
)
# Trace_ELBO loss (negative ELBO estimate) for the outside variant.
outside_elbo = Trace_ELBO(num_particles=1).loss(
    key, {}, outside_scan_model, outside_scan_guide, x
)
# Check that the inside and outside variants are consistent.
# These are floating-point JAX scalars produced by differently structured
# programs (per-step scan sites vs. one stacked observed site), so they can
# differ by rounding. Compare with a tolerance rather than exact ``==``.
assert jnp.allclose(
    model_in_ld, model_out_ld
), f"outside model ld {model_out_ld:.3f} != {model_in_ld:.3f} ld from inside model"
assert jnp.allclose(
    guide_in_ld, guide_out_ld
), f"outside guide ld {guide_out_ld:.3f} != {guide_in_ld:.3f} ld from inside guide"
# Observed failure in the original report: 3.095 != 8.298.
assert jnp.allclose(
    inside_elbo, outside_elbo
), f"outside elbo {outside_elbo:.3f} != {inside_elbo:.3f} from inside elbo"