I’m working on a `SkewMultivariateStudentT` distribution. It has a `log_prob` that looks like this:

```python
def log_prob(self, value: NDArray[float]) -> NDArray[float]:
    distance = value - self.loc
    # Batched Mahalanobis-style quadratic form over the leading dims
    Qy = jnp.einsum("...j,...jk,...k->...", distance, self.prec, distance)
    df_term = jnp.sqrt((self.df + self._width) / (Qy + self.df))
    distance_term = distance / self._std_devs * df_term[..., jnp.newaxis]
    x = jnp.squeeze(self.skewers @ distance_term[..., jnp.newaxis])
    skew = t_cdf_approx(self.df + self._width, x)
    # Azzalini-style skewing: log(2 * t_pdf(value) * T_cdf(x))
    return jnp.log(2) + self._mv_t.log_prob(value) + jnp.log(skew)
```

If I directly define a `SkewMultivariateStudentT` and call `log_prob` on it, the performance is fairly similar for the non-batched and batched versions (that is, versions where the `scale_tril` has 2 dimensions like 5x5 or 3 dimensions like 120x5x5).
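For context, the batching here is just einsum broadcasting over leading dimensions. A minimal standalone sketch of the quadratic-form step (generic names, not the actual class attributes):

```python
import jax.numpy as jnp
import numpy as np

def quad_form(distance, prec):
    # Batched quadratic form Qy = d^T P d; "..." broadcasts over leading dims,
    # so prec can be a single (5, 5) matrix or a (120, 5, 5) batch.
    return jnp.einsum("...j,...jk,...k->...", distance, prec, distance)

rng = np.random.default_rng(0)
d = jnp.asarray(rng.normal(size=(120, 5)))

prec_unbatched = jnp.eye(5)                                # shared 5x5 precision
prec_batched = jnp.broadcast_to(jnp.eye(5), (120, 5, 5))   # one 5x5 per row

q1 = quad_form(d, prec_unbatched)
q2 = quad_form(d, prec_batched)
```

Both calls produce a `(120,)` result, and they agree because the batched precision is just the shared matrix broadcast along the leading axis.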

If I embed the `SkewMultivariateStudentT` as the observed distribution in a simple model and then extract the distribution via `trace`, calling `log_prob` on the batched distribution is much slower (around one order of magnitude) than calling `log_prob` on the unbatched distribution, and also much slower than calling `log_prob` on a batched `MultivariateStudentT`. However, that performance discrepancy only occurs when the functions are JITed.
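For what it’s worth, the timing pattern I use is the standard one for jitted JAX functions (sketched here with a dummy function, not the notebook code): compile once outside the timer, then force completion with `block_until_ready`, since JAX dispatches asynchronously:

```python
import time
import jax
import jax.numpy as jnp

def time_jitted(fn, *args, repeats=10):
    jitted = jax.jit(fn)
    jitted(*args).block_until_ready()      # trigger compilation outside the timer
    start = time.perf_counter()
    for _ in range(repeats):
        jitted(*args).block_until_ready()  # block so we measure the actual compute
    return (time.perf_counter() - start) / repeats

# Dummy stand-in for a log_prob call on a (120, 5) batch
f = lambda x: jnp.log1p(jnp.exp(x)).sum()
x = jnp.ones((120, 5))
dt = time_jitted(f, x)
```

The discrepancy I describe shows up with this methodology, so it isn’t an artifact of timing compilation itself.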

Also, doing a full `mcmc.run` on a model with the batched `SkewMultivariateStudentT` takes about two orders of magnitude longer than on a model with either the unbatched `SkewMultivariateStudentT` or the batched `MultivariateStudentT`.

See the notebook with complete distribution code and performance tests here: batched-skew-t-perf.ipynb · GitHub.

Are these sorts of performance discrepancies expected? Any guesses as to what’s causing them? Are there workarounds? Thanks.

(Also, this may be more of a JAX issue? If so, let me know and I can ask there.)