Advice on forecast cross-validation

Hello,

I would like to estimate a model for forecasting and, in particular, measure its forecast accuracy with the following cross-validation (CV) scheme:

(1) Given R total observations (rows of X and y), estimate the model on the first T observations, where T < R.

(2) Use row T+1 of X to forecast the (T+1)-th value of y (call this y_fcast_1).

Then repeat (1) and (2) for j = 1, 2, ..., R-T-1: estimate on the first T+j observations, forecast observation T+j+1, and collect that forecast as y_fcast_(j+1). The last iteration estimates on the first R-1 observations and forecasts observation R, giving R-T forecasts in total.
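This scheme is an expanding-window ("rolling forecasting origin") one-step-ahead evaluation. Below is a minimal pure-Python sketch of just the bookkeeping, under stated assumptions: `fit_and_forecast(X_train, y_train, x_next)` is a hypothetical callable standing in for the NumPyro fit-and-predict step, `historical_mean` is a hypothetical toy forecaster used only to make the example runnable, and RMSE is an assumed accuracy metric (any loss would do). Indices are 0-based, so training on rows `[0, T)` and forecasting row index `T` corresponds to forecasting observation T+1 in the notation above.

```python
import math


def expanding_window_cv(X, y, T, fit_and_forecast):
    """One-step-ahead expanding-window CV (sketch).

    For n = T, T+1, ..., R-1 (0-based): fit on rows [0, n), forecast row n.
    Returns the R-T forecasts and their RMSE against y[T:].
    """
    R = len(y)
    if not 0 < T < R:
        raise ValueError("need 0 < T < R")
    forecasts = []
    for n in range(T, R):
        # Training window grows by one row each iteration
        forecasts.append(fit_and_forecast(X[:n], y[:n], X[n]))
    errors = [y_true - f for y_true, f in zip(y[T:], forecasts)]
    rmse = math.sqrt(sum(e * e for e in errors) / len(errors))
    return forecasts, rmse


def historical_mean(X_train, y_train, x_next):
    # Hypothetical toy forecaster: ignores X and predicts the training mean of y
    return sum(y_train) / len(y_train)


y = [1.0, 2.0, 3.0, 4.0, 5.0]
X = [[0.0]] * 5  # dummy regressors; historical_mean ignores them
forecasts, rmse = expanding_window_cv(X, y, T=3, fit_and_forecast=historical_mean)
print(forecasts, rmse)  # [2.0, 2.5] and sqrt(5.125)
```

In a real run, `fit_and_forecast` would run MCMC on `X_train, y_train` and return the posterior-predictive mean for `x_next`.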

So far I have:

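import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

def linear_regression_model(X=None, y=None):
    N, K = X.shape  # N observations, K regressors

    # Priors: error variance, slope coefficients, intercept
    sigma_2_eps = numpyro.sample("sigma_2_eps", dist.HalfNormal(1))
    beta = numpyro.sample("beta", dist.MultivariateNormal(loc=jnp.zeros(K), covariance_matrix=25 * jnp.eye(K)))
    beta0 = numpyro.sample("beta0", dist.Normal(0, 5))

    # Likelihood: y ~ N(beta0 + X beta, sigma_2_eps * I); with y=None, y is sampled (forecasting)
    numpyro.sample("y",
                   dist.MultivariateNormal(loc=beta0 + jnp.matmul(X, beta), covariance_matrix=sigma_2_eps * jnp.eye(N)),
                   obs=y)

# Fit on the first T observations (X_T: shape (T, K), y_T: shape (T,))
rng_key = random.PRNGKey(0)
rng_key, rng_key_mcmc, rng_key_pred = random.split(rng_key, 3)  # fresh keys; don't reuse a consumed key
kernel = NUTS(linear_regression_model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000, num_chains=1)
mcmc.run(rng_key_mcmc, X=X_T, y=y_T)
mcmc.print_summary()
samples = mcmc.get_samples()

# Forecast observation T+1; X_T_plus_1 must stay 2-D with shape (1, K)
predictive = Predictive(linear_regression_model, samples)
predictions = predictive(rng_key_pred, X=X_T_plus_1)["y"]  # shape (num_samples, 1)
y_fcast_T_plus_1 = jnp.mean(predictions, 0)  # posterior-predictive mean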

This seems to work for the first forecast (assuming, of course, that I'm doing the forecasting correctly). I suppose I could wrap it in a loop to generate and store the forecasts of y for T+1 through R, but is there a cleaner, more efficient, or faster way to do this?

Thanks in advance.

Best,
Christian