Hi all,
I’m still quite new to NumPyro and I’ve been trying to fit a simple time-to-event model to some simulated data. The model is defined as follows:
# -------------------------------------------------------------------
# NumPyro model
# -------------------------------------------------------------------
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# Note: weibull_cdf(x, alpha, beta) is a helper defined elsewhere (not shown here).

def model(test_times, results):
    """
    test_times: jnp array, shape (ntests,)
    results: jnp array of 0/1, shape (nindivs, ntests)
    """
    nindivs, ntests = results.shape
    # One latent event time per individual, Exponential prior with rate 1/10
    event_times = numpyro.sample(
        "event_times",
        dist.Exponential(1 / 10.0).expand([nindivs])
    )
    # Broadcast event_times and test_times to shape (nindivs, ntests)
    event_times_2d = jnp.broadcast_to(event_times[:, None], (nindivs, ntests))
    test_times_2d = jnp.broadcast_to(test_times[None, :], (nindivs, ntests))
    # Weibull CDF of the time elapsed since the event (test_time - event_time)
    alpha, beta = 3.432, 0.5328
    diff = test_times_2d - event_times_2d
    p_before_event = 1e-3
    cdf_weibull = weibull_cdf(diff, alpha, beta)
    tail_weibull = 1 - cdf_weibull
    # p_after_event = jnp.clip(tail_weibull, a_min=0.001)
    p_after_event = tail_weibull
    # Where test_time <= event_time, p = p_before_event;
    # otherwise, p = p_after_event
    p = jnp.where(test_times_2d <= event_times_2d, p_before_event, p_after_event)
    # Observed results ~ Bernoulli(p)
    numpyro.sample("obs", dist.Bernoulli(probs=p), obs=results)
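The model calls a `weibull_cdf` helper that is not shown in the post. A minimal sketch of what it might look like is below. It assumes `alpha` is the Weibull scale and `beta` the shape, which the post does not state, so swap them if the intended parameterization is the other way round. The guard for non-positive inputs is also an addition: `diff` is negative wherever a test precedes the event, a negative number raised to a fractional power is NaN, and a NaN in the unselected branch of `jnp.where` can still turn the gradient into NaN.

```python
import jax
import jax.numpy as jnp


def weibull_cdf(t, alpha, beta):
    """Weibull CDF, 1 - exp(-(t / alpha) ** beta) for t > 0 and 0 otherwise.

    Assumption: alpha is the scale and beta is the shape.
    """
    positive = t > 0
    # Replace non-positive inputs with a harmless constant before the power,
    # so neither the value nor the gradient of the unused branch is NaN.
    t_safe = jnp.where(positive, t, 1.0)
    cdf = 1.0 - jnp.exp(-((t_safe / alpha) ** beta))
    return jnp.where(positive, cdf, 0.0)


# Elapsed times before, at, and after the event
example = weibull_cdf(jnp.array([-2.0, 0.0, 3.432]), 3.432, 0.5328)
# Gradient at a negative elapsed time stays finite thanks to the guard
example_grad = jax.grad(lambda t: weibull_cdf(t, 3.432, 0.5328))(-2.0)
```

With a helper like this, the `jnp.where` in the model never sees a NaN from the "before event" branch.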
On simulated data, I was able to recover the true event times quite well using NUTS, the Laplace approximation, and MAP. However, when I run SVI, the approximated event times all collapse to a single value (in this particular case, approximately 0.09). I’ve been trying to debug this for the past few days without success, so I would be extremely grateful if someone could point out what the potential issues might be.
Thanks very much!