For the same model, the WAIC obtained with enumerated NUTS is much larger than the WAIC obtained with MixedHMC. Is this a problem with the WAIC calculation, or something else?

When calculating WAIC for a model with discrete latent variables, I found that the WAIC obtained with the enumerated NUTS algorithm is much larger than that obtained with the MixedHMC algorithm. Here is a toy model; the code for data generation and parameter estimation is as follows:

import numpy as np
import jax
from jax import nn, random, vmap
import jax.numpy as jnp
import pandas as pd
import numpyro
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive, DiscreteHMCGibbs, MixedHMC, HMC
import arviz as az

cat_p = np.random.uniform(0, 1)
cat = np.random.binomial(1, cat_p)
obs_p = np.array([0.95, 0.05])[cat]
Y = np.random.binomial(1, obs_p, 1000)


def tree_model():
    cat_p = numpyro.sample("cat_p", dist.Beta(1,1))
    cat = numpyro.sample("cat", dist.Bernoulli(cat_p), infer={"enumerate": "parallel"})
    obs_p = jnp.array([0.95, 0.05])[cat]
    with numpyro.plate("obs_plate", 1000, dim=-1):
        obs = numpyro.sample("obs", dist.Bernoulli(obs_p), obs=Y)

def mixedHCM_tree():
    cat_p = numpyro.sample("cat_p", dist.Beta(1,1))
    cat = numpyro.sample("cat", dist.Bernoulli(cat_p))
    obs_p = jnp.array([0.95, 0.05])[cat]
    with numpyro.plate("obs_plate", 1000, dim=-1):
        obs = numpyro.sample("obs", dist.Bernoulli(obs_p), obs=Y)

idatas = {}
kernel = NUTS(tree_model,)
mcmc = MCMC(kernel, num_warmup=20000, num_samples=10000, num_chains=4,)
mcmc.run(random.PRNGKey(1))
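# NOTE: `pred` is not defined in the original snippet; it presumably holds
# posterior predictive samples, e.g. (hypothetical):
# pred = Predictive(tree_model, mcmc.get_samples(), infer_discrete=True)(random.PRNGKey(0))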
with numpyro.handlers.seed(rng_seed=1):
    idatas["tree"] = az.from_numpyro(
                        mcmc,
                        posterior_predictive=pred
                        )

kernel = MixedHMC(HMC(mixedHCM_tree,))
mcmc = MCMC(kernel, num_warmup=20000, num_samples=10000, num_chains=4,)

mcmc.run(random.PRNGKey(1))
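# NOTE: as above, `pred` is assumed to be rebuilt here from the MixedHMC samples,
# e.g. (hypothetical): pred = Predictive(mixedHCM_tree, mcmc.get_samples())(random.PRNGKey(0))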
with numpyro.handlers.seed(rng_seed=1):
    idatas["mixed_HMC_tree"] = az.from_numpyro(
                        mcmc,
                        posterior_predictive=pred
                        )

In theory, when the same model is estimated with different algorithms, the WAIC values should be similar. However, in the current results they differ greatly.

az.compare(idatas, scale="deviance", ic="waic")
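For reference, a minimal sketch of inspecting each model's WAIC individually (the same information that az.compare aggregates; this assumes the idatas dict built above):

print(az.waic(idatas["tree"], scale="deviance"))
print(az.waic(idatas["mixed_HMC_tree"], scale="deviance"))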

Therefore, I would like to ask why the WAIC differs so much between the enumerated NUTS algorithm and the MixedHMC algorithm. Is it a problem with the WAIC calculation? If so, how can I correctly calculate fit indices such as -2LL and WAIC when using the enumerated NUTS algorithm?
Thank you in advance. This means a lot to me.

This seems like a more appropriate question for the arviz developers. What does arviz do with the discrete latent variables returned by MixedHMC?

Thank you for your help. I realized that WAIC is related to -2LL, so I tried using numpyro.infer.log_likelihood() to calculate the -2LL for the enumerated NUTS model, but I got an error. I would like to know how to use numpyro to calculate the -2LL for the enumerated NUTS model.
Code and error are as follows:


samples = mcmc.get_samples(group_by_chain=False)
numpyro.infer.log_likelihood(
                tree_model, samples)
####################################################
# error
AssertionError                            Traceback (most recent call last)
Cell In[129], line 2
      1 samples = mcmc.get_samples(group_by_chain=False)
----> 2 numpyro.infer.log_likelihood(
      3                 tree_model, samples)

File ~/miniconda3/envs/pyro/lib/python3.11/site-packages/numpyro/infer/util.py:1029, in log_likelihood(model, posterior_samples, parallel, batch_ndims, *args, **kwargs)
   1027 batch_size = int(np.prod(batch_shape))
   1028 chunk_size = batch_size if parallel else 1
-> 1029 return soft_vmap(single_loglik, posterior_samples, len(batch_shape), chunk_size)

File ~/miniconda3/envs/pyro/lib/python3.11/site-packages/numpyro/util.py:410, in soft_vmap(fn, xs, batch_ndims, chunk_size)
    404     xs = tree_map(
    405         lambda x: jnp.reshape(x, prepend_shape + (chunk_size,) + jnp.shape(x)[1:]),
    406         xs,
    407     )
    408     fn = vmap(fn)
--> 410 ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)
    411 map_ndims = int(num_chunks > 1) + int(chunk_size > 1)
    412 ys = tree_map(
    413     lambda y: jnp.reshape(
    414         y, (int(np.prod(jnp.shape(y)[:map_ndims])),) + jnp.shape(y)[map_ndims:]
    415     )[:batch_size],
    416     ys,
...
     85         key, self.probs, shape=sample_shape + self.batch_shape
     86     )
     87     return samples.astype(jnp.result_type(samples, int))

AssertionError: 

In addition, I compared (i) the -2LL calculated manually from the latent discrete variable estimates (i.e., the model's posterior predictive samples) of the enumerated NUTS run, (ii) the -2LL from numpyro.infer.log_likelihood() for the MixedHMC run, and (iii) the -2LL that arviz reports for the enumerated NUTS run. Does this mean that the -2LL arviz computes for the enumerated NUTS run is biased, and can the -2LL instead be calculated from the latent discrete variable estimates rather than from arviz's output? The code and its results are as follows:

# 1) -2LL calculated from posterior predictive samples (latent discrete estimates) of the enumerated NUTS run
cat_pre = idatas["tree"].posterior_predictive["cat"].values
pre_obs_p =  np.array([0.95, 0.05])[cat_pre].reshape(4, 10000, 1)
np.log(pre_obs_p*Y.reshape(1,1,-1)+(1-pre_obs_p)*(1-Y.reshape(1,1,-1))).mean(axis=(0,1)).sum()*-2
# -2LL: 402.91936465024804

# 2) -2LL from numpyro.infer.log_likelihood() for the MixedHMC run
samples = mcmc.get_samples(group_by_chain=False)
numpyro.infer.log_likelihood(
                mixedHCM_tree, samples)["obs"].mean(axis=0).sum()*-2
# -2LL: 402.9194

# 3) -2LL reported by arviz for the enumerated NUTS run
idatas["tree"].log_likelihood["obs"].mean(["chain", "draw"]).sum()*-2
# -2LL: 484.18322754

Thanks again for your help.

If you run MCMC with enumeration, you need to use infer_discrete afterward to get posterior samples for the discrete latent variables. Then I guess you can compare the two results using the non-enumerated model.
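A minimal sketch of that first step, assuming `mcmc` here refers to the enumerated NUTS run from above (Predictive with infer_discrete=True fills in the discrete sites):

from numpyro.infer import Predictive

# Continuous posterior samples from the enumerated NUTS run ("cat" is not included).
enum_samples = mcmc.get_samples()

# Draw the discrete latent "cat" conditioned on those samples.
discrete_pred = Predictive(tree_model, enum_samples, infer_discrete=True)
discrete_samples = discrete_pred(random.PRNGKey(2))

# Combine continuous and discrete samples for downstream use.
full_samples = {**enum_samples, "cat": discrete_samples["cat"]}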


I manually calculated the -2LL corresponding to the posterior samples of the discrete latent variables, and it is very close to the -2LL obtained by MixedHMC, a parameter estimation algorithm that does not require enumeration. However, the -2LL directly reported for the enumerated NUTS run differs from the -2LL obtained by MixedHMC. Therefore, I think manually calculating the -2LL from posterior samples may be the better choice for the enumerated NUTS algorithm.

In addition, I would like to know whether numpyro.infer.log_likelihood() can be used to calculate the likelihood for the enumerated NUTS model. If so, how can this be implemented in code?

Thank you for your patience and time.

As mentioned in the previous comment, you might need to use infer_discrete to draw posterior samples for the discrete variables. Then you can use log_likelihood with the non-enumerated model, which works both for samples from MixedHMC and for “enumerated NUTS + infer_discrete”. Currently, log_likelihood does not support enumerated models. I’m not sure why arviz comes up with the number in your comment…
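For illustration, a minimal sketch of that second step, assuming full_samples from the earlier infer_discrete snippet:

# Evaluate the observation log likelihood under the non-enumerated model,
# using samples from "enumerated NUTS + infer_discrete".
loglik = numpyro.infer.log_likelihood(mixedHCM_tree, full_samples)
print(loglik["obs"].mean(axis=0).sum() * -2)  # -2LL, comparable to the MixedHMC value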