Cost of instantiating a Numpyro distribution

Hi all,

I’m trying to write a small lib for state-space models which can work autonomously but is also compatible with the Numpyro ecosystem. I’m relying on Numpyro distributions but only instantiating them when I need to use them, keeping their parameters separate. As a result there will be many calls of the following form in my code:

x = dists.Normal(**params['dist']).sample(key)

logprob = dists.Normal(**params['dist']).log_prob(y)

i.e. I use Numpyro distributions for their functionality only and I don’t expect them to be passed around in the code (I only pass around their parameters). However, I see that Numpyro distributions have been written to be Pytree-serializable, and that a lot is done to check the arguments and precompute things every time a distribution is instantiated. Should I then worry that my usage will induce a large computational overhead?

Thanks in advance!

PS: the reason I’m doing this is that I’m primarily working with Markov kernels M(x,y) which are such that M(x,.) is a probability distribution defined via the combination of a mapping applied to x and a base noise distribution. I guess I could write the Markov kernel class as Pytree-serializable too, with both the mapping parameters and the base noise distribution stored internally. However, that would force me to write my whole lib as Pytree-serializable objects, because Markov kernels will be attributes of many objects in my code. This doesn’t seem to fit with the JAX way of thinking so I didn’t go for that yet, but maybe Numpyro should be used more like PyTorch in that respect?

Hi @mathis_c I feel that you don’t need to worry about the pytree flattening logic. The overhead, if any, will be removed after your program is compiled.

Hi @fehiepsi, thanks for the answer!

Great that Pytree flattening doesn’t impact performance. The thing is that I’m trying not to write my own library as Pytree-serializable mostly for code simplicity (because that would force me to do it for all the objects). That might be a wrong idea and I currently have the possibility to change directions in that respect if this is seriously flawed. My main doubts are mostly related to initialization of the distributions.

Suppose I’m using MultivariateNormal distributions in very large dimensions and initializing them with their precision matrices. I see the following in the Numpyro code:

elif precision_matrix is not None:
    loc, self.precision_matrix = promote_shapes(loc, precision_matrix)
    self.scale_tril = cholesky_of_inverse(self.precision_matrix)

This suggests that instantiating a MultivariateNormal with its parameters can become costly in large dimensions and that I should do it as little as possible in parts of the code where the parameters do not change (which would motivate a change in my lib). Currently I would have something like:

import jax
from numpyro import distributions as dists

dist = dists.MultivariateNormal  # just the class, not an instantiated object

dist_params = {'loc': loc, 'precision_matrix': precision}  # suppose very high dimensional params
key, subkey = jax.random.split(key, 2)

# some sampling code where the distribution is created to be sampled from
some_samples = jax.vmap(lambda key: dist(**dist_params).sample(key))(jax.random.split(subkey, num_samples))

# some other code where the distribution is instantiated to evaluate the log-prob of some sample
logprob = dist(**dist_params).log_prob(some_other_sample)

Note that I call dist(**dist_params) multiple times with the same params, and therefore a priori go through the cholesky_of_inverse function multiple times. Is that kind of thing also factored out after the program is compiled?
Thanks again.

It will recompute that in the constructor each time. It is better to create one instance and then call sample and log_prob on it.