Need help with speeding up my generative model

Hi there, I have a model with some latent discrete parameters using numpyro:

import numpy as np
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(table, n_h=2, n_b=5):
    n_row, n_col = table.shape

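    # mask for missing values (NaN marks a missing cell)
    is_observed = ~np.isnan(table)
    valid_table = table.copy()
    # masked cells are ignored in the likelihood, but keep the placeholder
    # inside the valid category range 0..n_b-1
    valid_table[~is_observed] = 0
    valid_table = valid_table.astype(jnp.int32)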

    
    with numpyro.plate("h", n_h, dim=-2):
        with numpyro.plate("col", n_col):
            beta = numpyro.sample("β", dist.Dirichlet(jnp.ones(n_b)/n_b))
                
    alpha = numpyro.sample("α", dist.Dirichlet(jnp.ones(n_h)/n_h))
    
    with numpyro.plate("row_i", n_row, dim=-2) as i:
        h_i = numpyro.sample("h_i", dist.Categorical(alpha))
        
        # n_col replaces the undefined n_positions; index beta by the sampled h_i
        with numpyro.plate("col_k", n_col) as k:
            numpyro.sample(
                "x_h,i,k", dist.Categorical(beta[h_i, k, :]).mask(is_observed),
                obs=valid_table
            )

It works well with DiscreteHMCGibbs:

from jax import random
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs

kernel = DiscreteHMCGibbs(NUTS(model), modified=True)
mcmc = MCMC(kernel, num_samples=1000, num_warmup=500, thinning=2)
rng_key = random.PRNGKey(100)
mcmc.run(rng_key, table)

However, it is relatively slow, and extremely slow when my table is big. I tried to use SVI instead but ran into many issues, for example:

from numpyro.infer import SVI, TraceEnum_ELBO, autoguide

guide = autoguide.AutoDelta(model)
optimizer = numpyro.optim.Adam(0.001)
# pass the same model the guide was built from
svi = SVI(model, guide, optimizer, TraceEnum_ELBO(max_plate_nesting=3))
svi_results = svi.run(random.PRNGKey(0), 50, table)

I got ValueError: CategoricalProbs distribution got invalid probs parameter.

I also tried to use Pyro but ran into many problems. For example, it enumerates over h_i, although I guess I could predict h_i afterward.
Also, I think I should have written the guide myself instead of using an autoguide, but I don't really know how to adjust my model and write a guide. If anyone can give me some pointers, or suggest smarter ways to speed up the model, it would be appreciated. Thank you!

I think your model can be enumerated, so you can add the annotation infer={"enumerate": "parallel"} to the h_i sample site, as in the HMM enum example. With h_i summed out, NUTS alone can sample α and β, which could be much faster than using DiscreteHMCGibbs.
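
To make concrete what enumeration does for this model, here is a minimal sketch in plain NumPy of the quantity that gets computed once h_i is summed out, plus the per-row class probabilities you could use to predict h_i afterward. This is not NumPyro's internal implementation. The function names (log_sum_exp, row_log_joint, marginal_log_lik, class_posterior) are made up for illustration, and the sketch assumes the table stores categories 0..n_b-1 as floats with NaN for missing cells.

```python
import numpy as np


def log_sum_exp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis."""
    a_max = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(a_max, axis=axis) + np.log(np.sum(np.exp(a - a_max), axis=axis))


def row_log_joint(table, alpha, beta):
    """Return log p(h_i = h, x_i | alpha, beta) with shape (n_h, n_row).

    table: (n_row, n_col) floats, categories 0..n_b-1, NaN = missing (assumed)
    alpha: (n_h,) class weights
    beta:  (n_h, n_col, n_b) category probabilities per class and column
    """
    is_observed = ~np.isnan(table)
    # Missing cells get an in-range placeholder; they are zeroed out below.
    x = np.where(is_observed, table, 0).astype(np.int64)
    cols = np.arange(table.shape[1])
    # log_lik[h, i, k] = log beta[h, k, x[i, k]]
    log_lik = np.log(beta)[:, cols, x]
    # Same effect as .mask(is_observed): missing cells contribute 0.
    log_lik = np.where(is_observed, log_lik, 0.0)
    return np.log(alpha)[:, None] + log_lik.sum(axis=-1)


def marginal_log_lik(table, alpha, beta):
    """log p(table | alpha, beta) with every h_i summed out."""
    return log_sum_exp(row_log_joint(table, alpha, beta), axis=0).sum()


def class_posterior(table, alpha, beta):
    """p(h_i = h | x_i, alpha, beta), shape (n_h, n_row); each column sums to 1."""
    log_joint = row_log_joint(table, alpha, beta)
    return np.exp(log_joint - log_sum_exp(log_joint, axis=0))
```

Because the marginal likelihood is a smooth function of α and β, a gradient-based sampler or optimizer can work on it directly. Calling class_posterior on posterior draws of α and β and averaging the result gives a prediction for each row's h_i after the fact.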