Constraining sample values

Hi,

I am working with a model where I would like to have conditionally independent betas that are drawn from a global beta distribution. However, each of the betas has a different constraint interval. When I try to apply these intervals, the sampled value may be outside of the constraint interval, and then the model starts with a NaN loss and never recovers. Here is a simplified example for illustration:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(dat, est=None):
    # Global hyperpriors shared by all groups
    beta_mu = numpyro.sample("beta_mu", dist.Normal(0., 1.))
    beta_sigma = numpyro.sample("beta_sigma", dist.HalfNormal(1.))

    # One beta per group; the plate broadcasts the scalar hyperpriors
    with numpyro.plate("plate_groups", 3):
        betas = numpyro.sample("betas", dist.Normal(beta_mu, beta_sigma))

    estimate = betas[dat["group"]]

    observed_sigma = numpyro.sample("observed_sigma", dist.HalfNormal(1.))

    with numpyro.plate("data", len(dat)):
        numpyro.sample("obs", dist.Normal(estimate, observed_sigma), obs=est)

def guide(dat, est=None):
    beta_mu = numpyro.sample("beta_mu", dist.Normal(
        loc=numpyro.param("loc_beta_mu", 0.),
        scale=numpyro.param("scale_beta_mu", 1., constraint=dist.constraints.positive)))
    beta_sigma = numpyro.sample("beta_sigma", dist.HalfNormal(
        numpyro.param("scale_beta_sigma", 1., constraint=dist.constraints.positive)))

    with numpyro.plate("plate_groups", 3):
        # The initial loc is the sampled beta_mu, which may fall outside the interval
        betas = numpyro.sample("betas", dist.Normal(
            loc=numpyro.param("loc_betas", jnp.array([beta_mu] * 3),
                              constraint=dist.constraints.interval(
                                  lower_bound=jnp.array([-3., -3., -1.]),
                                  upper_bound=jnp.array([2., 3., 3.]))),
            scale=numpyro.param("scale_betas", jnp.array([beta_sigma] * 3),
                                constraint=dist.constraints.positive)))

    observed_sigma = numpyro.sample("observed_sigma", dist.HalfNormal(
        numpyro.param("scale_observed_sigma", 1., constraint=dist.constraints.positive)))

Is it possible to constrain the sampled values of beta_mu so that they fall inside the betas' constraint intervals? Constraining the location and scale of beta_mu was not sufficient. I also tried converting beta_mu to a numpy array and enforcing the conditions manually. That worked when I drew a sample on its own, but inside the model it failed to convert the JAX device array to numpy.

Thanks in advance!

Hi @tjm, I think what you are looking for is TruncatedDistribution:

beta_mu = numpyro.sample(
    "beta_mu", TruncatedDistribution(Normal(...), low=..., high=...))

Thanks, @fehiepsi, that is exactly what I need. As a novice in numpyro, I don’t always find the documentation easy to navigate, but you’ve been very helpful in the forums, so I appreciate your responses!

Yeah, we always want to make the documentation more accessible (along with examples/tutorials from the community). Please let me know if you have any suggestions. How about a tips-and-tricks notebook that links to questions and answers on the forum and GitHub? I feel that it would be super helpful.

I think that would be super helpful. Speaking for myself, I found numpyro while looking for a computationally efficient bayesian framework and stumbled upon this post: https://florianwilhelm.info/2020/10/bayesian_hierarchical_modelling_at_scale/

It was a great post, but I think it is too complex to fully understand without the basics. I then found the pyro SVI tutorial (https://pyro.ai/examples/svi_part_i.html), which was great for a true intro. What I haven’t found a lot of are examples that are slightly more complex (preparing data and data shapes appropriately, nesting plates, incorporating constraints or transforming/truncating distributions, etc.) but are still generalizable. Tutorials with that type of real-world application of the general case would be fantastic, as would a more in-depth or varied set of tips and tricks for designing and tuning the model. This one is a nice start: https://pyro.ai/examples/svi_part_iv.html

Finally, I think one improvement would be error messaging around things like NaN loss. I’ve run into a couple of issues, such as misnaming or forgetting a parameter in the guide, or setting an initial value outside of its constraints. It took a lot of exploration to figure out what the issue was, when I would hope those cases would be reported as errors.

Again, thanks for all your help. I’m finding numpyro tremendously useful; these are just small things that would have made the learning curve a little less steep.
