# Parallelizing eigenvalue optimization

Hi

I’m currently working on an eigenvalue optimization problem. The observed data are the eigenvalues of a given matrix, and the elements of that matrix depend on some model parameters. Here is a simplified version of the problem I’m attempting.

```python
import numpy as np
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax.scipy.integrate import trapezoid as trapz  # on older JAX versions: jnp.trapz

# Problem data defined elsewhere (not shown): w1t, w3t, w5t, eigvals_true,
# Tsr, r, ell1, ell2, s_arr, wigvals, omega0


def model():
    # setting min and max value to be 0.1*true and 3.*true
    w1min, w1max = .1*abs(w1t), 3.*abs(w1t)
    w3min, w3max = .1*abs(w3t), 3.*abs(w3t)
    w5min, w5max = .1*abs(w5t), 3.*abs(w5t)

    w1 = numpyro.sample('w1', dist.Uniform(w1min, w1max))
    w3 = numpyro.sample('w3', dist.Uniform(w3min, w3max))
    w5 = numpyro.sample('w5', dist.Uniform(w5min, w5max))

    sigma = numpyro.sample('sigma', dist.Uniform(0.1, 10.0))
    eig_sample = numpyro.deterministic('eig', eig_mcmc_func(w1=w1, w3=w3, w5=w5))
    return numpyro.sample('obs', dist.Normal(eig_sample, sigma), obs=eigvals_true)


def eig_mcmc_func(w1=None, w3=None, w5=None):
    return get_eigs(create_supermatrix(w1, w3, w5))/2./omega0


# creates the matrix - details not pertinent for this question
def create_supermatrix(w1, w3, w5):
    integrand1 = Tsr[0, :] * w1
    integrand3 = Tsr[1, :] * w3
    integrand5 = Tsr[2, :] * w5
    integral1 = trapz(integrand1, x=r)
    integral3 = trapz(integrand3, x=r)
    integral5 = trapz(integrand5, x=r)
    prod_gamma1 = gamma(ell1)*gamma(ell2)*gamma(s_arr)
    prod_gamma3 = gamma(ell1)*gamma(ell2)*gamma(s_arr)
    prod_gamma5 = gamma(ell1)*gamma(ell2)*gamma(s_arr)
    wpi = (wigvals[:, 0]*integral1*prod_gamma1 +
           wigvals[:, 1]*integral3*prod_gamma3 +
           wigvals[:, 2]*integral5*prod_gamma5)
    #diag = minus1pow_vecm(m)*8*np.pi*omega0*(wigvals @ (prod_gammas * integral))
    #diag = minus1pow_vecm(m)*8*np.pi*omega0*wpi
    diag = 8*np.pi*omega0*wpi
    supmat = jnp.diag(diag)
    return supmat


def get_eigs(mat):
    # only the eigenvalues are used; eigh returns them in ascending order
    eigvals, eigvecs = jnp.linalg.eigh(mat)
    return eigvals


def gamma(ell):
    return jnp.sqrt((2*ell+1)/4./np.pi)
```
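For reference, here is the simplified model written out as equations, read directly off the code above. The shorthand is introduced here for readability and is not the original notation: $T_1, T_3, T_5$ are the rows `Tsr[0]`, `Tsr[1]`, `Tsr[2]`; $W_{k1}, W_{k3}, W_{k5}$ are the columns `wigvals[:, 0]`, `wigvals[:, 1]`, `wigvals[:, 2]`; and $G_k$ is the `prod_gamma` factor for entry $k$ (identical for all three terms). The integral is approximated with the trapezoidal rule on the grid `r`.

$$
\begin{aligned}
w_j &\sim \mathcal{U}\big(0.1\,|w_j^{\mathrm{true}}|,\ 3\,|w_j^{\mathrm{true}}|\big), \qquad j \in \{1, 3, 5\},\\
\sigma &\sim \mathcal{U}(0.1,\ 10),\\
d_k &= 8\pi\omega_0 \sum_{j \in \{1,3,5\}} W_{kj}\, G_k\, w_j \int T_j(r)\,dr,\\
\boldsymbol{\lambda} &= \operatorname{eig}\big(\operatorname{diag}(\mathbf{d})\big) = \operatorname{sort}(\mathbf{d}),\\
\mathbf{y}_{\mathrm{obs}} &\sim \mathcal{N}\!\left(\frac{\boldsymbol{\lambda}}{2\omega_0},\ \sigma^2 I\right).
\end{aligned}
$$

Because this simplified matrix is diagonal, its eigenvalues are just the entries of $\mathbf{d}$ in ascending order (the same order `eigh` returns), so `jnp.sort(diag)` would give the same result without an eigendecomposition. That shortcut only applies to the simplified version; the full matrices presumably have off-diagonal terms.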

I’m using CPUs for the computation. Is there a way to assign multiple processors to the eigenvalue solver so that the computation is sped up?

In the problem I’m attempting, I also have about 200 such matrices, all of which depend on the same model parameters w1, w3, w5. Is there a way to parallelize the sampler so that each matrix’s eigenvalue problem is solved on a different CPU (or a different set of CPUs)?
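A minimal sketch of the batched approach, assuming the ~200 matrices can be produced by one function of a matrix index and the shared parameters. `build_matrix`, `N_MATS` and `N` are hypothetical stand-ins, not the real `create_supermatrix`. `jax.vmap` maps the per-matrix eigenvalue solve over the index while broadcasting `w1, w3, w5`; `jnp.linalg.eigvalsh` also accepts a stacked `(N_MATS, N, N)` array directly, so both forms should give the same result.

```python
import jax
import jax.numpy as jnp

N_MATS, N = 200, 32  # hypothetical: 200 matrices of size N x N


def build_matrix(idx, w1, w3, w5):
    # Hypothetical stand-in for the idx-th matrix builder: a symmetric
    # N x N matrix whose entries depend on the shared parameters.
    x = jnp.linspace(0.0, 1.0, N)
    off = w3 * jnp.outer(x, x) + w5 * jnp.cos(idx) * jnp.eye(N, k=1)
    return jnp.diag(w1 * (x + idx)) + off + off.T


def eigs_one(idx, w1, w3, w5):
    # Eigenvalues (ascending) of one symmetric matrix; eigenvectors are skipped.
    return jnp.linalg.eigvalsh(build_matrix(idx, w1, w3, w5))


# Map over the matrix index only; w1, w3, w5 are shared (in_axes=None).
eigs_all = jax.jit(jax.vmap(eigs_one, in_axes=(0, None, None, None)))

idxs = jnp.arange(N_MATS, dtype=jnp.float32)
eigs = eigs_all(idxs, 1.0, 0.5, 0.2)  # shape (N_MATS, N)

# Equivalent: stack all matrices, then make one batched eigvalsh call.
mats = jax.vmap(build_matrix, in_axes=(0, None, None, None))(idxs, 1.0, 0.5, 0.2)
eigs_stacked = jnp.linalg.eigvalsh(mats)
```

Inside `model()`, `eig_sample` would then become something like `eigs_all(idxs, w1, w3, w5) / 2. / omega0`, with `eigvals_true` of shape `(N_MATS, N)`. Whether XLA actually spreads the batched solve over several CPU cores depends on the JAX version and backend, so it is worth timing against a loop over the matrices.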

You may get better answers by raising this question in the JAX discussions. I would try `jax.vmap` first; hopefully it will make use of multiple cores for your computation. I’m not sure whether it will work, but you can also try wrapping your function in `jax.pmap`. In NumPyro, we use `pmap` to run parallel chains, not to parallelize the computation of the density.
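If you want to try `jax.pmap` as well, here is a minimal sketch, assuming the same kind of hypothetical `build_matrix` stand-in and a batch that splits evenly across devices. On CPU, JAX normally exposes a single device, so `pmap` has nothing to spread over unless XLA is asked to create several host devices with `--xla_force_host_platform_device_count` (NumPyro wraps this as `numpyro.set_host_device_count`). Each device then handles its chunk of the matrices with `vmap`.

```python
import os

# Ask XLA for 4 host (CPU) devices. This must happen before JAX initializes
# its backend; numpyro.set_host_device_count(4) sets the same flag.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
).strip()

import jax
import jax.numpy as jnp

N_MATS, N = 200, 32  # hypothetical sizes
n_dev = jax.local_device_count()  # 4 on CPU if the flag took effect
assert N_MATS % n_dev == 0, "this sketch assumes an even split"


def build_matrix(idx, w1, w3, w5):
    # Hypothetical stand-in for the idx-th matrix builder (symmetric N x N).
    x = jnp.linspace(0.0, 1.0, N)
    off = w3 * jnp.outer(x, x) + w5 * jnp.cos(idx) * jnp.eye(N, k=1)
    return jnp.diag(w1 * (x + idx)) + off + off.T


def eigs_chunk(idxs, w1, w3, w5):
    # One device's share: build its matrices and solve them as one batch.
    mats = jax.vmap(build_matrix, in_axes=(0, None, None, None))(idxs, w1, w3, w5)
    return jnp.linalg.eigvalsh(mats)


# Split the leading axis of idxs across devices; replicate the parameters.
eigs_pmap = jax.pmap(eigs_chunk, in_axes=(0, None, None, None))

idxs = jnp.arange(N_MATS, dtype=jnp.float32).reshape(n_dev, N_MATS // n_dev)
eigs = eigs_pmap(idxs, 1.0, 0.5, 0.2).reshape(N_MATS, N)
```

Calling `eigs_pmap` inside `model()` means NUTS will jit a function that contains a `pmap`. JAX generally allows this, but it may not beat the plain `vmap` version, and the same devices may already be in use for parallel chains, so timing both is the safest way to decide.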