Passing a potential energy function to MixedHMC in NumPyro

Hi,

I’m using MixedHMC to sample from a distribution that has mixed (discrete and continuous) support. The potential function I’m using is the following (a modified version of this):

def generate_potential_energy_fn(X, y, J, sigma, mu):
    """Build the potential energy (negative unnormalised log density) of the model.

    The potential is the sum of three terms:
      * an independent Normal(0, sigma**2) prior on each entry of ``beta``,
      * a Bernoulli-logit likelihood of ``y`` with logits ``X @ (gamma * beta)``,
      * a quadratic/linear term on the binary mask ``gamma``:
        ``-0.5 * gamma^T J gamma + mu * sum(gamma)``.

    Args:
        X: design matrix; presumably shape (n, p) — confirm.
        y: observed 0/1 labels, one per row of ``X``.
        J: (p, p) coupling matrix for ``gamma``.
        sigma: prior standard deviation of ``beta`` (scalar or per-coefficient).
        mu: weight of the linear penalty on the number of active ``gamma`` entries.

    Returns:
        ``potential_energy(gamma, beta)`` returning a scalar.

    NOTE(review): NumPyro's HMC invokes ``potential_fn`` with a single dict of
    parameters, so this two-argument function would need wrapping, e.g.
    ``lambda p: potential_energy(p["gamma"], p["beta"])`` — confirm.
    """

    def potential_energy(gamma, beta):
        # -log N(beta | 0, sigma^2), summed over coefficients.
        beta_prior_potential = jnp.sum(
            0.5 * jnp.log(2 * jnp.pi * sigma ** 2) + 0.5 * beta ** 2 / sigma ** 2)

        # X @ diag(gamma) @ beta == X @ (gamma * beta); avoids a p x p matrix.
        logits = jnp.dot(X, gamma * beta)

        # Bernoulli-logit negative log-likelihood:
        #   -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))] = softplus(z) - y*z.
        # logaddexp(0, z) is a stable softplus. The previous log(p + 1e-12) form
        # clipped the loss and zeroed its gradient once the sigmoid saturated.
        likelihood_potential = jnp.sum(jnp.logaddexp(0.0, logits) - y * logits)

        # gamma is 1-D, so gamma @ J @ gamma is the quadratic form.
        gamma_potential = -0.5 * jnp.dot(jnp.dot(gamma, J), gamma) + mu * jnp.sum(gamma)

        return beta_prior_potential + likelihood_potential + gamma_potential

    return potential_energy

Here, gamma is a discrete variable and beta is a continuous one. I’m trying to sample from this distribution using MixedHMC sampler as follows:

# .... define X, y, J, sigma, mu (and a PRNG key)

# The factory is named generate_potential_energy_fn; generate_potential_energy
# is not defined in this snippet and would raise NameError before MCMC starts.
potential_eng = generate_potential_energy_fn(X, y, J, sigma, mu)

# NOTE: MixedHMC only accepts models, not potential functions, so this
# construction fails with "HMCGibbs does not support models specified via a
# potential function."
kernel = MixedHMC(
    HMC(potential_fn=potential_eng, trajectory_length=2),
    num_discrete_updates=20,
    random_walk=True,
)
mcmc = MCMC(kernel, num_warmup=2000, num_samples=1000, num_chains=2, progress_bar=True)
mcmc.run(key)

Running the above code gives me the following error:

AssertionError: HMCGibbs does not support models specified via a potential function.

How can I resolve this error?

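I think you can build a model with ImproperUniform and numpyro.factor instead of using potential_fn. As the error says, the Gibbs-based kernels that MixedHMC relies on only support models, not potential functions.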

1 Like

Is there any example that demonstrates the usage of ImproperUniform and numpyro.factor?

Currently, I’m using this model to make it work with MixedHMC:

import numpyro.distributions as dist
from numpyro.infer import MCMC, MixedHMC, HMC

class GammaDist(dist.BernoulliProbs):  # Extends Bernoulli because gamma is a binary vector
    """Distribution over a binary vector ``gamma`` with a quadratic log-probability.

    ``log_prob(gamma) = 0.5 * gamma^T J gamma - mu * sum(gamma)``, which is the
    negative of the potential term ``-0.5 * gamma^T J gamma + mu * sum(gamma)``.
    The log-probability is unnormalised. The BernoulliProbs base class is used
    for its {0, 1} support; its own density is fully replaced by ``log_prob``.

    Args:
        J: (p, p) coupling matrix; ``J.shape[0]`` sets the number of gamma entries.
        eta: stored on the instance but not used by ``log_prob``.
        mu: weight of the linear penalty on the number of active entries.
    """

    def __init__(self, J, eta, mu):
        self.J = J
        self.eta = eta  # NOTE(review): unused in log_prob; confirm its intended role.
        self.mu = mu  # previously ignored; log_prob hard-coded mu == 1
        # probs=0.5 only fixes the batch shape (one Bernoulli per row of J).
        super().__init__(probs=jnp.full(shape=(J.shape[0],), fill_value=0.5))

    def log_prob(self, gamma):
        """Return the unnormalised log-probability of the 1-D binary vector ``gamma``.

        NOTE(review): returns a scalar although batch_shape is (p,); this works
        when the caller sums log_prob over the site — confirm with the sampler.
        """
        # For 1-D gamma, gamma @ J @ gamma is the quadratic form (.T was a no-op).
        return 0.5 * jnp.dot(jnp.dot(gamma, self.J), gamma) - self.mu * jnp.sum(gamma)

def model(X, y, sigma, J, eta=1.0, mu=1.0):
    """NumPyro model for sparse logistic regression with a binary mask ``gamma``.

    The ``beta`` prior and the likelihood mirror the corresponding terms of
    ``generate_potential_energy_fn``. The ``gamma`` term is delegated to
    ``GammaDist(J, eta, mu)``.

    Args:
        X: design matrix; presumably shape (n, p) — confirm.
        y: observed 0/1 labels, one per row of ``X``.
        sigma: prior standard deviation of ``beta`` (scalar or length p).
        J: (p, p) coupling matrix passed to ``GammaDist``.
        eta, mu: forwarded to ``GammaDist``.

    Sample sites: "beta", "gamma", deterministic "prob", observed "y".
    """
    num_features = X.shape[1]
    # Independent Normal(0, sigma) per coefficient, matching the potential's
    # 0.5*log(2*pi*sigma**2) + 0.5*beta**2/sigma**2 term. MultivariateNormal(0, sigma)
    # treated sigma as a covariance matrix and left beta's dimension implicit.
    beta = npyro.sample(
        'beta', dist.Normal(jnp.zeros(num_features), sigma).to_event(1))
    gamma = npyro.sample('gamma', GammaDist(J, eta, mu))
    # gamma is 0/1, so beta * gamma switches coefficients off (X @ diag(gamma) @ beta).
    logits = jnp.dot(X, beta * gamma)
    npyro.deterministic("prob", logistic(logits))
    # Parameterising by logits avoids log(0) when the sigmoid saturates.
    npyro.sample("y", dist.Bernoulli(logits=logits), obs=y)

However, I’m new to Pyro/NumPyro, so I’m unsure whether this captures the potential energy function I originally intended to use.