# Scoring samples inside vs outside scan

Hi everyone,

We are working on a model that uses the `scan` operator and SVI in NumPyro. To anneal the KL divergence between the guide and the model prior, we apply a scale to `x_aux` in the outside variant below. The scaling is omitted in our small example.

We expected the inside and outside variants to compute the same ELBO, but they do not. Is our expectation wrong, or is there a problem with `scan` and masked distributions?

Thanks, Ola

```python
from numpyro import sample
from numpyro.distributions import Normal
from numpyro.contrib.control_flow import scan
from numpyro.infer import Trace_ELBO
from numpyro.infer.util import log_density
from numpyro.handlers import seed

from jax import numpy as jnp, random

### STATIC ARGUMENTS
key = random.PRNGKey(0)
x = jnp.arange(5)

### Score samples inside scan ###
def inside_scan_model(x):
    def scan_fn(carry, state):
        # Each step draws and scores its own "x" site.
        sample("x", Normal(state))
        return carry, carry

    scan(scan_fn, jnp.ones(()), x, length=x.shape[0])

def inside_scan_guide(x):
    def scan_fn(carry, state):
        sample("x", Normal())
        return carry, carry

    scan(scan_fn, jnp.ones(()), x, length=x.shape[0])

model_in_ld, tr = log_density(seed(inside_scan_model, key), (x,), {}, {})
guide_in_ld, tr = log_density(seed(inside_scan_guide, key), (x,), {}, {})
inside_elbo = Trace_ELBO().loss(key, {}, inside_scan_model, inside_scan_guide, x)

### Score samples outside scan ###
def outside_scan_model(x):
    def scan_fn(carry, state):
        # Assumed (the post never defined `draw`): sample "x" inside scan but
        # mask it, so it is scored only through `x_aux` outside scan.
        draw = sample("x", Normal(state).mask(False))
        return carry, draw

    _, draws = scan(scan_fn, jnp.ones(()), x, length=x.shape[0])

    sample("x_aux", Normal(x), obs=draws)

def outside_scan_guide(x):
    def scan_fn(carry, state):
        # Assumed, as in the model: masked draw, scored via `x_aux` below.
        draw = sample("x", Normal().mask(False))
        return carry, draw

    _, draws = scan(scan_fn, jnp.ones(()), x, length=x.shape[0])

    sample("x_aux", Normal(), obs=draws)

model_out_ld, tr = log_density(seed(outside_scan_model, key), (x,), {}, {})
guide_out_ld, tr = log_density(seed(outside_scan_guide, key), (x,), {}, {})
outside_elbo = Trace_ELBO().loss(key, {}, outside_scan_model, outside_scan_guide, x)

# Check that they are consistent (allclose instead of == for float sums)
assert jnp.allclose(
    model_in_ld, model_out_ld
), f"outside model ld {model_out_ld:.3f} != {model_in_ld:.3f} ld from inside model"
assert jnp.allclose(
    guide_in_ld, guide_out_ld
), f"outside guide ld {guide_out_ld:.3f} != {guide_in_ld:.3f} ld from inside guide"
assert jnp.allclose(
    inside_elbo, outside_elbo
), f"outside elbo {outside_elbo:.3f} != {inside_elbo:.3f} from inside elbo"  # Fails on NumPyro 0.12.1 with 3.095 != 8.298
```
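Why the two variants should agree, assuming the per-step `x` sites are masked in the outside variant: for one guide draw $z$, both losses reduce to the same sum. In the inside variant the terms come from the `x` sites; in the outside variant the masked `x` sites contribute nothing, and `x_aux` supplies $\log\mathcal{N}(z_i \mid x_i, 1)$ in the model and $\log\mathcal{N}(z_i \mid 0, 1)$ in the guide.

$$
-\Big[\sum_i \log\mathcal{N}(z_i \mid x_i, 1) - \sum_i \log\mathcal{N}(z_i \mid 0, 1)\Big]
= \sum_i \Big[\tfrac{(z_i - x_i)^2}{2} - \tfrac{z_i^2}{2}\Big]
= \sum_i \Big(\tfrac{x_i^2}{2} - x_i z_i\Big)
$$

Given the same guide draws, both variants should therefore report the same single-particle loss.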

Hi Ola, nested scan is currently not supported (see the note in the docs).

I see. I thought a nested scan would follow this pattern:

```python
...
def outer_fn(out_car, xs):
    def inner_fn(in_car, x):
        y = sample('x', Dist(in_car, x))
        return y, y

    in_it = sample('in_it', Dist(out_car))
    y, hist = scan(inner_fn, in_it, xs)
    return hist, hist

... = scan(outer_fn, out_it, xss)
```

Is it the replay of the guide trace that makes the scans nested?

Oops, I thought `scan_fn` was doing some sort of scan. Sorry for the confusion.

Could you check which one gives the correct answer? You can use `length=2` to check it by hand.

The log prob for `x=arange(5)` should be 8.298.

I just saw that the repo I’m using has NumPyro v0.12.1. I checked the program on the latest version of NumPyro (0.14.0), and it behaves correctly there. So nothing needs to be done after all, except a version upgrade.
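If other code depends on this behaviour, a small guard can flag older installs. This is a hypothetical helper (`version_tuple` and `check_numpyro_version` are illustrative names); 0.14.0 is simply the version reported to work in this thread, and the fix may have landed earlier:

```python
import re
import warnings

# Version reported to work in this thread; the fix may predate it.
MIN_NUMPYRO = (0, 14, 0)


def version_tuple(version):
    """Turn "0.12.1" into (0, 12, 1), keeping the leading digits of each part."""
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:
        parts.append(0)  # pad "0.14" to (0, 14, 0)
    return tuple(parts)


def check_numpyro_version(installed):
    """Hypothetical guard: warn and return False for versions older than MIN_NUMPYRO."""
    if version_tuple(installed) < MIN_NUMPYRO:
        warnings.warn(
            f"numpyro {installed} is older than 0.14.0; "
            "consider `pip install -U numpyro`"
        )
        return False
    return True
```

In a script, one would call `check_numpyro_version(numpyro.__version__)` once after `import numpyro`.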

Oh, nice to hear.