Hello,
I’m trying to forecast with the following model. First I tried:
def scan_fn(carry, error):
    """Advance the coefficient by one step.

    ``carry`` is a ``(sigma, coeff_prev)`` pair. The updated coefficient is
    ``coeff_prev + sigma * error``. It is returned twice: as the second entry
    of the next carry and as the per-step output. ``sigma`` is passed through
    unchanged.
    """
    scale, previous = carry
    updated = previous + scale * error
    return (scale, updated), updated
def weights_regression_sv(X=None, y=None):
    """Regression on ``X`` with normalised positive weights and a time-varying scale.

    Sample sites: ``beta_unweighted``, ``lambda_errors``, ``log_lambda_0``,
    ``phi`` and ``y``. Deterministic sites: ``beta``, ``log_lambdas`` and
    ``lambdas``.

    ``X`` must expose a two-element ``.shape`` (rows, columns); ``y`` is passed
    as ``obs`` to the ``y`` site. Returns the value of the ``y`` site.

    NOTE(review): ``lambda_errors`` has one entry per row of ``X`` and the same
    row count is passed to ``scan`` as its length, so values for that site
    produced with a different number of rows will not match -- verify before
    reusing posterior samples with a differently sized ``X``.
    """
    n_obs, n_features = X.shape

    # Positive raw weights, rescaled so that they sum to one.
    raw_weights = numpyro.sample(
        'beta_unweighted', dist.HalfNormal(jnp.ones(n_features))
    )
    beta = numpyro.deterministic("beta", raw_weights / jnp.sum(raw_weights))

    # SV: scan over the per-row shocks, starting from log_lambda_0, with phi
    # carried along as the first element of the carry.
    shocks = numpyro.sample('lambda_errors', dist.Normal(0.0, jnp.ones(n_obs)))
    initial_log_lambda = numpyro.sample('log_lambda_0', dist.Normal(0.0, 4))
    step_scale = numpyro.sample('phi', dist.InverseGamma(5, 0.15))
    _, log_lambda_path = scan(
        scan_fn, (step_scale, initial_log_lambda), shocks, n_obs
    )
    numpyro.deterministic("log_lambdas", log_lambda_path)
    lambdas = numpyro.deterministic("lambdas", jnp.exp(log_lambda_path))

    # The square root of each lambda is used as the Normal scale.
    mean = jnp.matmul(X, beta)
    return numpyro.sample("y", dist.Normal(mean, jnp.sqrt(lambdas)), obs=y)
and used
# Build a Predictive object from the samples returned by m_sv.get_samples(),
# evaluate it on the last row of X_ (kept two-dimensional by the list index),
# and plot a KDE of the "y" draws multiplied by 100.
predictive = Predictive(
    weights_regression_sv,
    posterior_samples=m_sv.get_samples(),
)
y_pred = predictive(random.PRNGKey(0), X=X_[[-1], :])["y"]
az.plot_kde(100 * y_pred)
plt.show()
but got the error:
Then I tried:
def scan_fn(carry, t):
    """Run one scan step: update the log-lambda state and sample ``y`` at index ``t``.

    ``carry`` is ``(phi, log_lambda_prev, xbeta, lambda_errors)``. The new
    log-lambda is ``log_lambda_prev + phi * lambda_errors[t]``. The ``y`` site
    is a Normal with mean ``xbeta[t]`` and scale ``sqrt(exp(log_lambda))``.

    Returns the next carry (with the updated log-lambda in second position and
    the other entries unchanged) together with the sampled ``y`` value.
    """
    step_scale, previous_log_lambda, means, shocks = carry

    current_log_lambda = previous_log_lambda + step_scale * shocks[t]
    scale = jnp.sqrt(jnp.exp(current_log_lambda))
    draw = numpyro.sample("y", dist.Normal(means[t], scale))

    return (step_scale, current_log_lambda, means, shocks), draw
def weights_regression_sv_fcast(X=None, y=None, forecast=None):
    """Variant of the weighted regression that samples ``y`` inside the scan.

    Sample sites: ``beta_unweighted``, ``lambda_errors``, ``log_lambda_0``,
    ``phi``, plus the ``y`` site created by ``scan_fn``, which is run under a
    ``condition`` handler with ``data={"y": y}``. Deterministic sites: ``beta``
    and, when ``forecast`` is truthy, ``y_forecast`` (the last scan output).

    ``X`` must expose a two-element ``.shape`` (rows, columns). The function
    has no explicit return value.

    NOTE(review): when ``forecast`` is truthy the scan runs ``T + 1`` steps
    while ``xbeta`` has only ``T`` entries, so the final step indexes past the
    end of ``xbeta``, and ``y`` is still the array used for conditioning --
    confirm this is intended.
    """
    n_obs, n_features = X.shape

    # One extra step is appended when a forecast is requested.
    n_steps = n_obs + 1 if forecast else n_obs

    # Positive raw weights, rescaled so that they sum to one.
    raw_weights = numpyro.sample(
        'beta_unweighted', dist.HalfNormal(jnp.ones(n_features))
    )
    beta = numpyro.deterministic("beta", raw_weights / jnp.sum(raw_weights))

    # SV
    shocks = numpyro.sample('lambda_errors', dist.Normal(0.0, jnp.ones(n_steps)))
    initial_log_lambda = numpyro.sample('log_lambda_0', dist.Normal(0.0, 4))
    step_scale = numpyro.sample('phi', dist.InverseGamma(5, 0.15))

    xbeta = jnp.matmul(X, beta)
    initial_carry = (step_scale, initial_log_lambda, xbeta, shocks)
    with numpyro.handlers.condition(data={"y": y}):
        _, ys = scan(scan_fn, initial_carry, jnp.arange(0, n_steps))

    if forecast:
        numpyro.deterministic("y_forecast", ys[-1])
But got this error during model estimation:
I was working off the “Time Series Forecasting” example in the NumPyro documentation, but it doesn’t really seem applicable here, given that my model uses X variables? I also looked at the “Forecasting with Dynamic Linear Model (DLM)” tutorial in the Pyro Tutorials 1.7.0 documentation, but that doesn’t seem applicable either, given the absence of the Forecast module in NumPyro?



