I am attempting to model an interval censored Weibull regression survival model with a discrete spike-slab prior on the feature coefficients. However, the following code is EXTREMELY SLOW. It has been frozen trying to get 40 samples for a couple hours now, and I stopped execution. Variable X dimension is 10,000 x 15 (10,000 datapoints with 15 features each). Any help would be greatly appreciated.
def model(self,X,t1, t2, y):
N,F = X.shape
gamma = numpyro.sample("gamma",dist.Bernoulli(.5).expand((F,)))
tau = numpyro.sample("tau",dist.InverseGamma(0.5,0.5).expand((F,)))
sigmas = jnp.where(gamma,jnp.sqrt(tau),jnp.sqrt(0.001))
beta = numpyro.sample("beta",dist.Normal(0,sigmas))
bias = numpyro.sample("bias",dist.Normal(0,1))
k = numpyro.sample("k",dist.Gamma(1,1))
p = 1. - jnp.exp(-jnp.exp(jnp.einsum("ij,j->i",X,beta) + bias)*( jnp.power(t2.ravel(),k) - jnp.power(t1.ravel(),k) ) )
numpyro.sample("likelihood",dist.Bernoulli(probs=p),obs=y.ravel())
kernel = MixedHMC(HMC(model,target_accept_prob=0.8,dense_mass=True,init_strategy=init_to_median))
mcmc = MCMC(kernel,
num_chains=4,
num_samples=20,
num_warmup=20,
chain_method="parallel",
progress_bar=True,
jit_model_args=True)
mcmc.run(jaxkey,X=X,t1=t1, t2=t2, y=y)
OS: Windows 10 Version 20H2
Python 3.9.13
Packages:
- jax==0.4.11
- jaxlib==0.4.11
- numpyro==0.12.1