State-space model: is lax.scan compatible with numpyro.sample?

I wanted to test out coding a state-space model in numpyro using lax.scan. I’m running into issues that make me suspect this is not supported — or perhaps I’m just getting something else wrong! Here’s my model:

def target(T=10, q=1, r=1, phi=0., beta=0.):
    
    def transition(state, i):
        x0, mu0 = state
        x1 = numpyro.sample(f'x_{i}', dist.Normal(phi*x0, q))
        mu1 = beta * mu0 + x1
        y1 = numpyro.sample(f'y_{i}', dist.Normal(mu1, r))
        return (x1, mu1), (x1, y1)
    
    mu0 = x0 = numpyro.sample('x_0', dist.Normal(0, q))
    y0 = numpyro.sample('y_0', dist.Normal(mu0, r))
    
    _, xy = jax.lax.scan(transition, (x0, mu0), np.arange(1, T))
    x, y = xy

    return np.append(x0, x), np.append(y0, y)

This returns:

x [-1.1470195 -0.3285517 -0.3285517 -0.3285517 -0.3285517 -0.3285517
 -0.3285517 -0.3285517 -0.3285517 -0.3285517]
y [-2.2391834   0.32762653  0.32762653  0.32762653  0.32762653  0.32762653
  0.32762653  0.32762653  0.32762653  0.32762653]

It appears the sample statements within transition only generate one random value, which is repeated in each iteration. When I try to use this model within Predictive, I get an error:

prior = Predictive(target, posterior_samples = {}, num_samples = 10)
prior_samples = prior(PRNGKey(2), T=10)
UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state.
Details: Can't lift level Traced<ShapedArray(float32[]):JaxprTrace(level=1/0)> to JaxprTrace(level=0/0).

I don’t need to get this model running, and understand that I could reparameterize it to generate all of the random variables outside the loop. I’m just wondering about more general state-space models with transitions that are not as easily re-parameterized: can one put sampling statements within a loop that is executed by lax.scan?
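
(In case it's useful, here is a tiny pure-JAX sketch of what I suspect is happening: lax.scan traces the body function once, so any randomness drawn through Python-level state rather than through the traced computation runs only at trace time and is baked in as a constant.)

import jax
import numpy as np

def body(carry, _):
    # this NumPy call runs once, at trace time, so the same draw
    # is baked into the compiled scan as a constant
    eps = np.random.normal()
    return carry, eps

_, draws = jax.lax.scan(body, 0.0, np.arange(5))
print(draws)  # the same value repeated five times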

Thanks!

Update: I checked that this behaves as expected if I replace jax.lax.scan with this reference implementation (slightly edited from the jax docs to handle nested containers):

from jax.tree_util import tree_flatten, tree_unflatten

def scan(f, init, xs, length=None):
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        y_flat, y_tree = tree_flatten(y)
        ys.append(y_flat)
    ys_stacked = [np.stack([y[i] for y in ys]) for i in range(len(y_flat))]
    return carry, tree_unflatten(y_tree, ys_stacked)

This gives:

x [-1.1470195  -0.3285517  -1.1249288   0.7978287   2.298611    1.3821741
 -0.71665144 -1.2928588   0.54819393 -0.929248  ]
y [-2.2391834   0.32762653 -0.07294703  0.45133084  0.5909368   2.4635558
  0.39577067 -1.1304063   0.41633856 -0.15525198]

So I guess this confirms that numpyro.sample and jax.lax.scan are not compatible. (I can certainly understand this given the complexity involved!) Does that seem right? Also, if correct: I find this a very elegant way to write a time-series model — would there be a chance of this being supported in the future?

Hi @sheldon, currently they are not composable as you expected. It seems complicated to support the general pattern, but something like pyro.infer.reparam can work here. Basically, your program can be reparameterized (which is very helpful for inference algorithms) as

x_noise = numpyro.sample('x_noise', dist.Normal(0, 1), sample_shape=(T - 1,))
y_noise = numpyro.sample('y_noise', dist.Normal(0, 1), sample_shape=(T - 1,))
# then use x_noise, y_noise in transition, e.g.
#    x1 = phi * x0 + q * x_noise[i]
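
To make this concrete, here is a minimal sketch of the fully reparameterized model (target_reparam is just an illustrative name). The transition is now deterministic, so plain jax.lax.scan is fine; x and y are no longer sample sites but deterministic transforms of the noise:

def target_reparam(T=10, q=1., r=1., phi=0., beta=0.):
    # draw all the noise up front, outside the scan
    x_noise = numpyro.sample('x_noise', dist.Normal(0, 1), sample_shape=(T - 1,))
    y_noise = numpyro.sample('y_noise', dist.Normal(0, 1), sample_shape=(T - 1,))

    def transition(state, eps):
        eps_x, eps_y = eps
        x0, mu0 = state
        x1 = phi * x0 + q * eps_x   # Normal(phi * x0, q) via its noise
        mu1 = beta * mu0 + x1
        y1 = mu1 + r * eps_y        # Normal(mu1, r) via its noise
        return (x1, mu1), (x1, y1)

    mu0 = x0 = numpyro.sample('x_0', dist.Normal(0, q))
    y0 = numpyro.sample('y_0', dist.Normal(mu0, r))

    _, (x, y) = jax.lax.scan(transition, (x0, mu0), (x_noise, y_noise))
    return np.append(x0, x), np.append(y0, y)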

What do you think?

In case you want to support those common patterns, could you open an issue on GitHub so that we can follow up on this FR after porting pyro.infer.reparam to NumPyro? Thanks!

@neerajprad what do you think about having something like reparam.ScanReparam(transition, length, {"x": LocScaleReparam, "y": LocScaleReparam})? I just have a vague idea that it will work.

@sheldon - Unfortunately, numpyro primitives like sample have side effects which need to be captured by the tracer for the programs to work correctly (in this case lax.scan doesn’t have that visibility). Reparameterizing as suggested by @fehiepsi will work best. We should highlight these gotchas in our README.

what do you think about having something like reparam.ScanReparam(transition, length, {"x": LocScaleReparam, "y": LocScaleReparam}) ? I just have a vague idea that it will work.

Do you think that it is a general enough solution? We can open up an issue and discuss if this is a common pattern that we’d like to support. An alternative is to highlight these use cases through specific examples that users can use as templates for their models.

Hi @fehiepsi and @neerajprad. Thanks for the responses! Reparameterization seems like a good solution, at least for many models. As a user, it’s very clear how to do it for this model, but I’m not sure it would be completely obvious to me for a more complex model. If there are routines to reparameterize automatically, when possible, that sounds like a nice feature.

I was actually thinking about this from the perspective of inference algorithms and not any specific model. The nice thing about the lax.scan design pattern is that it explicitly reveals the sequential nature of the model, and it seems like it would be a short leap from a model described in that format to something like SMC for that model. That would seem harder if noise is pulled out of the loop.

Understood! I agree that it is less engineering and the model looks more intuitive if we reparameterize automatically. Looks like we can support this in the future.

@neerajprad Yes, I think it will work (though complicated) for all reparameterized sites. Something like

def scan_reparam(transition, carry, xs):
    # inspect latent sites by running the first step of `transition(carry)`
    # replace `sample(site, dist)` statements
    # by `sample(site, noise_dist, sample_shape=(len(xs),))`,
    # and store the results in a dict `site_values`

    def new_transition(carry, xs_):
        i, x = xs_
        # use effect handler `reparam_transition_fn` for `transition`
        # to make `sample(site, ...)` return `loc + scale * site_values[site][i]`
        noises = {site: value[i] for site, value in site_values.items()}
        return block(reparam_given_noise(transform_fn, noises))(carry, x)

    # run scan for the remaining steps (I use the same
    # xs and carry here for simplicity)
    return lax.scan(new_transition, carry, (np.arange(len(xs)), xs))

While writing the sketch, I feel more confident that it will work. WDYT?

@sheldon - Reparameterization will more generally make it easier for the NUTS sampler to sample from this model due to non-centering and should be much faster.
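
(As an aside, the non-centering idea in isolation, independent of this model, is just to sample standard-normal noise and shift/scale it deterministically. A generic sketch, with x_base as an illustrative site name:)

import numpyro
import numpyro.distributions as dist

def centered(mu, sigma):
    # the sampler explores x directly; its scale is tied to sigma
    return numpyro.sample('x', dist.Normal(mu, sigma))

def non_centered(mu, sigma):
    # the sampler explores standard-normal noise; x is a deterministic transform
    eps = numpyro.sample('x_base', dist.Normal(0., 1.))
    return mu + sigma * eps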

If you would like to use lax.scan in other models where it might not be easy to reparameterize, you will need to pass the PRNGKey explicitly to the scan’s body function; otherwise the same source of randomness will be used each time. This is just a limitation of how numpyro’s effect handlers that carry state, like numpyro.seed, interact with JAX’s transformations (which need the functions to be deterministic functions of their input parameters). The workaround is simple: you just need to pass in the rng key explicitly and use it in your sample statements.

def target(T=10, q=1, r=1, phi=0., beta=0.):
    def transition(state, xs):
        i, key = xs
        # different keys for the two sample statements
        key1, key2 = random.split(key)
        x0, mu0 = state
        x1 = numpyro.sample(f'x_{i}', dist.Normal(phi * x0, q), rng_key=key1)
        mu1 = beta * mu0 + x1
        y1 = numpyro.sample(f'y_{i}', dist.Normal(mu1, r), rng_key=key2)
        return (x1, mu1), (x1, y1)

    mu0 = x0 = numpyro.sample('x_0', dist.Normal(0, q))
    y0 = numpyro.sample('y_0', dist.Normal(mu0, r))

    # Sample a rng_key and pass it to `scan`
    rng_key = numpyro.sample('key', dist.PRNGIdentity())
    _, xy = jax.lax.scan(transition, (x0, mu0), (np.arange(1, T), random.split(rng_key, T-1)))
    x, y = xy

    return np.append(x0, x), np.append(y0, y)

The only change with respect to your snippet is that we are passing and using explicit rng keys in the scanned function. Does that work for your use case?
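
For instance, to generate a single draw from this version you can seed the model explicitly (a quick sketch):

seeded = numpyro.handlers.seed(target, jax.random.PRNGKey(0))
x, y = seeded(T=10)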

@fehiepsi - This seems like an interesting utility function. So the basic idea is that if all the distributions inside transition are reparameterizable as loc-scale, we should be able to sample from the noise distribution beforehand and make the function a deterministic one? Once we have the reparameterizers in numpyro, this will be an interesting use case. :slight_smile:

Hi @neerajprad. Yes, this makes complete sense, thanks! I had considered a solution like this where the key is passed around as part of the state:

def transition(state, i, q=1., r=1., phi=0.5, beta=0.5):
    x0, mu0, key = state
    key, subkey1, subkey2 = jax.random.split(key, 3)
    x1 = numpyro.sample(f'x_{i}', dist.Normal(phi*x0, q), rng_key=subkey1)
    mu1 = beta * mu0 + x1
    y1 = numpyro.sample(f'y_{i}', dist.Normal(mu1, r), rng_key=subkey2)
    return (x1, mu1, key), (x1, y1)

but I didn’t know how to connect this kind of explicit key handling to the numpyro handlers (i.e. PRNGIdentity() distribution and the rng_key argument to numpyro.sample). This works perfectly, thanks!
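
(For the record, here is a sketch of how I imagine the carry-key variant above would be driven, combining it with the same PRNGIdentity trick; target_carry_key is just an illustrative name, and the kwargs are bound with functools.partial since lax.scan’s body only takes (carry, x).)

from functools import partial

def target_carry_key(T=10, q=1., r=1., phi=0.5, beta=0.5):
    mu0 = x0 = numpyro.sample('x_0', dist.Normal(0, q))
    y0 = numpyro.sample('y_0', dist.Normal(mu0, r))

    # seed the carried key via the same PRNGIdentity trick
    key = numpyro.sample('key', dist.PRNGIdentity())
    body = partial(transition, q=q, r=r, phi=phi, beta=beta)
    _, (x, y) = jax.lax.scan(body, (x0, mu0, key), np.arange(1, T))
    return np.append(x0, x), np.append(y0, y)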

Hi @neerajprad. One follow-up: with your workaround I can now generate from the model by executing target with the seed handler, but when I try to use the Predictive distribution it fails. I’m not sure if you were expecting that to work. Here’s my code:

def target(T=10, q=1., r=1., phi=0.5, beta=0.5):

    def transition(state, xs):
        i, key = xs
        key1, key2 = jax.random.split(key)
        x0, mu0 = state
        x1 = numpyro.sample(f'x_{i}', dist.Normal(phi * x0, q), rng_key=key1)
        mu1 = beta * mu0 + x1
        y1 = numpyro.sample(f'y_{i}', dist.Normal(mu1, r), rng_key=key2)
        return (x1, mu1), (x1, y1)

    mu0 = x0 = numpyro.sample('x_0', dist.Normal(0, q))
    y0 = numpyro.sample('y_0', dist.Normal(mu0, r))

    key = numpyro.sample('key', dist.PRNGIdentity())
    _, xy = jax.lax.scan(transition, (x0, mu0), (np.arange(1, T), jax.random.split(key, T-1)))
    x, y = xy
    
    return np.append(x0, x), np.append(y0, y)


prior = Predictive(target, posterior_samples = {}, num_samples = 10)
prior_samples = prior(PRNGKey(2), T=10)

And the result

UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state.
Details: Can't lift level Traced<ShapedArray(float32[]):JaxprTrace(level=1/0)> to JaxprTrace(level=0/0).

Very interesting, I didn’t expect that, but this isn’t a pattern that we have used in the past, so I’m happy to see these bugs percolate up and get fixed. I have filed an issue (Predictive distribution fails on model with lax.scan · Issue #566 · pyro-ppl/numpyro · GitHub) and will post a follow-up there.

I stumbled across this post while working on a new state-space model, and just wanted to mention for others who find this thread that the issues described above for jax.lax.scan do not occur if you use numpyro.contrib.control_flow.scan.

import numpyro
import numpy as np
import jax
from numpyro import distributions as dist
from numpyro.contrib.control_flow import scan

def target(T=10, q=1, r=1, phi=0., beta=0.):
    
    def transition(state, i):
        x0, mu0 = state
        x1 = numpyro.sample(f'x_{i}', dist.Normal(phi*x0, q))
        mu1 = beta * mu0 + x1
        y1 = numpyro.sample(f'y_{i}', dist.Normal(mu1, r))
        return (x1, mu1), (x1, y1)
    
    mu0 = x0 = numpyro.sample('x_0', dist.Normal(0, q))
    y0 = numpyro.sample('y_0', dist.Normal(mu0, r))
    
    _, xy = scan(transition, (x0, mu0), np.arange(1, T))
    x, y = xy

    return np.append(x0, x), np.append(y0, y)

seeded = numpyro.handlers.seed(target, jax.random.PRNGKey(0))
seeded()

returns

(array([-1.2515389 ,  1.476229  ,  0.281745  , -1.7618588 , -0.41219947,
        -0.4498062 ,  0.36490896, -1.2858144 ,  0.95181686, -0.34184098],
       dtype=float32),
 array([-1.8381894 ,  2.5052357 , -1.1551927 , -2.4374788 , -1.5104059 ,
        -0.284664  ,  0.17208518, -1.3326082 ,  0.6942739 , -0.30786556],
       dtype=float32))
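
One further note: with numpyro’s scan you can also drop the f-string site names and use fixed names; scan collects the per-step draws under a leading time dimension (e.g. sites 'x' and 'y'). A sketch of the transition written that way, assuming the same enclosing target as above so phi, q, beta, r come from the closure:

def transition(state, _):
    x0, mu0 = state
    x1 = numpyro.sample('x', dist.Normal(phi * x0, q))
    mu1 = beta * mu0 + x1
    y1 = numpyro.sample('y', dist.Normal(mu1, r))
    return (x1, mu1), (x1, y1)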