Mixture model: evaluating Predictive fails while model fitting works

Hi,
I am encountering a `KeyError: 'cluster_axis'` when evaluating `Predictive` on my `input_data`. What surprises me is that model fitting works, but `Predictive` throws the error.

To give some context:
I would like to use a Dirichlet mixture model to cluster the words in documents by their similarity. I already managed to apply the Dirichlet mixture model to cluster the words within a single document (thanks for the great tutorial on that :slight_smile:). So now I am basically just adding an extra layer that allows each document to have its own cluster weights.

I created a minimal example that still produces the error. It does not include all the dependencies I aim to model; however, it does include the multi-dimensional cluster weights and assignments that seem to be causing the problem.

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs, Predictive

# Minimal example input: 2 documents with 4 words each
words = jnp.array([[0, 0, 1, 1],
                   [0, 1, 1, 1]])
alphabet = '01'
cluster_num = 8
input_data = words, len(alphabet), cluster_num

def model_SBP(input_data):
    word_data, alphabet_length, cluster_num = input_data

    # parameters
    document_count = word_data.shape[0]
    word_count = word_data.shape[1]

    # define rates
    mutation_rate = 0.8
    error_rate = 0.5

    # create matrix of rates
    mutation_rate_matrix = jnp.array([mutation_rate, 1 - mutation_rate])

    cluster_weights = numpyro.sample(
        "cluster_weights", dist.Dirichlet(jnp.ones((document_count, word_count, cluster_num)))
    )

    with numpyro.plate("cluster_axis", cluster_num, dim=-1):
        centers = numpyro.sample(
            "centers", dist.Categorical(mutation_rate_matrix)
        )

    cluster_assignments = numpyro.sample(
        "cluster_assignments", dist.Categorical(cluster_weights)
    )

    with numpyro.plate("document_plate", document_count, dim=-2):
        with numpyro.plate("word_axis", word_count, dim=-1):
            #cluster_assignments = numpyro.sample(
            #    "cluster_assignments", dist.Categorical(cluster_weights)
            #)

            # create matrix of rates depending on cluster assignments and centers
            # (constant in this minimal example)
            error_rate_matrix = jnp.full(
                (document_count, word_count, alphabet_length),
                (1 - error_rate) / (alphabet_length - 1),
            )

            numpyro.sample("obs", dist.Categorical(error_rate_matrix), obs=word_data)

# Model fitting
rng_key = jax.random.PRNGKey(0)
num_samples = 10
num_warmup = num_samples // 2

kernel = NUTS(model_SBP)
mcmc = MCMC(
    DiscreteHMCGibbs(kernel),
    num_warmup=num_warmup,
    num_samples=num_samples,
    num_chains=1
)
mcmc.run(rng_key, input_data)

posterior_samples = mcmc.get_samples()

posterior_predictive = Predictive(model_SBP, posterior_samples, infer_discrete=True)
# The next call raises KeyError: 'cluster_axis'
posterior_predictions = posterior_predictive(rng_key, input_data=input_data)
```

I am very grateful for any hints on what the problem could be.

Could you try removing `infer_discrete=True`? I guess your `posterior_samples` already contain all the discrete latent variables.

Thanks a lot for your help :pray: Removing `infer_discrete=True` solved the problem. I realize now that I had misunderstood this parameter.