I’d like to use NUTS to infer a policy from an observed sequence of actions. Here’s a minimal working example that samples a policy and takes a single action, with no observations:
```python
import time

import jax.numpy as jnp
import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS


def model():
    # Prior over the policy: 4 actions x 2 state features
    pi = numpyro.sample("pi", dist.Normal(jnp.zeros((4, 2)), jnp.ones((4, 2))))
    state = 3 * jnp.ones(2)
    # One logit per action: (4, 2) @ (2,) -> (4,)
    a_logits = jnp.matmul(pi, state)
    # Unobserved discrete latent site
    a = numpyro.sample("a", dist.Categorical(logits=a_logits))


# Helper function for HMC inference
def run_inference(model, rng_key):
    start = time.time()
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000)
    mcmc.run(rng_key)
    mcmc.print_summary()
    print("\nMCMC elapsed time:", time.time() - start)
    return mcmc.get_samples()


rng_key = random.PRNGKey(0)
samples = run_inference(model, rng_key)
```
However, when I run this I get an unhelpful `NotImplementedError` that traces back to the `Categorical` distribution. Is the problem that I’m using NUTS on a model that has both a `Categorical` (discrete) and a `Normal` (continuous) latent variable? Is there another issue with my code? Could this be fixed by using Pyro rather than NumPyro?