Optimizing for Statistical Properties of Probabilistic Programs

Hello, I am new to Pyro and I am struggling to express my ideas in the language. I have so far read “An Introduction to Models in Pyro”, “An Introduction to Inference in Pyro”, “SVI Part I: An Introduction to Stochastic Variational Inference in Pyro”, “MLE and MAP Estimation”, and parts of other documentation.

I have a probabilistic program “P” that uses both deterministic and probabilistic (pyro.sample) statements to calculate/generate a scalar value “x”. The way I understand it, “P” thus defines a random variable “x” with some distribution “D”, and executing “P” generates a sample from “D”. I would like Pyro to optimize the parameters of “P” in such a way that “D” fulfills certain properties. For instance, I would like “D” to have a certain mean, median, standard deviation, skewness, etc.

All of the examples I have seen express optimization goals via observed data in pyro.sample statements. However, my optimization goals are not about individual samples but rather properties of the whole distribution, and I am not sure how to express this in Pyro. It seems like in Pyro it is very easy to infer parameters from data, but not from statistical properties, even though the former feels like the harder problem to me.

Can my problem be expressed in the framework of statistical inference or does it simply not fit? If not, it seems that Pyro is very much geared towards statistical inference but maybe it can still help me solve my problem? Or should I try to use a different library or even use PyTorch directly?

I would be grateful for any pointers on how to approach this problem and which tools can help me with it.

Hi @jules, it sounds like you want to implement some sort of moment-matching inference. That is probably outside Pyro’s Bayesian focus area, but you would likely have good luck with the reparametrized samplers in torch.distributions and pyro.distributions: use the .rsample() method to draw batches of samples from parametrized distributions, then directly optimize the resulting moments with a PyTorch optimizer in a typical stochastic training loop. If you wanted to use Pyro machinery, you could use pyro.plate to expand a model to a batch of samples and poutine.trace to record the sample statements, but it’s probably easier to hand-implement that for pure moment matching.
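To make the .rsample() suggestion concrete, here is a rough sketch in plain PyTorch. Everything specific in it is an assumption made up for illustration: the toy program (one Normal draw followed by a deterministic transform), the target mean and standard deviation, the squared-error loss on the batch moments, and the optimizer settings. It is not a recommended recipe, just one way the training loop could look.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

# Learnable parameters of the hypothetical program "P".
loc = torch.zeros((), requires_grad=True)
log_scale = torch.zeros((), requires_grad=True)  # optimize log(scale) so scale stays positive


def program(batch_size):
    """Toy stand-in for "P": one reparametrized draw plus a deterministic step."""
    z = Normal(loc, log_scale.exp()).rsample((batch_size,))  # gradients flow through rsample
    return z + 0.1 * z ** 2


# Illustrative targets for the distribution "D" of x.
target_mean, target_std = 2.0, 0.5

optimizer = torch.optim.Adam([loc, log_scale], lr=0.02)

for step in range(1500):
    optimizer.zero_grad()
    x = program(1024)  # a batch of samples from D
    # Penalize the gap between the batch moments and the targets.
    loss = (x.mean() - target_mean) ** 2 + (x.std() - target_std) ** 2
    loss.backward()
    optimizer.step()

# Estimate the moments of the fitted distribution from a larger batch.
with torch.no_grad():
    x = program(100_000)
    final_mean, final_std = x.mean().item(), x.std().item()

print(f"mean={final_mean:.3f} std={final_std:.3f}")
```

Other differentiable statistics of the batch (for example a skewness estimate) could be added to the loss in the same way. Note that this only works for sample statements whose distributions support .rsample(); the batch estimates are also noisy, so larger batches give more stable gradients.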

Thank you very much for the useful pointers!

I managed to get some minimal examples to work using plain torch.