I’m working on a `SkewMultivariateStudentT` distribution. It has a `log_prob` that looks like this:
```python
def log_prob(self, value: NDArray[float]) -> NDArray[float]:
    distance = value - self.loc
    Qy = jnp.einsum("...j,...jk,...k->...", distance, self.prec, distance)
    df_term = jnp.sqrt((self.df + self._width) / (Qy + self.df))
    distance_term = distance / self._std_devs * df_term[..., jnp.newaxis]
    x = jnp.squeeze(self.skewers @ distance_term[..., jnp.newaxis])
    skew = t_cdf_approx(self.df + self._width, x)
    return jnp.log(2) + self._mv_t.log_prob(value) + jnp.log(skew)
```
If I directly define a `SkewMultivariateStudentT` and call `log_prob` on it, the performance is fairly similar for the non-batched and batched versions (that is, versions where `scale_tril` has 2 dimensions, like 5x5, or 3 dimensions, like 120x5x5).
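(For concreteness, by "batched" I mean batching via the `...` broadcasting in the `einsum` above. A minimal sketch of the two shapes, in plain NumPy with made-up values, just to show the shape behavior:)

```python
import numpy as np

rng = np.random.default_rng(0)

# Event dimension 5, batch of 120, matching the 5x5 / 120x5x5 example.
loc = rng.normal(size=5)
prec = np.eye(5)                                         # unbatched, (5, 5)
prec_batched = np.broadcast_to(np.eye(5), (120, 5, 5))   # batched, (120, 5, 5)

value = rng.normal(size=5)
distance = value - loc                                   # (5,)

# The same einsum handles both cases via "..." broadcasting:
Qy = np.einsum("...j,...jk,...k->...", distance, prec, distance)            # ()
Qy_b = np.einsum("...j,...jk,...k->...", distance, prec_batched, distance)  # (120,)

print(Qy.shape, Qy_b.shape)  # → () (120,)
```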
However, if I embed the `SkewMultivariateStudentT` as the observed distribution in a simple model and then extract the distribution, calling `log_prob` on the batched distribution has much worse performance (around one order of magnitude) than calling `log_prob` on the unbatched distribution, and also much worse performance than calling `log_prob` on a batched `MultivariateStudentT`. That performance discrepancy only occurs when the functions are JITed.
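(The shape of the timing comparison is roughly this. It's a sketch with a stand-in quadratic form instead of the real `log_prob`, and `bench` is a hypothetical helper, not code from the notebook:)

```python
import time
import jax
import jax.numpy as jnp

def bench(f, *args, n=100):
    # Compile/warm up once, then time repeated calls, blocking on the
    # result so JAX's async dispatch doesn't hide the real cost.
    f(*args).block_until_ready()
    t0 = time.perf_counter()
    for _ in range(n):
        out = f(*args)
    out.block_until_ready()
    return (time.perf_counter() - t0) / n

# Stand-in for log_prob: the quadratic form over an event dim of 5.
def qform(x, prec):
    return jnp.einsum("...j,...jk,...k->...", x, prec, x)

x = jnp.ones(5)
prec_unbatched = jnp.eye(5)
prec_batched = jnp.broadcast_to(jnp.eye(5), (120, 5, 5))

jitted = jax.jit(qform)
print("unbatched:", bench(jitted, x, prec_unbatched))
print("batched:  ", bench(jitted, x, prec_batched))
```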
Also, doing a full `mcmc.run` on a model with the batched `SkewMultivariateStudentT` takes about two orders of magnitude longer than on a model with either the unbatched `SkewMultivariateStudentT` or the batched `MultivariateStudentT`.

See the notebook with the complete distribution code and performance tests here: batched-skew-t-perf.ipynb · GitHub.
Are these sorts of performance discrepancies expected? Any guesses as to what’s causing them? Are there workarounds? Thanks.
(Also, this may be more of a JAX issue? If so, let me know and I can ask there.)