Hello,
I’m trying to forecast with the following model. First I tried:
def scan_fn(carry, error):
sigma, coeff_prev = carry
coeff = coeff_prev + sigma*error
coeff_prev = coeff
return (sigma, coeff_prev), coeff
def weights_regression_sv(X=None, y=None):
T, K = X.shape
beta_unweighted = numpyro.sample('beta_unweighted', dist.HalfNormal(jnp.ones(K)))
beta = numpyro.deterministic("beta", beta_unweighted/jnp.sum(beta_unweighted))
# SV
lambda_errors = numpyro.sample('lambda_errors', dist.Normal(0.,jnp.ones(T)))
log_lambda_0 = numpyro.sample('log_lambda_0', dist.Normal(0.,4))
phi = numpyro.sample('phi', dist.InverseGamma(5,.15))
__, log_lambda_t = scan(scan_fn, (phi, log_lambda_0), lambda_errors, T)
log_lambdas = numpyro.deterministic("log_lambdas", log_lambda_t)
lambdas = numpyro.deterministic("lambdas", jnp.exp(log_lambda_t))
return numpyro.sample("y",dist.Normal(jnp.matmul(X,beta),jnp.sqrt(lambdas)),obs=y)
and used
predictive = Predictive(weights_regression_sv, posterior_samples=m_sv.get_samples())
y_pred = predictive(random.PRNGKey(0), X=X_[[-1],:])["y"]
az.plot_kde(100*y_pred)
plt.show();
but got the error:
Then I tried:
def scan_fn(carry, t):
phi, log_lambda_prev, xbeta, lambda_errors = carry
log_lambda = log_lambda_prev + phi*lambda_errors[t]
lambda_ = jnp.exp(log_lambda)
y_ = numpyro.sample("y", dist.Normal(xbeta[t],jnp.sqrt(lambda_)))
log_lambda_prev = log_lambda
return (phi, log_lambda_prev, xbeta, lambda_errors), y_
def weights_regression_sv_fcast(X=None, y=None, forecast=None):
T, K = X.shape
if forecast:
T_ = T+1
else:
T_ = T
beta_unweighted = numpyro.sample('beta_unweighted', dist.HalfNormal(jnp.ones(K)))
beta = numpyro.deterministic("beta", beta_unweighted/jnp.sum(beta_unweighted))
# SV
lambda_errors = numpyro.sample('lambda_errors', dist.Normal(0.,jnp.ones(T_)))
log_lambda_0 = numpyro.sample('log_lambda_0', dist.Normal(0.,4))
phi = numpyro.sample('phi', dist.InverseGamma(5,.15))
with numpyro.handlers.condition(data={"y": y}):
xbeta = jnp.matmul(X,beta)
__, ys = scan(scan_fn, (phi, log_lambda_0, xbeta, lambda_errors), jnp.arange(0,T_))
if forecast:
numpyro.deterministic("y_forecast", ys[-1])
But got this error during model estimation:
I was working off Time Series Forecasting — NumPyro documentation, but this doesn’t really seem applicable given the use of X variables? Also I looked at Forecasting with Dynamic Linear Model (DLM) — Pyro Tutorials 1.7.0 documentation but this also doesn’t seem applicable given the absence of the Forecast module in numpyro?