About SimplexToOrderedTransform and TransformReparam

Hi all,

I ran the Ordinal Regression tutorial (Ordinal Regression — NumPyro documentation).
I am not sure how the transforms.SimplexToOrderedTransform behaves.

I have confirmed that it usually works as follows.

# Sanity check: push a 3-component Dirichlet through SimplexToOrderedTransform
# (anchor point 0) and draw a single sample; the draw has 2 entries (output below).
d = dist.TransformedDistribution(dist.Dirichlet(np.ones((3,))),dist.transforms.SimplexToOrderedTransform(0))
d.sample(random.PRNGKey(0))
Array([-1.584826 ,  0.6460422], dtype=float32)

However, when running MCMC, it does not seem to work correctly, as shown in the following code.
The tutorial code uses TransformReparam(); can you tell me how this works?

import os

import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from jax.experimental.ode import odeint
import numpyro.distributions.constraints as constraints
from numpyro.infer import MCMC, NUTS
from numpyro.infer import Predictive
from numpyro.infer.util import initialize_model

import arviz as az

# Use the "arviz-darkgrid" plotting style for ArviZ figures.
az.style.use("arviz-darkgrid")

# Fail fast unless the installed NumPyro version string starts with 0.11.0.
assert numpyro.__version__.startswith("0.11.0")

# Run on the CPU platform with a host device count of 1.
numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)
from numpyro.infer.reparam import TransformReparam


# data generation
nclasses = 3
nsim = 50
simkeys = random.split(random.PRNGKey(1), 2)
# Outcome: `nsim` class labels drawn from a Categorical with all-zero logits.
Y = dist.Categorical(logits=np.zeros(nclasses)).sample(
    simkeys[0], sample_shape=(nsim,)
)
# Predictor: default Normal() draws shifted by the outcome label.
X = dist.Normal().sample(simkeys[1], sample_shape=(nsim,)) + Y
df = pd.DataFrame({"X": X, "Y": Y})


def model_ng(X, Y, nclasses, concentration, anchor_point=0.0):
    """Ordinal-regression model with a Dirichlet-based prior on the cutpoints.

    Sample sites, in order:
      * ``b_X_eta`` -- coefficient multiplying ``X``, with a ``Normal(0, 5)`` prior.
      * ``c_y`` -- cutpoints drawn from ``Dirichlet(concentration)`` pushed through
        ``SimplexToOrderedTransform(anchor_point)``.
      * ``Y`` -- observed outcome, ``OrderedLogistic(X * b_X_eta, c_y)``, declared
        inside the ``obs`` plate of size ``X.shape[0]``.

    ``nclasses`` is accepted but not used inside the function.
    """
    slope_prior = dist.Normal(0, 5)
    b_X_eta = numpyro.sample("b_X_eta", slope_prior)

    # NOTE: the `numpyro.handlers.reparam(config={"c_y": TransformReparam()})`
    # wrapper is left disabled here, so `c_y` is sampled directly from the
    # transformed distribution.
    cutpoint_prior = dist.TransformedDistribution(
        dist.Dirichlet(concentration),
        dist.transforms.SimplexToOrderedTransform(anchor_point),
    )
    c_y = numpyro.sample("c_y", cutpoint_prior)
    # Debug output: shape and value of the sampled cutpoints.
    print(c_y.shape, c_y)

    n_obs = X.shape[0]
    with numpyro.plate("obs", n_obs):
        numpyro.sample("Y", dist.OrderedLogistic(X * b_X_eta, c_y), obs=Y)


# Dirichlet concentration vector: one entry per class, every entry equal to 10.
concentration = np.ones((nclasses,)) * 10.0

# Fixed PRNG key passed to the MCMC run.
rng_key= random.PRNGKey(0)
# NUTS kernel over model_ng; num_chains=1, num_warmup=1000, num_samples=2000, thinning=1.
kernel = NUTS(model_ng)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
# Model arguments are passed by keyword; anchor_point is not passed, so model_ng
# uses its default.
mcmc.run(
    rng_key=rng_key,
    X=df["X"].values,
    Y=df["Y"].values,
    nclasses=nclasses,
    concentration=concentration,
)
# with exclude_deterministic=False, we will also show the ordinal probabilities sampled from Dirichlet (vis. `c_y_base`)
mcmc.print_summary(exclude_deterministic=False)
(3,) [1.2993407 2.6367652 9.292245 ]
(3,) Traced<ConcreteArray([1.2993407 2.6367652 9.292245 ], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([1.2993407, 2.6367652, 9.292245 ], dtype=float32)
  tangent = Traced<ShapedArray(float32[3])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[3]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x3019bfba0>, in_tracers=(Traced<ShapedArray(float32[3]):JaxprTrace(level=1/0)>,), out_tracer_refs=[<weakref at 0x3028781d0; to 'JaxprTracer' at 0x302878400>], out_avals=[ShapedArray(float32[3])], primitive=pjit, params={'jaxpr': { lambda ; a:f32[3]. let b:f32[3] = cumsum[axis=0 reverse=False] a in (b,) }, 'in_shardings': (UnspecifiedValue,), 'out_shardings': (UnspecifiedValue,), 'resource_env': None, 'donated_invars': (False,), 'name': '_cumulative_reduction', 'in_positional_semantics': (<_PositionalSemantics.GLOBAL: 1>,), 'out_positional_semantics': <_PositionalSemantics.GLOBAL: 1>, 'keep_unused': False, 'inline': False}, effects=set(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x174ac8330>, name_stack=NameStack(stack=(Transform(name='jvp'),))))
Output exceeds the size limit. Open the full output data in a text editor
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[58], line 62
     60 kernel = NUTS(model_ng)
     61 mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=1, thinning=1)
---> 62 mcmc.run(
     63     rng_key=rng_key,
     64     X=df["X"].values,
     65     Y=df["Y"].values,
     66     nclasses=nclasses,
     67     concentration=concentration,
     68 )
     69 # with exclude_deterministic=False, we will also show the ordinal probabilities sampled from Dirichlet (vis. `c_y_base`)
     70 mcmc.print_summary(exclude_deterministic=False)

File ~/Desktop/programming/numpyro_intro/.venv/lib/python3.9/site-packages/numpyro/infer/mcmc.py:628, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    626 map_args = (rng_key, init_state, init_params)
    627 if self.num_chains == 1:
--> 628     states_flat, last_state = partial_map_fn(map_args)
    629     states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
    630 else:

File ~/Desktop/programming/numpyro_intro/.venv/lib/python3.9/site-packages/numpyro/infer/mcmc.py:410, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
    408 # Check if _sample_fn is None, then we need to initialize the sampler.
    409 if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None):
...
-> 1617       raise TypeError(f'{name} got incompatible shapes for broadcasting: '
   1618                       f'{", ".join(map(str, map(tuple, shapes)))}.')
   1620 return tuple(result_shape)

TypeError: mul got incompatible shapes for broadcasting: (4,), (3,).

Hi @yoshida, it is a bug in SimplexToOrderedTransform: the methods forward_shape and inverse_shape are not implemented correctly (it currently uses the default ones, which map shape to shape rather than shape to shape[:-1] + (shape[-1] - 1,)). Could you make a GitHub issue or a PR to fix this?

1 Like

Thank you for your prompt reply.
Your comment helped me resolve the issue, and I made a GitHub issue (a bug in SimplexToOrderedTransform · Issue #1580 · pyro-ppl/numpyro · GitHub).