NumPyro model hangs when run with multiprocessing `Pool.starmap`

Hello hello,

I'm trying to parallelize one-step-ahead forecasts using `multiprocessing`. A simplified example is below (I've left out the `Predictive` step, since it doesn't seem to be the issue here):

def w_regression(X=None, y=None):
    """NumPyro regression model whose coefficients are non-negative and sum to one.

    Raw coefficients get independent HalfNormal(1) priors and are divided by
    their total to form the deterministic site ``"beta"``. The observation
    site ``"y"`` is Normal around ``X @ beta`` with HalfNormal(1) noise scale
    ``"sigma_eps"``; it is conditioned on ``y`` when ``y`` is given.

    Args:
        X: 2-D design matrix. Required in practice: the ``None`` default
            cannot be used because ``X.shape`` is unpacked immediately.
        y: Observed targets, or ``None`` to sample ``"y"`` (e.g. for prediction).

    Returns:
        The value of the ``"y"`` sample site.
    """
    _, n_features = X.shape

    noise_scale = numpyro.sample('sigma_eps', dist.HalfNormal(1))
    raw_weights = numpyro.sample('beta_unweighted', dist.HalfNormal(jnp.ones(n_features)))

    # Normalise so the weights sum to one.
    weight_total = jnp.sum(raw_weights)
    weights = numpyro.deterministic("beta", raw_weights / weight_total)

    mean = jnp.matmul(X, weights)
    return numpyro.sample("y", dist.Normal(mean, noise_scale), obs=y)
# Sampler settings for the NUTS run below.
n_warmup = 5_000
n_samples = 10_000
trees = 15              # passed as NUTS max_tree_depth
target_acc_prob = 0.99  # passed as NUTS target_accept_prob

# Single-chain MCMC over w_regression with the progress bar disabled.
m_ = MCMC(
    NUTS(
        w_regression,
        max_tree_depth=trees,
        target_accept_prob=target_acc_prob,
    ),
    num_warmup=n_warmup,
    num_samples=n_samples,
    num_chains=1,
    progress_bar=False,
)
def forecast_w_predictive_parallel(date, X, y, mcmc_obj, model):
    """Fit ``mcmc_obj`` on all observations up to and including ``date``.

    Simplified version: the t+1 test row is selected but not used, ``model``
    is not used, and the function always returns 1.
    """
    # Training window: every row whose label is <= date (.loc slicing is inclusive).
    X_train = X.copy().loc[:date].to_numpy()
    y_train = y.copy().loc[:date].to_numpy()

    # Hold-out for t+1: the second row of the window starting at `date`
    # (i.e. the row after `date`, assuming `date` is itself in the index).
    future_X = X.copy().loc[date:]
    future_y = y.copy().loc[date:]
    y_test = future_y.iloc[1]
    X_test = future_X.iloc[[1]].to_numpy()

    mcmc_obj.run(random.PRNGKey(0), X=X_train, y=y_train)
    return 1

# Parallelizing with Pool.starmap()
fTs = X.loc['1990-01-01':, :].index

import multiprocessing as mp

# JAX runs its own internal thread pools, and fork() of a multithreaded
# process can deadlock in the child. JAX itself warns that "os.fork() is
# incompatible with multithreaded code". That is why a Pool using the "fork"
# start method (the Linux default before Python 3.14) hangs here while a
# plain for-loop works. "spawn" starts fresh interpreters instead.
# Spawned children re-import the main module, so the pool must sit behind
# the __main__ guard. They must also be able to import
# forecast_w_predictive_parallel, so run this as a script rather than from
# an interactive session.
if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    # The context manager tears the pool down even if a worker raises.
    # starmap blocks until every result is back, so nothing is lost on exit.
    with ctx.Pool(mp.cpu_count()) as pool:
        results = pool.starmap(
            forecast_w_predictive_parallel,
            # Skip the last date: it has no t+1 row to forecast.
            [(month, X, y, m_, w_regression) for month in fTs[:-1]],
        )

When I run the last block, it just hangs: it keeps running but never finishes or returns anything. Any insight into why `mcmc_obj.run` won't run? (It works fine in a regular for-loop.)

Hi @clgarciga, have you tested this with JAX alone yet, to see whether JAX supports this pattern?

I hadn’t thought of doing that, thanks.

I tried, but I’m stuck at this:

