Hi,
I have a question about the following NumPyro model:
```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(data):
    x_loc = jnp.zeros((Np * 3,))
    x_scale = jnp.ones((Np * 3,))
    x = numpyro.sample("x", dist.Normal(x_loc, x_scale).to_event(1))
    img_intensity = some_function(x)
    obs_image = numpyro.sample("obs", dist.TruncatedNormal(loc=img_intensity, scale=measurement_std, low=0, high=1), obs=data)
```
My question: how can I make the scale (`measurement_std`) of the `TruncatedNormal` distribution a parameter that I infer during SVI?
Thanks,
Atharva