Hi,
I am encountering a "KeyError: 'cluster_axis'" when evaluating the Predictive (with infer_discrete=True) on my input_data. What surprises me is that model fitting works, but the Predictive throws the error.
To give some context:
I would like to use a Dirichlet mixture model to cluster words in documents by their similarity. I already managed to apply the Dirichlet mixture model to cluster the words of a single document (thanks for the great tutorial on that). Now I am adding an extra layer that allows each document to have its own cluster weights.
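For reference, the intended generative structure (one weight vector per document, centers shared across documents, one cluster assignment per word) can be sketched in plain NumPy. This is a minimal sketch, not the NumPyro model: it assumes the sizes of the minimal example (2 documents, 4 words each, an alphabet of 2 letters, 8 clusters), reuses the mutation rate 0.8 for the centers, and leaves out the emission step because the post does not specify how a word depends on its cluster center.

```python
import numpy as np

# Hypothetical sizes mirroring the minimal example:
# D documents, W words per document, alphabet of A letters, K clusters.
D, W, A, K = 2, 4, 2, 8
rng = np.random.default_rng(0)

# The "extra layer": one cluster-weight vector per document, shape (D, K).
doc_weights = rng.dirichlet(np.ones(K), size=D)

# One center letter per cluster, drawn with the example's mutation-rate
# probabilities [0.8, 0.2]; shape (K,).
centers = rng.choice(A, size=K, p=[0.8, 0.2])

# Each word picks a cluster using its own document's weights; shape (D, W).
assignments = np.stack(
    [rng.choice(K, size=W, p=doc_weights[d]) for d in range(D)]
)

# Center letter of each word's cluster (emission step omitted); shape (D, W).
word_centers = centers[assignments]
```

In NumPyro terms this would put the weights under the document plate only (shape (D, K)), whereas the minimal example draws a separate weight vector for every (document, word) pair.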
I created a minimal example that still reproduces the error. It does not include all the dependencies I aim to model, but it does include the multi-dimensional cluster weights and assignments that seem to be the problem.
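To see which dimensions those weights and assignments actually carry, here is a NumPy-only shape check (no NumPyro involved), assuming the sizes of the minimal example (D=2, W=4, K=8). Dirichlet weights of shape (D, W, K) yield assignments with batch shape (D, W), which lines up with document_plate and word_axis rather than with the size-K cluster_axis plate. Whether this mismatch is what triggers the KeyError inside infer_discrete is an assumption; the sketch only shows the shapes.

```python
import numpy as np

D, W, K = 2, 4, 8  # documents, words per document, clusters (as in the example)
rng = np.random.default_rng(0)

# As in the minimal example: one Dirichlet weight vector per (document, word).
cluster_weights = rng.dirichlet(np.ones(K), size=(D, W))  # shape (D, W, K)

# Draw one categorical sample per (document, word) by inverting the CDF along
# the last (cluster) axis. Pin the last CDF entry to 1 to guard against
# round-off, so every uniform draw falls into some bin.
cdf = np.cumsum(cluster_weights, axis=-1)
cdf[..., -1] = 1.0
u = rng.uniform(size=(D, W, 1))
cluster_assignments = (u < cdf).argmax(axis=-1)  # shape (D, W), not (K,)

# A size-K batch dimension at dim=-1 does not broadcast against (D, W),
# because the trailing sizes differ (K=8 vs W=4).
try:
    np.broadcast_shapes((K,), cluster_assignments.shape)
    shapes_compatible = True
except ValueError:
    shapes_compatible = False
```

By contrast, (D, W) broadcasts cleanly against the batch shape implied by document_plate (dim=-2) and word_axis (dim=-1).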
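import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs, Predictive

# Minimal example input: 2 documents with 4 words each
words = jnp.array([[0, 0, 1, 1],
                   [0, 1, 1, 1]])
alphabet = '01'
cluster_num = 8
input_data = words, len(alphabet), cluster_num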
def model_SBP(input_data):
    word_data, alphabet_length, cluster_num = input_data
    # parameters
    document_count = word_data.shape[0]
    word_count = word_data.shape[1]
    # define rates
    mutation_rate = 0.8
    error_rate = 0.5
    # create matrix of rates
    mutation_rate_matrix = jnp.array([mutation_rate, 1 - mutation_rate])
    cluster_weights = numpyro.sample(
        "cluster_weights", dist.Dirichlet(jnp.ones((document_count, word_count, cluster_num)))
    )
    with numpyro.plate("cluster_axis", cluster_num, dim=-1):
        centers = numpyro.sample(
            "centers", dist.Categorical(mutation_rate_matrix)
        )
    cluster_assignments = numpyro.sample(
        "cluster_assignments", dist.Categorical(cluster_weights)
    )
    with numpyro.plate("document_plate", document_count, dim=-2):
        with numpyro.plate("word_axis", word_count, dim=-1):
            #cluster_assignments = numpyro.sample(
            #    "cluster_assignments", dist.Categorical(cluster_weights)
            #)
            # create matrix of rates depending on cluster assignments and centers
            error_rate_matrix = jnp.full(
                (document_count, word_count, alphabet_length),
                (1 - error_rate) / (alphabet_length - 1),
            )
            numpyro.sample("obs", dist.Categorical(error_rate_matrix), obs=word_data)
# Model fitting
rng_key = jax.random.PRNGKey(0)
num_samples = 10
num_warmup = int(num_samples / 2)
kernel = NUTS(model_SBP)
mcmc = MCMC(
DiscreteHMCGibbs(kernel),
num_warmup=num_warmup,
num_samples=num_samples,
num_chains=1
)
mcmc.run(rng_key, input_data)
posterior_samples = mcmc.get_samples()
posterior_predictive = Predictive(model_SBP, posterior_samples, infer_discrete=True)
posterior_predictions = posterior_predictive(rng_key, input_data=input_data)
I am very grateful for any hints on what the problem could be.