Ordered Dirichlet Process

Hello. I have a Dirichlet process model very similar to the one outlined in "Dirichlet Process Mixture Models in Pyro" (Pyro Tutorials 1.8.1 documentation). The differences are that mine is written in NumPyro, the T plate contains a Dirichlet sample statement, and the observations are multinomial:

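import jax.numpy as jnp
import numpyro
from numpyro.distributions import Beta, Categorical, Dirichlet, Multinomial

def mix_weights(beta):
    # stick-breaking weights as in the Pyro DPMM tutorial: w_k = beta_k * prod_{j<k} (1 - beta_j)
    beta1m_cumprod = jnp.cumprod(1 - beta, axis=-1)
    return jnp.pad(beta, (0, 1), constant_values=1) * jnp.pad(beta1m_cumprod, (1, 0), constant_values=1)

def model(data, T=10, alpha=1.0):  # T (truncation level) and alpha: illustrative defaults
    with numpyro.plate("beta_plate", T - 1):
        beta = numpyro.sample("beta", Beta(1, alpha))
    with numpyro.plate("lambda_plate", T):
        probs = numpyro.sample("probs", Dirichlet(jnp.ones(4)))
    with numpyro.plate("data", data.shape[0]):
        z = numpyro.sample("z", Categorical(mix_weights(beta)))
        # Multinomial's first positional argument is total_count, so pass probs by keyword
        numpyro.sample("obs", Multinomial(total_count=data.sum(-1), probs=probs[z]), obs=data)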

I would like to impose an ordering on the components to make posterior predictive evaluation easier. Otherwise, because of label switching, the component labels are permuted arbitrarily from one posterior sample to the next.

I tried constraining the result of “mix_weights”, “beta”, and “probs” in turn, using both TransformedDistribution(dist, [OrderedTransform, ExpTransform]) and the simplex-to-ordered transform. I get various errors. For instance, when I apply it to beta, it complains that it expects a shape of at least 1 but got (); but if I impose a shape of 1 on beta, my model complains that it expected ().

A basic example of applying an ordering constraint would be very helpful, as would other approaches to dealing with label switching that do not resort to sorting after sampling.

Thanks very much.

I think you can use SimplexToOrderedTransform for that. Something like

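import numpyro.distributions as dist
from numpyro.distributions.transforms import OrderedTransform, SimplexToOrderedTransform
C = 4
dist.TransformedDistribution(dist.Normal(0, 1).expand([C - 1]), [OrderedTransform(), SimplexToOrderedTransform().inv])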

will give you an “ordered” simplex prior. Edit: this does not seem to be right; let me think a bit more.

@svilup, do you have any ideas on this problem? I think @akotlar would like an ordered simplex prior but is also open to other approaches for dealing with the label switching issue.

The topic “Ordinal regression with a principled prior” seems relevant.