def forecast_w_predictive_parallel(n, X, y, mcmc_obj, model):
    """One-step-ahead forecast from integer position ``n`` (JAX-array version).

    Fits ``mcmc_obj`` on rows ``0..n`` inclusive, matching the pandas
    ``.loc[:date]`` versions, and draws posterior-predictive samples of
    ``"y"`` for row ``n + 1``.

    Args:
        n: Position of the last training row.
        X: 2-D design matrix.
        y: 1-D targets aligned with the rows of ``X``.
        mcmc_obj: NumPyro ``MCMC`` object; it is re-run on every call.
        model: The model function used for ``Predictive`` (presumably the one
            ``mcmc_obj``'s kernel was built with).

    Returns:
        Posterior-predictive samples of ``"y"`` for row ``n + 1``.
    """
    # Training data up to and including row n. Previously the slice stopped at
    # n - 1, so row n fell in neither the training nor the test set.
    # NOTE: dynamic_slice needs concrete slice sizes, so n + 1 must not be a
    # traced value here (e.g. under jit/pmap).
    n_cols = X.shape[1]
    X_train = jax.lax.dynamic_slice(X, (0, 0), (n + 1, n_cols))
    y_train = y[: n + 1]
    # Test regressors for row n + 1, kept 2-D via list indexing so the model's
    # matmul sees a (1, n_cols) matrix.
    X_test = X[[n + 1]]

    # Fit the MCMC object that was passed in. Previously the global m_ was
    # used, while get_samples() below read mcmc_obj.
    mcmc_obj.run(random.PRNGKey(0), X=X_train, y=y_train)

    # Forecast with the caller-supplied model and this fit's posterior samples.
    predictive = Predictive(model, posterior_samples=mcmc_obj.get_samples())
    y_pred_samples = predictive(random.PRNGKey(0), X_test)["y"]

    return y_pred_samples


import jax

# Integer position of every forecast origin from 1990-01-01 onward.
X_copy = X_.copy()
X_copy['n'] = np.arange(len(X_copy))
n_start = int(X_copy.loc['1990-01-01'].n)
# Exclude the final row: it has no t+1 row to forecast (the pandas version
# likewise iterates fTs[:-1]). JAX clamps out-of-bounds indices instead of
# raising, so including it would silently "forecast" the last row again.
ns = jnp.array(range(n_start, len(X_copy) - 1))

# Static arguments must be hashable, and JAX arrays are not. So X and y are
# passed as ordinary arguments broadcast to every device (in_axes=None), and
# only the MCMC object and the model function stay static. In this in_axes
# tuple, the entries at the static positions (3, 4) are dropped by pmap.
# NOTE(review): pmap maps over devices, so len(ns) must not exceed
# jax.local_device_count(). Also, n arrives traced, which dynamic_slice's
# concrete-size requirement and MCMC.run may not accept. Confirm before
# relying on this path.
jax.pmap(
    forecast_w_predictive_parallel,
    in_axes=(0, None, None, None, None),
    static_broadcasted_argnums=(3, 4),
)(ns, jnp.array(X_), jnp.array(y_), m_, weights_regression)

Using ordinary slicing instead of `jax.lax.dynamic_slice` gave me this error (traceback not included here):

Switching to joblib appears to have worked:

def forecast_w_predictive_parallel(date, X, y, mcmc_obj, model):
    """One-step-ahead forecast from origin ``date`` (pandas version).

    Fits ``mcmc_obj`` on every row up to and including ``date``. It then draws
    posterior-predictive samples for the next row and scores their median
    against the observed value.

    Args:
        date: Forecast origin; a label in the index of ``X`` and ``y``.
        X: Regressor DataFrame.
        y: Target Series aligned with ``X``.
        mcmc_obj: NumPyro ``MCMC`` object; it is re-run on every call.
        model: The model function used for ``Predictive`` (presumably the one
            ``mcmc_obj``'s kernel was built with).

    Returns:
        ``(y_predictions, point_fcast, forecast_error)``: the predictive
        samples, their median, and ``y_test - median``.

    Raises:
        IndexError: if no row follows ``date`` (callers skip the last date).
    """
    print(date)

    # Training data up to and including t (.loc label slicing is inclusive).
    # The data is only read, so the previous defensive .copy() calls were
    # unnecessary.
    y_train, X_train = y.loc[:date].to_numpy(), X.loc[:date].to_numpy()
    # Test data for t+1: the row after `date`. X_test stays 2-D via [[1]].
    y_test = y.loc[date:].iloc[1]
    X_test = X.loc[date:].iloc[[1]].to_numpy()

    # Run the model on the object that was passed in.
    mcmc_obj.run(random.PRNGKey(0), X=X_train, y=y_train)

    # Forecast from *this* fit. The previous code read the global m_ and
    # weights_regression. In a worker process those are not guaranteed to be
    # the object just run (they may be independently unpickled copies), so
    # the samples could be stale or missing.
    predictive = Predictive(model, posterior_samples=mcmc_obj.get_samples())
    y_predictions = predictive(random.PRNGKey(0), X_test)["y"]

    # Point forecast (posterior-predictive median) and its error.
    point_fcast = jnp.median(y_predictions)
    forecast_error = y_test - point_fcast

    return (y_predictions, point_fcast, forecast_error)
from joblib import Parallel, delayed

# One delayed forecast per origin. The final date is skipped because the
# forecast function needs a row after the origin.
forecast_jobs = (
    delayed(forecast_w_predictive_parallel)(origin, X_cpi_1m, y_cpi_1m, m_, weights_regression)
    for origin in fTs[:-1]
)
Parallel(n_jobs=16)(forecast_jobs)

Thank you joblib.

1 Like