Large potential energy while using HMCGibbs at the initial stage

Gibbs part is basically erased in the following code because it is not related to the problem. I just don’t know why the vv_state.z_grad in hmc.py is extremely large.

numpyro/numpyro/infer/hmc.py

Here is the simplified version of the code. The acceptance rate keeps at 0.
Is there any problem?

import jax
import jax.numpy as jnp
import jax.random as jnr
import jax.nn as nn
import numpyro
import numpyro.distributions as dist

key = jnr.PRNGKey(0)
ramp_ev=jnr.uniform(key, (800,100))
ramp_ed=jnr.randint(key,(800,),50,100)

from latimerOrigin import Latimer
import numpyro.handlers as handlers

class RampingModel:
    def __init__(self,
                 ev_batch,ed_batch,
                 dt=10e-3
                 ):
        self.dt=dt
        self.ev_batch,self.ed_batch=ev_batch,ed_batch
        self.constants()

    def constants(self):
        
        self.x0_mu=0.0
        self.x0_sigma=1.0

        self.beta_mu=0.0
        self.beta_sigma=0.1
        
        # self.omega_alpha_=1.02
        self.omega_alpha=0.02
        self.omega_beta=0.02
        
        self.gamma_alpha=2.0
        self.gamma_beta=0.05

    def model(self,ev_batch,ed_batch): # single trial spike train series, coherence
        
        T,N=ev_batch.shape
        
        x0=numpyro.sample("x_0",dist.Normal(self.x0_mu,self.x0_sigma))
        beta=numpyro.sample("beta",dist.Normal(
            self.beta_mu*jnp.zeros(5),
            self.beta_sigma*jnp.ones(5)))
        omega2=numpyro.sample("omega2",dist.InverseGamma(
            self.omega_alpha,self.omega_beta))
        
        x=numpyro.deterministic('x',jnp.ones([T,N]))
        tau=numpyro.deterministic('tau',jnp.ones(N)*T)
        gamma=numpyro.sample("gamma",dist.Gamma(
            self.gamma_alpha,self.gamma_beta))
        # FIXME: gamma becomes NaN or very large, 
        # not related to concrete distribution form, not related to gamma number
        # not related to the gibbs part
        
        mint=jnp.minimum(tau+1,ed_batch)

        with handlers.mask(mask=(jnp.arange(T)[...,None]<mint)):
            # y_mean=jnp.log(1.0+jnp.exp(gamma*jnp.where((x>=1.0),1.0,x)))*self.dt
            y_mean=jnp.log(1.0+jnp.exp(gamma*x))*self.dt
            
            numpyro.sample("obs",
                           dist.Poisson(jnp.maximum(y_mean,1e-16)),
                           obs=ev_batch)
            jax.debug.print("gamma={gamma}",gamma=gamma)
    
    def gibbs_fn(self,rng_key,gibbs_sites,hmc_sites):
        print("---- gibbs start ----")
        jax.debug.print("jax: ---- gibbs start ----")

        T,N=self.ev_batch.shape
        gamma=hmc_sites['gamma']
        omega2=gibbs_sites['omega2']
        beta=gibbs_sites['beta']
        x0=gibbs_sites['x_0']
        
        
        x=dist.Normal(jnp.zeros([T,N]),jnp.ones([T,N])).sample(rng_key)
        tau=jnp.ones(N)*T

        # something updated here but not related to large `gamma`

        jax.debug.print("jax: beta={beta} gamma={gamma} omega2={omega2} x0={x0}",
                        beta=beta,gamma=gamma,omega2=omega2,x0=x0)
        jax.debug.print("jax: xmax={xmax}, xmin={xmin}",xmax=x.max(),xmin=x.min())

        print("---- gibbs end ----")
        jax.debug.print("jax: ---- gibbs end ----")
        return {'beta':beta,'x_0':x0,'omega2':omega2,'x':x,'tau':tau}


from numpyro.infer import MCMC, HMCGibbs, HMC
la=RampingModel(ramp_ev.T,ramp_ed)
print(ramp_ed[:10],ramp_ed[-10:])

def MCMCAll(la,ramp_ev,ramp_ed):

    gibbs_fn=la.gibbs_fn
    hmc=HMC(la.model,
            target_accept_prob=0.8,step_size=1.0,
            )

    kernel=HMCGibbs(hmc,gibbs_fn=gibbs_fn,gibbs_sites=['beta','x_0','omega2','x','tau'])

    mcmc = MCMC(kernel, num_warmup=10, num_samples=10, progress_bar=True,num_chains=1)
    mcmc.run(jnr.PRNGKey(0),jnp.nan_to_num(ramp_ev).T,ramp_ed,init_params={
        'x_0':0.5,'beta':jnp.array([0.0,0.0,0.0,0.0,0.0]),
        'x':0.0*jnp.zeros_like(ramp_ev.T),'gamma':50.0,'omega2':0.005})

    mcmc.print_summary()
    return mcmc

mcmc=MCMCAll(la,ramp_ev,ramp_ed)

Blockquote
Outputs are:

[53 57 93 59 68 89 85 55 89 52] [79 85 68 68 80 92 52 94 53 97]
gamma=60.00014114379883
gamma=5.300010681152344
gamma=5.300010681152344
gamma=5.184705457665547e+21
  0%|          | 0/20 [00:00<?, ?it/s]
---- gibbs start ----
---- gibbs end ----
warmup:  20%|██        | 4/20 [00:01<00:03,  4.41it/s, 377 steps of size 1.07e-03. acc. prob=0.00]
gamma=5.184705457665547e+21
jax: ---- gibbs start ----
jax: beta=[0. 0. 0. 0. 0.] gamma=5.184705457665547e+21 omega2=0.004999999888241291 x0=0.5
jax: xmax=4.439892768859863, xmin=-4.001267433166504
jax: ---- gibbs end ----
gamma=5.184705457665547e+21
gamma=nan
gamma=nan
gamma=nan
gamma=nan
...

It seems that jax.value_and_grad in hmc_utils.py gives the key contribution to this, though I still don’t understand why.

/numpyro/infer/hmc_util.py

It also seems not denpedent on distribution with gamma or the distribution for obs, I also turn dist.Gamma to dist.Normal with dist.TrucatedDistribution. Or trun dist.Poisson to dist.Exponential

Have you tried using log_density or potential_energy? Those tools allow you to compute potential energy explicitly. Or you can also compute density manually at each random variable.