Scoring samples inside vs outside scan

Hi everyone,

We are working on a model that uses the scan operator and SVI in NumPyro. To anneal the KL divergence between the guide and the model prior, we apply a scale to x_aux in the outside variant below; the scale is omitted in this small example.

We expected the inside and outside variants to compute the same ELBO, but they do not. Is our expectation wrong, or is there a problem with scan and masked distributions?

Thanks, Ola

from numpyro import sample
from numpyro.distributions import Normal
from numpyro.contrib.control_flow import scan
from numpyro.infer import Trace_ELBO
from numpyro.infer.util import log_density
from numpyro.handlers import seed


from jax import numpy as jnp, random


### STATIC ARGUMENTS
key = random.PRNGKey(0)
x = jnp.arange(5)


### Score samples inside scan ###
def inside_scan_model(x):
    def scan_fn(carry, state):
        sample("x", Normal(state))
        return carry, carry

    scan(scan_fn, jnp.ones(()), x, length=x.shape[0])


def inside_scan_guide(x):
    def scan_fn(carry, state):
        sample("x", Normal())
        return carry, carry

    scan(scan_fn, jnp.ones(()), x, length=x.shape[0])


model_in_ld, tr = log_density(seed(inside_scan_model, key), (x,), {}, {})
guide_in_ld, tr = log_density(seed(inside_scan_guide, key), (x,), {}, {})
inside_elbo = Trace_ELBO().loss(key, {}, inside_scan_model, inside_scan_guide, x)


### Score samples outside scan ###
def outside_scan_model(x):
    def scan_fn(carry, state):
        draw = sample("x", Normal(state).mask(False))
        return carry, draw

    _, draws = scan(scan_fn, jnp.ones(()), x, length=x.shape[0])

    sample("x_aux", Normal(x), obs=draws)


def outside_scan_guide(x):
    def scan_fn(carry, state):
        draw = sample("x", Normal().mask(False))
        return carry, draw

    _, draws = scan(scan_fn, jnp.ones(()), x, length=x.shape[0])

    sample("x_aux", Normal(), obs=draws)


model_out_ld, tr = log_density(seed(outside_scan_model, key), (x,), {}, {})
guide_out_ld, tr = log_density(seed(outside_scan_guide, key), (x,), {}, {})
outside_elbo = Trace_ELBO().loss(key, {}, outside_scan_model, outside_scan_guide, x)

# Check that they are consistent
assert (
    model_in_ld == model_out_ld
), f"outside model ld {model_out_ld:.3f} != {model_in_ld:.3f} ld from inside model"
assert (
    guide_in_ld == guide_out_ld
), f"outside guide ld {guide_out_ld:.3f} != {guide_in_ld:.3f} ld from inside guide"
assert (
    inside_elbo == outside_elbo
), f"outside elbo {outside_elbo:.3f} != {inside_elbo:.3f} from inside elbo"  # Fails with 3.095 != 8.298
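The expectation above rests on a simple identity: summing per-step log densities inside a scan gives the same total as scoring the stacked draws once after the scan. A minimal pure-JAX sketch of that identity, independent of NumPyro (the helper names inside_score and outside_score are illustrative, not NumPyro APIs, and x_t stands in for the scan's state):

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def inside_score(x, z):
    """Accumulate log N(z_t | x_t, 1) one step at a time inside lax.scan."""
    def step(total, xz):
        x_t, z_t = xz
        return total + norm.logpdf(z_t, loc=x_t, scale=1.0), None

    total, _ = jax.lax.scan(step, jnp.zeros(()), (x, z))
    return total


def outside_score(x, z):
    """Score all stacked draws at once, after the scan, as one vectorized term."""
    return norm.logpdf(z, loc=x, scale=1.0).sum()


x = jnp.arange(5.0)
z = jax.random.normal(jax.random.PRNGKey(0), x.shape)  # stand-in draws
print(inside_score(x, z), outside_score(x, z))
```

The two totals agree up to float32 rounding, so a large gap between the two NumPyro variants suggests the difference comes from how the library handles the masked sites rather than from the two formulations themselves.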

Hi Ola, nested scan is currently not supported (see the note in the docs).

I see. I thought a nested scan would look like the following pattern:

...
def outer_fn(out_car, xs):
    def inner_fn(in_car, x):
        y = sample('x', Dist(in_car, x))
        return y, y

    in_it = sample('in_it', Dist(out_car))
    y, hist = scan(inner_fn, in_it, xs)
    return hist, hist

... = scan(outer_fn, out_it, xss)

Is it the replay of the guide trace that makes the scans nested?

Oops, I thought scan_fn was doing some sort of scan itself. Sorry for the confusion.

Could you check which one gives the correct answer? You can use length=2 to check it.
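A hand check along these lines for length=2 (so x = [0, 1]): with one guide draw z, the single-particle loss is log q(z) - log p(z), and because guide and model are both unit-scale normals it reduces to the sum of x_t**2 / 2 - z_t * x_t. A minimal stdlib sketch, assuming one particle and that the masked x sites contribute nothing; the draws below are made up for illustration, not output from either program:

```python
import math


def normal_logpdf(v, loc, scale=1.0):
    """Log density of Normal(loc, scale) evaluated at v."""
    return -0.5 * ((v - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def single_sample_loss(x, z):
    """log q(z) - log p(z): guide Normal(0, 1) and model Normal(x_t, 1) at each step."""
    return sum(normal_logpdf(z_t, 0.0) - normal_logpdf(z_t, x_t) for x_t, z_t in zip(x, z))


def closed_form_loss(x, z):
    """The same quantity after the Gaussian terms cancel: sum of x_t**2 / 2 - z_t * x_t."""
    return sum(0.5 * x_t ** 2 - z_t * x_t for x_t, z_t in zip(x, z))


x = [0.0, 1.0]   # length=2, i.e. arange(2)
z = [0.3, -0.7]  # hypothetical guide draws, chosen for illustration
print(single_sample_loss(x, z), closed_form_loss(x, z))
```

Plugging in the x values recorded in each variant's guide trace (the "x" site of the trace returned by log_density) and comparing against the loss each variant reports shows which one is consistent.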

The log prob for x=arange(5) should be 8.298.
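For this toy model the single-particle loss (guide log density minus model log density at one guide draw z) has a closed form, since model and guide are both unit-scale normals. A short derivation, assuming one particle and that the masked x sites contribute nothing:

```latex
\begin{aligned}
\mathcal{L}(z) &= \log q(z) - \log p(z)
  = \sum_{t} \Big[\log \mathcal{N}(z_t \mid 0, 1) - \log \mathcal{N}(z_t \mid x_t, 1)\Big] \\
  &= \sum_{t} \Big[-\tfrac{1}{2} z_t^2 + \tfrac{1}{2} (z_t - x_t)^2\Big]
  = \sum_{t} \Big(\tfrac{1}{2} x_t^2 - z_t x_t\Big), \\
\mathbb{E}_{z \sim q}\big[\mathcal{L}(z)\big] &= \tfrac{1}{2} \sum_{t} x_t^2
  = \tfrac{1}{2} (0 + 1 + 4 + 9 + 16) = 15 \quad \text{for } x = (0, 1, 2, 3, 4).
\end{aligned}
```

A single draw therefore gives 15 minus the sum of z_t x_t, so a value such as 8.298 depends on the particular guide draw; evaluating this formula at the recorded draws is one way to tell which variant is right.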

I just noticed that the repo I’m using is on NumPyro v0.12.1. I checked the program on the latest version of NumPyro (0.14.0), and it behaves correctly there. So nothing needs to be done after all (except a version upgrade).

Oh, nice to hear.