Question about NumPyro Model

Hi,

I have a question about a NumPyro model:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(data):
    x_loc = jnp.zeros((Np * 3,))
    x_scale = jnp.ones((Np * 3,))
    x = numpyro.sample("x", dist.Normal(x_loc, x_scale).to_event(1))

    img_intensity = some_function(x)

    obs_image = numpyro.sample(
        "obs",
        dist.TruncatedNormal(loc=img_intensity, scale=measurement_std, low=0, high=1),
        obs=data,
    )

My question is: how can I make the scale (measurement_std) of the TruncatedNormal distribution a parameter that I infer during SVI?

Thanks,
Atharva

I think you can write it like this:

init_value = jnp.ones(n)  # n is the dimension of img_intensity
scale_param = numpyro.param("scale", init_value, constraint=dist.constraints.positive)

Then you can use it as the scale of the observed variable:

obs_image = numpyro.sample("obs", dist.TruncatedNormal(loc=img_intensity, scale=scale_param, low=0, high=1), obs=data)

Thanks for the reply @dilara!

What if I want to assign it a distribution (a prior)? Will the following work?

def model(data):
    x_loc = jnp.zeros((10,))
    x_scale = jnp.ones((10,))
    x = numpyro.sample("x", dist.Normal(x_loc, x_scale).to_event(1))

    img_intensity = some_function(x)

    init_value = jnp.ones(len(img_intensity))
    scale_param = numpyro.param("scale", init_value, constraint=dist.constraints.positive)
    # Note: dist.Exponential is parameterized by its rate, so scale_param acts as a rate here
    measurement_std = numpyro.sample("measurement_std", dist.Exponential(scale_param))

    obs_image = numpyro.sample(
        "obs",
        dist.TruncatedNormal(loc=img_intensity, scale=measurement_std, low=0, high=1),
        obs=data,
    )

Also, will I have to make any changes to the guide?

I think the model is fine programming-wise, but whether it works well depends on what you want to achieve with it.

Make sure that every latent variable in the model (x and measurement_std) has a matching sample statement in the guide. If you are using SVI, see the Pyro tutorial "SVI Part I: An Introduction to Stochastic Variational Inference in Pyro" (Pyro Tutorials 1.8.6 documentation) for more information.
