Hi,
I am working with a model where I would like to have conditionally independent betas that are drawn from a shared global distribution (mean beta_mu, scale beta_sigma). However, each beta has a different constraint interval. When I try to apply these intervals, the sampled value can fall outside the constraint interval; the model then starts with a NaN loss and never recovers. Here is a simplified example for illustration:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(dat, est=None):
    beta_mu = numpyro.sample("beta_mu", dist.Normal(0., 1.))
    beta_sigma = numpyro.sample("beta_sigma", dist.HalfNormal(1.))
    with numpyro.plate("plate_groups", 3):
        # the plate broadcasts the scalar hyperparameters across the 3 groups
        betas = numpyro.sample("betas", dist.Normal(beta_mu, beta_sigma))
    estimate = betas[dat["group"]]
    observed_sigma = numpyro.sample("observed_sigma", dist.HalfNormal(1.))
    with numpyro.plate("data", len(dat)):
        numpyro.sample("obs", dist.Normal(estimate, observed_sigma), obs=est)

def guide(dat, est=None):
    beta_mu = numpyro.sample("beta_mu", dist.Normal(
        loc=numpyro.param("loc_beta_mu", 0.),
        scale=numpyro.param("scale_beta_mu", 1., constraint=dist.constraints.positive)))
    beta_sigma = numpyro.sample("beta_sigma", dist.HalfNormal(
        numpyro.param("scale_beta_sigma", 1., constraint=dist.constraints.positive)))
    with numpyro.plate("plate_groups", 3):
        # loc_betas is initialized from beta_mu and constrained to a per-group interval
        betas = numpyro.sample("betas", dist.Normal(
            loc=numpyro.param("loc_betas", jnp.full(3, beta_mu),
                              constraint=dist.constraints.interval(
                                  lower_bound=jnp.array([-3., -3., -1.]),
                                  upper_bound=jnp.array([2., 3., 3.]))),
            scale=numpyro.param("scale_betas", jnp.full(3, beta_sigma),
                                constraint=dist.constraints.positive)))
    observed_sigma = numpyro.sample("observed_sigma", dist.HalfNormal(
        numpyro.param("scale_observed_sigma", 1., constraint=dist.constraints.positive)))
Is it possible to constrain the sampled values for beta_mu so that they do not fall outside the betas' constraint intervals? Constraining the location and scale of beta_mu was not sufficient. I also tried converting beta_mu to a NumPy array and enforcing the conditions manually; that worked when I drew a sample on its own, but inside the model it failed to convert the JAX device array to NumPy.
Thanks in advance!