Hi there,
I am new to NumPyro and am trying to build a Cox proportional hazards model for survival analysis. I followed this example and tried to replicate the model in NumPyro:
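import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# n_intervals, exposure and death (n_subjects x n_intervals) are defined outside the model
def cox_model(df, duration_col='time', event_col='event'):
    # Piecewise-constant baseline hazard, one value per interval
    lambda0 = numpyro.sample('lambda0', dist.Gamma(0.01, 0.01), sample_shape=(n_intervals,))
    # Coefficient of the metastasized covariate
    beta = numpyro.sample('beta', dist.Normal(0, 1000))
    # Hazard for subject i in interval j: exp(beta * x_i) * lambda0_j
    lambda_ = numpyro.deterministic('lambda_', jnp.outer(jnp.exp(beta * df.metastasized.values), lambda0))
    # Poisson rate: time at risk times hazard
    mu = numpyro.deterministic('mu', exposure * lambda_)
    numpyro.sample('obs', dist.Poisson(mu), obs=death)
rng_key = random.PRNGKey(0)
rng_key, alpha_key = random.split(rng_key)
mcmc = MCMC(NUTS(cox_model, target_accept_prob=0.95), num_samples=1000, num_warmup=1000, num_chains=2)
mcmc.run(alpha_key, df)
mcmc_samples = mcmc.get_samples()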
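`cox_model` reads `n_intervals`, `exposure` and `death` from the enclosing scope. The following is a minimal sketch of how such arrays could be built from per-subject durations and event indicators, assuming fixed-width intervals and a piecewise-constant hazard. The helper name `make_exposure_and_death`, the interval width and the toy data are hypothetical and are not taken from the example.

```python
import numpy as np

def make_exposure_and_death(durations, events, interval_length):
    """Split follow-up into fixed-width intervals for a piecewise-exponential model.

    Returns n_intervals plus two (n_subjects, n_intervals) arrays:
    exposure[i, j] = time subject i was at risk during interval j,
    death[i, j]    = 1 if subject i had the event in interval j, else 0.
    """
    durations = np.asarray(durations, dtype=float)
    events = np.asarray(events, dtype=int)
    n_subjects = durations.shape[0]
    # Left edges of the intervals; together they cover [0, max duration]
    starts = np.arange(0.0, durations.max(), interval_length)
    n_intervals = starts.size
    # Time at risk in each interval, clipped to [0, interval_length]
    exposure = np.clip(durations[:, None] - starts[None, :], 0.0, interval_length)
    # Index of the interval in which each subject's follow-up ends
    last = np.maximum(np.ceil(durations / interval_length).astype(int) - 1, 0)
    death = np.zeros((n_subjects, n_intervals), dtype=int)
    death[np.arange(n_subjects), last] = events
    return n_intervals, exposure, death

# Toy data (hypothetical): three subjects, intervals of length 3
n_intervals, exposure, death = make_exposure_and_death([2.0, 5.0, 6.0], [1, 0, 1], 3.0)
print(n_intervals)  # 2
print(exposure)
print(death)
```

With a DataFrame like the one passed to `cox_model`, the inputs would be `df['time'].values` and `df['event'].values`. Each row has zero exposure in every interval after that subject's follow-up ends, so `exposure * lambda_` is exactly 0 in those cells.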
When I run it, I get this error:
ValueError: Poisson distribution got invalid rate parameter
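One way to narrow down which values are invalid is to reproduce the rate computation outside MCMC with concrete numbers and count suspicious entries. The sketch below uses NumPy in float32 to mimic JAX's default precision. The covariate, exposure and baseline hazard are hypothetical. The value `beta = 100` is also hypothetical, although a draw of that magnitude is common under `Normal(0, 1000)`.

```python
import numpy as np

def check_rate(mu):
    """Count entries of a Poisson rate array that are NaN, infinite, zero or negative."""
    mu = np.asarray(mu)
    return {
        "n_nan": int(np.isnan(mu).sum()),
        "n_inf": int(np.isinf(mu).sum()),
        "n_zero": int((mu == 0).sum()),
        "n_negative": int((mu < 0).sum()),
    }

# Hypothetical inputs, in float32 to mimic JAX's default precision
x = np.array([0, 1, 1], dtype=np.float32)                    # metastasized indicator
exposure = np.array([[3, 3, 1, 0],
                     [3, 2, 0, 0],
                     [3, 3, 3, 3]], dtype=np.float32)         # time at risk per interval
lambda0 = np.array([0.5, 0.2, 0.1, 0.05], dtype=np.float32)  # baseline hazard
beta = np.float32(100.0)                                     # a large draw from Normal(0, 1000)

# Same arithmetic as cox_model; silence the overflow/invalid warnings it triggers
with np.errstate(over="ignore", invalid="ignore"):
    lambda_ = np.outer(np.exp(beta * x), lambda0)
    mu = exposure * lambda_

print(check_rate(mu))  # {'n_nan': 2, 'n_inf': 6, 'n_zero': 1, 'n_negative': 0}
```

In this toy case, `exp(100)` overflows float32 to `inf`, and multiplying `inf` by a zero exposure yields `nan`. A NaN rate is never a valid Poisson parameter.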
How can I modify my model to prevent this error? Thanks for your help.