# For the same model, Enumerated NUTS gives a much larger WAIC than Mixed HMC. Is this a problem with the WAIC calculation or something else?

When I calculated the WAIC for a model with discrete latent variables, I found that the WAIC obtained with the Enumerated NUTS algorithm is much larger than the one obtained with the Mixed HMC algorithm. Here is a toy model; the code for data generation and parameter estimation is as follows:

``````
import numpy as np
import jax
from jax import nn, random, vmap
import jax.numpy as jnp
import pandas as pd
import numpyro
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive, DiscreteHMCGibbs, MixedHMC, HMC
import arviz as az

# Simulate data: a single global category selects the success probability.
cat_p = np.random.uniform(0, 1)
cat = np.random.binomial(1, cat_p)
obs_p = np.array([0.95, 0.05])[cat]
Y = np.random.binomial(1, obs_p, 1000)

def tree_model():
    # Enumerated version: `cat` is marginalized out during inference.
    cat_p = numpyro.sample("cat_p", dist.Beta(1, 1))
    cat = numpyro.sample("cat", dist.Bernoulli(cat_p), infer={"enumerate": "parallel"})
    obs_p = jnp.array([0.95, 0.05])[cat]
    with numpyro.plate("obs_plate", 1000, dim=-1):
        obs = numpyro.sample("obs", dist.Bernoulli(obs_p), obs=Y)

def mixedHCM_tree():
    # Same model, but `cat` is sampled directly by MixedHMC.
    cat_p = numpyro.sample("cat_p", dist.Beta(1, 1))
    cat = numpyro.sample("cat", dist.Bernoulli(cat_p))
    obs_p = jnp.array([0.95, 0.05])[cat]
    with numpyro.plate("obs_plate", 1000, dim=-1):
        obs = numpyro.sample("obs", dist.Bernoulli(obs_p), obs=Y)

idatas = {}

# Enumerated NUTS
kernel = NUTS(tree_model)
mcmc = MCMC(kernel, num_warmup=20000, num_samples=10000, num_chains=4)
mcmc.run(random.PRNGKey(1))
# NOTE: `pred` is not defined in the code as posted.
with numpyro.handlers.seed(rng_seed=1):
    idatas["tree"] = az.from_numpyro(
        mcmc,
        posterior_predictive=pred
    )

# Mixed HMC
kernel = MixedHMC(HMC(mixedHCM_tree))
mcmc = MCMC(kernel, num_warmup=20000, num_samples=10000, num_chains=4)
mcmc.run(random.PRNGKey(1))
with numpyro.handlers.seed(rng_seed=1):
    idatas["mixed_HMC_tree"] = az.from_numpyro(
        mcmc,
        posterior_predictive=pred
    )
``````

In theory, when estimating the same model with different estimation algorithms, the WAIC values should be similar. However, in the current results, the WAIC values differ greatly.

``````
az.compare(idatas, scale="deviance", ic="waic")
``````
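For reference, WAIC is a function of the pointwise log-likelihood matrix only, so two samplers can only disagree on WAIC if they produce different pointwise log-likelihood values. The sketch below computes WAIC on the deviance scale from an `(S, N)` array (S posterior draws, N observations) using the textbook definition. The function name `waic_deviance` is hypothetical, and the use of the population variance (`ddof=0`) is an assumption; this is not taken from the arviz source.

```python
import numpy as np

def waic_deviance(log_lik):
    """WAIC on the deviance scale from an (S, N) pointwise log-likelihood array.

    Hypothetical helper using the textbook definition; ddof=0 is an assumption.
    Returns (waic, p_waic).
    """
    log_lik = np.asarray(log_lik, dtype=float)
    n_draws = log_lik.shape[0]
    # lppd_i = log(mean over draws of exp(log_lik[:, i])), computed stably
    lppd_i = np.logaddexp.reduce(log_lik, axis=0) - np.log(n_draws)
    # Effective number of parameters: variance over draws, per observation
    p_waic_i = log_lik.var(axis=0)
    elpd = np.sum(lppd_i - p_waic_i)
    return -2.0 * elpd, np.sum(p_waic_i)

# Tiny check: 4 draws, 3 observations, every likelihood equal to 0.5
demo_waic, demo_p = waic_deviance(np.full((4, 3), np.log(0.5)))
```

Feeding the `log_likelihood["obs"]` arrays from both `idatas` entries (flattened over chain and draw) through the same function is one way to check whether the gap comes from the log-likelihood values themselves rather than from the WAIC formula.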

So why does the WAIC differ so much between the Enumerated NUTS algorithm and the Mixed HMC algorithm? Is it a problem with how the WAIC is calculated? If so, how can I correctly calculate fit indices such as -2LL and WAIC when using the Enumerated NUTS algorithm?
Thank you in advance. This means a lot to me.

This seems like a more appropriate question for the developers of `arviz`. What does `arviz` do with the discrete latent variables returned by `MixedHMC`?

Thank you for your help. I realized that WAIC is related to -2LL, so I tried using `numpyro.infer.log_likelihood()` to calculate the -2LL for the Enumerated NUTS algorithm, but I got an error. I would like to know how to use numpyro to calculate the -2LL for the Enumerated NUTS model.
Code and error are as follows:

``````
samples = mcmc.get_samples(group_by_chain=False)
numpyro.infer.log_likelihood(tree_model, samples)
####################################################
# error
AssertionError                            Traceback (most recent call last)
Cell In, line 2
1 samples = mcmc.get_samples(group_by_chain=False)
----> 2 numpyro.infer.log_likelihood(
3                 tree_model, samples)

File ~/miniconda3/envs/pyro/lib/python3.11/site-packages/numpyro/infer/util.py:1029, in log_likelihood(model, posterior_samples, parallel, batch_ndims, *args, **kwargs)
1027 batch_size = int(np.prod(batch_shape))
1028 chunk_size = batch_size if parallel else 1
-> 1029 return soft_vmap(single_loglik, posterior_samples, len(batch_shape), chunk_size)

File ~/miniconda3/envs/pyro/lib/python3.11/site-packages/numpyro/util.py:410, in soft_vmap(fn, xs, batch_ndims, chunk_size)
404     xs = tree_map(
405         lambda x: jnp.reshape(x, prepend_shape + (chunk_size,) + jnp.shape(x)[1:]),
406         xs,
407     )
408     fn = vmap(fn)
--> 410 ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)
411 map_ndims = int(num_chunks > 1) + int(chunk_size > 1)
412 ys = tree_map(
413     lambda y: jnp.reshape(
414         y, (int(np.prod(jnp.shape(y)[:map_ndims])),) + jnp.shape(y)[map_ndims:]
415     )[:batch_size],
416     ys,
...
85         key, self.probs, shape=sample_shape + self.batch_shape
86     )
87     return samples.astype(jnp.result_type(samples, int))

AssertionError:
``````

In addition, I compared three -2LL values: (i) the -2LL computed from the draws of the latent discrete variable (i.e., the model's posterior predictive samples of `cat`) under the enumerated NUTS algorithm, (ii) the -2LL returned by `numpyro.infer.log_likelihood()` under MixedHMC, and (iii) the -2LL that arviz reports for the enumerated NUTS algorithm. The first two agree, while the third is much larger. Does this mean that the -2LL arviz computes for the enumerated NUTS algorithm is biased? And can I compute the -2LL from the latent discrete variable draws instead of using arviz's output? The code and results are as follows:

``````
# 1) -2LL from posterior predictive draws of `cat`, enumerated NUTS algorithm
cat_pre = idatas["tree"].posterior_predictive["cat"].values
pre_obs_p = np.array([0.95, 0.05])[cat_pre].reshape(4, 10000, 1)
np.log(pre_obs_p*Y.reshape(1,1,-1)+(1-pre_obs_p)*(1-Y.reshape(1,1,-1))).mean(axis=(0,1)).sum()*-2
# -2LL: 402.91936465024804

# 2) -2LL from numpyro, MixedHMC algorithm
samples = mcmc.get_samples(group_by_chain=False)
numpyro.infer.log_likelihood(mixedHCM_tree, samples)["obs"].mean(axis=0).sum()*-2
# -2LL: 402.9194

# 3) -2LL from arviz, enumerated NUTS algorithm
idatas["tree"].log_likelihood["obs"].mean(["chain", "draw"]).sum()*-2
# -2LL: 484.18322754
``````

In addition, I would like to know whether `numpyro.infer.log_likelihood()` can be used to calculate the log-likelihood for a model fitted with the enumerated NUTS algorithm. If so, how can this be implemented in code?
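Until that is answered, approach (i) above can be written as a small reusable helper. The sketch below builds the `(S, N)` pointwise log-likelihood matrix for this toy model directly in NumPy, assuming posterior draws of `cat` are already available (for example the `posterior_predictive["cat"]` values used in the post). The names `pointwise_log_lik` and `minus_two_ll` are hypothetical and not part of numpyro or arviz.

```python
import numpy as np

# P(y = 1 | cat) for cat = 0 and cat = 1, as in the toy model
OBS_P = np.array([0.95, 0.05])

def pointwise_log_lik(cat_draws, y):
    """(S, N) Bernoulli log-likelihood of y given S posterior draws of `cat`.

    Hypothetical helper: `cat_draws` is assumed to hold 0/1 draws of the
    global latent variable, flattened over chains and draws.
    """
    cat_draws = np.asarray(cat_draws, dtype=int).reshape(-1)
    y = np.asarray(y).reshape(1, -1)
    p = OBS_P[cat_draws].reshape(-1, 1)
    return y * np.log(p) + (1 - y) * np.log1p(-p)

def minus_two_ll(log_lik):
    """-2LL as computed in the post: mean over draws, then sum over observations."""
    return -2.0 * log_lik.mean(axis=0).sum()

# Tiny check: two draws with cat = 0 and three observations
demo_ll = pointwise_log_lik([0, 0], [1, 1, 0])
demo_m2ll = minus_two_ll(demo_ll)
```

With the objects from the post, `pointwise_log_lik(cat_pre.reshape(-1), Y)` should give the matrix behind result (1), which can then be compared entry by entry with the arrays from (2) and (3).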