Hi,
I have a question about a NumPyro model:
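```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# Np, some_function and measurement_std are defined elsewhere
def model(data):
    x_loc = jnp.zeros((Np * 3,))
    x_scale = jnp.ones((Np * 3,))
    x = numpyro.sample("x", dist.Normal(x_loc, x_scale).to_event(1))
    img_intensity = some_function(x)
    obs_image = numpyro.sample("obs", dist.TruncatedNormal(loc=img_intensity, scale=measurement_std, low=0, high=1), obs=data)
```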
My question is: how can I make the scale (`measurement_std`) of the TruncatedNormal distribution a parameter that I infer during SVI?
Thanks,
Atharva
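I think you can write it like this:
```python
init_value = jnp.ones(n)  # n is the dimension of img_intensity
# Note: the keyword is `constraint` (singular); a misspelled keyword is silently ignored
scale_param = numpyro.param("scale", init_value, constraint=dist.constraints.positive)
```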
Then you can use it for the observed variable:
```python
obs_image = numpyro.sample("obs", dist.TruncatedNormal(loc=img_intensity, scale=scale_param, low=0, high=1), obs=data)
```
Thanks for the reply @dilara!
What if I want to give it a distribution instead? Will the following work?
```python
def model(data):
    x_loc = jnp.zeros((10,))
    x_scale = jnp.ones((10,))
    x = numpyro.sample("x", dist.Normal(x_loc, x_scale).to_event(1))
    img_intensity = some_function(x)
    init_value = jnp.ones(len(img_intensity))
    scale_param = numpyro.param("scale", init_value, constraint=dist.constraints.positive)
    # dist.Exponential is parameterized by its rate, so scale_param acts as a rate here
    measurement_std = numpyro.sample("measurement_std", dist.Exponential(scale_param))
    obs_image = numpyro.sample("obs", dist.TruncatedNormal(loc=img_intensity, scale=measurement_std, low=0, high=1), obs=data)
```
Also, will I have to make any changes to the guide?
I think the model looks fine programming-wise, but whether it works well depends on what you want the model to do.
Make sure that all latent variables in the model (`x` and `measurement_std`) have a matching `sample` statement in the guide. If you are using SVI, you can check the Pyro tutorial "SVI Part I: An Introduction to Stochastic Variational Inference in Pyro" (Pyro Tutorials 1.8.6 documentation) for more information.