Implementing WAIC for a NumPyro model

Hi there,

I want to implement WAIC (the Watanabe–Akaike information criterion) to help with model evaluation and selection. For this, I need to calculate the log pointwise predictive density (lppd) as follows:

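$$
\text{lppd} = \sum_{i=1}^{n} \log \left( \frac{1}{S} \sum_{s=1}^{S} p(y_i \mid \theta^{s}) \right)
$$

with S posterior draws θ^s and n observations y_i.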

This involves a density for each posterior sample and each observation.
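The formula can be evaluated without ever exponentiating a large negative number. The sketch below assumes the pointwise log densities log p(y_i | θ^s) are already stored in an array of shape (S, n) (posterior samples by observations); `lppd` and `log_lik` are hypothetical names, and `logsumexp` from `jax.scipy.special` does the averaging in log space.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def lppd(log_lik: jnp.ndarray) -> jnp.ndarray:
    """Computed lppd from an (S, n) array of log p(y_i | theta^s)."""
    num_samples = log_lik.shape[0]
    # log((1/S) * sum_s exp(log_lik[s, i])) for each observation i,
    # evaluated stably with logsumexp over the sample axis.
    pointwise = logsumexp(log_lik, axis=0) - jnp.log(num_samples)
    # Sum over the n observations.
    return jnp.sum(pointwise)


# Toy example: 4 posterior samples, 3 observations.
# Every column averages to a density of 0.2.
log_lik = jnp.log(jnp.array([
    [0.1, 0.2, 0.3],
    [0.3, 0.2, 0.1],
    [0.2, 0.2, 0.2],
    [0.2, 0.2, 0.2],
]))
print(lppd(log_lik))  # ≈ 3 * log(0.2) ≈ -4.828
```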

I noticed that NumPyro has a utility function, log_density() (in numpyro.infer.util).

My question is: can I use it to get the log density and then simply apply np.exp() to its return value to get the desired density? Or is there a better way to calculate the density?

Many Thanks!

I have adapted log_density as follows:

import jax.numpy as jnp
from jax.lax import broadcast_shapes
from numpyro.distributions.util import is_identically_one
from numpyro.handlers import substitute, trace


def log_density(model, dataset, params):
    # Fix the model's latent sites to one set of posterior samples.
    model = substitute(model, data=params)
    # Run the model on the data and record every site it visits.
    model_trace = trace(model).get_trace(dataset)
    log_joint = jnp.zeros(())
    for site in model_trace.values():
        if site["type"] == "sample":
            value = site["value"]
            intermediates = site["intermediates"]
            scale = site["scale"]
            if intermediates:
                log_prob = site["fn"].log_prob(value, intermediates)
            else:
                guide_shape = jnp.shape(value)
                model_shape = tuple(
                    site["fn"].shape()
                )  # TensorShape from tfp needs casting to tuple
                try:
                    broadcast_shapes(guide_shape, model_shape)
                except ValueError:
                    raise ValueError(
                        "Model and guide shapes disagree at site: '{}': {} vs {}".format(
                            site["name"], model_shape, guide_shape
                        )
                    )
                log_prob = site["fn"].log_prob(value)

            # Apply any site scaling (e.g. from subsampling).
            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            # Summing collapses all observations (and the priors) into a
            # single joint log density for this set of params.
            log_prob = jnp.sum(log_prob)
            log_joint = log_joint + log_prob
    return log_joint, model_trace

This works for my model: it gives the log density for a specific set of sampled parameters.

However, I get 0.0 when I apply np.exp() to the resulting log density, because it can be a very large negative number (e.g., -34562.83977473) and its exponential underflows to zero in floating point. So this does not seem to be an appropriate way to calculate the density.
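The underflow is easy to reproduce: the smallest positive float64 is about 4.9e-324 (roughly exp(-744.4)), and float32, JAX's default, bottoms out far earlier, so exp(-34562.8) rounds to exactly 0.0. Note also that log_density sums over all observations and the priors, whereas lppd needs one log density per observation and per sample. For the averaging over samples, a common workaround is to stay in log space with log-sum-exp. A minimal sketch, with made-up log densities for illustration:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Made-up joint log densities for three posterior samples.
log_dens = jnp.array([-34562.84, -34560.12, -34565.03])

# Exponentiating directly underflows: every entry becomes 0.0,
# so the log of their mean is -inf.
naive = jnp.log(jnp.mean(jnp.exp(log_dens)))

# log-mean-exp via logsumexp: logsumexp(x) - log(S).
# logsumexp subtracts max(x) before exponentiating, so nothing underflows.
stable = logsumexp(log_dens) - jnp.log(log_dens.shape[0])

print(naive, stable)
```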

I also notice that the Distribution base class in NumPyro only provides a log_prob method, i.e. the log of a distribution's density.

Is there a recommended way to implement a function that calculates the density or the lppd?

Thanks for any feedback!

Update:

I now have a working version of WAIC that uses log_likelihood(), which produces a log-likelihood matrix, as its starting point.

My implementation is based on several R implementations of WAIC, e.g.:

Widely Applicable Information Criterion

calculates the WAIC
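For reference, the standard definitions such implementations compute from an S × n matrix of log p(y_i | θ^s), with S posterior draws θ^s and n observations y_i, where Var is the sample variance over the S draws and lppd is the log pointwise predictive density:

$$
p_{\text{WAIC}} = \sum_{i=1}^{n} \operatorname{Var}_{s=1}^{S}\left[\log p(y_i \mid \theta^{s})\right],
\qquad
\text{WAIC} = -2\left(\text{lppd} - p_{\text{WAIC}}\right)
$$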

I hope my experience helps others with similar needs.

Further discussion on this topic is welcome. Thanks!

arviz has a good implementation here. The core logic is pretty simple; you just need to ignore all the wrappers.
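A minimal sketch of that core computation in JAX, assuming an (S, n) log-likelihood matrix with posterior samples on axis 0 (for example the "obs" entry returned by log_likelihood()). This is not arviz's code: `simple_waic` is a hypothetical name, and using the S - 1 sample variance for p_waic is an assumption (with thousands of draws, dividing by S instead makes little difference). It reports WAIC on the deviance scale, -2 × elpd.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def simple_waic(log_lik: jnp.ndarray):
    """WAIC from an (S, n) log-likelihood matrix (samples on axis 0).

    Returns (waic, p_waic) on the deviance scale, -2 * elpd_waic.
    """
    num_samples = log_lik.shape[0]
    # Pointwise lppd: log of the posterior-mean likelihood, kept in log space.
    lppd_i = logsumexp(log_lik, axis=0) - jnp.log(num_samples)
    # Pointwise effective number of parameters: sample variance over draws.
    p_waic_i = jnp.var(log_lik, axis=0, ddof=1)
    elpd_i = lppd_i - p_waic_i
    return -2.0 * jnp.sum(elpd_i), jnp.sum(p_waic_i)


# Synthetic log-likelihood matrix: 1000 draws, 20 observations.
log_lik = jax.random.normal(jax.random.PRNGKey(0), (1000, 20)) - 1.0
waic_value, p_waic = simple_waic(log_lik)
```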

Thank you for your suggestion! I will check it out.

Here is what I've been using, ported from Pyro (pyro.ops.stats.waic):

import jax.numpy as jnp
import jax.scipy.special as jsp


def _weighted_mean(
    input: jnp.ndarray,
    log_weights: jnp.ndarray,
    axis: int = 0,
    keepdims: bool = False,
) -> jnp.ndarray:
    """
    Computes the weighted mean of `input` using `log_weights` along the specified axis.

    Args:
        input: Array containing values whose mean is computed.
        log_weights: Log weights array. Must be 1D and match `input.shape[axis]`.
        axis: Axis along which to compute the mean.
        keepdims: Whether to keep the reduced axis dimension.

    Returns:
        Weighted mean of `input` along the specified axis.
    """
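    axis = input.ndim + axis if axis < 0 else axis

    # Normalize the 1D log weights first, so the normalization is always
    # taken over the sample dimension regardless of `axis`.
    log_weights_norm = log_weights - jsp.logsumexp(log_weights)
    weights = jnp.exp(log_weights_norm)

    # Reshape to (n_samples, 1, ..., 1) so the weights broadcast against
    # `input` along `axis` and every trailing dimension.
    weights = weights.reshape([-1] + (input.ndim - axis - 1) * [1])

    return jnp.sum(input * weights, axis=axis, keepdims=keepdims)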


def _weighted_variance(
    input: jnp.ndarray,
    log_weights: jnp.ndarray,
    axis: int = 0,
    keepdims: bool = False,
    unbiased: bool = False,
) -> jnp.ndarray:
    """
    Computes the weighted variance of `input` using `log_weights` along the specified axis.

    Args:
        input: Array containing values whose variance is computed.
        log_weights: Log weights array. Must be 1D and match `input.shape[axis]`.
        axis: Axis along which to compute the variance.
        keepdims: Whether to keep the reduced axis dimension.
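        unbiased: Whether to apply a Bessel-style correction, ESS / (ESS - 1), where ESS is the
            effective sample size of the weights; this reduces to N / (N - 1) for uniform weights.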

    Returns:
        Weighted variance of `input` along the specified axis.
    """
    mean = _weighted_mean(input, log_weights, axis=axis, keepdims=True)
    deviation_squared = (input - mean) ** 2

    if unbiased:
        weights = jnp.exp(log_weights - jsp.logsumexp(log_weights))
        ess = (jnp.sum(weights) ** 2) / jnp.sum(weights**2)
        correction = ess / (ess - 1.0)
    else:
        correction = 1.0

    return (
        _weighted_mean(deviation_squared, log_weights, axis=axis, keepdims=keepdims)
        * correction
    )


def waic(
    log_likelihood: jnp.ndarray,
    log_weights: jnp.ndarray = None,
    pointwise: bool = False,
    axis: int = 0,
):
    """
    Computes the WAIC (Widely Applicable Information Criterion) for Bayesian model comparison.

    Args:
        log_likelihood: Log-likelihood array of shape (n_samples, ...).
        log_weights: Optional log weights array of shape (n_samples,).
        pointwise: Return pointwise values instead of summed.
        axis: Sample axis in `log_likelihood` (default 0).

    Returns:
        tuple: (WAIC value, effective number of parameters p_waic)
                If `pointwise=True`, returns pointwise components instead of summed totals.
    """
    if log_weights is None:
        log_weights = jnp.zeros(log_likelihood.shape[axis])
    else:
        expected_shape = (log_likelihood.shape[axis],)
        if log_weights.shape != expected_shape:
            raise ValueError(
                f"log_weights must have shape {expected_shape}, got {log_weights.shape}"
            )

    axis = log_likelihood.ndim + axis if axis < 0 else axis

    expanded_weights = log_weights.reshape(
        [-1] + (log_likelihood.ndim - axis - 1) * [1]
    )
    weighted_log_likelihood = log_likelihood + expanded_weights
    lpd = jsp.logsumexp(weighted_log_likelihood, axis=axis) - jsp.logsumexp(
        log_weights, axis=0
    )

    # Pyro uses the unbiased (sample) variance here, matching the standard p_WAIC definition.
    p_waic = _weighted_variance(log_likelihood, log_weights, axis=axis, unbiased=True)

    elpd = lpd - p_waic
    waic_values = -2 * elpd

    return (
        (waic_values, p_waic) if pointwise else (jnp.sum(waic_values), jnp.sum(p_waic))
    )