Cannot find initial params (even with init_to_median)

Given the model:

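```python
import jax.numpy as jnp
import numpyro as npy
from numpyro.distributions import Bernoulli, Normal, Uniform


def model(data):
    x5 = npy.sample("x5", Uniform(-99.9999, 99.9999))
    x6 = npy.sample("x6", Uniform(-99.9999, 99.9999))
    x7 = npy.sample("x7", Uniform(0, 1))
    x8 = npy.sample("x8", Uniform(-99.9999, 99.9999))
    x9 = npy.sample("x9", Uniform(0.0001, 99.9999))
    x10 = npy.sample("x10", Uniform(-99.9999, 99.9999))

    with npy.plate("observations", len(data)):
        # Lower bound of the likelihood; Uniform(x5, x6) is only a valid
        # interval when x5 < x6.
        lower = npy.sample("Sample0", Uniform(x5, x6))
        # Upper bound: a Normal draw when Sample1 == 1, otherwise x10.
        upper = jnp.where(
            npy.sample("Sample1", Bernoulli(x7)),
            npy.sample("Sample2", Normal(x8, x9)),
            x10,
        )
        # Every observation must lie inside (lower, upper).
        npy.sample("obs", Uniform(lower, upper), obs=data)
```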

With validation enabled, NumPyro cannot find valid initial parameters for this model.
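One way to see why initialization can fail here: `obs` has a Uniform likelihood whose bounds are themselves random. As far as I know, with validation enabled NumPyro sets the log-density of out-of-support values to -inf, so any starting point where an observation falls outside the bounds, or where x6 < x5, has infinite potential energy and is rejected. Below is a plain-Python sketch of that support check. The function name is hypothetical, and it only mimics the check rather than reproducing NumPyro's implementation:

```python
import math


def uniform_log_likelihood(data, low, high):
    """Log-likelihood of i.i.d. observations under Uniform(low, high).

    Returns -inf for an empty or inverted interval (low >= high) or when
    any observation falls outside [low, high], as a support check would.
    """
    if not low < high:
        return -math.inf
    if any(not low <= x <= high for x in data):
        return -math.inf
    return -len(data) * math.log(high - low)
```

For example, `uniform_log_likelihood([0.5, 1.0], 0.0, 2.0)` is `-2 * log(2)`. Moving one point to 3.0, or swapping the bounds, gives `-inf`.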

Reproduction notebook:

Try `init_to_value`. The model probably doesn't like it when, e.g., the `x`s have large or extreme magnitudes, which can easily happen depending on your initialization scheme.
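To make the `init_to_value` suggestion concrete, here is a minimal sketch that picks moderate starting values from the data. The helper `make_init_values` and its `margin` spacing are hypothetical choices, not anything from NumPyro. The only goal is that every observation sits strictly inside the likelihood bounds whichever branch `Sample1` takes. `Sample1` is left out on the assumption that it is enumerated, or initialized by whatever fallback `init_to_value` uses for sites missing from `values`:

```python
def make_init_values(data, bound=99.9999):
    """Hypothetical helper (not part of NumPyro): moderate starting values
    for the model's latent sites so that every observation lies strictly
    inside (Sample0, upper), whichever branch Sample1 takes."""
    lo, hi = min(data), max(data)
    margin = max(hi - lo, 1.0)  # arbitrary spacing between the bounds and the data
    n = len(data)
    values = {
        "x5": lo - 3 * margin,  # prior bounds for Sample0, with x5 < x6
        "x6": lo - margin,
        "x7": 0.5,  # Bernoulli probability for Sample1
        "x8": hi + margin,  # mean of Sample2, above max(data)
        "x9": 1.0,  # scale of Sample2
        "x10": hi + margin,  # upper bound used when Sample1 == 0
        # Plated sites need one value per observation.
        "Sample0": [lo - 2 * margin] * n,  # inside (x5, x6), below min(data)
        "Sample2": [hi + margin] * n,  # upper bound used when Sample1 == 1
    }
    # The global sites must also lie inside their Uniform(-99.9999, 99.9999) priors.
    for name in ("x5", "x6", "x8", "x10"):
        if not -bound < values[name] < bound:
            raise ValueError(f"{name}={values[name]} is outside its prior support; rescale the data")
    return values
```

With the model above, this could be passed as `NUTS(model, init_strategy=init_to_value(values={k: jnp.asarray(v) for k, v in make_init_values(data).items()}))`. Converting the lists to arrays is an assumption about what `init_to_value` expects for plated sites.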