Wildly varying performance for batched `log_prob` in custom distribution depending on invocation

I’m working on a SkewMultivariateStudentT distribution. It has a log_prob that looks like this:

    def log_prob(self, value: NDArray[float]) -> NDArray[float]:
        distance = value - self.loc
        # Mahalanobis-style quadratic form using the precomputed precision matrix
        Qy = jnp.einsum("...j,...jk,...k->...", distance, self.prec, distance)
        df_term = jnp.sqrt((self.df + self._width) / (Qy + self.df))
        distance_term = distance / self._std_devs * df_term[..., jnp.newaxis]
        x = jnp.squeeze(self.skewers @ distance_term[..., jnp.newaxis])
        # Skew factor: approximate Student-t CDF of the skewed projection
        skew = t_cdf_approx(self.df + self._width, x)
        return jnp.log(2) + self._mv_t.log_prob(value) + jnp.log(skew)

If I directly define a SkewMultivariateStudentT and call log_prob on it, the performance is fairly similar for both non-batched and batched versions (that is, versions where the scale_tril has 2 dimensions like 5x5 or 3 dimensions like 120x5x5).

If I embed the SkewMultivariateStudentT as the observed distribution in a simple model and then extract the distribution from the trace, calling log_prob on the batched distribution performs much worse (around one order of magnitude) than calling log_prob on the unbatched distribution, and much worse than calling log_prob on a batched MultivariateStudentT. However, that performance discrepancy only occurs when the functions are JITed.
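
For concreteness, the comparison I'm running looks roughly like this (the constructor arguments for SkewMultivariateStudentT and the latent `loc` prior are illustrative stand-ins; the real definitions and timings are in the notebook linked below):

    import jax
    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist
    from numpyro.handlers import seed, trace

    # SkewMultivariateStudentT is the custom distribution defined above / in the notebook
    def model(obs):
        # Illustrative simple model; the batched case uses a (120, 5, 5)
        # scale_tril instead of (5, 5)
        loc = numpyro.sample("loc", dist.Normal(jnp.zeros(5), 1.0).to_event(1))
        numpyro.sample(
            "obs",
            SkewMultivariateStudentT(df=5.0, loc=loc, scale_tril=jnp.eye(5), skewers=jnp.ones(5)),
            obs=obs,
        )

    obs = jnp.zeros(5)
    tr = trace(seed(model, rng_seed=0)).get_trace(obs)
    traced_log_prob = jax.jit(tr["obs"]["fn"].log_prob)   # much slower when batched

    # versus log_prob on a directly constructed distribution, which stays fast
    direct = SkewMultivariateStudentT(
        df=5.0, loc=jnp.zeros(5), scale_tril=jnp.eye(5), skewers=jnp.ones(5)
    )
    direct_log_prob = jax.jit(direct.log_prob)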

Also, doing a full mcmc.run on a model with the batched SkewMultivariateStudentT takes about two orders of magnitude longer than a model with either the unbatched SkewMultivariateStudentT or the batched MultivariateStudentT.
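
The end-to-end timing is just a standard NUTS run over a model like the sketch above, along the lines of:

    from numpyro.infer import MCMC, NUTS

    # The real model has priors over loc, df, etc.; this just shows the shape of the call
    mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500)
    mcmc.run(jax.random.PRNGKey(0), obs)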

See the notebook with complete distribution code and performance tests here: batched-skew-t-perf.ipynb · GitHub.

Are these sorts of performance discrepancies expected? Any guesses as to what’s causing them? Are there workarounds? Thanks.

(Also, this may be more of a JAX issue? If so, let me know I can ask there.)

By performance, I guess you meant speed performance? In JAX, as far as I know, we should avoid jitting internal functions. You can see that we rarely use jit in NumPyro (except for the outermost loops).

Comparing MVT and SkewMVT under MCMC is not a fair comparison because SkewMVT has additional expensive operators like cholesky, cho_solve, … In addition, SkewMVT might introduce additional numerical error that keeps the MCMC chains from mixing well, so each MCMC step requires a longer trajectory.

Thanks for the reply!

I guess you meant speed performance?

Yes, I meant speed.

as far as I know, we should avoid jitting internal functions

I’m only JITing these small fragments because I was trying to narrow down the performance problem (while still JITing, since these fragments would be JITed as part of the larger whole in a normal MCMC run).

because SkewMVT has additional expensive operators like cholesky, cho_solve,…

Yeah, this is what I suspected but wasn’t 100% certain about. Most of those expensive operations are done during the __init__ of the distribution. So I think those operations don’t need to be performed when log_prob is invoked directly on the distribution (so it’s fast), but they are invoked when log_prob is pulled from the model trace (so it’s slow). And those same operations don’t get a big speedup from JITing, which explains why JITing makes the batched SkewMVT look much worse in relative terms.

Does that theory sound plausible? If so, is there any way to move the expensive operations in __init__ outside of the hot loop so that the log_prob from the trace behaves more like the direct log_prob from the distribution? I sort of assume not, because when SkewMVT is used within the model it’s constructed with varying loc, df, etc. coming from other parts of the model. But I don’t have a great mental model for what NumPyro is doing behind the scenes.

Thanks again.

those operations don’t need to be performed when log_prob is invoked directly on the distribution

I think that’s what’s going on under the hood. We trace the model once and use the collected distributions to compute log probabilities. I’m not sure why we would need to invoke __init__ again when computing log_prob.
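
Roughly, the pattern is something like this (a simplified sketch of the substitute-and-trace idea, not the exact internals of numpyro.infer.util.log_density):

    from numpyro import handlers

    def model_log_density(model, params, *args, **kwargs):
        # Substitute latent values into the model, run it once under a trace,
        # then sum the log probabilities of the collected sample sites.
        tr = handlers.trace(handlers.substitute(model, data=params)).get_trace(*args, **kwargs)
        return sum(
            site["fn"].log_prob(site["value"]).sum()
            for site in tr.values()
            if site["type"] == "sample"
        )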

is there any way to move the expensive operations in __init__ outside of the hot loop so that the log_prob from trace is more like the direct log_prob from the distribution

If those arrays are constants, you can use numpy/scipy operators on them. The outputs of those operators will be treated as constants when the function is compiled.
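
For example (a minimal sketch, assuming scale_tril really is a fixed array rather than something computed inside the model):

    import numpy as np          # eager: runs once, at trace time
    import jax
    import jax.numpy as jnp

    scale_tril = np.tril(np.ones((5, 5)))             # genuinely constant array
    prec = np.linalg.inv(scale_tril @ scale_tril.T)   # computed with numpy, not jax.numpy

    @jax.jit
    def quad_form(value):
        # `prec` is a plain numpy array, so XLA bakes it in as a constant
        # instead of recomputing the inverse on every call
        return jnp.einsum("...j,jk,...k->...", value, prec, value)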

The __init__ operations are constant with respect to the values passed into log_prob, but not with respect to the distribution’s parameters (loc, scale_tril, etc.), which are themselves part of the model. I think this explains why that part of __init__ is invoked again (to accommodate the changing loc, scale_tril, etc.), and I think it means the numpy/scipy approach won’t work here. (But I could still be very confused about how NumPyro works.) In any case, I get TracerArrayConversionErrors if I try the numpy/scipy approach.
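
That error makes sense to me, since the parameters arrive as tracers inside the traced/jitted model, e.g.:

    import numpy as np
    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(scale_tril):
        # scale_tril is a JAX tracer here, so numpy can't materialize it
        # and raises TracerArrayConversionError
        return np.linalg.inv(scale_tril)

    # f(jnp.eye(5))  # -> TracerArrayConversionError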

That said, this is mostly moot now since, by rearranging the linear algebra, I was able to get the runtime down to within a factor of 2 of MultivariateStudentT, which seems much more reasonable.

Thanks again for the help!
