Hi,
I am trying to develop some intuition about what’s happening to the joint log probability of the model as a function of x.
This is what my model looks like:
```python
import numpyro
import numpyro.distributions as dist

# some_function and measurement_std are defined elsewhere in my code
def model(data):
    x = numpyro.sample("x", dist.Uniform(low=-2, high=2), sample_shape=(10,))
    img_estimated = some_function(x)
    with numpyro.plate("data", size=data.shape[0]):
        obs_image = numpyro.sample(
            "obs",
            dist.TruncatedNormal(
                loc=img_estimated, scale=measurement_std, low=0, high=1
            ),
            obs=data,
        )
```
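For intuition, the joint log probability this model defines can be written out term by term. Here f stands for some_function and σ for measurement_std, and the sum over j runs over every observed entry of data (assuming img_estimated has the same shape as data):

$$
\log p(x, \text{data}) = \sum_{i=1}^{10} \log \mathrm{Uniform}(x_i \mid -2, 2) + \sum_{j} \log \mathrm{TruncatedNormal}\big(\text{data}_j \mid f(x)_j, \sigma, 0, 1\big)
$$

The first term equals $-10 \log 4$ for every x inside $[-2, 2]^{10}$ (and $-\infty$ outside), so inside the support all of the variation with x comes from the likelihood term.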
Specifically, I want to define a function that takes me from x (i.e. x_1, x_2, …, x_10) to the joint log probability of the model.
How can I get the joint log probabilities?
Thank you,
Atharva