Hi,
My NumPyro model below runs extremely slowly. Do you have any suggestions for speeding it up, or other improvements?
# DATA columns: participant, stimulus, trait, response, rt
# Count distinct ids with set comprehensions; the original
# len(list(set([...]))) built a list, a set, and another list per count.
K = len({int(row[0]) for row in DATA})  # number of distinct participants
N = len({int(row[1]) for row in DATA})  # number of distinct stimuli
T = len({int(row[2]) for row in DATA})  # number of distinct traits
R = len(DATA)                           # total number of rating rows
def stickbreak(v):
    """Convert stick-breaking fractions into mixture weights.

    Given fractions ``v`` with last-axis length ``k``, returns a vector of
    ``k + 1`` weights on the last axis (any leading batch axes are kept).
    The log-space cumulative product keeps the computation numerically stable.
    """
    leading = [[0, 0]] * (v.ndim - 1)
    # Probability that the stick survives past each break point.
    survivors = jnp.exp(jnp.cumsum(jnp.log1p(-v), axis=-1))
    # Append a final fraction of 1 (the last component takes what remains).
    fractions = jnp.pad(v, leading + [[0, 1]], constant_values=1)
    # Prepend survival probability 1 for the first component.
    remaining = jnp.pad(survivors, leading + [[1, 0]], constant_values=1)
    return fractions * remaining
C = 7
def model():
    """NumPyro model: cultural-consensus rating model with a truncated
    stick-breaking (Dirichlet-process-style) mixture over up to C cultures.

    Reads module-level globals: DATA (rows: participant, stimulus, trait,
    response, rt), K (participants), R (rows), C (mixture truncation),
    plus visual_face_vectors and language_trait_vectors (feature matrices;
    their shapes are not visible in this file).
    """
    # Concentration parameter of the stick-breaking process.
    alpha = numpyro.sample('alpha', Gamma(1, 1))
    with numpyro.plate('weights', C - 1):
        # Stick-breaking fractions; stickbreak(v) turns them into C weights.
        # NOTE(review): numpyro's Beta is Beta(concentration1, concentration0);
        # the extra positional argument `C - 1` looks unintended (the plate
        # already supplies the batch shape) — verify this runs as written.
        v = numpyro.sample('v', Beta(1, alpha, C - 1))
    dimension = 7  # size of the shared latent trait space
    # Sample competence hyperpriors (Gamma shape `a` and rate `b` below).
    a = numpyro.sample("competenceMean", Exponential(0.1))
    b = numpyro.sample("competenceVariance", Exponential(1))
    with numpyro.plate('compet', K):
        # Per-participant competence; used later as 1/scale of the rating noise.
        competence = numpyro.sample("competence", Gamma(a, b))
        # Discrete per-participant culture assignment (handled by the Gibbs
        # part of DiscreteHMCGibbs at inference time).
        culture_id = numpyro.sample("culture_id", Categorical(stickbreak(v)))
    # Latent construct of image and language features.
    visual_f_prior = numpyro.sample("visual_f_prior", Gamma(10, 1))
    language_f_prior = numpyro.sample("language_f_prior", Gamma(10, 1))
    # NOTE(review): Normal's second argument is the *scale* (std dev), so the
    # coefficients below use std = 1/prior, not precision — confirm intended.
    with numpyro.plate('latent_visual_coefficients', C*dimension*512):
        visual_f_coefficient = numpyro.sample("visual_f_coefficient", Normal(0, 1/visual_f_prior))
    # One (dimension x 512) projection matrix per culture.
    visual_f_coefficient = jnp.reshape(visual_f_coefficient, (C, dimension, 512))
    with numpyro.plate('latent_language_coefficients', C*dimension*300):
        language_f_coefficient = numpyro.sample("language_f_coefficient", Normal(0, 1/language_f_prior))
    # One (dimension x 300) projection matrix per culture.
    language_f_coefficient = jnp.reshape(language_f_coefficient, (C, dimension, 300))
    # Per-participant additive bias and multiplicative scaling of the consensus.
    with numpyro.plate('biasTerm', K):
        bias = numpyro.sample("bias", Normal(0, 0.01))
    with numpyro.plate('scalingTerm', K):
        scaling = numpyro.sample("scaling", Normal(0, 0.01))
    # Vectorized likelihood over all R rows: `i` is the full index array
    # supplied by the plate, so every line below operates on whole arrays.
    with numpyro.plate("data_loop", R) as i:
        # Stimulus/trait ids are treated as 1-indexed (note the `- 1`), but
        # the participant id DATA[i, 0] is used without `- 1` below.
        # NOTE(review): possible off-by-one if participants are 1-indexed too.
        visual_face_vector = visual_face_vectors[:,DATA[i,1].astype(int) - 1]
        visual_f = visual_f_coefficient[culture_id[DATA[i,0].astype(int)].astype(int)]
        visual_latent_factor = jnp.einsum('mik,ijk->ij', visual_face_vector, visual_f)
        language_face_vector = language_trait_vectors[:,DATA[i,2].astype(int) - 1]
        language_f = language_f_coefficient[culture_id[DATA[i,0].astype(int)].astype(int)]
        language_latent_factor = jnp.einsum('mik,ijk->ij', language_face_vector, language_f)
        # Row-wise dot product of the two latent factors = consensus signal.
        consensus = jnp.einsum('ab,ab-> a', language_latent_factor, visual_latent_factor)
        mu = ((1+scaling[DATA[i, 0].astype(int)]) * consensus) + bias[DATA[i, 0].astype(int)]
        # "precision" is used as 1/scale, i.e. std = 1/competence.
        # NOTE(review): if precision meant inverse *variance*, the scale
        # should be 1/sqrt(precision) — confirm.
        precision = competence[DATA[i, 0].astype(int)]
        # Observed ratings mapped from the 0-100 scale to the real line.
        # NOTE(review): responses of exactly 0 or 100 give logit = ±inf —
        # verify the data excludes the endpoints.
        rating = numpyro.sample("rating", Normal(mu, 1/precision), obs=logit(DATA[i, 3].astype(float)/100))
# NUTS for the continuous parameters, wrapped in Gibbs updates for the
# discrete per-participant culture assignments sampled inside the model.
kernel = DiscreteHMCGibbs(NUTS(model, init_strategy=numpyro.infer.init_to_sample))
# NOTE(review): 100 warmup / 100 samples is tiny for a model of this size —
# fine for smoke-testing, likely far too few draws for real inference.
mcmc = MCMC(kernel, num_warmup=100, num_samples=100, chain_method='parallel',
            num_chains=1, jit_model_args=True)
mcmc.run(random.PRNGKey(0))
samples = mcmc.get_samples(group_by_chain=True)
mcmc.print_summary()
# Reuse the samples already extracted above instead of calling
# get_samples(group_by_chain=True) a second time.
diagnos = numpyro.diagnostics.summary(samples)