I want to fit an ordinal regression model with three categorical inputs using enumeration and `TraceEnum_ELBO`. Here is an example with simulated data:
from jax import numpy as jnp
import numpy as np
import numpyro
from numpyro.infer import reparam
from numpyro import distributions as dist
from numpyro.ops.indexing import Vindex
from numpyro import handlers, optim
from jax.random import PRNGKey
from numpyro.infer import (
SVI,
Trace_ELBO,
TraceEnum_ELBO,
)
from numpyro.contrib.funsor import config_enumerate
from numpyro.infer.autoguide import AutoNormal
# Data simulation: y, x0, x1, x2 are all categorical variables over n_levels categories.
def data_simulator(N: int = 1000, m: int = 200, n_levels: int = 40, rng=None):
    """Simulate an ordinal target ``y`` and three noisy categorical inputs.

    Each input row is the empirical distribution (frequencies over
    ``n_levels`` categories) of ``m`` noisy, rounded copies of ``y``.

    Args:
        N: Total number of observations; the first ``int(0.8 * N)`` form the
            training split, the rest the test split.
        m: Number of noisy draws summarised in each input row.
        n_levels: Number of categories; all values are clipped to
            ``[0, n_levels - 1]``.
        rng: Source of randomness exposing ``negative_binomial``, ``choice``
            and ``standard_normal`` (e.g. ``np.random.default_rng(0)``).
            Defaults to the global ``np.random`` state, as before.

    Returns:
        ``(y_train, x_train, y_test, x_test)``: ``y_*`` are integer arrays of
        shape ``(n,)``; ``x_*`` are float arrays of shape ``(n, 3, n_levels)``
        whose last axis sums to 1.
    """
    rng = np.random if rng is None else rng

    y = rng.negative_binomial(3, 3 / 13, N)
    # Move 30% of the observations at levels 2, 4 and 6 up by one level, so y
    # is not a plain negative binomial.
    for level in (2, 4, 6):
        (candidates,) = np.where(y == level)
        moved = rng.choice(candidates, int(len(candidates) * 0.3), False)
        y[moved] = level + 1
    # Clip after shifting so shifted values also stay in range for small n_levels.
    y = np.clip(y, 0, n_levels - 1)

    def to_frequencies(noisy):
        """Round/clip each row of draws to categories; return per-row frequencies."""
        cats = np.clip(np.round(noisy), 0, n_levels - 1).astype(int)
        counts = np.apply_along_axis(np.bincount, 1, cats, minlength=n_levels)
        return counts / m

    y_col = y[:, np.newaxis]
    x0 = to_frequencies(y_col + rng.standard_normal((N, m)))
    x1 = to_frequencies(y_col + rng.standard_normal((N, m)) * 1.5 + 0.2)
    x2 = to_frequencies(y_col + rng.standard_normal((N, m)) * 0.5 - 0.1)
    x = np.stack([x0, x1, x2], axis=1)  # (N, 3, n_levels)

    n_split = int(N * 0.8)
    y_train, y_test = np.split(y, [n_split])
    x_train, x_test = np.split(x, [n_split])
    return y_train, x_train, y_test, x_test
# ordinal regression model with x0, x1, x2 as inputs,
# x = column_stack([x0, x1, x2], axis=1)
@config_enumerate
def model_f(data, **kwargs):
array = jnp.arange(40)
dm = []
for i, x in enumerate(np.split(data["X"], 3, axis=1)):
# with numpyro.handlers.block(), numpyro.handlers.seed(rng_seed=PRNGKey(i)):
# with numpyro.plate(f"model_plate_{i}", 800) as ind:
sub = numpyro.sample(f"model{i}", dist.CategoricalProbs(x.squeeze()))
sub_i = Vindex(array)[sub]
dm.append(sub_i)
dm = jnp.column_stack(dm)
# with numpryo.plate()
with numpyro.plate("beta_plate", 3):
beta = numpyro.sample("beta", dist.Normal(0, 1))
# beta = numpyro.sample("beta", dist.Normal(0, 1), sample_shape=(3,))
out = dm @ beta
with numpyro.handlers.reparam(config={"c_y": reparam.TransformReparam()}):
c_y = numpyro.sample(
"c_y",
dist.TransformedDistribution(
# dirichlet prior on latent class probabilities
dist.Dirichlet(np.ones(40)),
# anchor point for cutpoints
dist.transforms.SimplexToOrderedTransform(0),
),
)
with numpyro.plate("y_plate", data["X"].shape[0]):
return numpyro.sample("y", dist.OrderedLogistic(out, c_y), obs=data["y"])
if __name__ == "__main__":
    # Script entry: simulate data and fit model_f with SVI, marginalising the
    # discrete model{i} sites via TraceEnum_ELBO.
    y_train, x_train, y_test, x_test = data_simulator()
    train_data = {"X": x_train, "y": y_train}
    test_data = {"X": x_test, "y": y_test}  # held out; not used below yet

    # max_plate_nesting=1: plates in model_f use only dim -1, so enumerated
    # sites can be placed in dims to the left of it.
    elbo = TraceEnum_ELBO(max_plate_nesting=1)
    svi = SVI(
        model_f,
        AutoNormal(model_f),
        optim.Adam(step_size=0.005),
        loss=elbo,
        data=train_data,  # static kwarg passed to model and guide on every step
    )
    svi_result = svi.run(PRNGKey(0), 100)
    print(f"final ELBO loss: {float(svi_result.losses[-1]):.3f}")
I get the following error:
Exception has occurred: ValueError
Missing a plate statement for batch dimension -1 at site 'model0'. You can use `numpyro.util.format_shapes` utility to check shapes at all sites of your model.
I am confused by this error. Is plate notation necessary here? `x` is an (800, 40) matrix, so `sub` should be an (800,) vector for one batch of enumeration. I tried adding plate notation, but that gave me a much bigger matrix with a shape like (800, 800).
sub = numpyro.sample(f"model{i}", dist.CategoricalProbs(x.squeeze()))
sub_i = Vindex(array)[sub]