I would like to profile my script to find out which line slows it down. I can run cProfile, but it is not very helpful because it does not show which line is responsible for the slowdown.
def stickbreak(v):
    """Turn stick-breaking fractions ``v`` into mixture weights.

    The last axis of ``v`` holds K - 1 break fractions. The result has K
    entries along that axis: entry k is ``v[k] * prod_{j<k} (1 - v[j])`` and
    the final entry is the leftover ``prod_j (1 - v[j])``. All leading axes
    are treated as batch axes and are not padded.
    """
    # Zero padding for every batch axis; only the last axis is extended.
    leading = [[0, 0]] * (len(v.shape) - 1)
    # Running product of (1 - v), accumulated in log space via log1p.
    remaining = jnp.exp(jnp.cumsum(jnp.log1p(-v), axis=-1))
    # Append a 1 so the last weight receives the whole remaining stick.
    fractions = jnp.pad(v, leading + [[0, 1]], constant_values=1)
    # Prepend a 1 so the first weight breaks an unbroken stick.
    stick_left = jnp.pad(remaining, leading + [[1, 0]], constant_values=1)
    return fractions * stick_left
# Truncation level of the stick-breaking prior: stickbreak() turns C - 1
# break fractions into C mixture weights, and each of the C components owns a
# 512-long coefficient vector (see the reshape inside model()).
C = 7


def model():
    """NumPyro model: stick-breaking mixture of per-culture linear ratings.

    Relies on module-level names defined outside this view:
    ``K`` (size of the per-participant plates), ``R`` (number of data rows),
    ``DATA`` and ``face_vectors``. From how they are indexed below, DATA
    column 0 appears to be a 0-based participant index, column 1 a 1-based
    face index, and column 3 a rating divided by 100 before being observed;
    face_vectors is assumed to be 2-D (faces x 512). Confirm against the
    data loader.

    Performance: each row's consensus is a row-wise dot product between its
    culture's coefficients and its face vector, computed in O(R * 512)
    without forming an R x R intermediate matrix.
    """
    # Stick-breaking weights over C cultures, with a Gamma prior on the
    # concentration alpha.
    alpha = numpyro.sample('alpha', Gamma(1, 1))
    with numpyro.plate('weights', C - 1):
        # Beta takes only the two concentrations; the C - 1 batch comes from
        # the enclosing plate. A third positional argument is not a sample
        # size (in NumPyro that slot is validate_args).
        v = numpyro.sample('v', Beta(1, alpha))

    # Participant competence ~ Gamma(concentration=a, rate=b). Despite the
    # site names, a and b are the Gamma's concentration and rate, not its
    # mean and variance.
    a = numpyro.sample("competenceMean", Exponential(0.3))
    b = numpyro.sample("competenceVariance", Exponential(0.5))
    with numpyro.plate('compet', K):
        competence = numpyro.sample("competence", Gamma(a, b))
        # Discrete culture assignment per plate element, drawn with the
        # stick-breaking weights.
        culture_id = numpyro.sample("culture_id", Categorical(stickbreak(v)))

    # One 512-long coefficient vector per culture, sampled flat and reshaped.
    with numpyro.plate('latent_construct_coefficients', C * 512):
        feature_coefficient = numpyro.sample("feature_coefficient", Normal(0, 2))
    feature_coefficient = jnp.reshape(feature_coefficient, (C, 512))

    beta_intercept = numpyro.sample("beta_intercept", Exponential(0.05))
    with numpyro.plate('biasTerm', K):
        bias = numpyro.sample("bias", Normal(0, 0.01))
    with numpyro.plate('scalingTerm', K):
        scaling = numpyro.sample("scaling", Normal(0, 0.01))

    with numpyro.plate("data_loop", R) as i:
        # Decode each row's indices once instead of re-slicing DATA per use.
        person = DATA[i, 0].astype(int)
        face = DATA[i, 1].astype(int) - 1  # column 1 is 1-based
        f_coef = feature_coefficient[culture_id[person].astype(int)]
        face_vector = face_vectors[face]
        # Row-wise dot product: equals diag(f_coef @ face_vector.T), but
        # avoids materialising the R x R matrix (O(R^2 * 512) time and
        # O(R^2) memory) just to read its diagonal.
        consensus = beta_intercept + jnp.sum(f_coef * face_vector, axis=-1)
        mu = (1 + scaling[person]) * consensus + bias[person]
        # NOTE(review): 'precision' is used as 1/scale, so the standard
        # deviation is 1/competence (not 1/sqrt(competence)) -- confirm intent.
        precision = competence[person]
        numpyro.sample(
            "rating",
            Normal(mu, 1 / precision),
            obs=DATA[i, 3].astype(float) / 100,
        )
# NUTS updates the continuous latent sites; DiscreteHMCGibbs wraps it so the
# model's discrete site (culture_id, a Categorical draw) is Gibbs-updated.
kernel = DiscreteHMCGibbs(
    NUTS(model, init_strategy=numpyro.infer.init_to_sample)
)
# A single chain is run, so chain_method='parallel' is effectively unused.
mcmc = MCMC(
    kernel,
    num_warmup=100,
    num_samples=100,
    chain_method='parallel',
    num_chains=1,
    jit_model_args=True,
)