Thanks @martinjankowiak. I’ve tried init_to_uniform
and init_to_median
with no luck. Here is my model code. You can see that obs_mean
is a fairly complicated function of the model parameters. Do you know if there’s an easy way to break up the gradient using the chain rule to see where it is becoming NaN
?
import numpy as np
import xarray as xr
from jax import numpy as jnp
from numpyro import plate, sample, deterministic
from numpyro.distributions import Normal, HalfNormal
from numpyro.handlers import reparam
from numpyro.infer.reparam import LocScaleReparam
def model(ds: xr.Dataset, use_data) -> None:
wh_obs = ds.LB.notnull().to_numpy()
obs = np.log(ds.LB**2 + 1).to_numpy()[wh_obs] / 2 if use_data else None
n = ds.US.size
arm_idx = ds.wh_treatment.astype(int).to_numpy()
treatment_date_days = jnp.asarray(ds.treatment_date_days.to_numpy())
rm_start_days = jnp.asarray(ds.rm_start_days.to_numpy())
rm_end_days = jnp.asarray(ds.rm_end_days.to_numpy())
lab_date_days = jnp.asarray(ds.lab_date_days.to_numpy())
lambda_treat = sample("lambda_treat", HalfNormal(30))
lambda_rescue = sample("lambda_rescue", HalfNormal(30))
mu_log_p0 = sample("mu_log_p0", Normal(np.log(30), np.log(4)))
sigma_log_p0 = sample("sigma_log_p0", HalfNormal(np.log(4)))
mu_log_sigma_res = sample("mu_log_sigma_res", Normal(np.log(np.log(2)), np.log(3)))
sigma_log_sigma_res = sample("sigma_log_sigma_res", HalfNormal(np.log(3)))
mu_log_c1 = sample("mu_log_c1", Normal(np.log(500), np.log(10)))
sigma_log_c1 = sample("sigma_log_c1", HalfNormal(np.log(10)))
mu_log_c3 = sample("mu_log_c3", Normal(np.log(0.1), np.log(10)))
sigma_log_c3 = sample("sigma_log_c3", HalfNormal(np.log(10)))
mu_c4 = sample("mu_c4", Normal(scale=np.log(1.3)))
sigma_c4 = sample("sigma_c4", HalfNormal(np.log(1.3)))
with plate("treatment", 2):
mu_c5 = sample("mu_c5", Normal(scale=np.log(2)))
sigma_c5 = sample("sigma_c5", HalfNormal(np.log(2)))
lsr = LocScaleReparam(0)
with plate("us", n), reparam(config={f: lsr for f in ["c4", "c5"]}):
d = {}
for f in "log_p0 log_c1 log_c3 log_sigma_res".split():
var_name = f"{f}_z"
d[var_name] = sample(var_name, Normal())
p0 = deterministic("p0", jnp.exp(mu_log_p0 + sigma_log_p0 * d["log_p0_z"]))
c1 = deterministic("c1", jnp.exp(mu_log_c1 + sigma_log_c1 * d["log_c1_z"]))
c3 = deterministic("c3", jnp.exp(mu_log_c3 + sigma_log_c3 * d["log_c3_z"]))
sigma_res = deterministic(
"sigma_res",
jnp.exp(mu_log_sigma_res + sigma_log_sigma_res * d["log_sigma_res_z"]),
)
c4 = sample("c4", Normal(loc=mu_c4, scale=sigma_c4))
c5 = sample("c5", Normal(loc=mu_c5[arm_idx], scale=sigma_c5[arm_idx]))
# U x p
dt = lab_date_days - lab_date_days[:, 0, None]
dt_treat = lab_date_days[:, :, None] - jnp.maximum(
treatment_date_days[:, None, :], lab_date_days[:, 0, None, None]
) # U x p x t
rm_off = jnp.nan_to_num((lab_date_days - rm_end_days[:, None]).clip(0))
rm_duration = (
jnp.nan_to_num((lab_date_days - rm_start_days[:, None]).clip(0)) - rm_off
)
abs_int_c2 = deterministic(
"int_c2",
c3[:, None] * dt
- c4[:, None]
* (
rm_duration
- lambda_rescue
* -jnp.expm1(-rm_duration / lambda_rescue)
* jnp.exp(-rm_off / lambda_rescue)
)
- c5[:, None]
* jnp.where(dt_treat > 0, -jnp.expm1(-dt_treat / lambda_treat), 0).sum(2),
)
exp_2_int_c2 = jnp.exp(-2 * abs_int_c2)
p02 = jnp.square(p0)
pt2 = deterministic(
"pt2",
jnp.where(
dt > 0,
p02[:, None] * exp_2_int_c2
+ (1 - exp_2_int_c2) * c1[:, None] / (abs_int_c2 / dt),
p02[:, None],
),
)
obs_mean = deterministic("obs_mean", jnp.log(pt2[wh_obs] + 1) / 2)
sample(
"p",
Normal(
loc=obs_mean,
scale=jnp.broadcast_to(sigma_res[:, None], wh_obs.shape)[wh_obs],
),
obs=obs,
)