I am trying to develop some intuition about what's happening to the joint log probability of the model as a function of the latent variable x.
This is what my model looks like:
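```python
import numpyro
import numpyro.distributions as dist


def model(data):
    # Ten latent variables, each with a Uniform(-2, 2) prior.
    x = numpyro.sample("x", dist.Uniform(low=-2, high=2), sample_shape=(10,))
    # some_function and measurement_std are defined elsewhere.
    img_estimated = some_function(x)
    with numpyro.plate("data", size=data.shape):
        obs_image = numpyro.sample(
            "obs",
            dist.TruncatedNormal(
                loc=img_estimated, scale=measurement_std, low=0, high=1
            ),
            obs=data,
        )
```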
Specifically, I want to define a function that takes me from the latent values (x_1, ..., x_10) to the joint log prob of the model.
How can I get the joint log probabilities?