Trouble Initializing Parameters with Highly Degenerate Data

I am working with a highly degenerate 5-parameter space. The degeneracy seems to be preventing the sampler from initializing, as I keep getting the following error: RuntimeError: Cannot find valid initial parameters. Please check your model again.

Sampling does work if I initialize at the true underlying parameters that characterize the data, so I am fairly confident that my likelihood and model are correct/reasonable. However, my likelihood immediately goes to negative infinity if I perturb any of the initial parameters even slightly from the truth. This isn’t really surprising given how degenerate my parameters are (as can be seen in the corner plot below, from a run where I did get the sampler to work). Is there a way to get numpyro to explore the prior space a little more in order to find suitable initialization values, without feeding it the true underlying parameter values?
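
As a side note on the numerics: a log-likelihood that drops to negative infinity under a small perturbation can also come from floating-point underflow, not only from the degeneracy itself. The sketch below uses NumPy and assumes that `gauss` is a plain normal probability density (its definition is not shown in this post); `gauss_pdf`, `gauss_logpdf`, and the numbers are hypothetical. It compares taking the log of a density with evaluating the log-density directly, for a trial value 50 standard deviations from the mean.

```python
import numpy as np


def gauss_pdf(x, mu, sigma):
    """Normal probability density (assumed form of the `gauss` helper)."""
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2.0 * np.pi))


def gauss_logpdf(x, mu, sigma):
    """Log of the same density, computed without exponentiating first."""
    return -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma * np.sqrt(2.0 * np.pi))


# Hypothetical numbers: a tight measurement and a trial value 50 sigma away.
mu, sigma = 0.0, 1e-3
x = 0.05

# exp(-1250) underflows to 0.0 in float64, so the log of the density is -inf.
with np.errstate(divide="ignore"):
    log_of_product = np.log(gauss_pdf(x, mu, sigma))

# The log-density itself is large and negative, but finite.
sum_of_logs = gauss_logpdf(x, mu, sigma)
```

The log-distance ratio uncertainties in observed_events_arr are of order 1e-4 to 1e-3, so offsets of many sigma may be easy to reach with a small perturbation. If that is what is happening, summing log-densities instead of taking the log of a product would keep the value finite; the same expressions carry over to jax.numpy.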

I’ve attached some snippets of my code here, which may or may not help. If necessary/helpful, I can try to come up with a minimal working example that’s a little less complex that shows the same issue when initializing parameters in a highly degenerate parameter space.

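import jax
import jax.numpy as jnp
import numpyro
from astropy.cosmology import Planck15
from numpyro.distributions import Uniform
from numpyro.infer import init_to_value

# gauss, luminosity_distance_to_logdistance_ratio, N_sigma, and seed are defined elsewhere (not shown)

observed_events_arr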
Array([[ 7.34172580e+01,  5.03441519e+01,  3.08489716e+02,
         2.72776158e+02,  6.89062366e+01,  6.64002407e-02,
         6.49709053e-02,  6.51768431e-02,  6.45408666e-02,
         5.32270433e-04, -1.93336753e-04, -7.17941016e-03,
         6.90383933e-04,  7.58451098e+01],
       [ 4.38415890e+01,  3.74324644e+01,  4.51519409e+02,
         3.87845539e+02,  1.48479969e+02,  9.52877513e-02,
         9.64373619e-02,  9.65043597e-02,  9.59946169e-02,
         5.47997308e-04, -6.11011847e-05,  5.90109523e-03,
         4.67350738e-04,  7.84641557e+01]], dtype=float64)

def numpyro_many_ll(dL_arr_samp, ztot_arr_samp, H0_samp):
    """Log-likelihood function."""
    # Probabilities relating to each GW event luminosity distance
    gw_prob = gauss(observed_events_arr[:, 3], observed_events_arr[:, 4], dL_arr_samp)

    # Probability related to redshift data
    zobs_prob = gauss(observed_events_arr[:, 8], observed_events_arr[:, 9], ztot_arr_samp)

    # Probability related to peculiar velocity (or log-distance ratio) data
    logdist_arr_samp = luminosity_distance_to_logdistance_ratio(
        dL_arr_samp,
        ztot_arr_samp,
        H0_samp,
        observed_events_arr[:, 0],
        observed_events_arr[:, 1],
    )
    eta_prob = gauss(observed_events_arr[:, 11], observed_events_arr[:, 12], logdist_arr_samp)

    # Combined likelihood from all forms of data
    like = gw_prob * zobs_prob * eta_prob

    # Take the product of the likelihoods of all events, then take the log to get the log-likelihood
    loglike = jnp.log(jnp.prod(like))

    # Avoid infinities by clipping the loglikelihood close to positive & negative infinity
    finfo64 = jnp.finfo(jnp.float64)
    loglike = jnp.clip(loglike, a_min=finfo64.min, a_max=finfo64.max)

    return loglike


def many_model():
    with numpyro.plate("events", len(observed_events_arr)):
        dL_obs_min = observed_events_arr[:, 3] - N_sigma * observed_events_arr[:, 4]
        zeros = jnp.zeros(len(observed_events_arr))

        dL_prior = Uniform(
            low=jnp.maximum(zeros, dL_obs_min),  # elementwise: floor each event's lower bound at zero
            high=observed_events_arr[:, 3] + N_sigma * observed_events_arr[:, 4],
        )
        ztot_prior = Uniform(
            low=observed_events_arr[:, 8] - N_sigma * observed_events_arr[:, 9],
            high=observed_events_arr[:, 8] + N_sigma * observed_events_arr[:, 9],
        )

        dL_arr_samp = numpyro.sample("dL", dL_prior)
        ztot_arr_samp = numpyro.sample("ztot", ztot_prior)

    H0_prior = Uniform(low=40.0, high=90.0)
    H0_samp = numpyro.sample("H0", H0_prior)

    ll = numpyro_many_ll(dL_arr_samp, ztot_arr_samp, H0_samp)
    numpyro.factor("custom_logp", ll)
    
    
true_initial_values = {
    "dL": observed_events_arr[:, 3],
    "ztot": observed_events_arr[:, 8],
    "H0": Planck15.H0.to_value(),
}

sampler = numpyro.infer.MCMC(
    numpyro.infer.NUTS(many_model, init_strategy=init_to_value(values=true_initial_values)),
    num_warmup=500,
    num_samples=2000,
    num_chains=1,
)

sampler.run(jax.random.PRNGKey(seed))
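
On the question of exploring the prior for a starting point: one generic approach, independent of numpyro internals, is to draw many candidates from the priors, evaluate the log-likelihood at each, and start the chain from the best finite one via init_to_value. The following is a minimal NumPy sketch of that idea, not a tested fix for the model above; `toy_loglike`, the prior bounds, the fixed redshift, and the candidate count are all hypothetical stand-ins.

```python
import numpy as np


def toy_loglike(dL, H0):
    """Stand-in log-likelihood with a narrow ridge (hypothetical, not the model above)."""
    z = 0.065  # illustrative fixed redshift
    c = 299792.458  # speed of light in km/s
    # The low-redshift Hubble law ties dL and H0 together along a thin degeneracy.
    return -0.5 * ((dL - c * z / H0) / 0.5) ** 2


rng = np.random.default_rng(0)
n_candidates = 10_000

# Draw candidates from uniform priors (bounds are illustrative).
dL_cand = rng.uniform(66.0, 480.0, size=n_candidates)
H0_cand = rng.uniform(40.0, 90.0, size=n_candidates)

# Evaluate every candidate; map NaN or infinite values to -inf so argmax skips them.
logp = toy_loglike(dL_cand, H0_cand)
logp = np.where(np.isfinite(logp), logp, -np.inf)
best = int(np.argmax(logp))

# Starting point with the highest finite log-likelihood among the candidates.
init_values = {"dL": dL_cand[best], "H0": H0_cand[best]}
```

The resulting dictionary could then be passed as init_to_value(values=init_values) in place of true_initial_values. numpyro.infer also provides other strategies, such as init_to_median and init_to_sample, that can be given as init_strategy; whether they succeed here would depend on how narrow the region of finite likelihood is.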