Ordinal regression with a principled prior

First of all, thank you for the amazing package!

I’ve been using the Ordinal Regression example, but I’d like to enhance it with a more principled prior, as suggested by M. Betancourt in his Case Study on Ordinal regression (section 2.2).

However, I’m struggling to create cut point “parameters” that would be sampled by MCMC alongside the other random variables. It might be an improper use of numpyro.param() and/or numpyro.factor().

The expected behaviour is that the cutpoints change during the MCMC run; however, they stay equal to their initial values.

Helper function for the lpd of the prior model as per Michael’s blog:

import jax.numpy as jnp
from jax.scipy.special import expit
import numpyro.distributions as dist

def induced_dirichlet_lpdf(c, alpha, phi):
    """
    Function to calculate the log-density for induced Dirichlet prior
    Based on: Ordinal Regression
    INPUTS:
    - c:jnp.array: cut point array (length=number of categories less 1)
    - alpha:jnp.array: concentration for Dirichlet distribution
    - phi:float: anchor point honouring M. Betancourt's implementation

    RETURNS:
    - p:jnp.array: vector of implied probabilities for each category
    - lpd:float: log-density for the prior (includes log det Jacobian)
    """

    K = len(c) + 1  # number of categories
    p = jnp.zeros((K,), dtype=float)  # proba
    J = jnp.zeros((K, K), dtype=float)  # Jacobian

    sigma = expit(phi - c)  # Logistic function // CDF

    # Induced Ordinal Proba
    p = p.at[0].set(sigma[0])  # typo: should be 1 - sigma[0] (fixed in the updated version)
    p = p.at[1:-1].set(sigma[:-1] - sigma[1:])
    p = p.at[-1].set(sigma[-1])

    # Baseline column of Jacobian
    J = J.at[:, 0].set(1.0)

    # Non-zero entries of the remaining columns
    rho = jnp.multiply(sigma, 1 - sigma)  # partial deriv.
    i, j = jnp.diag_indices(K - 1)  # shifted by one column (column 0 is the baseline)
    # diagonal elements
    J = J.at[..., i + 1, j + 1].set(-rho)
    # off-diagonal elements (one row above the diagonal)
    J = J.at[..., i, j + 1].set(rho)

    _, log_det = jnp.linalg.slogdet(J)

    lpd = dist.Dirichlet(alpha).log_prob(p) + log_det

    return p, lpd

and the model to run it:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(covariates, obs=None):
    """
    Simple Ordinal regression with Dirichlet induced prior for the latent cut points
    """

    num_classes = 3  # fixed for now
    dim_cov = covariates.shape[1]
    N = covariates.shape[0]

    cutoff_init = jnp.linspace(-2, 2, num_classes - 1)
    anchor_point = 0.0
    concentration_prior = jnp.array([1, 1, 1])

    beta = numpyro.sample('beta', dist.Normal(0, 10).expand([dim_cov]).to_event(1))
    mean_pred = jnp.dot(beta, covariates.T)

    # parameters for cut points
    c_y = numpyro.param('c_y', init_value=cutoff_init, constraint=dist.constraints.ordered_vector)

    # get their log-density
    p_y, log_density = induced_dirichlet_lpdf(c_y, concentration_prior, anchor_point)

    # add to the overall log likelihood
    numpyro.factor("induced_prior_transformation", log_density)

    if obs is None:
        numpyro.deterministic('mean_pred', mean_pred)
        numpyro.deterministic('c_y', c_y)
        numpyro.deterministic('p_y', p_y)

    with numpyro.plate("N", N) as idx:
        return numpyro.sample("obs", dist.OrderedLogistic(mean_pred, c_y), obs=obs)

I’ve searched for similar threads, but the only helpful one was Implicit priors, and it wasn’t very clear about param() vs factor().

Any help would be greatly appreciated! I’m happy to contribute it to the example if it’s useful.

Hi @svilup, you can use the dist.ImproperUniform(constraints.ordered_vector, batch_shape=(), event_shape=(num_classes - 1,)) distribution in your model (one cutpoint fewer than the number of classes). Using numpyro.param as an improper distribution was deprecated and has been removed in a past release.

Thank you, @fehiepsi!

I’ve tried that, but I’m struggling with it because sample() isn’t implemented (e.g., I can’t run my model to test it, I can’t trace it without conditioning, and the utility for converting to ArviZ InferenceData doesn’t work).

I’m curious about the reason behind this change; would you mind commenting on it? I’d be keen to understand NumPyro under the hood a bit more.


Separately, is my usage of numpyro.factor() correct? I can see the log-likelihood of the induced Dirichlet moving in the right direction when I change the priors, but it doesn’t influence the cut points as much as I’d expect (despite the relatively low log probability of the observed site).

All in all, I have some simple generated data, and the model runs into thousands of divergences.
Since I don’t have any constants/intercepts that would cause the obvious non-identifiability, I’m not sure where they are coming from.

I think it would be tricky to implement samplers for improper distributions (maybe we could use rejection sampling?). If you want a proper distribution for an ordered vector, you can use transformed distributions as in the ordinal regression tutorial. That should work with ArviZ, but it is not equivalent to the Stan code, because the Stan code uses improper priors (i.e. parameters). I’m interested in knowing how they draw samples for those parameters from the improper prior distribution.

Regarding numpyro.factor, I haven’t checked the implementation of induced_dirichlet_lpdf but your usage seems correct to me. If you use transformed distributions as above, you don’t need to compute induced_dirichlet_lpdf at all. You just use that distribution in your model.

Uff, I didn’t appreciate its complexity - I assumed that ImproperUniform is just a wrapper for TransformedDistribution+Uniform, similarly to the TransformedDistribution+Normal example you’ve referenced.


I’ll have to keep digging then, because I don’t see how TransformedDistribution would solve my problem.
I need to translate the ordered cutpoints into probabilities (differences in CDFs) and add their log-likelihood to the model. It’s such simple code that I must just be missing some trick :slight_smile:

An improper distribution differs from Transformed(Uniform) in that the former has uniform density over the transformed domain (i.e. uniform density over all cutpoints), while the latter has uniform density over the base domain.

With a transformed distribution, you assume that your cutpoints follow t(Z), where t is the transform and Z is the base distribution. We also implement all the math needed to get the log density of that transformed distribution (which involves the Jacobian of the transforms). I think the Stan code differs in what Z and t are, and it needs to compute the log density of t(Z) by hand. As I mentioned, there is nothing to prevent you from doing that. Just provide a fake sampler for your improper prior with TransformedDistribution(dist.Normal(...), OrderedTransform()).mask(False). The mask guarantees that your cutpoints have a uniform improper prior, while still allowing you to get “fake” samples (i.e. samples from t(Normal)) when you run the model. Personally, I wouldn’t use mask, and I would avoid implementing a custom density if possible.

Thank you, Fehiepsi! That’s a helpful trick with .mask(False).

I was a bit embarrassed that I could have figured out the difference myself, so I took it as an opportunity to unpick Numpyro a bit more - effect handlers are awesome!

I’ve solved it; there were two sources of problems:

  1. a typo in induced_dirichlet_lpdf (I appreciate your comment on custom densities), and
  2. too much separation between categories, which created a non-identifiability that the improper/transformed-Normal priors weren’t able to push through.

With the induced Dirichlet prior, it cut through the divergences like a knife through butter! :slight_smile:


Btw, I’ve revisited the available transforms, but (IMO) none of them serves this use case of inducing cut-points with a Dirichlet prior, because it needs a transform from cut-points to a simplex under the logistic CDF, which is quite specific.
I have tried, as you suggested, doing that separately and then simply passing the probabilities (p) to dist.Dirichlet(concentration).log_prob(p), but that misses the log-det-Jacobian term of the transform. If I add that as well, then I have effectively implemented induced_dirichlet_lpdf anyway.

I agree I wouldn’t want to derive it from scratch but since M. Betancourt did it, there is no reason not to use it :slight_smile:


If there was interest, I could add it as one more model to the Ordinal regression example.

Updated log density function:

import jax.numpy as jnp
from jax.scipy.special import expit
import numpyro.distributions as dist

def induced_dirichlet_lpdf(c, alpha, phi):
    """
    Function to calculate the log-density for induced Dirichlet prior
    Based on: Ordinal Regression

    INPUTS:
    - c:jnp.array: cut point array (length=number of categories less 1)
    - alpha:jnp.array: concentration for Dirichlet distribution
    - phi:float: anchor point honouring M. Betancourt's notation

    RETURNS:
    - p:jnp.array: vector of implied probabilities for each category
    - lpd:float: log-density for the prior (includes log det Jacobian)
    """

    K = len(c) + 1  # number of categories
    p = jnp.zeros((K,), dtype=float)  # proba
    J = jnp.zeros((K, K), dtype=float)  # Jacobian

    sigma = expit(phi - c)  # Logistic function // CDF

    # Induced Ordinal Proba
    p = p.at[0].set(1 - sigma[0])
    p = p.at[1:-1].set(sigma[:-1] - sigma[1:])
    p = p.at[-1].set(sigma[-1])

    # Baseline column of Jacobian
    J = J.at[:, 0].set(1.0)

    # Non-zero entries of the remaining columns
    rho = jnp.multiply(sigma, 1 - sigma)  # partial deriv.
    i, j = jnp.diag_indices(K - 1)  # shifted by one column (column 0 is the baseline)
    # diagonal elements: dp_{k+1}/dc_k = -rho_k
    J = J.at[..., i + 1, j + 1].set(-rho)
    # off-diagonal elements (one row above the diagonal): dp_k/dc_k = rho_k
    J = J.at[..., i, j + 1].set(rho)

    _, log_det = jnp.linalg.slogdet(J)

    dirichlet_lpd = dist.Dirichlet(alpha).log_prob(p)
    lpd = dirichlet_lpd + log_det

    return p, lpd

If there was interest, I could add it as one more model to the Ordinal regression example.

Sure, please go ahead; personally I found this discussion super helpful. I would recommend moving your code into a transform class:

class Simplex2OrderedTransform(dist.transforms.Transform):
    ...

and reuse it in the model:

sample("c", TransformedDistribution(Dirichlet(...), Simplex2OrderedTransform()))

so that you can also draw cutpoint samples from this prior (rather than generating “fake” samples as in my last comment). If you want to record p, you can use “TransformReparam” as in this example - MCMC will give you both c_base (which is p) and c.

You can also use ComposeTransform([StickBreakingTransform().inv, OrderedTransform()]) instead of the above transform. This transform has an analytical determinant, so it is expected to be much faster (if J.size is large). If it also gives you the desired result, please use it in the tutorial (for both simplicity and Pyroic style). :slight_smile:

Re. ComposeTransform(), I don’t think it achieves the same thing (I did some simulations to confirm it).

The intention in the case study is:

  1. sample a vector
  2. transform into an ordered vector → Record the output as c_y to use for OrderedLogistic()
  3. transform into a simplex that represents probabilities of each category in the latent “affinity” space defined by the OrderedLogistic model (ie, not stick-breaking)
  4. add prior knowledge/regularization to these probabilities in the same latent space by indirectly using the Dirichlet prior

I’m not sure we can achieve the same via the composed transform as per your suggestion:

proba_to_cutpoint_transform = dist.transforms.ComposeTransform(
    [dist.transforms.StickBreakingTransform().inv, dist.transforms.OrderedTransform()]
)

with numpyro.handlers.reparam(config={'c_y': TransformReparam()}):
    c_y = numpyro.sample("c_y", dist.TransformedDistribution(dist.Dirichlet(concentration_prior),
                                                             proba_to_cutpoint_transform))

I think we would do:

  1. sample a random simplex via Dirichlet prior
  2. transform into a set of cut points (non-ordered, via stick-breaking)
  3. transform into an ordered vector → Record the output as c_y to use for OrderedLogistic()

While it practically works to break through some of the non-identifiability, it sounds like it assumes an entirely different model and we can no longer add any prior knowledge around the category probabilities.
The reason is that the Dirichlet prior would no longer correspond to the probabilities of the individual categories (ie, to the latent “affinity” space), as the ordered transform in step 3 breaks that…
Eg, if you start with an equal probability in step 1, by step 3 the implied probability will be strongly concentrated in the first few categories, because of how the Ordered transform works.

I could also see some issues popping up because the cut-points are not anchored (ie, they could move freely with the latent beta parameter to create the same “category probabilities”)


It seems that the easiest way to apply our domain knowledge to the probabilities of individual categories is to regularize the cut-points indirectly with Dirichlet log_prob in the transformed space (ie, after cutpoints–>proba–>Dirichlet log_prob).

What would be the cleanest / the most pyroic way to implement it?
Option A (similar to c_y_smp in model2 in the tutorial):

  • sample cut-points
  • apply a deterministic transform into ordinal category proba (eg, Cutpoint2OrderedLogisticCategoryProbaTransform - an awful name, I know!)
  • add Dirichlet prior with observed = transformed proba from the previous step

My concern here is that applying the deterministic transform will not add the log_det_J to the model density.

Option B:

  • sample cut-points
  • apply some “ReversedTransformedDistribution”, where Dirichlet would be applied only after the transformations

All other options that I think of are even more convoluted, which is why I would like to ask for your advice.


Regarding the performance of an analytic vs. numerical determinant of the Jacobian: it’s a fair point, but I’m less concerned, since ordinal regression makes sense only for a relatively small number of categories, so the impact should be limited (J has shape (K, K), where K is the number of categories).
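For what it's worth, this particular Jacobian has a closed-form determinant. Factoring $\rho_k = \sigma_k(1-\sigma_k)$ out of the column for $c_k$ leaves columns of the form $e_k - e_{k+1}$, which each sum to zero; adding all rows to the last one turns that row into $(K, 0, \dots, 0)$, and the remaining minor is bidiagonal with unit diagonal, so $|\det J| = K \prod_k \rho_k$. A small sketch checking this against slogdet (function names are illustrative); since $\log K$ is a constant, it does not affect MCMC.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import expit


def jacobian_dp_dc(c, phi=0.0):
    # Augmented K x K Jacobian as in Betancourt's induced_dirichlet_lpdf
    K = c.shape[0] + 1
    sigma = expit(phi - c)
    rho = sigma * (1 - sigma)
    k = jnp.arange(K - 1)
    J = jnp.zeros((K, K)).at[:, 0].set(1.0)  # baseline column
    J = J.at[k + 1, k + 1].set(-rho)         # diagonal entries
    J = J.at[k, k + 1].set(rho)              # entries just above the diagonal
    return J


def log_det_numeric(c, phi=0.0):
    return jnp.linalg.slogdet(jacobian_dp_dc(c, phi))[1]


def log_det_analytic(c, phi=0.0):
    # log|det J| = log K + sum_k log(sigma_k (1 - sigma_k)), computed stably
    z = phi - c
    K = c.shape[0] + 1
    return jnp.log(K) + jnp.sum(jax.nn.log_sigmoid(z) + jax.nn.log_sigmoid(-z))


c = jnp.array([-1.0, 0.3, 2.0])
print(log_det_numeric(c), log_det_analytic(c))
```

The analytic form is an O(K) sum instead of an O(K^3) determinant, although, as noted, K is small in practice.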

apply our domain knowledge to the probabilities of individual categories is to regularize the cut-points indirectly with Dirichlet log_prob in the transformed space

An equivalent Pyroic way is to define Simplex2OrderedTransform and use
TransformedDistribution(Dirichlet(...), Simplex2OrderedTransform()) for c as in my last comment. More explicitly,

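class Simplex2OrderedTransform(dist.transforms.Transform):
    def __call__(self, x):
        p = x
        c = ...
        return c

    def _inverse(self, y):
        c = y
        p = ...
        return p

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        J = ...  # dp/dc evaluated at c = y, as in induced_dirichlet_lpdf
        return -jnp.linalg.slogdet(J)[1]  # log|det dc/dp| = -log|det dp/dc|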

where I use the same notation of c, p, J as in your implementation.

This is the prior Betancourt used in his tutorial.
The idea is clearer if you look at the sampling section of his article, where he draws samples from a Dirichlet distribution and then transforms them into cut points.

A more Pyroic way is to use ComposeTransform as in my last comment. When we assume that the predictor is 0 (i.e. there is no latent structure, in Stan terms), I agree that Simplex2OrderedTransform is more aligned with the implementation of OrderedLogistic, so it seems more intuitive. But in the general case with a random predictor, I don’t agree that it helps us add prior knowledge about the category probabilities: if the predictor is around 10, for example, I feel this can lead to a poor setting. Maybe discuss Simplex2OrderedTransform mainly in the tutorial and just mention ComposeTransform as a convenient way to get a transform between the two domains. What do you think?

Regarding Cutpoint2OrderedLogisticCategoryProbaTransform, I’m not sure I follow. Isn’t it just Simplex2OrderedTransform().inv?

I think we have a slightly different understanding of what that means. My view is that we sample cutpoints, not probabilities - that is consistent with the Stan code (probabilities feature only to calculate the density via Dirichlet lpdf).

If we sampled probabilities via the Dirichlet first and then transformed them, we wouldn’t be able to apply our domain knowledge to the category probabilities (because of the ordered transform in between), and we would have a many-to-one map. Quoting Michael here:

(4th section in 2.2) … Because of the simplex constraint on the ordinal probabilities we have to be a particularly careful in the construction of the pushforward probability density function here. One might, for example, be tempted to transform the $K$ probabilities into the $K-1$ internal cut points and an affinity $\gamma$ only to find that this transformation is singular and yields an ill-posed probability density function. This is because the internal cut points and affinity are non-identified and the map to probabilities is many-to-one.
Instead we condition on an anchor point, $\phi$, and then map the $K$ probabilities and their sum-to-one constraint to the $K-1$ cut points and a new variable that encodes the constraint, $S = \sum_{k=1}^{K} p_k = 1$…

I believe what he talks about is the indirect regularization / inducing by sampling c and then just adding the corresponding Dirichlet lpd.

WDYT? There is a Patreon live stream with Michael on 17th Aug, so I can submit it as a question.


Proposal for the transform class:

import jax
import jax.numpy as jnp
from jax.scipy.special import expit, logit
import numpyro.distributions as dist
from numpyro.distributions import constraints


class Ordered2SimplexTransform(dist.transforms.Transform):
    """
    Transform specific to the context of OrderedLogistic model
    Transform an ordered vector of cutpoints into a simplex representing ordered category probabilities (via the difference in Logistic CDF at cutpoints)

    **References:**

    1. *Ordinal Regression Case Study, section 2.2*,
       M. Betancourt, https://betanalpha.github.io/assets/case_studies/ordinal_regression.html
    """

    domain = constraints.ordered_vector
    codomain = constraints.simplex

    def __init__(self, anchor_point=0.0):
        self.anchor_point = anchor_point

    def __call__(self, x):
        K = x.shape[0] + 1  # number of categories
        sigma = expit(self.anchor_point - x)

        # Implied probabilities of each category in Ordered Logistic model
        y = jnp.zeros((K,), dtype=float)
        y = y.at[0].set(1 - sigma[0])
        y = y.at[1:-1].set(sigma[:-1] - sigma[1:])
        y = y.at[-1].set(sigma[-1])

        return y

    def _inverse(self, y):
        def scan_fn(expit_term, y_k):
            # expit_term is the previous sigma (1.0 before the first cutpoint)
            x = self.anchor_point - logit(expit_term - y_k)
            next_expit_term = expit(self.anchor_point - x)
            return next_expit_term, x

        x0 = 1.0  # float initial carry, matching the float carry returned by scan_fn
        _, x = jax.lax.scan(scan_fn, x0, y[:-1])
        return x

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        K = x.shape[0] + 1  # number of categories
        sigma = expit(self.anchor_point - x)

        J = jnp.zeros((K, K), dtype=float)
        # baseline column (encoding the simplex constraint)
        J = J.at[:, 0].set(1.0)
        rho = jnp.multiply(sigma, 1 - sigma)  # partial deriv.
        i, j = jnp.diag_indices(K - 1)  # shifted by one column (column 0 is the baseline)
        # diagonal elements
        J = J.at[..., i + 1, j + 1].set(-rho)
        # off-diagonal elements (one row above the diagonal)
        J = J.at[..., i, j + 1].set(rho)

        _, log_det = jnp.linalg.slogdet(J)
        return log_det

    # Shape bookkeeping, needed when the transform (or its .inv) is used in TransformedDistribution
    def forward_shape(self, shape):
        return shape[:-1] + (shape[-1] + 1,)

    def inverse_shape(self, shape):
        return shape[:-1] + (shape[-1] - 1,)

To be used like this (I haven’t found a better way to add this to the joint model density)

def model():
    ...
    anchor_point=0.0
    concentration_prior=jnp.ones(n_categories)*10
    ...
    p_y=Ordered2SimplexTransform(anchor_point)(c_y)
    if Y is not None:
        numpyro.factor("dirichlet_smp", dist.Dirichlet(concentration_prior).log_prob(p_y)+Ordered2SimplexTransform(anchor_point).log_abs_det_jacobian(c_y,p_y))
    ...
    numpyro.sample('obs', dist.OrderedLogistic(eta, c_y), obs=Y)

The way NumPyro calculates the log density of a transformed distribution is to transform the value back to the base domain, compute the log density there, and then add the log determinant of the Jacobian. That is exactly what you do when you take the cutpoints, transform them into a simplex, compute the log probability of this simplex, and add the log determinant of the Jacobian. It gives the same log density as the one in the article.

To be clearer: if we use an improper distribution for a variable and add a separate Normal density for it, then the behavior under MCMC should be the same as if we used a Normal prior for that variable. All MCMC needs is a latent site and its log density. As long as the log density is the same, it is your choice whether to use an improper distribution or not; the two ways of coding it are equivalent.

When running the model to get samples or predictions (which is different from using MCMC), NumPyro samples those cutpoints from a TransformedDistribution by sampling a simplex and then transforming it into cutpoints. In the Stan article I see the same thing: Dirichlet samples are drawn first, then transformed into cutpoints (in the rng section there).

That’s my understanding. Please ask the author about the points that are unclear to you :slight_smile: You might want to compute your custom density yourself and compare it with the TransformedDistribution log density. If they differ, there might be a bug in how the transformed distribution is implemented (given the non-scalar event shape involved here).


Sold!

Yes, that makes sense since we have one-to-one mapping, so it must be possible to run simplex–>ordered as well.
My head was stuck in comparing it with “stick-breaking + ordered” transform, which wouldn’t be equivalent.

As for the Stan code, I think you’re looking at writeLines(readLines("simulate_ordinal.stan")) when he generates new data, whereas I was looking at the fitting step in writeLines(readLines("ordered_logistic_induced.stan")).

The reason why I care is that if we run it on a map “simplex–>ordered” then we would need a different Jacobian term (he implemented derivations “dp/dc” evaluated at “c”, which corresponds to “ordered–>simplex” map with this sign…). Or would you disagree?

I think the formula in the Stan code matches the NumPyro implementation of TransformedDistribution with Simplex2OrderedTransform. If you implement Ordered2SimplexTransform, you need to use Ordered2SimplexTransform().inv in TransformedDistribution. If that helps your reasoning, please just use it.

I suspect it’s just flipping the sign). Or would you disagree?

I think his calculation is correct, and it matches the NumPyro implementation and your implementation of Ordered2SimplexTransform. If you use Simplex2OrderedTransform, please flip the sign in the log_abs_det_jacobian implementation.

Regarding simulate_ordinal.stan or ordered_logistic_induced.stan, the corresponding ones in NumPyro are:

  • simulate_ordinal.stan: run the NumPyro model and collect the values; typically we use the Predictive class for this task (or provide a seed and call the model function)
  • ordered_logistic_induced.stan: used to perform MCMC inference

In both use cases there is only one generative model. The code in ordered_logistic_induced.stan is a translation of that generative model that helps MCMC run; the code in simulate_ordinal.stan is a translation of that generative model that helps with simulation. Because Stan has no transformed distributions, both use cases are implemented by hand.


Thank you for the extensive explanation! You’re right as always.
I’ve done some post-it note math to understand why - all makes sense now.

I’ll move over to github.

I have opened the issue here: Adding Dirichlet prior for Ordinal Regression case study · Issue #1129 · pyro-ppl/numpyro · GitHub

For anyone interested in implementing it themselves, please note that you need to flip the sign of log_abs_det_jacobian compared to the case study cited above (due to NumPyro's TransformedDistribution implementation).

Thanks, @svilup!

Regarding flipping the sign, just to clarify: the Stan tutorial computes the Jacobian of the transform c → p. If we implement the transform c → p in NumPyro, there is no need to flip the sign. If we implement the transform p → c, we need to flip the sign of the Jacobian term derived in the Stan tutorial. (For any transform t, t.log_abs_det_jacobian(x, y) = -t.inv.log_abs_det_jacobian(y, x).) This has nothing to do with the TransformedDistribution implementation. For modeling, we usually don’t need to worry about implementation details; all we need is to specify the correct distributions (modulo bugs, …). There are tricky likelihoods that are not easily modeled and require us to implement the density by hand, but they are domain-specific, I guess.

Noted! I didn’t mean to imply there is anything wrong with TransformedDistribution :slight_smile:

I sometimes get confused when talking about transforms and their Jacobians, so I found it helpful to write down some pseudo-math (c = cutpoints = “Ordered”, p = ordinal probabilities = “Simplex”):
$$\pi_c \cdot dc = \pi_p \cdot dp$$
$$\pi_c = \pi_p \cdot \left|\frac{\partial p}{\partial c}\right|$$
$$\log \pi_c = \log \pi_p + \log\left|\frac{\partial p}{\partial c}\right| = \log \pi_p + \log|J_{pc}|$$
where $\pi_c$ and $\pi_p$ are the respective densities and the last term is exactly the one implemented by Michael ($J_{pc} = \partial p / \partial c$).

On that basis, it seems that for the Simplex2Ordered transform NumPyro expects the following:
$$\log \pi_c = \log \pi_p - \log|J_{cp}|,$$
so, as you suggested, I just flipped the sign ($\log|J_{cp}| = -\log|J_{pc}|$) to get:
$$\log \pi_c = \log \pi_p - (-\log|J_{pc}|),$$
where $\log|J_{pc}|$ is the log_abs_det_jacobian() version from the code example above.

When reading math involving both transforms and distributions, I also get confused. In fact, there were several bugs in the past related to wrong signs; it is easy to make mistakes. In my last comment, I just wanted to clarify that users don’t need to know the math of transformed distributions to implement a transform. From our discussion, people might be led to think that flipping the sign is due to some detail of the math of transformed distributions. Actually, flipping the sign is simply a property of taking the inverse transform, independent of anything probabilistic. Users only need to know the property log|J_cp| = -log|J_pc| when they implement the transform themselves.

When p(c) and p(p) are involved, things are more complicated. I think your comment will be helpful for those who want to implement the transformed distribution log density by hand, like in the Stan post.