NumPyro Funsor error when using a positive constraint on a normally distributed variable

I also want to point out that the same error occurs when using OrderedTransform. For instance, the following code does not work:

def my_model(L, pi, s_line_fit_params, h_line_fit_params, s_prior, h_prior):
    """NumPyro model: enumerated class ``c`` plus an ordered pair ``theta_34``.

    This is the variant the author reports as failing with a funsor error.

    Inside a plate named "L" of size ``L`` the model:

    * samples a discrete site "c" from ``Categorical(pi)``, marked for
      parallel enumeration;
    * samples "s" and "h" from Normal distributions whose loc and scale are
      read from columns 0 and 1 of ``s_prior`` / ``h_prior`` at row ``c``;
    * builds ``gamma_1`` and ``gamma_2`` as affine functions of ``s`` and
      ``h`` (intercept in column 0, slope in column 1 of the respective
      ``*_line_fit_params`` at row ``c``);
    * samples "theta_34" from a Normal pushed through ``OrderedTransform``.

    Args:
        L: Size of the plate "L".
        pi: Probabilities passed to ``dist.Categorical``.
        s_line_fit_params: Indexed as ``[c, 0]`` (intercept) and ``[c, 1]``
            (slope) for ``gamma_1``.
        h_line_fit_params: Indexed as ``[c, 0]`` (intercept) and ``[c, 1]``
            (slope) for ``gamma_2``.
        s_prior: Indexed as ``[c, 0]`` (loc) and ``[c, 1]`` (scale) for "s".
        h_prior: Indexed as ``[c, 0]`` (loc) and ``[c, 1]`` (scale) for "h".

    Returns:
        None. The function has no return statement; it only registers
        sample sites.
    """
    
    with numpyro.plate("L", L):
        # Discrete latent class; infer={'enumerate': 'parallel'} requests parallel enumeration.
        c = numpyro.sample("c", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'})
        
        # Class-dependent Normal priors: column 0 is used as loc, column 1 as scale.
        s = numpyro.sample("s", dist.Normal(loc=jnp.array(s_prior[c, 0]), scale=jnp.array(s_prior[c, 1])))
        h = numpyro.sample("h", dist.Normal(loc=jnp.array(h_prior[c, 0]), scale=jnp.array(h_prior[c, 1])))
        
        # Affine maps of s and h with class-dependent intercept and slope.
        gamma_1 = jnp.array(s_line_fit_params[c, 0] + s * s_line_fit_params[c, 1])
        gamma_2 = jnp.array(h_line_fit_params[c, 0] + h * h_line_fit_params[c, 1])
        
        # Stack the two components along a new trailing axis of length 2.
        theta_mean = jnp.stack([gamma_1, gamma_2], axis=-1)
        # NOTE(review): `sigma_gamma` is neither a parameter nor defined in
        # this function; presumably a module-level array indexed by class and
        # column -- confirm it is in scope where the model runs.
        theta_std = jnp.stack([sigma_gamma[c, 2], sigma_gamma[c, 3]], axis=-1)
        
        # Normal base distribution pushed through OrderedTransform. The
        # resulting value is bound to `theta_ordered` but not used afterwards.
        # NOTE(review): the base Normal is parameterised by the stacked arrays
        # whose trailing axis (length 2) was added inside the plate; check how
        # that axis interacts with the plate/enumeration dims -- TODO confirm.
        theta_ordered = numpyro.sample("theta_34",
                                 dist.TransformedDistribution(
                                     dist.Normal(loc=jnp.array(theta_mean), scale=jnp.array(theta_std)),
                                     transforms.OrderedTransform()
                                 ))
        
        

However, the code below (which I think is equivalent to what OrderedTransform does under the hood) works:

def my_model(L, pi, s_line_fit_params, h_line_fit_params, s_prior, h_prior):
    """NumPyro model: enumerated class ``c`` plus a hand-built ordered pair.

    This is the variant the author reports as running without error. Rather
    than using ``OrderedTransform``, it samples two scalar Normal sites and
    combines them so that "theta_2" is strictly greater than "theta_1".

    Inside a plate named "L" of size ``L`` the model:

    * samples a discrete site "c" from ``Categorical(pi)``, marked for
      parallel enumeration;
    * samples "s" and "h" from Normal distributions whose loc and scale are
      read from columns 0 and 1 of ``s_prior`` / ``h_prior`` at row ``c``;
    * builds ``gamma_1`` and ``gamma_2`` as affine functions of ``s`` and
      ``h`` (intercept in column 0, slope in column 1 of the respective
      ``*_line_fit_params`` at row ``c``);
    * samples "theta_1" and "theta_2_raw" from Normals and records
      ``theta_1 + exp(theta_2_raw)`` as the deterministic site "theta_2".

    Args:
        L: Size of the plate "L".
        pi: Probabilities passed to ``dist.Categorical``.
        s_line_fit_params: Indexed as ``[c, 0]`` (intercept) and ``[c, 1]``
            (slope) for ``gamma_1``.
        h_line_fit_params: Indexed as ``[c, 0]`` (intercept) and ``[c, 1]``
            (slope) for ``gamma_2``.
        s_prior: Indexed as ``[c, 0]`` (loc) and ``[c, 1]`` (scale) for "s".
        h_prior: Indexed as ``[c, 0]`` (loc) and ``[c, 1]`` (scale) for "h".

    Returns:
        None. The function has no return statement; it only registers
        sample and deterministic sites.
    """
    
    with numpyro.plate("L", L):
        # Discrete latent class; infer={'enumerate': 'parallel'} requests parallel enumeration.
        c = numpyro.sample("c", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'})
        
        # Class-dependent Normal priors: column 0 is used as loc, column 1 as scale.
        s = numpyro.sample("s", dist.Normal(loc=jnp.array(s_prior[c, 0]), scale=jnp.array(s_prior[c, 1])))
        h = numpyro.sample("h", dist.Normal(loc=jnp.array(h_prior[c, 0]), scale=jnp.array(h_prior[c, 1])))
        
        # Affine maps of s and h with class-dependent intercept and slope.
        gamma_1 = jnp.array(s_line_fit_params[c, 0] + s * s_line_fit_params[c, 1])
        gamma_2 = jnp.array(h_line_fit_params[c, 0] + h * h_line_fit_params[c, 1])
        
        # Stack along a new trailing axis; the components are sliced back out
        # below with [..., 0] and [..., 1].
        theta_mean = jnp.stack([gamma_1, gamma_2], axis=-1)
        # NOTE(review): `sigma_gamma` is neither a parameter nor defined in
        # this function; presumably a module-level array indexed by class and
        # column -- confirm it is in scope where the model runs.
        theta_std = jnp.stack([sigma_gamma[c, 2], sigma_gamma[c, 3]], axis=-1)
        
        # First component: Normal with the first stacked loc/scale.
        theta_1 = numpyro.sample("theta_1", dist.Normal(loc=theta_mean[..., 0], scale=theta_std[..., 0]))
        
        # Second component is sampled on an unconstrained scale, then
        # exponentiated and added to theta_1; exp(.) > 0, so theta_2 > theta_1.
        # NOTE(review): the Normal here is placed on `theta_2_raw` (the log of
        # the gap), not on theta_2 itself -- confirm this matches the intended
        # prior from the Stan model.
        theta_2_raw = numpyro.sample("theta_2_raw", dist.Normal(loc=theta_mean[..., 1], scale=theta_std[..., 1]))
        theta_2 = numpyro.deterministic("theta_2", theta_1 + jnp.exp(theta_2_raw))
        

Although the latter code does not raise an error and converges, it produces very unexpected values, even after I give it good initial values (for the variables s and h). This model has been working fine in Stan (with the ordered constraint and the positive constraint). I’m hoping to get this to work in NumPyro for the speed benefits, but I’m struggling. Any help is appreciated.