I also want to point out that the same error occurs when using `OrderedTransform`. For instance, the following code does not work:
def my_model(L, pi, s_line_fit_params, h_line_fit_params, s_prior, h_prior):
    """Mixture model whose per-item ordered pair ``theta_34`` uses ``OrderedTransform``.

    Inside a plate of size ``L``:

    * ``c`` ~ Categorical(pi), with the infer flag requesting parallel
      enumeration of this discrete site;
    * ``s`` ~ Normal(s_prior[c, 0], s_prior[c, 1]) and
      ``h`` ~ Normal(h_prior[c, 0], h_prior[c, 1]);
    * ``gamma_1 = s_line_fit_params[c, 0] + s * s_line_fit_params[c, 1]``,
      and likewise ``gamma_2`` from ``h`` and ``h_line_fit_params``;
    * ``theta_34`` is drawn from a Normal whose loc is ``(gamma_1, gamma_2)``
      (stacked on the last axis), pushed through ``OrderedTransform``.

    Args:
        L: size of the ``"L"`` plate.
        pi: component probabilities for ``c`` (first positional argument
            of ``dist.Categorical``).
        s_line_fit_params: per-component rows indexed by ``c``; column 0 is
            added as an offset and column 1 multiplies ``s``.
        h_line_fit_params: same layout as ``s_line_fit_params``, applied to ``h``.
        s_prior: per-component rows indexed by ``c``; column 0 is the loc and
            column 1 the scale of the Normal prior on ``s``.
        h_prior: same layout as ``s_prior``, for ``h``.

    Returns:
        None. All values are recorded only as NumPyro sample sites.

    NOTE(review): ``sigma_gamma`` is read below but is not a parameter. It
    must exist at module level (its columns 2 and 3 are used as scales);
    otherwise this raises NameError. Confirm.
    """
    with numpyro.plate("L", L):
        # Discrete component assignment, marked for parallel enumeration.
        c = numpyro.sample("c", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'})
        s = numpyro.sample("s", dist.Normal(loc=jnp.array(s_prior[c, 0]), scale=jnp.array(s_prior[c, 1])))
        h = numpyro.sample("h", dist.Normal(loc=jnp.array(h_prior[c, 0]), scale=jnp.array(h_prior[c, 1])))
        # Linear predictors: column 0 is an offset, column 1 a slope on s / h.
        gamma_1 = jnp.array(s_line_fit_params[c, 0] + s * s_line_fit_params[c, 1])
        gamma_2 = jnp.array(h_line_fit_params[c, 0] + h * h_line_fit_params[c, 1])
        # Trailing axis of size 2 holds (gamma_1, gamma_2).
        theta_mean = jnp.stack([gamma_1, gamma_2], axis=-1)
        # NOTE(review): sigma_gamma is not defined in this function; see docstring.
        theta_std = jnp.stack([sigma_gamma[c, 2], sigma_gamma[c, 3]], axis=-1)
        # NOTE(review): the Normal below is the *base* distribution. The
        # OrderedTransform is applied to its samples, so theta_mean/theta_std
        # describe the unconstrained vector, not the ordered values. This
        # differs from Stan's `ordered[2] theta; theta ~ normal(...)`, which
        # puts the Normal density on the ordered values themselves. Confirm
        # against the NumPyro TransformedDistribution docs.
        # NOTE(review): theta_ordered is not used after this statement.
        theta_ordered = numpyro.sample("theta_34",
                                       dist.TransformedDistribution(
                                           dist.Normal(loc=jnp.array(theta_mean), scale=jnp.array(theta_std)),
                                           transforms.OrderedTransform()
                                       ))
However, the code below (which I believe does the equivalent of what `OrderedTransform` does under the hood) works:
def my_model(L, pi, s_line_fit_params, h_line_fit_params, s_prior, h_prior):
    """Same mixture model, with the ordered pair built by hand.

    Inside a plate of size ``L``:

    * ``c`` ~ Categorical(pi), with the infer flag requesting parallel
      enumeration of this discrete site;
    * ``s`` ~ Normal(s_prior[c, 0], s_prior[c, 1]) and
      ``h`` ~ Normal(h_prior[c, 0], h_prior[c, 1]);
    * ``gamma_1``/``gamma_2`` are linear in ``s``/``h`` using the
      per-component line-fit parameters;
    * ``theta_1`` ~ Normal(theta_mean[..., 0], theta_std[..., 0]);
    * ``theta_2_raw`` ~ Normal(theta_mean[..., 1], theta_std[..., 1]);
    * ``theta_2 = theta_1 + exp(theta_2_raw)`` is recorded as a
      deterministic site. Since ``exp`` is positive, ``theta_2 > theta_1``
      always holds.

    Args:
        L: size of the ``"L"`` plate.
        pi: component probabilities for ``c`` (first positional argument
            of ``dist.Categorical``).
        s_line_fit_params: per-component rows indexed by ``c``; column 0 is
            added as an offset and column 1 multiplies ``s``.
        h_line_fit_params: same layout as ``s_line_fit_params``, applied to ``h``.
        s_prior: per-component rows indexed by ``c``; column 0 is the loc and
            column 1 the scale of the Normal prior on ``s``.
        h_prior: same layout as ``s_prior``, for ``h``.

    Returns:
        None. All values are recorded only as NumPyro sample/deterministic sites.

    NOTE(review): ``sigma_gamma`` is read below but is not a parameter. It
    must exist at module level (its columns 2 and 3 are used as scales);
    otherwise this raises NameError. Confirm.
    NOTE(review): the Normal prior is placed on ``theta_2_raw``, i.e. on
    ``log(theta_2 - theta_1)``, not on ``theta_2``. So ``theta_2`` is *not*
    centred at ``gamma_2``. This presumably explains results that differ from
    a Stan model declaring ``ordered[2] theta; theta ~ normal(...)``. To mirror
    Stan, one option (TODO, verify) is to sample the ordered vector with a flat
    prior on the ordered support and add the Normal log-density of the ordered
    values via ``numpyro.factor``.
    """
    with numpyro.plate("L", L):
        # Discrete component assignment, marked for parallel enumeration.
        c = numpyro.sample("c", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'})
        s = numpyro.sample("s", dist.Normal(loc=jnp.array(s_prior[c, 0]), scale=jnp.array(s_prior[c, 1])))
        h = numpyro.sample("h", dist.Normal(loc=jnp.array(h_prior[c, 0]), scale=jnp.array(h_prior[c, 1])))
        # Linear predictors: column 0 is an offset, column 1 a slope on s / h.
        gamma_1 = jnp.array(s_line_fit_params[c, 0] + s * s_line_fit_params[c, 1])
        gamma_2 = jnp.array(h_line_fit_params[c, 0] + h * h_line_fit_params[c, 1])
        # Trailing axis of size 2 holds (gamma_1, gamma_2).
        theta_mean = jnp.stack([gamma_1, gamma_2], axis=-1)
        # NOTE(review): sigma_gamma is not defined in this function; see docstring.
        theta_std = jnp.stack([sigma_gamma[c, 2], sigma_gamma[c, 3]], axis=-1)
        theta_1 = numpyro.sample("theta_1", dist.Normal(loc=theta_mean[..., 0], scale=theta_std[..., 0]))
        # Unconstrained log-gap; its Normal prior does not constrain theta_2 directly.
        theta_2_raw = numpyro.sample("theta_2_raw", dist.Normal(loc=theta_mean[..., 1], scale=theta_std[..., 1]))
        # exp(...) > 0 guarantees theta_2 > theta_1.
        theta_2 = numpyro.deterministic("theta_2", theta_1 + jnp.exp(theta_2_raw))
Although the latter code runs without error and converges, it produces very unexpected values, even when given good initial values for the variables `s` and `h`. This model has been working fine in Stan (with the `ordered` constraint and the positive constraint). I’m hoping to get this to work in NumPyro for the speed benefits, but I’m struggling. Any help is appreciated.