Implementing WAIC for a NumPyro model

Hi there,

I want to implement WAIC (the Watanabe–Akaike information criterion) to help with model evaluation and selection. For this purpose, I need to calculate the log pointwise predictive density (lppd) as follows:

$$\mathrm{lppd} = \sum_{i=1}^{n} \log\left( \frac{1}{S} \sum_{s=1}^{S} p(y_i \mid \theta^{s}) \right)$$

This involves a density for each posterior sample and each observation.

And I noticed that there is a runtime utility log_density() in NumPyro.

My question is: can I use it to first get the log density, and then simply apply np.exp() to its return value to get the desired density? Or is there a better way to calculate the density?

Many Thanks!

I have a version adapted from log_density, as follows:

import jax.numpy as jnp
from jax.lax import broadcast_shapes
from numpyro.distributions.util import is_identically_one
from numpyro.handlers import substitute, trace

def log_density(model, dataset, params):
    model = substitute(model, data=params)
    model_trace = trace(model).get_trace(dataset)
    log_joint = jnp.zeros(())
    for site in model_trace.values():
        if site["type"] == "sample":
            value = site["value"]
            intermediates = site["intermediates"]
            scale = site["scale"]
            if intermediates:
                log_prob = site["fn"].log_prob(value, intermediates)
            else:
                guide_shape = jnp.shape(value)
                model_shape = tuple(
                    site["fn"].shape()
                )  # TensorShape from tfp needs casting to tuple
                try:
                    broadcast_shapes(guide_shape, model_shape)
                except ValueError:
                    raise ValueError(
                        "Model and guide shapes disagree at site: '{}': {} vs {}".format(
                            site["name"], model_shape, guide_shape
                        )
                    )
                log_prob = site["fn"].log_prob(value)

            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            log_prob = jnp.sum(log_prob)
            log_joint = log_joint + log_prob
    return log_joint, model_trace

And this works for my model to get log_density for a specific set of sampled params.

However, I get 0.0 when I take np.exp() of the resulting log density, because the log density can be a large negative number (e.g., -34562.83977473) and the exponential underflows. So this does not seem to be an appropriate way to calculate the density.
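The underflow described above is the usual reason to stay in log space. Below is a minimal sketch of the log-sum-exp trick, assuming a hypothetical `log_lik` array of shape (num_samples, num_observations) that holds per-observation log-likelihoods rather than the summed joint density. The helper names `log_mean_exp` and `lppd` are illustrative and are not part of NumPyro.

```python
import numpy as np

def log_mean_exp(log_values, axis=0):
    """Return log(mean(exp(log_values))) along `axis` without underflow."""
    # Subtract the maximum so the largest exponent is exp(0) = 1.
    max_val = np.max(log_values, axis=axis, keepdims=True)
    shifted = np.exp(log_values - max_val)
    return np.squeeze(max_val, axis=axis) + np.log(np.mean(shifted, axis=axis))

def lppd(log_lik):
    """Sum over observations of the log of the posterior-mean density."""
    return np.sum(log_mean_exp(log_lik, axis=0))

# Hypothetical values: rows are posterior samples, columns are observations.
log_lik = np.array([[-34562.8, -1.2],
                    [-34560.1, -0.9],
                    [-34565.4, -1.5]])

# Naive route: exp() underflows to 0.0 in the first column, so log() gives -inf.
with np.errstate(divide="ignore"):
    naive = np.log(np.mean(np.exp(log_lik), axis=0))

# Stable route: stays finite for both columns.
stable = log_mean_exp(log_lik, axis=0)
total_lppd = lppd(log_lik)
```

The two routes agree wherever the naive one does not underflow, so the stable version can be used unconditionally.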

But I notice that the Distribution base class in NumPyro only provides a log_prob method for a distribution’s density.

Are there any suggestions for how to implement a function that calculates the density / lppd?

Thanks for any feedback!

Update:

I have a working version of WAIC that uses log_likelihood(), which produces a log-likelihood matrix, as a starting point.

My implementation is based on several R implementations of WAIC, e.g.:
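For readers who want something concrete, here is a rough sketch of WAIC computed from such a matrix. It is not the exact code from this thread. It assumes a (num_samples, num_observations) array, which in NumPyro might be taken from one entry of the dict returned by `numpyro.infer.log_likelihood`. A synthetic array stands in for it here. The function name `waic`, the `ddof=1` variance (chosen to mirror R's `var()`), and the -2 deviance scaling are assumptions.

```python
import numpy as np

def waic(log_lik):
    """WAIC from a (num_samples, num_observations) pointwise log-likelihood matrix."""
    num_samples = log_lik.shape[0]
    # lppd_i = log((1/S) * sum_s exp(log_lik[s, i])), evaluated in log space.
    lppd_i = np.logaddexp.reduce(log_lik, axis=0) - np.log(num_samples)
    # Penalty term: posterior variance of the log-likelihood for each observation.
    p_waic_i = np.var(log_lik, axis=0, ddof=1)
    elpd_i = lppd_i - p_waic_i
    return {
        "lppd": float(np.sum(lppd_i)),
        "p_waic": float(np.sum(p_waic_i)),
        "elpd_waic": float(np.sum(elpd_i)),
        "waic": float(-2.0 * np.sum(elpd_i)),  # deviance scale
    }

# Synthetic stand-in for a real log-likelihood matrix (1000 draws, 20 observations).
rng = np.random.default_rng(0)
log_lik = rng.normal(loc=-1.0, scale=0.1, size=(1000, 20))
result = waic(log_lik)

# Sanity check: with no posterior variation, the penalty vanishes.
const = waic(np.full((10, 5), -2.0))
```

Lower WAIC on the deviance scale indicates better estimated out-of-sample fit. Comparing models requires that the matrices are computed on the same observations.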

Widely Applicable Information Criterion

calculates the WAIC

I hope my experience helps others with a similar need.

Further discussion on this topic is also welcome. Thanks!

arviz has a good implementation here. The core logic is pretty simple; you just need to ignore all the wrappers.

Thank you for your suggestion! I will check it out.