Hello,
I am trying to sample the following random walk with truncated-normal errors:
\alpha_t = \alpha_{t-1} + \epsilon_t, \qquad \epsilon_t \sim \text{Normal}(0,\sigma_{\alpha}^2)\,[-\alpha_{t-1},\; 1-\alpha_{t-1}]
such that \alpha_t \in (0,1) for all t.
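To be concrete about the process I'm after, here is a plain-Python simulation of it (rejection sampling for the truncated normal; `sigma_alpha = 0.1` and `T = 200` are just illustrative values):

```python
import random


def truncated_gauss(mu, sigma, low, high, rng):
    # Rejection sampling: redraw until the value lands in [low, high].
    while True:
        x = rng.gauss(mu, sigma)
        if low <= x <= high:
            return x


def simulate_walk(T, sigma_alpha, alpha0, seed=0):
    rng = random.Random(seed)
    alpha = alpha0
    path = []
    for _ in range(T):
        # epsilon_t is truncated to [-alpha_{t-1}, 1 - alpha_{t-1}],
        # which keeps alpha_t inside [0, 1].
        eps = truncated_gauss(0.0, sigma_alpha, -alpha, 1.0 - alpha, rng)
        alpha = alpha + eps
        path.append(alpha)
    return path


path = simulate_walk(200, 0.1, 0.5)
```

This is just to pin down the data-generating process; the question below is about expressing it in numpyro.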
My code is:

```python
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan


def model(x: np.ndarray = None, y: np.ndarray = None):
    T = len(y)

    # ################
    # Prior
    # ################
    y_sigma = numpyro.sample('y_sigma', dist.HalfNormal(scale=1))
    sigma_bar = numpyro.sample('sigma_bar', dist.HalfNormal(scale=1))
    alpha_sigma = numpyro.sample('alpha_sigma', dist.HalfNormal(scale=1))
    alpha_0 = numpyro.sample('alpha_0',
                             dist.TruncatedNormal(loc=0, scale=1, low=0.01, high=0.99),
                             sample_shape=(1,))

    def transition_fn(carry, t):
        sigma, param_old = carry
        err = numpyro.sample(f'alpha_errs_{t}',
                             dist.TruncatedNormal(loc=0, scale=alpha_sigma,
                                                  low=0.001 - param_old,
                                                  high=0.999 - param_old))
        param_new = param_old + err
        return (sigma, param_new), (param_new, err)

    _, (_alpha, errs) = scan(transition_fn, init=(alpha_sigma, alpha_0), xs=jnp.arange(0, T))
    alpha = numpyro.deterministic('alpha', _alpha.flatten())

    # ################
    # Likelihood
    # ################
    def emission_fn(carry, t):
        y_sigma = carry
        ymean = jnp.log(sigma_bar) + alpha[t] * x[t]
        yhat = numpyro.sample("y", dist.Normal(loc=ymean, scale=y_sigma))
        return y_sigma, yhat

    with numpyro.handlers.condition(data={"y": y}):
        _, yhat = scan(emission_fn, init=y_sigma, xs=jnp.arange(0, T))
```
I am running into the rng_key problem at this sample statement:

```python
err = numpyro.sample(f'alpha_errs_{t}',
                     dist.TruncatedNormal(loc=0, scale=alpha_sigma,
                                          low=0.001 - param_old,
                                          high=0.999 - param_old))
```
as described here: State-space model: is lax.scan compatible with numpyro.sample? - numpyro - Pyro Discussion Forum
The suggestion there, `rng_key = numpyro.sample('key', dist.PRNGIdentity())`, fails for me (perhaps `PRNGIdentity` was removed in a later numpyro release). What is the current, proper way to obtain a fresh rng_key for each `sample` call inside `scan`?

I'm using numpyro version 0.11. Thanks in advance.