Scan with sample statement needs rng_key

Hello,

I am trying to sample the following random walk with truncated-normal errors:

\alpha_t = \alpha_{t-1} + \epsilon_t, \quad \epsilon_t \sim \mathrm{Normal}(0, \sigma_{\alpha}^2)\,[-\alpha_{t-1},\ 1-\alpha_{t-1}]

such that \alpha_t \in (0,1) for all t.
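
To make the constraint concrete, the process can be simulated outside numpyro. This is only a rough pure-Python sketch, not part of the model: the helper names (`sample_truncated_normal`, `simulate_walk`) and the values 0.5, 0.1 and 50 are made up for illustration, and the truncation is done by simple rejection sampling, which is an assumption of this sketch rather than what numpyro does internally.

```python
import random


def sample_truncated_normal(rng, sigma, low, high):
    # Rejection sampling: redraw from Normal(0, sigma) until the draw
    # falls strictly inside (low, high).
    while True:
        draw = rng.gauss(0.0, sigma)
        if low < draw < high:
            return draw


def simulate_walk(alpha_0, sigma_alpha, T, seed=0):
    rng = random.Random(seed)
    alpha = alpha_0
    path = []
    for _ in range(T):
        # The bounds depend on the previous state, so alpha + eps stays in (0, 1).
        eps = sample_truncated_normal(rng, sigma_alpha, 0.0 - alpha, 1.0 - alpha)
        alpha = alpha + eps
        path.append(alpha)
    return path


path = simulate_walk(alpha_0=0.5, sigma_alpha=0.1, T=50)
```

Each step shifts the truncation interval by the previous value, which is why the bounds in the numpyro version have to be computed from the carry.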

My code is:

...
from numpyro.contrib.control_flow import scan
...

def model(x: np.ndarray = None, y: np.ndarray = None):

    T = len(y)

    # ################
    # Prior 
    # ################

    y_sigma  = numpyro.sample('y_sigma', dist.HalfNormal(scale=1))
    sigma_bar = numpyro.sample('sigma_bar', dist.HalfNormal(scale=1))
    alpha_sigma =  numpyro.sample('alpha_sigma', dist.HalfNormal(scale=1))

    alpha_0 = numpyro.sample('alpha_0', dist.TruncatedNormal(loc=0,scale=1,low=0.01,high=0.99), sample_shape=(1,))
    
    def transition_fn(carry,t):
        sigma, param_old = carry
        err  = numpyro.sample(f'alpha_errs_{t}',
                               dist.TruncatedNormal(loc=0,scale=alpha_sigma,low=0.001-param_old,high=0.999-param_old))
        param_new = param_old + err
        return (sigma, param_new), (param_new,err)

    _, (_alpha,errs) = scan(transition_fn, init = (alpha_sigma,alpha_0), xs = jnp.arange(0, T))

    alpha = numpyro.deterministic('alpha',_alpha.flatten())

    # ################
    # Likelihood 
    # ################

    def transition_fn(carry,t):
        y_sigma = carry
        ymean = jnp.log(sigma_bar) + alpha[t]*x[t]
        yhat = numpyro.sample("y", dist.Normal(loc=ymean, scale=y_sigma))

        return (y_sigma), yhat
            
    with numpyro.handlers.condition(data={"y": y}):
        _, yhat = scan(transition_fn, init = (y_sigma), xs = jnp.arange(0, T))

I am running into the rng_key problem for the block:

err  = numpyro.sample(f'alpha_errs_{t}',
                               dist.TruncatedNormal(loc=0,scale=alpha_sigma,low=0.001-param_old,high=0.999-param_old))

as described in the thread "State-space model: is lax.scan compatible with numpyro.sample?" on the Pyro Discussion Forum.

The suggestion to use rng_key = numpyro.sample('key', dist.PRNGIdentity()) fails for me (maybe it was removed in later versions of numpyro).

What would be the current proper way to get a new rng_key for each sample call when using scan?

I’m using numpyro version 0.11. Thanks in advance.

I think I answered my own question. Is this correct?

    def transition_fn(carry,t):
        sigma, param_old = carry
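        # Fixed site name: inside scan, t is a traced value rather than a Python int,
        # so it should not be formatted into the name.
        err  = numpyro.sample('alpha_errs',
                               dist.TruncatedNormal(loc=0,scale=alpha_sigma,low=0.001-param_old,high=0.999-param_old))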
        param_new = param_old + err
        return (sigma, param_new), (param_new,err)

    with numpyro.handlers.seed(rng_seed=0):
        _, (_alpha,errs) = scan(transition_fn, init = (alpha_sigma,alpha_0), xs = jnp.arange(0, T))

To sample from a numpyro model, you need to provide a seed, something like

numpyro.handlers.seed(model, rng_seed=0)(...)

or

with numpyro.handlers.seed(rng_seed=0):
  model(...)