MixedHMC Extremely Slow

I am attempting to fit an interval-censored Weibull regression survival model with a discrete spike-and-slab prior on the feature coefficients. However, the following code is extremely slow: it was stuck for a couple of hours trying to draw 40 samples before I stopped execution. X has shape 10,000 x 15 (10,000 data points with 15 features each). Any help would be greatly appreciated.

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, MixedHMC, HMC, init_to_median

def model(X, t1, t2, y):
    N, F = X.shape

    # Spike-and-slab prior: gamma selects a wide ("slab") or narrow
    # ("spike") scale for each of the F coefficients
    gamma = numpyro.sample("gamma", dist.Bernoulli(0.5).expand((F,)))
    tau = numpyro.sample("tau", dist.InverseGamma(0.5, 0.5).expand((F,)))

    sigmas = jnp.where(gamma, jnp.sqrt(tau), jnp.sqrt(0.001))
    beta = numpyro.sample("beta", dist.Normal(0.0, sigmas))
    bias = numpyro.sample("bias", dist.Normal(0.0, 1.0))

    # Weibull shape parameter
    k = numpyro.sample("k", dist.Gamma(1.0, 1.0))

    # Probability of an event falling in the censoring interval (t1, t2]
    p = 1.0 - jnp.exp(
        -jnp.exp(jnp.einsum("ij,j->i", X, beta) + bias)
        * (jnp.power(t2.ravel(), k) - jnp.power(t1.ravel(), k))
    )
    numpyro.sample("likelihood", dist.Bernoulli(probs=p), obs=y.ravel())

# jaxkey is a JAX PRNG key, e.g. jaxkey = jax.random.PRNGKey(0)
kernel = MixedHMC(HMC(model, target_accept_prob=0.8, dense_mass=True, init_strategy=init_to_median))

mcmc = MCMC(
    kernel,
    num_chains=4,
    num_samples=20,
    num_warmup=20,
    chain_method="parallel",
    progress_bar=True,
    jit_model_args=True,
)

mcmc.run(jaxkey, X=X, t1=t1, t2=t2, y=y)

OS: Windows 10 Version 20H2
Python 3.9.13
Packages:

  • jax==0.4.11
  • jaxlib==0.4.11
  • numpyro==0.12.1

I don’t necessarily expect MixedHMC to perform well for this kind of model; you might try DiscreteHMCGibbs instead.
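For reference, a minimal sketch of the kernel swap (keeping the model, data, and HMC settings from above unchanged):

from numpyro.infer import MCMC, HMC, DiscreteHMCGibbs, init_to_median

# DiscreteHMCGibbs alternates HMC updates on the continuous sites
# (tau, beta, bias, k) with Gibbs updates on the discrete site (gamma)
kernel = DiscreteHMCGibbs(
    HMC(model, target_accept_prob=0.8, dense_mass=True, init_strategy=init_to_median)
)
mcmc = MCMC(kernel, num_warmup=20, num_samples=20, num_chains=4, chain_method="parallel")
mcmc.run(jaxkey, X=X, t1=t1, t2=t2, y=y)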

Yes, switching the kernel did seem to speed up the sampling. I have not read the papers for MixedHMC or DiscreteHMCGibbs, so if you do not mind, could you briefly explain why DiscreteHMCGibbs was the more appropriate choice? Thank you so much!

It’s hard to say. I think MixedHMC might struggle with lots of (possibly very) different-looking modes indexed by discrete variables.

As a last follow-up, does numpyro support Bayesian variable selection, i.e. a prior like \beta \sim N(0, \tau \times \gamma) where \gamma \sim \mathrm{Bernoulli}(0.5)?

Although right now I am using a spike-and-slab prior, I would really like to use the prior specified above, but I am having quite a difficult time doing so with numpyro.

Could I just do something like:

\begin{align}
w &\sim N(0, \tau) \\
\gamma &\sim \mathrm{Bernoulli}(0.5) \\
\beta &= w \times \gamma
\end{align}

def model(X, t1, t2, y):
    N, F = X.shape

    # Variable-selection prior: beta_j = w_j * gamma_j, so beta_j is
    # exactly zero whenever gamma_j = 0
    gamma = numpyro.sample("gamma", dist.Bernoulli(0.5).expand((F,)))
    tau = numpyro.sample("tau", dist.InverseGamma(0.5, 0.5).expand((F,)))

    w = numpyro.sample("w", dist.Normal(0.0, jnp.sqrt(tau)))
    beta = numpyro.deterministic("beta", w * gamma)
    bias = numpyro.sample("bias", dist.Normal(0.0, 1.0))

    k = numpyro.sample("k", dist.Gamma(1.0, 1.0))

    p = 1.0 - jnp.exp(
        -jnp.exp(jnp.einsum("ij,j->i", X, beta) + bias)
        * (jnp.power(t2.ravel(), k) - jnp.power(t1.ravel(), k))
    )
    numpyro.sample("likelihood", dist.Bernoulli(probs=p), obs=y.ravel())

Doing Bayesian variable selection properly with black-box inference techniques would be quite difficult, because in Bayesian variable selection you actually have independent weights (here w) for each configuration of gamma (i.e. it’s a trans-dimensional inference problem). For Bayesian variable selection you really want custom inference algorithms; e.g., see my library here (though it may not cover whatever likelihood assumptions etc. you’d like to make).
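To make the trans-dimensional point concrete, one standard way to write a genuine variable-selection prior is

\begin{align}
p(\beta \mid \gamma, \tau) &= \prod_{j\,:\,\gamma_j = 1} N(\beta_j \mid 0, \tau_j) \prod_{j\,:\,\gamma_j = 0} \delta_0(\beta_j)
\end{align}

so each of the 2^15 = 32,768 configurations of gamma over the 15 features indexes a model with its own set of active weights, which is what makes this hard for generic black-box samplers.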