Log joint probability of NumPyro Model

Hi,

I am trying to develop some intuition about what’s happening to the joint log probability of the model as a function of x.

This is what my model looks like:

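import numpyro
import numpyro.distributions as dist

# some_function and measurement_std are not shown here.
def model(data):
    # 10 latent parameters, each with a Uniform(-2, 2) prior
    x = numpyro.sample("x", dist.Uniform(low=-2, high=2), sample_shape=(10,))
    img_estimated = some_function(x)

    with numpyro.plate("data", size=data.shape[0]):
        # Likelihood: observations truncated to [0, 1] around the estimated image
        obs_image = numpyro.sample(
            "obs",
            dist.TruncatedNormal(
                loc=img_estimated, scale=measurement_std, low=0, high=1
            ),
            obs=data,
        )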

Specifically, I want to define a function that maps x (i.e. x_1, x_2, …, x_10) to the joint log probability of the model.
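For intuition, the quantity being asked about can be written out for the model above. In this sketch, f stands for some_function, σ for measurement_std, and y_j for the entries of data; these symbols are introduced here for illustration. The sum over j runs over every observed element, and f(x)_j is the element of img_estimated paired with y_j after broadcasting. φ and Φ are the standard normal pdf and cdf, and the density holds for y_j in [0, 1]:

```latex
\begin{aligned}
\log p(x, y) &= \sum_{i=1}^{10} \log \mathrm{Uniform}(x_i \mid -2, 2)
  + \sum_{j} \log \mathrm{TruncatedNormal}\big(y_j \mid f(x)_j, \sigma, 0, 1\big), \\
\log \mathrm{TruncatedNormal}(y \mid \mu, \sigma, 0, 1)
  &= \log \phi\!\left(\frac{y - \mu}{\sigma}\right) - \log \sigma
  - \log\!\left[\Phi\!\left(\frac{1 - \mu}{\sigma}\right) - \Phi\!\left(\frac{-\mu}{\sigma}\right)\right].
\end{aligned}
```

For x inside the support [-2, 2]^10 the prior term is the constant -10 log 4, so any variation of the log joint with x comes from the likelihood, including the truncation normalizer, which depends on x through μ = f(x).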

How can I get the joint log probabilities?

Thank you,
Atharva

You can use the helper function numpyro.infer.util.log_density:

def model(data):
    x = numpyro.sample("x", ...)
    ...

log_joint, _ = numpyro.infer.util.log_density(model, (data,), {}, {"x": x_value})

Thank you @eb8680_2. It worked great!