Bayesian Nonparametric VAE in Pyro



I am new to Pyro and generative modelling.

Recently, I have been interested in models that combine the interpretability of a discrete, structured PGM with neural-network (NNET) likelihood functions, in particular a model such as the one in Nonparametric Variational Auto-encoders for Hierarchical Representation Learning. That model uses an alternating optimisation strategy, switching between optimising the nested CRP and optimising the VAE. Its generative model consists of a nested CRP that generates z and an NNET decoder that maps z to x.
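
To make the generative story concrete, here is a minimal stdlib-only sketch of the tree-path part of a nested CRP prior, under the usual sequential-seating description: each datapoint walks down from the root for a fixed number of levels, and at every node it picks an existing child with probability proportional to how many earlier datapoints took that child, or opens a new child with probability proportional to alpha. The function name and its parameters are hypothetical. Attaching parameters to tree nodes to produce z, and the NNET decoder mapping z to x, are deliberately left out.

```python
import random
from collections import defaultdict


def sample_ncrp_paths(num_points, depth, alpha, seed=0):
    """Hypothetical helper: sample one root-to-leaf path (a tuple of child
    indices) per datapoint from a nested CRP of fixed depth."""
    rng = random.Random(seed)
    # counts[node][child] = how many datapoints went from `node` to `child`;
    # a node is identified by the tuple of child indices leading to it.
    counts = defaultdict(list)
    paths = []
    for _ in range(num_points):
        node = ()  # start at the root
        for _level in range(depth):
            child_counts = counts[node]
            # existing children weighted by occupancy, plus alpha for a new child
            weights = child_counts + [alpha]
            child = rng.choices(range(len(weights)), weights=weights)[0]
            if child == len(child_counts):  # open a new child
                child_counts.append(0)
            child_counts[child] += 1
            node = node + (child,)
        paths.append(node)
    return paths
```

Every step here is a Python loop over datapoints and levels, which is exactly the kind of naive version that later needs to be batched.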

I was excited about Pyro because I believed its flexibility would let us design models like the one above. Building a simple VAE is straightforward in other toolkits as well, such as PyTorch. Although the DMM and AIR examples show how powerful a PPL such as Pyro can be, it is still not clear to me how to build models with discrete structure. For example, in a language acquisition domain you might want to infer the underlying (discrete) linguistic structure given speech.

Is Pyro in its current form unsuitable for models such as the one referenced above, or more generally for models that combine complicated discrete structures with NNETs?

I am a noob in this field, so your thoughts and comments would help me a lot.



Hi, what do you mean by suitable? Do you have a more specific question about the model in that paper? I haven’t read it carefully, but it seems to me that it would be reasonably straightforward to implement the model and algorithm described in the paper in Pyro, as long as one were sufficiently clever about batching datapoints together (as in e.g. the AIR tutorial or via a tool like Matchbox).

If you’re looking for a way to get started with reproducing that paper in particular, I’d suggest beginning with the CRP part of the generative model: write the most naive implementation of the truncated nested CRP generative model (i.e. a bunch of nested for loops or recursive calls), and then vectorize that naive version by replacing for loops over conditionally independent variables with pyro.iarange(...) context managers and poutine.broadcast.

You’d also have to implement the mean-field coordinate ascent updates yourself, as Pyro does not compute those automatically, but that should be a fairly direct translation from the equations in the paper if you write the model and mean-field variational distribution as Pyro models and use poutine.trace to grab the distributions and parameters.


Thanks for taking the time to reply. I can get going with this information.