User-defined distribution; Poisson Binomial

Hi: I’m entirely new to NumPyro / Pyro, somewhat new to probabilistic programming, and am trying to determine whether NumPyro or Pyro would be a viable option for an application I have in mind. The model is a complicated, potentially non-standard hierarchical Bayesian model, so I’m inclined to think that I should at least initially run it with MCMC rather than VI. Given speed concerns, would that suggest using NumPyro rather than Pyro? The model also requires a Poisson binomial distribution, which I do not see implemented in NumPyro or Pyro. The distribution is implemented in an R package (C++ / Rcpp) and there is a pure Python implementation (though that implementation may not have a viable exact method for this application). I’m hoping to reuse some existing code, if possible.

I’ve noticed some discussion on this forum about creating user-defined distributions for Pyro, but not for NumPyro. Would implementing a user-defined distribution for NumPyro be much different from doing so for Pyro? Also, how difficult might this be for a new user? The exact method I need requires a fast Fourier transform (FFT), for which there is an existing C program. There are Python implementations, but I’m not sure whether they are fast enough. Thoughts?

If you want to do HMC, I suggest you use NumPyro. However, you can’t freely mix outside packages with NumPyro/JAX, because JAX can’t compute gradients through code it doesn’t trace. Your best bet, if it’s not too complicated, would be to implement the distribution yourself using JAX’s FFT operations. The main bit of work would be writing the log_prob method and wrapping it in some generic distribution boilerplate; you don’t need a sampler, for example, if you’re doing HMC.
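To make the FFT suggestion concrete, here is a minimal sketch of a Poisson binomial PMF written only with JAX operations, so that gradients with respect to the success probabilities are available. It uses the standard characteristic-function (DFT) formula; the function names `poisson_binomial_pmf` and `poisson_binomial_log_prob` are made up for illustration, and the sketch assumes a 1-D vector of probabilities. It is not code from NumPyro or from the R package mentioned above.

```python
import jax
import jax.numpy as jnp


def poisson_binomial_pmf(probs):
    """PMF of K = number of successes among independent Bernoulli(probs[j]) trials.

    `probs` is assumed to be a 1-D array of length n.
    Returns an array of length n + 1 whose k-th entry is P(K = k).
    """
    n = probs.shape[0]
    # The (n + 1)-th roots of unity: exp(2*pi*i*l / (n + 1)) for l = 0..n.
    omega = jnp.exp(2j * jnp.pi * jnp.arange(n + 1) / (n + 1))
    # Characteristic function at each root: prod_j (1 - p_j + p_j * omega_l).
    factors = 1.0 - probs[:, None] + probs[:, None] * omega[None, :]
    char_fn = jnp.prod(factors, axis=0)
    # A forward DFT of those values, divided by n + 1, recovers the PMF.
    pmf = jnp.fft.fft(char_fn).real / (n + 1)
    # Round-off can leave tiny negative entries, so clip before taking logs.
    return jnp.clip(pmf, 0.0, 1.0)


def poisson_binomial_log_prob(probs, k):
    """Log-probability of observing k successes (hypothetical helper)."""
    return jnp.log(poisson_binomial_pmf(probs)[k])


probs = jnp.array([0.1, 0.4, 0.9])
pmf = poisson_binomial_pmf(probs)
# Gradient of the log-probability with respect to the probabilities.
grad = jax.grad(poisson_binomial_log_prob)(probs, 2)
```

One caveat: JAX computes in 32-bit floats by default, so very small tail probabilities from this formula are dominated by round-off. If the tails matter for your application, enabling 64-bit mode may be worth trying.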

Thank you, that’s very helpful! Good to know jax has fft operations.

Is there an example you’d recommend of a (fairly straightforward) distribution implementation that I could use as a template?

I guess you might look at BinomialProbs, since it has a similar support, but @fehiepsi might have a better suggestion. Note that you don’t need to implement methods like enumerate_support.

Much appreciated!