I want to implement WAIC (the Watanabe–Akaike information criterion) to help with model evaluation and selection. For this purpose, I need to calculate the log pointwise predictive density (lppd).
This involves a density for each posterior sample and each observation.
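For reference, this is the usual textbook form of the lppd, written out as a sketch. The symbols are labels chosen here, not taken from any particular library: $N$ observations $y_i$, and $S$ posterior samples $\theta^{(s)}$.

```latex
\mathrm{lppd} = \sum_{i=1}^{N} \log\left( \frac{1}{S} \sum_{s=1}^{S} p\left(y_i \mid \theta^{(s)}\right) \right)
```

The inner term $p(y_i \mid \theta^{(s)})$ is the per-sample, per-observation density referred to above.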
And I noticed that there is a runtime utility log_density() in NumPyro.
My question is: can I use it to get the log density first, and then simply apply np.exp() to its return value to get the desired density? Or is there a better way to calculate the density?
I have an adapted version from log_density as follows:
import jax.numpy as jnp
from jax.lax import broadcast_shapes
# Assumption: is_identically_one is imported from numpyro.distributions.util;
# its location may differ between NumPyro versions.
from numpyro.distributions.util import is_identically_one
from numpyro.handlers import substitute, trace


def log_density(model, dataset, params):
    # Fix the sample sites to the given parameter values.
    model = substitute(model, data=params)
    model_trace = trace(model).get_trace(dataset)
    log_joint = jnp.zeros(())
    for site in model_trace.values():
        if site["type"] == "sample":
            value = site["value"]
            intermediates = site["intermediates"]
            scale = site["scale"]
            if intermediates:
                log_prob = site["fn"].log_prob(value, intermediates)
            else:
                guide_shape = jnp.shape(value)
                model_shape = tuple(
                    site["fn"].shape()
                )  # TensorShape from tfp needs casting to tuple
                try:
                    broadcast_shapes(guide_shape, model_shape)
                except ValueError:
                    raise ValueError(
                        "Model and guide shapes disagree at site: '{}': {} vs {}".format(
                            site["name"], model_shape, guide_shape
                        )
                    )
                log_prob = site["fn"].log_prob(value)
            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob
            # Sum over all elements of the site, then add to the joint.
            log_prob = jnp.sum(log_prob)
            log_joint = log_joint + log_prob
    return log_joint, model_trace
And this works for my model to get log_density for a specific set of sampled params.
However, I find that I get 0.0 when taking np.exp() of the resulting log_density, because the log density can be a very large negative number (e.g., -34562.83977473) and the exponential underflows. So this does not seem to be an appropriate way to calculate the density.
But I notice that the Distribution base class in NumPyro only provides a log_prob method for a distribution’s density.
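To make the underflow concrete, here is a minimal sketch in plain NumPy. The log-density values are made up for illustration (only the first is the number quoted above), and the helper name `log_mean_exp` is hypothetical. It compares exponentiating directly with averaging in log space after subtracting the maximum.

```python
import numpy as np


def log_mean_exp(log_values, axis=0):
    """Return log(mean(exp(log_values))) along `axis`, staying in log space."""
    log_values = np.asarray(log_values, dtype=float)
    # Shift by the maximum so the largest term becomes exp(0) = 1.
    max_log = np.max(log_values, axis=axis, keepdims=True)
    shifted_sum = np.sum(np.exp(log_values - max_log), axis=axis, keepdims=True)
    n = log_values.shape[axis]
    result = max_log + np.log(shifted_sum) - np.log(n)
    return np.squeeze(result, axis=axis)


# Illustrative log densities for three hypothetical posterior samples.
log_densities = np.array([-34562.83977473, -34560.0, -34565.5])

# Direct exponentiation: every value is far below the smallest float64.
naive = np.exp(log_densities)

# Log-space average: finite, and close to the largest log density.
stable = log_mean_exp(log_densities)

print(naive)
print(stable)
```

The shift by the maximum does not change the result mathematically; it only keeps the intermediate exponentials in a representable range.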
Is there any suggestion about a potential way to implement a function to calculate density / lppd?
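One possible direction, sketched under assumptions rather than offered as a definitive implementation: suppose a matrix of pointwise log-likelihoods with shape (S, N) is already available, one row per posterior sample and one column per observation. How that matrix is obtained is outside this sketch (NumPyro's `numpyro.infer.log_likelihood` utility may be one option). The function names `lppd` and `waic` below are hypothetical, and the penalty term uses the common variance-based form of the effective number of parameters.

```python
import numpy as np


def lppd(log_lik):
    """Log pointwise predictive density from an (S, N) log-likelihood matrix."""
    log_lik = np.asarray(log_lik, dtype=float)
    num_samples = log_lik.shape[0]
    # log-sum-exp over posterior samples, computed stably per observation.
    log_sum = np.logaddexp.reduce(log_lik, axis=0)
    return np.sum(log_sum - np.log(num_samples))


def waic(log_lik):
    """WAIC on the deviance scale: -2 * (lppd - p_waic)."""
    log_lik = np.asarray(log_lik, dtype=float)
    # Effective number of parameters: per-observation variance across samples.
    p_waic = np.sum(np.var(log_lik, axis=0, ddof=1))
    return -2.0 * (lppd(log_lik) - p_waic)


# Tiny made-up example: S = 2 posterior samples, N = 2 observations.
example = np.log(np.array([[0.2, 0.5],
                           [0.4, 0.5]]))
print(lppd(example))
print(waic(example))
```

Because everything stays in log space, the same functions remain finite when the entries are as negative as the values quoted above.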