Hi Pyro Devs,
I am trying to fit a PERT distribution on the data, for which I have defined a custom distribution. But when I try to run MCMC on top of it, it works for some set of data while doesn’t for others and throws the below error.
~/.local/lib/python3.6/site-packages/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs)
428 if not_jax_tracer(is_valid):
429 if device_get(~jnp.all(is_valid)):
--> 430 raise RuntimeError("Cannot find valid initial parameters. Please check your model again.")
431 return ModelInfo(ParamInfo(init_params, pe, grad), potential_fn, postprocess_fn, model_trace)
432
RuntimeError: Cannot find valid initial parameters. Please check your model again.
Below is my definition of PERT distribution
class Pert(numpyro.distributions.Distribution):
def __init__(self, min_val, ml_val, max_val, lamb = 4.0, **kwargs):
#batch_shape = broadcast_shape(mask.shape, component0.shape, component1.shape)
self.a = min_val
self.b = ml_val
self.c = max_val
self.lamb = lamb
self.rng_key = random.PRNGKey(0)
super().__init__(batch_shape=(), event_shape=())
self.build()
def build(self):
self.alpha = 1 + (self.lamb * ((self.b-self.a) / (self.c-self.a)))
self.beta = 1 + (self.lamb * ((self.c-self.b) / (self.c-self.a)))
self.range = self.c - self.a
self._beta_dist = numpyro.distributions.Beta(self.alpha, self.beta)
def sample(self,key, sample_shape=()):
return self._beta_dist.sample(key,sample_shape)*self.range + self.a
def log_prob(self, value):
x = (value - self.a) / self.range
x = jnp.clip(x,0,1)
return self._beta_dist.log_prob(x) - jnp.log(self.range)
and my code for PERT fitting is below
class Pert_fitting:
def __init__(self,data):
self.data = data
def _pert_model(self, data):
param1 = numpyro.sample('param1', dist.Gamma(1.0,1.0))
param2 = numpyro.sample('param2', dist.Gamma(1.0,1.0))
param3 = numpyro.sample('param3', dist.Gamma(1.0,1.0))
with numpyro.plate('data_plate', len(data)):
observation = numpyro.sample('observation', Pert(param1, param2, param3), obs=jnp.array(data))
def run_mcmc(self):
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
nuts_kernel = NUTS(self._pert_model, init_strategy = init_to_value(values={'param1':0,'param2':3.0,'param3':50.0 }))
mcmc = MCMC(nuts_kernel,1000,1000)
#with numpyro.validation_enabled():
mcmc.run(rng_key_, self.data)
return mcmc
ptf = Pert_fitting(obs_data)
mcmc= ptf.run_mcmc()
I have tried with numpyro.validation_enabled() too, which is giving the following error
<ipython-input-48-01ac64a0b237> in _pert_model(self, data)
8 param3 = numpyro.sample('param3', dist.Gamma(1.0,1.0))
9 with numpyro.plate('data_plate', len(data)):
---> 10 observation = numpyro.sample('observation', Pert(param1, param2, param3), obs=jnp.array(data))
11
12 def run_mcmc(self):
<ipython-input-26-046522502922> in __init__(self, min_val, ml_val, max_val, lamb, **kwargs)
8 self.rng_key = random.PRNGKey(0)
9 super().__init__(batch_shape=(), event_shape=())
---> 10 self.build()
11
12 def build(self):
<ipython-input-26-046522502922> in build(self)
14 self.beta = 1 + (self.lamb * ((self.c-self.b) / (self.c-self.a)))
15 self.range = self.c - self.a
---> 16 self._beta_dist = numpyro.distributions.Beta(self.alpha, self.beta)
17
18 def sample(self,key, sample_shape=()):
~/.local/lib/python3.6/site-packages/numpyro/distributions/continuous.py in __init__(self, concentration1, concentration0, validate_args)
59 self.concentration0 = jnp.broadcast_to(concentration0, batch_shape)
60 self._dirichlet = Dirichlet(jnp.stack([self.concentration1, self.concentration0],
---> 61 axis=-1))
62 super(Beta, self).__init__(batch_shape=batch_shape, validate_args=validate_args)
63
~/.local/lib/python3.6/site-packages/numpyro/distributions/continuous.py in __init__(self, concentration, validate_args)
117 super(Dirichlet, self).__init__(batch_shape=batch_shape,
118 event_shape=event_shape,
--> 119 validate_args=validate_args)
120
121 def sample(self, key, sample_shape=()):
~/.local/lib/python3.6/site-packages/numpyro/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
142 if not_jax_tracer(is_valid):
143 if not is_valid:
--> 144 raise ValueError("The parameter {} has invalid values".format(param))
145 super(Distribution, self).__init__()
146
ValueError: The parameter concentration has invalid values
Can you please help me debug this? I am not sure if there is a problem in defining the PERT distribution or running the model.
Could you please tell me what can be the potential ways of solving “RuntimeError: Cannot find valid initial parameter. Please check your model again.”? Because I am getting this same error in other models too which otherwise works perfectly fine but throws the error for few datasets.
Thanks,