Hello hello,
I'm trying to parallelize one-step-ahead forecasts using `multiprocessing`. A simplified example is below (it omits the `Predictive` step, since that doesn't seem to be the issue here):
def w_regression(X=None, y=None):
    """Linear regression with non-negative weights normalised to sum to one.

    Args:
        X: 2-D design matrix (rows = observations, columns = regressors);
            the tuple-unpacking of ``X.shape`` rejects any other rank.
        y: Observed targets, or None to sample ``y`` from the prior predictive.

    Returns:
        The value of the ``"y"`` sample site.
    """
    _, n_features = X.shape  # row count is not needed, only the column count

    # Site order matters for PRNG splitting: noise scale first, then weights.
    noise_scale = numpyro.sample('sigma_eps', dist.HalfNormal(1))
    raw_weights = numpyro.sample(
        'beta_unweighted', dist.HalfNormal(jnp.ones(n_features))
    )
    # Rescale the positive draws so the recorded "beta" lies on the simplex.
    weights = numpyro.deterministic("beta", raw_weights / raw_weights.sum())

    mean = jnp.matmul(X, weights)
    return numpyro.sample("y", dist.Normal(mean, noise_scale), obs=y)
# Sampler settings (module-level so other code can read them).
n_warmup = 5000          # passed to MCMC as num_warmup
n_samples = 10000        # passed to MCMC as num_samples
trees = 15               # passed to NUTS as max_tree_depth
target_acc_prob = 0.99   # passed to NUTS as target_accept_prob

# A single-chain NUTS sampler for w_regression; its run() is invoked once per
# forecast date by forecast_w_predictive_parallel.
m_ = MCMC(
    NUTS(
        w_regression,
        target_accept_prob=target_acc_prob,
        max_tree_depth=trees,
    ),
    num_samples=n_samples,
    num_warmup=n_warmup,
    progress_bar=False,
    num_chains=1,
)
def forecast_w_predictive_parallel(date, X, y, mcmc_obj, model, seed=0):
    """Fit ``mcmc_obj`` on all data up to ``date`` and slice out the t+1 point.

    Args:
        date: Index label of the last training row (t). It should be an exact
            label in the index and must not be the last one, because the row
            after it is taken as the test point.
        X: Regressors, indexed like ``y`` (pandas DataFrame).
        y: Target series (pandas Series).
        mcmc_obj: Object with a ``run(rng_key, X=..., y=...)`` method, e.g. a
            NumPyro ``MCMC``. It is re-run (and mutated) in place.
        model: Currently unused; kept so existing argument tuples still match.
            Presumably intended for the omitted Predictive step.
        seed: Integer seed for the PRNG key handed to ``run``. The default 0
            reproduces the previous hard-coded ``PRNGKey(0)``.

    Returns:
        1 — a placeholder until the Predictive step is added back.
    """
    # .loc slicing already returns new objects, so the former y.copy() /
    # X.copy() calls only duplicated the full frames on every call.
    train_y = y.loc[:date].to_numpy()
    train_X = X.loc[:date].to_numpy()

    # t+1 observation: position 1 of the slice that starts at `date`.
    # Computed for the (omitted) Predictive step; not consumed yet.
    y_test = y.loc[date:].iloc[1]
    X_test = X.loc[date:].iloc[[1]].to_numpy()

    mcmc_obj.run(random.PRNGKey(seed), X=train_X, y=train_y)
    return 1
# Parallelizing with Pool.starmap()
import multiprocessing as mp

# Why the original hung: JAX/XLA starts its own worker threads as soon as it is
# used, and the default "fork" start method (Linux) copies a multithreaded
# process. The forked children can deadlock inside mcmc_obj.run. A "spawn"
# context starts fresh interpreters instead.
#
# Requirements of "spawn":
#   * this driver code must sit under the __main__ guard, otherwise every
#     child re-executes it while importing the main module;
#   * the worker function must be importable from a .py module. Functions
#     defined in a notebook cell cannot be looked up by spawned children.
# NOTE(review): each process also runs its own multithreaded XLA runtime, so
# fewer than cpu_count() workers may be faster; tune if CPUs are oversubscribed.
if __name__ == "__main__":
    fTs = X.loc['1990-01-01':, :].index

    ctx = mp.get_context("spawn")
    # The context manager tears the pool down once starmap has returned every
    # result, so no worker processes are left behind.
    with ctx.Pool(mp.cpu_count()) as pool:
        results = pool.starmap(
            forecast_w_predictive_parallel,
            [(month, X, y, m_, w_regression) for month in fTs[:-1]],
        )
When I run the last block, it just freezes: it keeps running but never finishes or returns anything. Any insight into why `mcmc_obj.run` hangs inside the pool? (The same code works fine in a regular for-loop.)