Hi all,
I’m currently trying to create a model that samples from a list of means and lower Cholesky matrices.
The goal is to use each mean vector together with its lower Cholesky matrix to sample individually from a multivariate normal distribution.
My question is whether, when I use indexing, the multivariate normal distributions are still sampled individually from each pair of mean and Cholesky matrix.
I would like to avoid the for-loop version as much as possible, as it slows down sampling.
```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def for_loop_model(means, L_corrs):
    """Ability model to estimate the mean and covariance across all abilities

    Args:
        means (ndarray): Mean of each factor for each person (n_persons x n_factors)
        L_corrs (ndarray): Lower Cholesky factor of each person's correlation
            matrix (n_persons x n_factors x n_factors)
    """
    n_persons = means.shape[0]

    # === FOR LOOP VERSION ===
    # Current understanding:
    # Sampling Multivariate Normal with each person's means and L_corrs
    abilities = []  # renamed from `abs`, which shadows the Python built-in
    for i in range(n_persons):
        sample_abilities = numpyro.sample(
            f"sample_abs_{i}",
            dist.MultivariateNormal(loc=means[i], scale_tril=L_corrs[i]),
        )
        abilities.append(sample_abilities)
    abilities = jnp.stack(abilities)  # (n_persons, n_factors)
    return abilities


# === OR ===


def plate_model(means, L_corrs, persons_idx):  # illustrative name for version 2
    # === INDEXING + PLATE ===
    # Using index to get the mean and L_corrs of each person and individually sample
    # the multivariate normal distribution
    n_persons = len(persons_idx)  # one index per person
    with numpyro.plate("student_2", n_persons, dim=-1):
        sample_abilities = numpyro.sample(
            "sample_abs",
            dist.MultivariateNormal(
                loc=means[persons_idx], scale_tril=L_corrs[persons_idx]
            ),
        )
    ...
```
Thanks in advance!