How to sample ordered samples

I’m reading the Gaussian mixture model tutorial, and with this model we have a non-identifiability problem: the group labels and their corresponding probabilities can be interchanged and we’d still get the same output.

One of the solutions to this problem is to enforce an ordering on the parameters. In the mixture example, we can enforce one group mean to always be smaller than the other. So my question is twofold:

  1. How do we enforce an ordering in Pyro? When using pyro.sample to generate samples, how do we make sure the samples for one parameter are always bigger than those for another? Of course, if we have more than two groups we’d have to enforce an ordering on more variables. The pyro.distributions.constraints module seems to only work with pyro.param, not pyro.sample?
  2. What can we do when the data is multidimensional? If the data is one-dimensional the group means have a natural ordering, but if the data is multidimensional there doesn’t seem to be such a natural ordering.

@olivierma In NumPyro we have OrderedTransform for that purpose, which can be easily ported to Pyro. You can use this transform with TransformedDistribution. You can look at some examples here.
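
For reference, a minimal sketch of that pattern in NumPyro might look like the following (the two-component mixture setup, the priors, and the names are just assumptions for illustration; handling of the discrete assignment during inference is omitted):

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.distributions.transforms import OrderedTransform

K = 2  # number of mixture components (assumed for illustration)

def model(data):
    # K component means constrained to be strictly increasing:
    # sample K unconstrained values and push them through OrderedTransform
    locs = numpyro.sample(
        "locs",
        dist.TransformedDistribution(
            dist.Normal(0.0, 1.0).expand([K]).to_event(1),
            OrderedTransform(),
        ),
    )
    scale = numpyro.sample("scale", dist.LogNormal(0.0, 1.0))
    weights = numpyro.sample("weights", dist.Dirichlet(jnp.ones(K)))
    with numpyro.plate("data", data.shape[0]):
        assignment = numpyro.sample("assignment", dist.Categorical(weights))
        numpyro.sample("obs", dist.Normal(locs[assignment], scale), obs=data)
```

With this setup every draw of `locs` satisfies locs[0] < locs[1] < … < locs[K-1], which removes the label-switching symmetry for the means.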

What can we do when the data is multidimensional?

I’m not sure about this. If the extra dimension of your data is a batch dimension, then you can use OrderedTransform as usual. If not, you can treat it as one-dimensional data, e.g. one ordering of [[1, 2, 3], [4, 5, 6]] is [1, 2, 3, 4, 5, 6] :smiley: (see the sketch below).
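
If you do go the flatten-to-one-dimension route, a minimal sketch of that idea could look like this (K components with D-dimensional means and the standard-normal prior are hypothetical choices for illustration):

```python
import numpyro
import numpyro.distributions as dist
from numpyro.distributions.transforms import OrderedTransform

K, D = 2, 3  # hypothetical: 2 components with 3-dimensional means

def model():
    # impose a single ordering over all K * D entries of the component means,
    # then reshape back to (K, D); this is one arbitrary flattening order
    flat_locs = numpyro.sample(
        "flat_locs",
        dist.TransformedDistribution(
            dist.Normal(0.0, 1.0).expand([K * D]).to_event(1),
            OrderedTransform(),
        ),
    )
    locs = flat_locs.reshape((K, D))
    return locs
```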


Thanks as always @fehiepsi :smile: If the transform can be easily ported to Pyro, I’d very much appreciate your effort!

No, I’m not talking about batch dimensions; I’m talking about the case where the data should be modeled by multivariate distributions. In the mixture model example the data can all be aligned on the real line, and I wonder what we can do when the dimension is higher. Maybe using some metric? I have no idea.

(Edit: I see that in the implementation you just exponentiate all the elements from the second one onward to ensure they are positive, and then take a cumulative sum to ensure the ordering. I guess without porting the class I can just do the same in my model…)
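
For reference, a hand-rolled version of that idea in PyTorch could look like the sketch below. Note that if you transform samples yourself instead of going through TransformedDistribution, you also need to account for the log-abs-det Jacobian of the transform, which the built-in transform handles for you.

```python
import torch

def to_ordered(x):
    # x: unconstrained tensor; its last dimension will be made strictly increasing.
    # Keep the first entry, exponentiate the rest so the increments are positive,
    # then take a cumulative sum along the last dimension.
    z = torch.cat([x[..., :1], x[..., 1:].exp()], dim=-1)
    return torch.cumsum(z, dim=-1)

# e.g. to_ordered(torch.tensor([0.0, -1.0, 2.0])) -> tensor([0.0000, 0.3679, 7.7569])
```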

And another question, if I’m not bothering you too much: I was under the impression that Pyro is where most of the development happens and that functionality is later ported to NumPyro and other backends, but clearly that is not the case. Is there any feature-parity comparison between the different backends, and are there recommendations for which backend suits which audience?

Is there any feature-parity comparison between the different backends, and are there recommendations for which backend suits which audience?

Right now, NumPyro has a much faster implementation of NUTS, especially for small models, but Pyro has much more functionality, especially for variational inference and discrete latent variables.

Over the next few months we plan to replace much of the inference machinery in Pyro (like the internals of TraceEnum_ELBO) with our new intermediate language Funsor, which will support both JAX and PyTorch and narrow the gap in inference functionality between Pyro and NumPyro. Lower-level components like normalizing flows and higher-level pyro.contrib packages will mostly continue to be backend-specific.


Thanks!