How to use multiple enumeration in regression model?

I want to fit a ordinal regression model with three categorical inputs using enumeration and TraceEnum_ELBO. Here is an example with simulated data.

from jax import numpy as jnp
import numpy as np
import numpyro
from numpyro.infer import reparam
from numpyro import distributions as dist
from numpyro.ops.indexing import Vindex
from numpyro import handlers, optim
from jax.random import PRNGKey
from numpyro.infer import (
    SVI,
    Trace_ELBO,
    TraceEnum_ELBO,
)
from numpyro.contrib.funsor import config_enumerate
from numpyro.infer.autoguide import AutoNormal

# data simulation, y, x0, x1, x2 are all categorical variables
def data_simulator(N: int = 1000, m: int = 200):
    y = np.random.negative_binomial(3, 3 / 13, N)
    y = np.clip(y, a_min=0, a_max=39)

    def modify_sample(n):
        (id,) = np.where(y == n)
        mod_id = np.random.choice(id, int(len(id) * 0.3), False)
        y[mod_id] = n + 1

    for i in [2, 4, 6]:
        modify_sample(i)

    def construct_frequency(a):
        v = np.unique(a, return_counts=True)
        out = np.zeros(40)
        out[v[0]] = v[1] / m
        return out

    def proc_array(a):
        out = np.round(a)
        out = np.clip(out, a_min=0, a_max=39).astype("int")
        return np.apply_along_axis(construct_frequency, 1, out)

    y_T = np.expand_dims(y, 1)

    x0 = y_T + np.random.standard_normal((N, m))
    x0 = proc_array(x0)

    x1 = y_T + np.random.standard_normal((N, m)) * 1.5 + 0.2
    x1 = proc_array(x1)

    x2 = y_T + np.random.standard_normal((N, m)) * 0.5 - 0.1
    x2 = proc_array(x2)

    x = np.stack([x0, x1, x2], axis=1)
    n_split = int(N * 0.8)
    y_train, y_test = np.split(y, [n_split])
    x_train, x_test = np.split(x, [n_split])
    return y_train, x_train, y_test, x_test


# ordinal regression model with x0, x1, x2 as inputs, 
# x = column_stack([x0, x1, x2], axis=1)
@config_enumerate
def model_f(data, **kwargs):
    array = jnp.arange(40)
    dm = []
    for i, x in enumerate(np.split(data["X"], 3, axis=1)):
        # with numpyro.handlers.block(), numpyro.handlers.seed(rng_seed=PRNGKey(i)):
        # with numpyro.plate(f"model_plate_{i}", 800) as ind:
        sub = numpyro.sample(f"model{i}", dist.CategoricalProbs(x.squeeze()))
        sub_i = Vindex(array)[sub]
        dm.append(sub_i)
    dm = jnp.column_stack(dm)
    # with numpryo.plate()
    with numpyro.plate("beta_plate", 3):
        beta = numpyro.sample("beta", dist.Normal(0, 1))
    # beta = numpyro.sample("beta", dist.Normal(0, 1), sample_shape=(3,))
    out = dm @ beta

    with numpyro.handlers.reparam(config={"c_y": reparam.TransformReparam()}):
        c_y = numpyro.sample(
            "c_y",
            dist.TransformedDistribution(
                # dirichlet prior on latent class probabilities
                dist.Dirichlet(np.ones(40)),
                # anchor point for cutpoints
                dist.transforms.SimplexToOrderedTransform(0),
            ),
        )

    with numpyro.plate("y_plate", data["X"].shape[0]):
        return numpyro.sample("y", dist.OrderedLogistic(out, c_y), obs=data["y"])


if __name__ == "__main__":
    y, x, y1, x1 = data_simulator()
    # sns.histplot(y, binwidth=1)
    # sns.histplot(y1, binwidth=1)

    train_data = {"X": x, "y": y}
    test_data = {"X": x1, "y": y1}
    elbo = TraceEnum_ELBO(max_plate_nesting=1)
    svi = SVI(
        model_f, AutoNormal(model_f), optim.Adam(step_size=0.005), loss=elbo, data=train_data
    )
    svi.run(PRNGKey(0), 100)
    print()

And I get the following errors

Exception has occurred: ValueError

Missing a plate statement for batch dimension -1 at site 'model0'. You can use `numpyro.util.format_shapes` utility to check shapes at all sites of your model.

I am confused about the error. Is plate notation necesarry here? x is a (800, 40) matrix and sub should be a (800,) vector for one batch of enumeration. I tried adding plate notation but it gave me a much bigger matrix with shape like (800, 800).

sub = numpyro.sample(f"model{i}", dist.CategoricalProbs(x.squeeze()))
sub_i = Vindex(array)[sub]

It seems to me that you need a plate dimension for model_i sites, because it has batch dimension with size 800. Make sure to denote dim=-1 for your plate statements. To make enumeration work, you need to denote all batch dimensions of the variables using plate statements.

Thank you! Setting dimension solves the problem. Then I have another error:

Cannot find valid initial parameters. Please check your model again.

After reading this github issue, I tried init_to_value in AutoNormal:

svi = SVI(
        model_f,
        AutoNormal(
            model_f,
            init_loc_fn=partial(
                initialization.init_to_value(
                    values={"model0": 3, "model1": 3, "model2": 3}
                )
            ),
        ),
        optim.Adam(step_size=0.005),
        loss=elbo,
    )
    svi.run(PRNGKey(0), num_steps=100, data=train_data)

but this cannot solve the initial parameter error. Could you give me some advice on next step?

I’m not sure if AutoNormal supports enumeration. @ordabayev do you know?

@fehiepsi I think so. At least I was able to use it with AutoDelta. @statsman can you have a look at this tutorial that I wrote - Gaussian Mixture Model — NumPyro documentation? I remember that it wasn’t very straightforward and I had to use block/seed handlers and warm up the guide to make it work.