Given the following model:
def model():
    """NumPyro model whose observation distribution has random, per-observation bounds.

    Six global latents (x5..x10) are drawn from Uniform priors. Inside the
    ``observations`` plate, each observation gets its own lower bound
    (``Sample0``) and upper bound (either ``Sample2`` or ``x10``, chosen by the
    Bernoulli site ``Sample1``). The observed ``data`` is then scored under
    ``Uniform(low, high)``.

    Reads the ``data`` array from the enclosing scope (it is not defined in
    this snippet).
    """
    # Global latent parameters with wide Uniform priors.
    x5 = npy.sample("x5", Uniform(-99.9999, 99.9999))
    x6 = npy.sample("x6", Uniform(-99.9999, 99.9999))
    x7 = npy.sample("x7", Uniform(0, 1))  # probability for the Bernoulli switch
    x8 = npy.sample("x8", Uniform(-99.9999, 99.9999))
    x9 = npy.sample("x9", Uniform(0.0001, 99.9999))  # used as a Normal scale, kept > 0
    x10 = npy.sample("x10", Uniform(-99.9999, 99.9999))

    with npy.plate("observations", len(data)):
        # Per-observation lower bound.
        # NOTE(review): Uniform(x5, x6) is only a valid distribution when
        # x5 < x6, yet x5 and x6 are sampled independently, so roughly half
        # of the prior draws give an invalid (low >= high) support. With
        # validation enabled this likely contributes to the init failure;
        # confirm.
        low = npy.sample("Sample0", Uniform(x5, x6))

        # Per-observation upper bound: Sample2 where the Bernoulli draw is
        # truthy, otherwise the shared x10.
        # NOTE(review): Sample1 is a discrete latent site. Gradient-based
        # initialization/inference cannot handle it directly unless it is
        # enumerated or marginalized; verify how inference is run.
        high = jnp.where(
            npy.sample("Sample1", Bernoulli(x7)),
            npy.sample("Sample2", Normal(x8, x9)),
            x10,
        )

        # NOTE(review): every data point must fall inside its own
        # [low, high) interval for a finite log-density, and nothing here
        # guarantees low < high. Random init therefore rarely lands in the
        # feasible region; confirm this is the root cause.
        npy.sample("obs", Uniform(low, high), obs=data)
Valid initial parameters cannot be found for this model when distribution validation is enabled.
Reproduction notebook: