Log joint probability of NumPyro Model


I am trying to develop some intuition about how the joint log probability of my model changes as a function of x.

This is my model:

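import numpyro
import numpyro.distributions as dist

def model(data):
    # Latent vector x = (x_1, ..., x_10), each with a Uniform(-2, 2) prior
    x = numpyro.sample("x", dist.Uniform(low=-2, high=2), sample_shape=(10,))
    img_estimated = some_function(x)

    with numpyro.plate("data", size=data.shape[0]):
        # The original call was cut off when posting; the site name, the
        # TruncatedNormal distribution and obs=data are reconstructed from
        # the loc/scale/low/high arguments and are an assumption.
        obs_image = numpyro.sample(
            "obs_image",
            dist.TruncatedNormal(loc=img_estimated, scale=measurement_std, low=0, high=1),
            obs=data,
        )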

Specifically, I want to define a function that maps x (i.e. x_1, x_2, …, x_10) to the joint log probability of the model.

How can I get the joint log probabilities?

Thank you,

You can use the helper function numpyro.infer.util.log_density:

def model(data):
    x = numpyro.sample("x", ...)

log_joint, _ = numpyro.infer.util.log_density(model, (data,), {}, {"x": x_value})

Thank you @eb8680_2. It worked great!