Hi,
I’m using MixedHMC to sample from a distribution that has mixed (discrete and continuous) support. The potential function I’m using is the following (a modified version of this):
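import jax.numpy as jnp

def generate_potential_energy_fn(X, y, J, sigma, mu):
    # NumPyro calls potential_fn with a single argument (a dict of parameters)
    def potential_energy(params):
        gamma, beta = params["gamma"], params["beta"]
        # -log N(beta | 0, sigma^2), summed over coefficients
        beta_prior_potential = jnp.sum(
            0.5 * jnp.log(2 * jnp.pi * sigma ** 2) + 0.5 * beta ** 2 / sigma ** 2)
        # logistic-regression probabilities, with gamma masking the coefficients
        probs = 1 / (
            1 + jnp.exp(-jnp.dot(jnp.dot(X, jnp.diag(gamma).astype(jnp.float32)), beta)))
        # negative Bernoulli log-likelihood
        likelihood_potential = -jnp.sum(
            y * jnp.log(probs + 1e-12) + (1 - y) * jnp.log(1 - probs + 1e-12))
        # quadratic potential on the discrete indicators gamma
        gamma_potential = -0.5 * jnp.dot(jnp.dot(gamma.T, J), gamma) + mu * jnp.sum(gamma)
        return beta_prior_potential + likelihood_potential + gamma_potential
    return potential_energy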
Here, gamma is a discrete variable and beta is a continuous one. I’m trying to sample from this distribution using the MixedHMC sampler as follows:
from jax import random
from numpyro.infer import HMC, MCMC, MixedHMC

# .... define X, y, J, sigma, mu
key = random.PRNGKey(0)
potential_eng = generate_potential_energy_fn(X, y, J, sigma, mu)
kernel = MixedHMC(HMC(potential_fn=potential_eng, trajectory_length=2), num_discrete_updates=20, random_walk=True)
mcmc = MCMC(kernel, num_warmup=2000, num_samples=1000, num_chains=2, progress_bar=True)
mcmc.run(key)
Running the above code gives me the following error:
AssertionError: HMCGibbs does not support models specified via a potential function.
How can I resolve this error?