Hello everyone! I am very new to NumPyro and, despite searching online and on this forum for solutions, I can't work out how to fix my code.
@config_enumerate
def gamma_mixture(data, K=3):
    """Mixture of K Gamma distributions with a discrete per-point assignment.

    The ``assignment`` site is discrete. ``@config_enumerate`` marks it for
    parallel enumeration, so an enumeration-aware sampler (e.g. NUTS)
    marginalizes it instead of sampling it.

    Args:
        data: 1-D array-like of observations. The Gamma likelihood has support
            on (0, inf), so zeros or negative values give a -inf log-density.
            NOTE(review): confirm the data are strictly positive.
        K: Number of mixture components (defaults to 3, the value previously
            hard-coded).
    """
    # Accept lists / pandas Series as well as JAX arrays.
    data = jnp.asarray(data)

    # Global variables: mixture weights and per-component Gamma parameters.
    weights = numpyro.sample("weights", dist.Dirichlet(jnp.ones(K)))
    with numpyro.plate("components", K):
        # Uniform(0, 10) is already non-negative, so the former abs() was a no-op.
        alpha = numpyro.sample("alpha", dist.Uniform(0, 10))
        # The rate is the reciprocal of a Uniform(0, 3) draw (that draw acts as the scale).
        beta = 1 / numpyro.sample("beta", dist.Uniform(0, 3))

    with numpyro.plate("data", len(data)):
        # Local variable: which component each observation comes from.
        assignment = numpyro.sample("assignment", dist.Categorical(weights))
        numpyro.sample(
            "obs",
            dist.Gamma(concentration=alpha[assignment], rate=beta[assignment]),
            obs=data,
        )
# Fit the mixture model with NUTS. The discrete "assignment" site is marked
# for enumeration by @config_enumerate, so only continuous latents are sampled.
chosen_model = gamma_mixture
nuts_kernel = NUTS(
    model=chosen_model,
    target_accept_prob=0.80,  # step-size adaptation target acceptance rate
)
mcmc = MCMC(
    nuts_kernel,
    num_warmup=1000,
    num_samples=2000,
    num_chains=1,
)
# NOTE(review): `data` must be defined earlier in the script; it is not shown here.
mcmc.run(random.PRNGKey(0), data)
mcmc.print_summary()
Thanks in advance for your assistance!