Code efficiency for matrix factorization

Hi,

My code is extremely slow. Do you have any suggestions for speeding it up?

# DATA columns: participant, stimulus, trait, response, rt
# Count the distinct IDs directly with set comprehensions instead of
# len(list(set([...]))), which builds two throwaway containers per line.
# NOTE(review): later code indexes arrays by the raw IDs in DATA, which
# assumes participant/stimulus/trait IDs are contiguous — confirm.
K = len({int(row[0]) for row in DATA})  # number of participants
N = len({int(row[1]) for row in DATA})  # number of stimuli
T = len({int(row[2]) for row in DATA})  # number of traits
R = len(DATA)                           # total number of responses (rows)

def stickbreak(v):
    """Stick-breaking transform.

    Maps break fractions ``v`` (length C-1 along the last axis, values in
    [0, 1)) to a length-C probability vector on the last axis; leading
    batch axes are preserved.
    """
    batch_pad = [[0, 0]] * (v.ndim - 1)
    # Stick length remaining after each break: cumulative product of
    # (1 - v), computed in log space for numerical stability.
    remaining = jnp.exp(jnp.cumsum(jnp.log1p(-v), axis=-1))
    # Append a final fraction of 1 (the last piece takes what is left)
    # and prepend a remaining length of 1 (the stick starts whole).
    fractions = jnp.pad(v, batch_pad + [[0, 1]], constant_values=1)
    sticks = jnp.pad(remaining, batch_pad + [[1, 0]], constant_values=1)
    return fractions * sticks

C = 7  # truncation level of the stick-breaking (Dirichlet-process-like) culture prior
def model():
    """NumPyro model: mixed-culture consensus model of trait ratings.

    Each participant belongs to one of up to C latent "cultures"; each
    culture has its own visual and language factor loadings.  A rating is
    modeled as a scaled/biased inner product of visual and language latent
    factors, observed with participant-specific precision ("competence").
    """
    # Concentration of the stick-breaking prior over culture weights.
    alpha = numpyro.sample('alpha', Gamma(1, 1))
    with numpyro.plate('weights', C - 1):
        # NOTE(review): Beta takes two concentration parameters; the third
        # positional argument `C - 1` looks like a mistake (the plate already
        # supplies the batch size) — confirm against the numpyro version used.
        v = numpyro.sample('v', Beta(1, alpha, C - 1))
    dimension = 7  # size of the latent factor space shared by both modalities
    # Sample competence: Gamma(a, b) precision per participant.
    a = numpyro.sample("competenceMean", Exponential(0.1))
    b = numpyro.sample("competenceVariance", Exponential(1))
    with numpyro.plate('compet', K ):
        competence = numpyro.sample("competence", Gamma(a, b))
        # Discrete culture assignment per participant (Gibbs-sampled by
        # DiscreteHMCGibbs below).
        culture_id = numpyro.sample("culture_id",  Categorical(stickbreak(v)))

    # Latent construct of image and language features: per-culture loading
    # matrices mapping raw feature vectors (512-d visual, 300-d language)
    # into the shared `dimension`-d latent space.
    visual_f_prior = numpyro.sample("visual_f_prior", Gamma(10, 1))
    language_f_prior = numpyro.sample("language_f_prior", Gamma(10, 1))
    # NOTE(review): sampling C*dimension*512 scalars in a plate and reshaping
    # works, but a non-centered/vectorized parameterization would likely be
    # much faster for NUTS — see the reparameterization suggestion below.
    with numpyro.plate('latent_visual_coefficients', C*dimension*512):
        visual_f_coefficient =  numpyro.sample("visual_f_coefficient", Normal(0, 1/visual_f_prior))
        visual_f_coefficient = jnp.reshape(visual_f_coefficient, (C, dimension, 512))

    with numpyro.plate('latent_language_coefficients', C*dimension*300):
        language_f_coefficient =  numpyro.sample("language_f_coefficient", Normal(0, 1/language_f_prior))
        language_f_coefficient = jnp.reshape(language_f_coefficient, (C, dimension, 300))
    
    # Bias and scaling: per-participant affine response adjustment.
    with numpyro.plate('biasTerm', K):
        bias = numpyro.sample("bias", Normal(0, 0.01))
    with numpyro.plate('scalingTerm', K):
        scaling = numpyro.sample("scaling", Normal(0, 0.01))

    # Likelihood over all R responses; `i` indexes rows of DATA.
    # assumes DATA columns are (participant, stimulus, trait, response, rt)
    # with stimulus/trait IDs 1-based — TODO confirm.
    with numpyro.plate("data_loop", R) as i:
        # Stimulus feature vector for each row (stimulus IDs are 1-based).
        visual_face_vector = visual_face_vectors[:,DATA[i,1].astype(int) - 1]
        # Loadings of the culture each row's participant belongs to.
        visual_f = visual_f_coefficient[culture_id[DATA[i,0].astype(int)].astype(int)]
        visual_latent_factor = jnp.einsum('mik,ijk->ij', visual_face_vector, visual_f)

        # Trait word-embedding vector for each row (trait IDs are 1-based).
        language_face_vector = language_trait_vectors[:,DATA[i,2].astype(int) - 1]
        language_f = language_f_coefficient[culture_id[DATA[i,0].astype(int)].astype(int)]
        language_latent_factor = jnp.einsum('mik,ijk->ij', language_face_vector, language_f)
        # Per-row inner product of the two latent factors.
        consensus = jnp.einsum('ab,ab-> a', language_latent_factor, visual_latent_factor)
        # Participant-specific affine transform of the consensus score.
        mu = ((1+scaling[DATA[i, 0].astype(int)]) * consensus) + bias[DATA[i, 0].astype(int)]
        precision = competence[DATA[i, 0].astype(int)]
        # Observed ratings are 0-100; logit(rating/100) puts them on the real line.
        # NOTE(review): Normal's second argument is a scale, so 1/precision
        # (not 1/sqrt(precision)) treats `competence` as an inverse scale — confirm intended.
        rating = numpyro.sample("rating", Normal(mu, 1/precision), obs=logit(DATA[i, 3].astype(float)/100))

# HMC-within-Gibbs: NUTS for the continuous sites, Gibbs moves for the
# discrete `culture_id` assignments.
kernel = DiscreteHMCGibbs(NUTS(model, init_strategy=numpyro.infer.init_to_sample))
# NOTE(review): 100 warmup / 100 samples is very short for a model this
# size — increase once the model is fast enough to afford it.
mcmc = MCMC(kernel, num_warmup=100, num_samples=100, chain_method='parallel',
            num_chains=1, jit_model_args=True)
mcmc.run(random.PRNGKey(0))
samples = mcmc.get_samples(group_by_chain=True)
mcmc.print_summary()
# Reuse the samples already fetched instead of calling get_samples() again.
diagnos = numpyro.diagnostics.summary(samples)

You should probably reparameterize `visual_f_coefficient` and `language_f_coefficient` — a non-centered parameterization usually helps NUTS considerably on hierarchical scale priors like these. See this tutorial.

Note that this looks like a high-dimensional problem (tens of thousands of latent coefficients), so there are limits to how fast it can get.