Hierarchical NUTS with Continuous and Discrete Distributions

I’d like to use NUTS to infer a policy from an observed sequence of actions. Here’s a minimal working example that samples a policy and takes one action, with no observations:

```python
import time

import jax.numpy as np
import jax.random as random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model():

    # Prior over the policy
    pi = numpyro.sample("pi", dist.Normal(np.zeros((4, 2)), np.ones((4, 2))))

    state = 3*np.ones(2)
    a_logits = np.matmul(pi, state)
    a = numpyro.sample("a", dist.Categorical(logits=a_logits))

# helper function for HMC inference
def run_inference(model, rng_key):
    start = time.time()
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000)
    mcmc.run(rng_key)
    mcmc.print_summary()
    print('\nMCMC elapsed time:', time.time() - start)
    return mcmc.get_samples()

rng_key = random.PRNGKey(0)
samples = run_inference(model, rng_key)
```

However, when I run this I get a very unhelpful NotImplementedError message that traces back to the Categorical distribution. Is the problem that I’m running NUTS on a model with both Categorical and Normal variables? Is there another issue with my code? Could this be fixed by using Pyro rather than NumPyro?


@Exp5LogMingus NUTS only works with continuous latent variables, so discrete latent variables have to be marginalized out. Pyro can do this by enumeration, but NumPyro does not support it yet (we are working on it). I think NUTS will work for your model in Pyro.


Thanks! Interestingly, NumPyro handles the sampling fine when all of the discrete variables are conditioned on (which is the case in my model). That is, I can run NUTS on the following model with NumPyro without any errors:

```python
conditioned_model = condition(model, param_map={"a": np.asarray(1)})
```

Are there any issues I might run into with this further down the line? For instance, would I be able to run a posterior predictive check?

I think it should work well for you as long as your discrete variables are observed nodes. What doesn’t work yet is models with discrete “latent” variables.

I’ve now expanded the model, and though it is still only sampling from continuous latent variables, the NotImplementedError has returned. Specifically, the new model is as follows:

```python
import time

import jax.numpy as np
import jax.lax as lax
import jax.random as random
from jax.ops import index, index_update

import numpyro
import numpyro.distributions as dist
from numpyro.handlers import condition
from numpyro.infer import MCMC, NUTS
from numpyro.infer.util import Predictive

def model(length):

    # Prior over the policy
    pi = numpyro.sample("pi", dist.Normal(np.zeros((4, 2)), np.ones((4, 2))))

    def update_state(state, action):
        # actions 0/1: -1/+1 horizontally; actions 2/3: -1/+1 vertically (clipped to the grid)
        d = lax.convert_element_type(action/2, 'int32')
        fn = lambda x: index_update(state, index[d], x)
        new_state = lax.cond(np.remainder(action, 2) == 0,
                             np.max([0, state[d]-1]), fn,
                             np.min([length-1, state[d]+1]), fn)

        return new_state

    # Iterative function to run trajectory
    def run_trajectory(state, pi):

        # tests whether state is the terminal state
        cond_fun = lambda state_t: np.logical_not(np.allclose(state_t[0], (length-1)*np.ones(2)))

        t = 0
        init_val = (state, t)

        def body_fun(state_t):
            # unwrap state_t into state and t
            state, t = state_t

            # generate logits to sample next action
            a_logits = np.matmul(pi, state)
            a = numpyro.sample("a_{}".format(t), dist.Categorical(logits=a_logits))

            # update state, t, and state_t based on newly sampled action
            state = update_state(state, a)
            t += 1
            state_t = (state, t)
            return state_t

        return lax.while_loop(cond_fun, body_fun, init_val)

    state = run_trajectory(np.zeros(2), pi)[0]

conditioned_model = condition(model, param_map={"a_0": np.asarray(1), "a_1": np.asarray(3)})

# helper function for HMC inference
def run_inference(model, rng_key, length):
    start = time.time()
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000)
    mcmc.run(rng_key, length)
    mcmc.print_summary()
    print('\nMCMC elapsed time:', time.time() - start)
    return mcmc.get_samples()

rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
samples = run_inference(conditioned_model, rng_key_, 2)
```
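One thing that may be worth checking here (an observation about JAX tracing, not a confirmed diagnosis of the error): inside `lax.while_loop`, `body_fun` is traced with an abstract loop counter, so `"a_{}".format(t)` does not produce `"a_0"`, `"a_1"`, and so on. If that is the case, the names in the `condition` dict never match a sample site, and the actions would stay latent. A minimal pure-JAX sketch, where `site_names` is a hypothetical stand-in for the NumPyro site names:

```python
import jax.lax as lax
import jax.numpy as jnp

site_names = []  # hypothetical stand-in for the numpyro.sample site names

def body_fun(t):
    # t is an abstract tracer here, not a Python int, so formatting it
    # records the tracer's repr instead of "a_0", "a_1", ...
    site_names.append("a_{}".format(t))
    return t + 1

# The loop executes twice, but body_fun is traced into a loop body
# rather than called once per iteration in Python.
final_t = lax.while_loop(lambda t: t < 2, body_fun, jnp.int32(0))
print(int(final_t), site_names)
```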