Cannot find valid initial parameters: posterior has NaN gradients

I’m trying to run inference with NUTS, and I’m getting the error message

RuntimeError: Cannot find valid initial parameters. Please check your model again.

The likelihood function is Normal with a fairly complicated loc parameter, and this seems to be what’s causing the problem. However, I’ve been using numpyro for over a year, and I’ve never encountered this error before.

Using pdb, it seems that the program is failing in the initialize_model() function: find_valid_initial_params() returns some initial parameters along with their gradients, but some of the gradients are NaN, so initialize_model() throws an exception.

So the question is why my log posterior has undefined gradients, and I’m wondering if anyone has any suggestions for how to dig into this.

Debugging is made even trickier by the fact that, inside numpyro, many of the variables are transformed to unconstrained space. For example, a HalfNormal-distributed site is sampled on the log scale.
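To illustrate what I mean, here is a minimal snippet (separate from my model) showing the bijection numpyro uses for a positive-constrained site:

from numpyro.distributions import HalfNormal
from numpyro.distributions.transforms import biject_to

# NUTS works in unconstrained space; a positive-constrained site such as
# HalfNormal is mapped there via an exp/log bijection.
t = biject_to(HalfNormal(1.0).support)  # ExpTransform: unconstrained -> positive
print(t(0.0))      # 1.0 -- unconstrained 0.0 maps to 1.0 on the constrained side
print(t.inv(1.0))  # 0.0 -- i.e. the site is effectively "logged" internally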

hard to say what’s going on without looking at the model in great detail, but you might try different initialization strategies and see if that addresses your issue. demo’d e.g. here
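something along these lines (rough sketch; model_args stands in for whatever your model takes):

import jax
from numpyro.infer import MCMC, NUTS, init_to_median

# swap in a different init strategy (the default is init_to_uniform)
kernel = NUTS(model, init_strategy=init_to_median())
mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), *model_args)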

Thanks @martinjankowiak. I’ve tried init_to_uniform and init_to_median with no luck. Here is my model code. You can see that obs_mean is a fairly complicated function of the model parameters. Do you know if there’s an easy way to break up the gradient using the chain rule to see where it is becoming NaN?

import numpy as np
import xarray as xr
from jax import numpy as jnp
from numpyro import plate, sample, deterministic
from numpyro.distributions import Normal, HalfNormal
from numpyro.handlers import reparam
from numpyro.infer.reparam import LocScaleReparam

def model(ds: xr.Dataset, use_data) -> None:
    wh_obs = ds.LB.notnull().to_numpy()
    obs = np.log(ds.LB**2 + 1).to_numpy()[wh_obs] / 2 if use_data else None
    n = ds.US.size
    arm_idx = ds.wh_treatment.astype(int).to_numpy()
    treatment_date_days = jnp.asarray(ds.treatment_date_days.to_numpy())
    rm_start_days = jnp.asarray(ds.rm_start_days.to_numpy())
    rm_end_days = jnp.asarray(ds.rm_end_days.to_numpy())
    lab_date_days = jnp.asarray(ds.lab_date_days.to_numpy())

    lambda_treat = sample("lambda_treat", HalfNormal(30))
    lambda_rescue = sample("lambda_rescue", HalfNormal(30))

    mu_log_p0 = sample("mu_log_p0", Normal(np.log(30), np.log(4)))
    sigma_log_p0 = sample("sigma_log_p0", HalfNormal(np.log(4)))

    mu_log_sigma_res = sample("mu_log_sigma_res", Normal(np.log(np.log(2)), np.log(3)))
    sigma_log_sigma_res = sample("sigma_log_sigma_res", HalfNormal(np.log(3)))

    mu_log_c1 = sample("mu_log_c1", Normal(np.log(500), np.log(10)))
    sigma_log_c1 = sample("sigma_log_c1", HalfNormal(np.log(10)))

    mu_log_c3 = sample("mu_log_c3", Normal(np.log(0.1), np.log(10)))
    sigma_log_c3 = sample("sigma_log_c3", HalfNormal(np.log(10)))

    mu_c4 = sample("mu_c4", Normal(scale=np.log(1.3)))
    sigma_c4 = sample("sigma_c4", HalfNormal(np.log(1.3)))

    with plate("treatment", 2):
        mu_c5 = sample("mu_c5", Normal(scale=np.log(2)))
        sigma_c5 = sample("sigma_c5", HalfNormal(np.log(2)))

    lsr = LocScaleReparam(0)
    with plate("us", n), reparam(config={f: lsr for f in ["c4", "c5"]}):
        d = {}
        for f in "log_p0 log_c1 log_c3 log_sigma_res".split():
            var_name = f"{f}_z"
            d[var_name] = sample(var_name, Normal())
        p0 = deterministic("p0", jnp.exp(mu_log_p0 + sigma_log_p0 * d["log_p0_z"]))
        c1 = deterministic("c1", jnp.exp(mu_log_c1 + sigma_log_c1 * d["log_c1_z"]))
        c3 = deterministic("c3", jnp.exp(mu_log_c3 + sigma_log_c3 * d["log_c3_z"]))
        sigma_res = deterministic(
            "sigma_res",
            jnp.exp(mu_log_sigma_res + sigma_log_sigma_res * d["log_sigma_res_z"]),
        )
        c4 = sample("c4", Normal(loc=mu_c4, scale=sigma_c4))
        c5 = sample("c5", Normal(loc=mu_c5[arm_idx], scale=sigma_c5[arm_idx]))

    # U x p
    dt = lab_date_days - lab_date_days[:, 0, None]
    dt_treat = lab_date_days[:, :, None] - jnp.maximum(
        treatment_date_days[:, None, :], lab_date_days[:, 0, None, None]
    )  # U x p x t
    rm_off = jnp.nan_to_num((lab_date_days - rm_end_days[:, None]).clip(0))
    rm_duration = (
        jnp.nan_to_num((lab_date_days - rm_start_days[:, None]).clip(0)) - rm_off
    )
    abs_int_c2 = deterministic(
        "int_c2",
        c3[:, None] * dt
        - c4[:, None]
        * (
            rm_duration
            - lambda_rescue
            * -jnp.expm1(-rm_duration / lambda_rescue)
            * jnp.exp(-rm_off / lambda_rescue)
        )
        - c5[:, None]
        * jnp.where(dt_treat > 0, -jnp.expm1(-dt_treat / lambda_treat), 0).sum(2),
    )
    exp_2_int_c2 = jnp.exp(-2 * abs_int_c2)
    p02 = jnp.square(p0)
    pt2 = deterministic(
        "pt2",
        jnp.where(
            dt > 0,
            p02[:, None] * exp_2_int_c2
            + (1 - exp_2_int_c2) * c1[:, None] / (abs_int_c2 / dt),
            p02[:, None],
        ),
    )
    obs_mean = deterministic("obs_mean", jnp.log(pt2[wh_obs] + 1) / 2)
    sample(
        "p",
        Normal(
            loc=obs_mean,
            scale=jnp.broadcast_to(sigma_res[:, None], wh_obs.shape)[wh_obs],
        ),
        obs=obs,
    )

i’m not sure but i’d start by looking in obvious places like division by zero (/ abs_int_c2), logarithms of non-positive arguments, etc.
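e.g. something like this to eyeball the forward values of each site (rough sketch using your ds; note it only checks values, not gradients):

from jax import numpy as jnp
from numpyro.handlers import seed, trace

# run the model forward once and flag any site with non-finite values
tr = trace(seed(model, rng_seed=0)).get_trace(ds, use_data=True)
for name, site in tr.items():
    if site["type"] in ("sample", "deterministic"):
        if not jnp.all(jnp.isfinite(jnp.asarray(site["value"]))):
            print(name, "contains non-finite values")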

Thanks @martinjankowiak. I’ve tried your suggestion without any luck. I’m wondering if there is an easy way to extract the model log joint density (and ideally the deterministic variables as well) as a jax function, so that I can compute the gradients manually and see where they’re NaNing out.

You can use log_density. Given a model and a sample of its parameters, you can do

from functools import partial
from numpyro.infer.util import log_density
joint_fn = partial(log_density, model, model_args, model_kwargs)
jax.grad(joint_fn, has_aux=True)(values)  # log_density returns (log_joint, trace)

For debugging, you can comment out parts of the code to identify the issue. It’s tricky to debug NaNs in gradients.
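You can also turn on JAX’s NaN checker before evaluating the gradient; it raises an error at the first operation that produces a NaN (rough sketch, reusing joint_fn and values from above; it adds noticeable overhead):

import jax

# raise as soon as any primitive produces a NaN, including inside grad
jax.config.update("jax_debug_nans", True)
jax.grad(joint_fn, has_aux=True)(values)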

Thanks @fehiepsi, I’ll do this. I had tried calling log_density directly on my seeded model, but then I just got a value. I didn’t know about this trick of using partial.

Thanks @fehiepsi, using log_density was incredibly helpful. After much detective work, I managed to isolate the issue. Below is a prototypical example.

import numpy as np
from jax import grad, jit, numpy as jnp

@jit
def f(x):
    return jnp.where(np.array(False), x / 0, 0)

This function is clearly zero everywhere, yet its gradient evaluates to NaN: autodiff still differentiates the untaken x / 0 branch (giving inf), and the zero cotangent that jnp.where assigns to that branch times inf is NaN.

>>> f(1.)
Array(0., dtype=float64)
>>> grad(f)(1.)
Array(nan, dtype=float64, weak_type=True)

Do you think I should report this as an issue in JAX?

Okay, I see this is a known issue in JAX.

This is a known issue. See where-nan.pdf in the tensorflow/probability GitHub repository for a workaround.
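For reference, the workaround described there is the “double where” trick: feed the risky branch safe inputs so its gradient stays finite even where the branch is masked out. A minimal sketch of the idea:

from jax import grad, numpy as jnp

def f_bad(x):
    # NaN gradient at x == 0: the untaken branch still divides by zero
    return jnp.where(x > 0, 1.0 / x, 0.0)

def f_safe(x):
    # "double where": substitute a harmless value where the branch is unused
    safe_x = jnp.where(x > 0, x, 1.0)
    return jnp.where(x > 0, 1.0 / safe_x, 0.0)

print(grad(f_bad)(0.0))   # nan
print(grad(f_safe)(0.0))  # 0.0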