I have adapted a version of `log_density` as follows:

```
import jax.numpy as jnp
from jax.lax import broadcast_shapes
from numpyro.distributions.util import is_identically_one
from numpyro.handlers import substitute, trace


def log_density(model, dataset, params):
    # Fix the sample sites to the given parameter values, then trace the model.
    model = substitute(model, data=params)
    model_trace = trace(model).get_trace(dataset)
    log_joint = jnp.zeros(())
    for site in model_trace.values():
        if site["type"] == "sample":
            value = site["value"]
            intermediates = site["intermediates"]
            scale = site["scale"]
            if intermediates:
                log_prob = site["fn"].log_prob(value, intermediates)
            else:
                guide_shape = jnp.shape(value)
                model_shape = tuple(
                    site["fn"].shape()
                )  # TensorShape from tfp needs casting to tuple
                try:
                    broadcast_shapes(guide_shape, model_shape)
                except ValueError:
                    raise ValueError(
                        "Model and guide shapes disagree at site: '{}': {} vs {}".format(
                            site["name"], model_shape, guide_shape
                        )
                    )
                log_prob = site["fn"].log_prob(value)
            # Apply any site-level scaling before summing into the joint.
            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob
            log_prob = jnp.sum(log_prob)
            log_joint = log_joint + log_prob
    return log_joint, model_trace
```

This works for my model: it returns the log density for a specific set of sampled params.

However, I get `0.0` when I take `np.exp()` of the resulting log density, because the log density can be a very large negative number (e.g., `-34562.83977473`) and the exponential underflows. So exponentiating does not seem to be an appropriate way to calculate the density.

But I notice that the `Distribution` base class in NumPyro only provides a `log_prob` method for a distribution's density.

Are there any suggestions for how to implement a function that calculates the density / lppd?
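
To make the question concrete, here is a minimal sketch of the kind of log-space computation I have in mind. It is only an illustration under assumptions: `log_liks` is a hypothetical array of shape `(num_draws, num_obs)` holding pointwise log-likelihoods for each posterior draw (the function above does not return this; it returns a single summed joint), and `lppd_from_log_liks` is a made-up helper name. The idea is to average over draws with `jax.scipy.special.logsumexp` instead of calling `exp` directly.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def lppd_from_log_liks(log_liks):
    """Log pointwise predictive density from a hypothetical
    (num_draws, num_obs) array of pointwise log-likelihoods."""
    num_draws = log_liks.shape[0]
    # log(mean over draws of exp(log_lik)) per observation, kept in log space
    pointwise = logsumexp(log_liks, axis=0) - jnp.log(num_draws)
    return jnp.sum(pointwise)


# Made-up values, chosen so that the first column underflows under exp().
log_liks = jnp.array(
    [[-34562.8, -2.0],
     [-34560.1, -1.5],
     [-34565.3, -2.5]]
)
naive = jnp.exp(log_liks[:, 0])      # underflows, as described above
lppd = lppd_from_log_liks(log_liks)  # stays finite
```

With this approach the result is still a log quantity, so the very negative values never need to be exponentiated on their own.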

Thanks for any feedback!