Update: the following version works:
def scan_fn(carry, error):
    """Advance the scanned state by one step.

    The carry is ``(sigma, previous)``; the new value is
    ``previous + sigma * error``. That is a scaled random-walk step when the
    ``error`` inputs are independent noise draws.

    Returns:
        ``((sigma, new_value), new_value)``: the updated carry, with sigma passed
        through unchanged, and the per-step output collected by the scan.
    """
    step_size, last_value = carry
    next_value = last_value + step_size * error
    return (step_size, next_value), next_value
def weights_regression_sv_fcast(X=None, y=None, forecast=0):
    """Weighted regression with stochastic volatility and an in-model forecast.

    The regression weights are non-negative (HalfNormal) and normalised to sum
    to one. The observation variance is ``exp(log_lambda_t)``, where
    ``log_lambda_t`` is built by ``scan`` from ``log_lambda_0`` and
    ``phi``-scaled standard-normal innovations, one per row of ``X``.

    Args:
        X: 2-D design matrix; its row count is the number of time steps and its
            column count is the number of weights.
        y: Observations for the first ``T - forecast`` rows, or None to leave
            ``"y_obs"`` unobserved.
        forecast: Number of trailing rows of ``X`` that feed the unobserved
            ``"y_pred"`` site instead of the likelihood.
    """
    n_steps, n_weights = X.shape

    # Non-negative raw weights, rescaled so that they sum to one.
    raw_weights = numpyro.sample(
        'beta_unweighted', dist.HalfNormal(jnp.ones(n_weights))
    )
    weights = numpyro.deterministic("beta", raw_weights / jnp.sum(raw_weights))

    # Stochastic volatility: the log-variance path is accumulated by scan_fn,
    # which adds phi * innovation to the previous value at every step.
    # NOTE(review): assumes `scan` takes (f, init, xs, length) like
    # jax.lax.scan / numpyro's scan; confirm against the import.
    innovations = numpyro.sample('lambda_errors', dist.Normal(0., jnp.ones(n_steps)))
    initial_log_var = numpyro.sample('log_lambda_0', dist.Normal(0., 4))
    step_scale = numpyro.sample('phi', dist.InverseGamma(5, .15))
    _, log_var_path = scan(
        scan_fn, (step_scale, initial_log_var), innovations, n_steps
    )
    numpyro.deterministic("log_lambdas", log_var_path)
    variances = numpyro.deterministic("lambdas", jnp.exp(log_var_path))

    # Rows before the cutoff are fitted; rows after it are forecast.
    # "lambdas" is a variance, so its square root is the Normal scale.
    cutoff = n_steps - forecast
    fit_rows = X[:cutoff]
    forecast_rows = X[cutoff:]
    numpyro.sample(
        "y_obs",
        dist.Normal(jnp.matmul(fit_rows, weights), jnp.sqrt(variances[:cutoff])),
        obs=y,
    )
    numpyro.sample(
        "y_pred",
        dist.Normal(
            jnp.matmul(forecast_rows, weights), jnp.sqrt(variances[cutoff:])
        ),
        obs=None,
    )
Is getting forecasts this way (from the unobserved `y_pred` site inside the model) meaningfully different from getting them by running `Predictive` on the fitted model, as in
# Alternative approach: replay the model with posterior draws from the fitted
# m_sv, conditioning only on the last row of X_.
# Indexing X_ with a list ([[-1], :]) keeps the row axis, giving a 1-row slice
# rather than a 1-D vector.
# NOTE(review): assumes weights_regression_sv names its likelihood site "y";
# the _fcast model above uses "y_obs"/"y_pred" instead. Confirm.
# NOTE(review): if weights_regression_sv, like the model above, has latent
# sites sized by the number of rows of X (e.g. 'lambda_errors' of length T),
# the posterior draws are sized to the training X, not to this 1-row X.
# Check how Predictive handles that here.
predictive = Predictive(weights_regression_sv, posterior_samples=m_sv.get_samples())
y_pred = predictive(random.PRNGKey(0), X=X_[[-1],:])["y"]
?
P.S. Thanks to the Pyro Discussion Forum thread "Using lax.scan for time series in NumPyro".