Parallelizing eigenvalue optimization

Hi

I’m working on an eigenvalue optimization problem. The observed data are the eigenvalues of a matrix whose elements depend on a set of model parameters. Here is a simplified version of the problem.

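import numpy as np
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax.scipy.integrate import trapezoid as trapz

# Data and constants (w1t, w3t, w5t, Tsr, r, ell1, ell2, s_arr, wigvals,
# omega0, eigvals_true) are assumed to be defined elsewhere.

def model():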
    # prior bounds: 0.1x and 3x the absolute true value of each parameter
    w1min, w1max = .1*abs(w1t), 3.*abs(w1t)
    w3min, w3max = .1*abs(w3t), 3.*abs(w3t)
    w5min, w5max = .1*abs(w5t), 3.*abs(w5t)
    
    w1 = numpyro.sample('w1', dist.Uniform(w1min, w1max))
    w3 = numpyro.sample('w3', dist.Uniform(w3min, w3max))
    w5 = numpyro.sample('w5', dist.Uniform(w5min, w5max))
    
    sigma = numpyro.sample('sigma', dist.Uniform(0.1, 10.0))
    eig_sample = numpyro.deterministic('eig', eig_mcmc_func(w1=w1, w3=w3, w5=w5))
    return numpyro.sample('obs', dist.Normal(eig_sample, sigma), obs=eigvals_true)

def eig_mcmc_func(w1=None, w3=None, w5=None):
    return get_eigs(create_supermatrix(w1, w3, w5))/2./omega0
    

# Creates the matrix (details not pertinent to this question)
def create_supermatrix(w1, w3, w5):
    integrand1 = Tsr[0, :] * w1
    integrand3 = Tsr[1, :] * w3
    integrand5 = Tsr[2, :] * w5
    integral1 = trapz(integrand1, x=r)
    integral3 = trapz(integrand3, x=r)
    integral5 = trapz(integrand5, x=r)
    prod_gamma1 = gamma(ell1)*gamma(ell2)*gamma(s_arr[0])
    prod_gamma3 = gamma(ell1)*gamma(ell2)*gamma(s_arr[1])  # s = 3 term
    prod_gamma5 = gamma(ell1)*gamma(ell2)*gamma(s_arr[2])  # s = 5 term
    wpi = (wigvals[:, 0]*integral1*prod_gamma1 + 
           wigvals[:, 1]*integral3*prod_gamma3 +
           wigvals[:, 2]*integral5*prod_gamma5)
    #diag = minus1pow_vecm(m)*8*np.pi*omega0*(wigvals @ (prod_gammas * integral))
    #diag = minus1pow_vecm(m)*8*np.pi*omega0*wpi
    diag = 8*np.pi*omega0*wpi
    supmat = jnp.diag(diag)
    return supmat

def get_eigs(mat):
    eigvals, eigvecs = jnp.linalg.eigh(mat)
    return eigvals

def gamma(ell):
    return jnp.sqrt((2*ell+1)/4./np.pi)

I’m running the computation on CPUs. Is there a way to assign multiple processors to the eigenvalue solver to speed up the computation?
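A side observation on the simplified version above, assuming the real matrices are also diagonal (which may not hold for the full problem): create_supermatrix builds the matrix with jnp.diag, and the eigenvalues of a diagonal matrix are just its diagonal entries. Sorting the diagonal then gives the same ascending output as the general symmetric solver without the cubic-cost decomposition. The function names below are hypothetical and only illustrate the comparison.

```python
import jax
import jax.numpy as jnp


@jax.jit
def eigs_via_eigh(diag):
    # General symmetric eigensolver (what get_eigs does), applied to a diagonal matrix.
    return jnp.linalg.eigvalsh(jnp.diag(diag))


@jax.jit
def eigs_via_sort(diag):
    # For a diagonal matrix the eigenvalues are the diagonal entries;
    # eigvalsh returns them in ascending order, so sorting matches it.
    return jnp.sort(diag)


diag = jnp.array([3.0, -1.0, 2.5, 0.5])
print(eigs_via_eigh(diag))
print(eigs_via_sort(diag))
```

If the full matrices have off-diagonal couplings, this shortcut no longer applies and a real eigensolver is needed.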

In the full problem, I also have about 200 such matrices, all depending on the same model parameters w1, w3, w5. Is there a way to parallelize the sampler so that each matrix’s eigenvalue problem is solved on a different CPU (or set of CPUs)?
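One way to express "many matrices, shared parameters" in JAX is a single batched computation with jax.vmap: the per-matrix data gets a batch axis, while the shared parameters are broadcast with in_axes=None. The sketch below is a minimal illustration under assumptions: build_matrix is a hypothetical stand-in for create_supermatrix (a symmetric matrix linear in w = (w1, w3, w5)), and the sizes are made up. Whether XLA spreads the batched solve across several CPU cores is up to the backend and not guaranteed.

```python
import jax
import jax.numpy as jnp

N_MATRICES, DIM = 200, 8  # hypothetical sizes


def build_matrix(w, basis):
    # Hypothetical stand-in for create_supermatrix: a symmetric matrix that is
    # linear in the shared parameters w = (w1, w3, w5).
    # basis has shape (3, DIM, DIM); each basis[k] is symmetric.
    return jnp.einsum("k,kij->ij", w, basis)


def eigs_one(w, basis):
    return jnp.linalg.eigvalsh(build_matrix(w, basis))


# vmap over the per-matrix basis (axis 0) and broadcast the shared w (None).
batched_eigs = jax.jit(jax.vmap(eigs_one, in_axes=(None, 0)))

raw = jax.random.normal(jax.random.PRNGKey(0), (N_MATRICES, 3, DIM, DIM))
bases = 0.5 * (raw + jnp.swapaxes(raw, -1, -2))  # symmetrize each basis matrix
w = jnp.array([1.0, 0.5, 0.25])

eigs = batched_eigs(w, bases)
print(eigs.shape)
```

Since jnp.linalg.eigvalsh also accepts a stacked (..., n, n) array, building all 200 matrices into one array and calling it once is an equivalent alternative; inside a NumPyro model the result can be used directly as the mean of the likelihood.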

I think you could raise this question in the JAX discussions to get better answers. You might try jax.vmap first; hopefully it will make use of multiple cores for your computation. I’m not sure it will work, but you can also try applying jax.pmap to your function. In NumPyro, we use pmap to run chains in parallel, not to parallelize the computation of the density.
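As a sketch of the pmap route on a CPU-only machine, assuming XLA is asked to expose several virtual host devices (the --xla_force_host_platform_device_count flag, which numpyro.set_host_device_count also sets) before JAX is imported: the batch of matrices is split across devices with pmap, and vmap handles each device's share. build_matrix and the sizes are hypothetical, and this has not been tested inside an MCMC run; whether it beats plain vmap depends on the matrix size and the machine.

```python
import os

# Must be set before JAX is imported: expose 4 virtual CPU devices.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp

N_MATRICES, DIM = 200, 8  # hypothetical sizes


def build_matrix(w, basis):
    # Hypothetical stand-in for create_supermatrix (shared parameters w).
    return jnp.einsum("k,kij->ij", w, basis)


def eigs_one(w, basis):
    return jnp.linalg.eigvalsh(build_matrix(w, basis))


cpu_devices = jax.devices("cpu")
n_dev = len(cpu_devices)

# Inner vmap: the matrices assigned to one device.
# Outer pmap: split the leading axis across devices; w is broadcast (in_axes=None).
per_device = jax.vmap(eigs_one, in_axes=(None, 0))
parallel_eigs = jax.pmap(per_device, in_axes=(None, 0), devices=cpu_devices)

raw = jax.random.normal(jax.random.PRNGKey(0), (N_MATRICES, 3, DIM, DIM))
bases = 0.5 * (raw + jnp.swapaxes(raw, -1, -2))  # symmetrize each basis matrix
w = jnp.array([1.0, 0.5, 0.25])

assert N_MATRICES % n_dev == 0, "pad the batch so it splits evenly across devices"
sharded = bases.reshape(n_dev, N_MATRICES // n_dev, 3, DIM, DIM)
eigs = parallel_eigs(w, sharded).reshape(N_MATRICES, DIM)
print(n_dev, eigs.shape)
```

Note that these virtual devices share the same physical cores that NumPyro would use for parallel chains, so combining this with chain_method="parallel" splits the same CPU budget two ways.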