I am running a plain vanilla latent Dirichlet allocation (LDA) model with NumPyro, and I would like to investigate the results once I fix the topic-word distributions (the beta parameters). I have noticed two things: 1) if I use substitute() or condition() with a site that does not actually exist in the model, the model simply runs as if nothing had been substituted and does not throw any error. This can be quite confusing, so I was wondering whether it is intentional; 2) if I substitute a site that does exist, condition() works while substitute() breaks. As I want to treat the beta parameters as data, I thought substitute() would be the most appropriate choice here, so I wonder what I am doing wrong.
Here is the model:
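import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax.random import PRNGKey as Key  # assumption: Key is an alias for PRNGKey
from numpyro.handlers import condition, substitute
from numpyro.infer import MCMC, NUTS

def hmc_simple_model(W, K, alpha, eta):
    D, V = jnp.shape(W)
    N = W.sum(axis=1)  # number of words in each document
    with numpyro.plate("topics", K):
        # topic-word distributions
        beta = numpyro.sample("beta", dist.Dirichlet(eta * jnp.ones([V])))
    with numpyro.plate("docs", D):
        # document-topic distributions
        theta = numpyro.sample("theta", dist.Dirichlet(alpha * jnp.ones([K])))
    ThetaBeta = jnp.matmul(theta, beta)
    distMultinomial = dist.Multinomial(total_count=N, probs=ThetaBeta)
    with numpyro.plate("hist", D):
        numpyro.sample("obs", distMultinomial, obs=W)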
Minimal code for data simulation:
D = 200 # n. documents
V = 3000 # n. terms in the vocabulary
eta = 0.2 # hyperparameter for topic concentrations
alpha = 1 # hyperparameter for document distributions
K = 10 # number of topics
# topic-term distributions
beta = dist.Dirichlet(eta * jnp.ones([V])).expand([K]).sample(Key(8))
# document-topic distributions
theta = dist.Dirichlet(alpha * jnp.ones([K])).expand([D]).sample(Key(89))
thetabeta = jnp.matmul(theta,beta)
N = jnp.floor(dist.Uniform(100,200).expand([D]).sample(Key(2)))
# doc-term matrix
W = dist.Multinomial(total_count = N, probs = thetabeta).sample(Key(20))
Minimal code for running the model:
n_warmup = 500
n_samples = 500
# condition (or substitute) on a non-existent site: this runs without any error
hmc_model_sub = condition(hmc_simple_model, {"not_a_site": 5})
nuts_kernel_sub = NUTS(hmc_model_sub)
mcmc_sub = MCMC(nuts_kernel_sub, num_warmup=n_warmup,
                num_samples=n_samples)
mcmc_sub.run(rng_key=Key(75), W=W, eta=eta, K=K, alpha=alpha)
# substitute an existing site: this does not run (while condition does)
hmc_model_sub = substitute(hmc_simple_model, {"beta": beta})
nuts_kernel_sub = NUTS(hmc_model_sub)
mcmc_sub = MCMC(nuts_kernel_sub, num_warmup=n_warmup,
                num_samples=n_samples)
mcmc_sub.run(rng_key=Key(75), W=W, eta=eta, K=K, alpha=alpha)
And below is the error I get when trying substitute() with an existing site:
