Hi,

I am trying to develop some intuition about what’s happening to the joint log probability of the model as a function of `x`.

This is what my model looks like,

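```
import numpyro
import numpyro.distributions as dist

# `some_function` (my forward model) and `measurement_std` are defined elsewhere.
def model(data):
    # 10 latent parameters x_1, ..., x_10, each with a Uniform(-2, 2) prior
    x = numpyro.sample("x", dist.Uniform(low=-2, high=2), sample_shape=(10,))
    img_estimated = some_function(x)
    # Truncated-normal likelihood on [0, 1] for each observed data point
    with numpyro.plate("data", size=data.shape[0]):
        obs_image = numpyro.sample(
            "obs",
            dist.TruncatedNormal(
                loc=img_estimated, scale=measurement_std, low=0, high=1
            ),
            obs=data,
        )
```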

Specifically, I want to define a function that takes me from `x` (i.e. `x_1`, `x_2`, …, `x_10`) to the joint log prob of the model.

How can I compute this joint log probability?

Thank you,

Atharva