Gibbs
part is basically erased in the following code because it is not related to the problem. I just don’t know why the vv_state.z_grad
in hmc.py
is extremely large.
Here is the simplified version of the code. The acceptance rate keeps at 0.
Is there any problem?
import jax
import jax.numpy as jnp
import jax.random as jnr
import jax.nn as nn
import numpyro
import numpyro.distributions as dist
key = jnr.PRNGKey(0)
ramp_ev=jnr.uniform(key, (800,100))
ramp_ed=jnr.randint(key,(800,),50,100)
from latimerOrigin import Latimer
import numpyro.handlers as handlers
class RampingModel:
def __init__(self,
ev_batch,ed_batch,
dt=10e-3
):
self.dt=dt
self.ev_batch,self.ed_batch=ev_batch,ed_batch
self.constants()
def constants(self):
self.x0_mu=0.0
self.x0_sigma=1.0
self.beta_mu=0.0
self.beta_sigma=0.1
# self.omega_alpha_=1.02
self.omega_alpha=0.02
self.omega_beta=0.02
self.gamma_alpha=2.0
self.gamma_beta=0.05
def model(self,ev_batch,ed_batch): # single trial spike train series, coherence
T,N=ev_batch.shape
x0=numpyro.sample("x_0",dist.Normal(self.x0_mu,self.x0_sigma))
beta=numpyro.sample("beta",dist.Normal(
self.beta_mu*jnp.zeros(5),
self.beta_sigma*jnp.ones(5)))
omega2=numpyro.sample("omega2",dist.InverseGamma(
self.omega_alpha,self.omega_beta))
x=numpyro.deterministic('x',jnp.ones([T,N]))
tau=numpyro.deterministic('tau',jnp.ones(N)*T)
gamma=numpyro.sample("gamma",dist.Gamma(
self.gamma_alpha,self.gamma_beta))
# FIXME: gamma becomes NaN or very large,
# not related to concrete distribution form, not related to gamma number
# not related to the gibbs part
mint=jnp.minimum(tau+1,ed_batch)
with handlers.mask(mask=(jnp.arange(T)[...,None]<mint)):
# y_mean=jnp.log(1.0+jnp.exp(gamma*jnp.where((x>=1.0),1.0,x)))*self.dt
y_mean=jnp.log(1.0+jnp.exp(gamma*x))*self.dt
numpyro.sample("obs",
dist.Poisson(jnp.maximum(y_mean,1e-16)),
obs=ev_batch)
jax.debug.print("gamma={gamma}",gamma=gamma)
def gibbs_fn(self,rng_key,gibbs_sites,hmc_sites):
print("---- gibbs start ----")
jax.debug.print("jax: ---- gibbs start ----")
T,N=self.ev_batch.shape
gamma=hmc_sites['gamma']
omega2=gibbs_sites['omega2']
beta=gibbs_sites['beta']
x0=gibbs_sites['x_0']
x=dist.Normal(jnp.zeros([T,N]),jnp.ones([T,N])).sample(rng_key)
tau=jnp.ones(N)*T
# something updated here but not related to large `gamma`
jax.debug.print("jax: beta={beta} gamma={gamma} omega2={omega2} x0={x0}",
beta=beta,gamma=gamma,omega2=omega2,x0=x0)
jax.debug.print("jax: xmax={xmax}, xmin={xmin}",xmax=x.max(),xmin=x.min())
print("---- gibbs end ----")
jax.debug.print("jax: ---- gibbs end ----")
return {'beta':beta,'x_0':x0,'omega2':omega2,'x':x,'tau':tau}
from numpyro.infer import MCMC, HMCGibbs, HMC
la=RampingModel(ramp_ev.T,ramp_ed)
print(ramp_ed[:10],ramp_ed[-10:])
def MCMCAll(la,ramp_ev,ramp_ed):
gibbs_fn=la.gibbs_fn
hmc=HMC(la.model,
target_accept_prob=0.8,step_size=1.0,
)
kernel=HMCGibbs(hmc,gibbs_fn=gibbs_fn,gibbs_sites=['beta','x_0','omega2','x','tau'])
mcmc = MCMC(kernel, num_warmup=10, num_samples=10, progress_bar=True,num_chains=1)
mcmc.run(jnr.PRNGKey(0),jnp.nan_to_num(ramp_ev).T,ramp_ed,init_params={
'x_0':0.5,'beta':jnp.array([0.0,0.0,0.0,0.0,0.0]),
'x':0.0*jnp.zeros_like(ramp_ev.T),'gamma':50.0,'omega2':0.005})
mcmc.print_summary()
return mcmc
mcmc=MCMCAll(la,ramp_ev,ramp_ed)
Blockquote
Outputs are:
[53 57 93 59 68 89 85 55 89 52] [79 85 68 68 80 92 52 94 53 97]
gamma=60.00014114379883
gamma=5.300010681152344
gamma=5.300010681152344
gamma=5.184705457665547e+21
0%| | 0/20 [00:00<?, ?it/s]
---- gibbs start ----
---- gibbs end ----
warmup: 20%|██ | 4/20 [00:01<00:03, 4.41it/s, 377 steps of size 1.07e-03. acc. prob=0.00]
gamma=5.184705457665547e+21
jax: ---- gibbs start ----
jax: beta=[0. 0. 0. 0. 0.] gamma=5.184705457665547e+21 omega2=0.004999999888241291 x0=0.5
jax: xmax=4.439892768859863, xmin=-4.001267433166504
jax: ---- gibbs end ----
gamma=5.184705457665547e+21
gamma=nan
gamma=nan
gamma=nan
gamma=nan
...
It seems that jax.value_and_grad in hmc_utils.py gives the key contribution to this, though I still don’t understand why.
It also seems not denpedent on distribution with gamma
or the distribution for obs
, I also turn dist.Gamma
to dist.Normal
with dist.TrucatedDistribution
. Or trun dist.Poisson
to dist.Exponential