I’d like to use NUTS to infer a policy from an observed sequence of actions. Here’s a minimal example that samples a policy and takes a single action, with no observations:
import time

import jax.numpy as np
import jax.random as random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS


def model():
    # Prior over the policy: 4 actions x 2 state features
    pi = numpyro.sample("pi", dist.Normal(np.zeros((4, 2)), np.ones((4, 2))))
    state = 3 * np.ones(2)
    a_logits = np.matmul(pi, state)
    # "a" is not observed, so it is a discrete latent variable
    a = numpyro.sample("a", dist.Categorical(logits=a_logits))


# Helper function for NUTS inference
def run_inference(model, rng_key):
    start = time.time()
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000)
    mcmc.run(rng_key)
    mcmc.print_summary()
    print('\nMCMC elapsed time:', time.time() - start)
    return mcmc.get_samples()


rng_key = random.PRNGKey(0)
samples = run_inference(model, rng_key)
However, when I run this I get a very unhelpful NotImplementedError whose traceback points to the call to the Categorical distribution. Is the problem that I’m running NUTS on a model with both Categorical and Normal latent variables? Is there another issue with my code? Could this be fixed by using Pyro rather than NumPyro?