Using the TFPKernel wrapper?

I’m attempting to use the TFPKernel wrapper to access the ReplicaExchangeMC kernel implemented in TensorFlow Probability. I wasn’t able to get MCMC to run with ReplicaExchangeMC, so I wrote a simple toy model (fitting a Gaussian) using the NoUTurnSampler to see if any of the TFP kernels work properly. I was able to sample with TFP NUTS, but the chains don’t move at all (stuck at the initial values). NumPyro NUTS works perfectly on the same model.

I’m new to NumPyro (coming from PyMC), so I wanted to check that I’m using the TFPKernel class correctly before attempting to debug the source code. Toy model:

```python
import numpy as np
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC
from numpyro.contrib.tfp.mcmc import NoUTurnSampler

samples = np.random.normal(10, 2, 500)

def toy_model(observations):
    mu = numpyro.sample('mu', dist.Normal(0, 20))
    sigma = numpyro.sample('sigma', dist.HalfNormal(10))
    numpyro.sample('obs', dist.Normal(mu, sigma), obs=observations)

tfp_nuts_kernel = NoUTurnSampler(model=toy_model, step_size=1.0)

mcmc = MCMC(tfp_nuts_kernel, num_chains=4, num_warmup=1000, num_samples=2000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, observations=samples)

trace = mcmc.get_samples(group_by_chain=True)
```
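To put a number on the "stuck at the initial values" symptom, here is a minimal sketch that flags chains whose samples never change. `stuck_chains` is a hypothetical helper, not a NumPyro API. It assumes `trace` comes from `mcmc.get_samples(group_by_chain=True)`, whose arrays are shaped `(num_chains, num_samples, ...)`. NumPyro's `mcmc.print_summary()` (which reports `n_eff` and `r_hat` per site) also exposes frozen chains. This helper just makes the check explicit for each chain.

```python
import numpy as np


def stuck_chains(trace, atol=0.0):
    """Flag chains whose samples never change.

    Hypothetical helper. Assumes `trace` looks like the output of
    mcmc.get_samples(group_by_chain=True): {site: array(num_chains, num_samples, ...)}.
    Returns {site: bool array of length num_chains}, True where the chain is stuck.
    """
    flags = {}
    for site, values in trace.items():
        values = np.asarray(values)
        # Range of each chain along the sample axis (axis=1).
        spread = values.max(axis=1) - values.min(axis=1)
        # Collapse any event dimensions so there is one number per chain.
        spread = spread.reshape(spread.shape[0], -1).max(axis=1)
        flags[site] = spread <= atol
    return flags


# Usage with made-up arrays (not real MCMC output): chain 0 moves, chain 1 is frozen.
fake_trace = {"mu": np.stack([np.linspace(9.0, 11.0, 5), np.full(5, 0.5)])}
flags = stuck_chains(fake_trace)
print(flags["mu"].tolist())  # [False, True]
```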

A secondary issue is that this syntax doesn’t work:

```python
tfp_nuts_kernel = TFPKernel[tfp.mcmc.NoUTurnSampler](model=toy_model, step_size=1.)
```

This fails with:

```
AssertionError                            Traceback (most recent call last)
Cell In[235], line 1
----> 1 kernel_test = TFPKernel[tfp.mcmc.NoUTurnSampler](model = toy_model, step_size = 1)

File ~/opt/anaconda3/envs/bayestrat-dev-m1-pymc570/lib/python3.11/site-packages/numpyro/contrib/tfp/mcmc.py:58, in _TFPKernelMeta.__getitem__(cls, kernel_class)
     57 def __getitem__(cls, kernel_class):
---> 58     assert issubclass(kernel_class, tfp.mcmc.TransitionKernel)
     59     assert (
     60         "target_log_prob_fn" in inspect.getfullargspec(kernel_class).args
     61     ), f"the first argument of {kernel_class} must be `target_log_prob_fn`"
     63     _PyroKernel = type(kernel_class.__name__, (TFPKernel,), {})

AssertionError: 
```

Is there an error in my usage of the TFPKernel wrapper, or an issue with the wrapper itself (perhaps related to package versions)? I’m running the latest NumPyro (0.12.1) and TFP (0.21.0). I tried testing older versions of TFP, but ran into compatibility issues with TensorFlow (I was unable to install versions older than 2.12.0 on an M1 Mac).

The TFP docs indicate that NUTS is a subclass of TransitionKernel, so I’m not sure what causes the issue. Maybe you can check what kernel_class is and test again whether it is a subclass of TFP’s TransitionKernel.

We have some tests here which suggest that your code should work.