Hi there, I have a NumPyro model with some discrete latent parameters:

```
import numpy as np
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(table, n_h=2, n_b=5):
    n_row, n_col = table.shape
    # mask for missing values (NaN cells)
    is_observed = ~np.isnan(table)
    valid_table = table.copy()
    # any in-support category (0..n_b-1) works here; these cells are masked out below
    valid_table[~is_observed] = 0
    valid_table = valid_table.astype(jnp.int32)
    with numpyro.plate("h", n_h, dim=-2):
        with numpyro.plate("col", n_col):
            # per-group, per-column category probabilities: (n_h, n_col, n_b)
            beta = numpyro.sample("β", dist.Dirichlet(jnp.ones(n_b) / n_b))
    # mixing weights over the n_h groups
    alpha = numpyro.sample("α", dist.Dirichlet(jnp.ones(n_h) / n_h))
    with numpyro.plate("row_i", n_row, dim=-2):
        # discrete group assignment of each row: (n_row, 1)
        h_i = numpyro.sample("h_i", dist.Categorical(alpha))
        with numpyro.plate("col_k", n_col) as k:
            numpyro.sample(
                "x_h,i,k",
                dist.Categorical(beta[h_i, k, :]).mask(is_observed),
                obs=valid_table,
            )
```

It works well with `DiscreteHMCGibbs`:

```
from jax import random
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs

kernel = DiscreteHMCGibbs(NUTS(model), modified=True)
mcmc = MCMC(kernel, num_samples=1000, num_warmup=500, thinning=2)
rng_key = random.PRNGKey(100)
mcmc.run(rng_key, table)
```

However, it’s relatively slow, and extremely slow when my table is big. I tried SVI instead but ran into several issues. For example:

```
from numpyro.infer import SVI, TraceEnum_ELBO, autoguide

guide = autoguide.AutoDelta(model)
optimizer = numpyro.optim.Adam(0.001)
svi = SVI(model, guide, optimizer, TraceEnum_ELBO(max_plate_nesting=3))
svi_results = svi.run(random.PRNGKey(0), 50, table)
```

I got `ValueError: CategoricalProbs distribution got invalid probs parameter.`

I also tried Pyro but ran into many problems; for example, it enumerates over `h_i`, though I guess I could predict `h_i` afterward.
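Predicting `h_i` afterward only needs its conditional distribution given a row and point estimates of `α` and `β`, which is the same quantity enumeration sums over. NumPyro also provides `infer_discrete` in `numpyro.contrib.funsor` for this; below is a hand-rolled pure-JAX sketch instead. `posterior_h` and the toy numbers are hypothetical, and `β` is assumed to have shape `(n_h, n_col, n_b)` as in the model above.

```python
import numpy as np
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def posterior_h(table, alpha, beta):
    """Hypothetical helper: p(h_i = h | x_i, alpha, beta), shape (n_row, n_h).

    table: (n_row, n_col) floats with NaN for missing cells
    alpha: (n_h,) mixing weights
    beta:  (n_h, n_col, n_b) per-group, per-column category probabilities
    """
    is_observed = ~np.isnan(table)
    valid_table = np.where(is_observed, table, 0).astype(np.int32)
    n_col = table.shape[1]
    # log beta[h, k, x_ik], with missing cells contributing 0, summed over k
    lp = jnp.log(beta)[:, jnp.arange(n_col), valid_table]
    lp = jnp.where(is_observed, lp, 0.0).sum(-1)  # (n_h, n_row)
    # log p(h_i = h, x_i | alpha, beta)
    log_joint = (jnp.log(alpha)[:, None] + lp).T  # (n_row, n_h)
    # normalize over h to get the conditional distribution of h_i
    return jnp.exp(log_joint - logsumexp(log_joint, axis=-1, keepdims=True))


# toy example: 2 groups, 2 columns, 2 categories
alpha = jnp.array([0.5, 0.5])
beta = jnp.array([[[0.9, 0.1], [0.9, 0.1]],
                  [[0.1, 0.9], [0.1, 0.9]]])
table = np.array([[0.0, 0.0], [1.0, np.nan]])
probs = posterior_h(table, alpha, beta)
h_hat = probs.argmax(-1)  # most probable group per row: [0, 1]
```

Taking the argmax gives a MAP-style assignment; sampling from `probs` instead would give posterior draws of `h_i` for fixed `α` and `β`.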

I also think I should write the guide myself instead of using an autoguide, but I don’t really know how to adjust my model or write a guide. Any pointers, or other smarter ways to speed up the model, would be appreciated. Thank you!