Cannot find initial parameter with PERT distribution

Hi Pyro Devs,

I am trying to fit a PERT distribution to my data, for which I have defined a custom distribution. When I run MCMC on it, it works for some datasets but not for others, where it throws the error below.

~/.local/lib/python3.6/site-packages/numpyro/infer/util.py in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs)
    428     if not_jax_tracer(is_valid):
    429         if device_get(~jnp.all(is_valid)):
--> 430             raise RuntimeError("Cannot find valid initial parameters. Please check your model again.")
    431     return ModelInfo(ParamInfo(init_params, pe, grad), potential_fn, postprocess_fn, model_trace)
    432 

RuntimeError: Cannot find valid initial parameters. Please check your model again.

Below is my definition of the PERT distribution:

import jax.numpy as jnp
from jax import random
import numpyro

class Pert(numpyro.distributions.Distribution):
    def __init__(self, min_val, ml_val, max_val, lamb = 4.0, **kwargs):
        # a = minimum, b = most likely value (mode), c = maximum
        self.a = min_val
        self.b = ml_val
        self.c = max_val
        self.lamb = lamb
        self.rng_key = random.PRNGKey(0)
        super().__init__(batch_shape=(), event_shape=())
        self.build()
        
    def build(self):
        # PERT is a Beta(alpha, beta) distribution rescaled from [0, 1] to [a, c]
        self.alpha = 1 + (self.lamb * ((self.b-self.a) / (self.c-self.a)))
        self.beta = 1 + (self.lamb * ((self.c-self.b) / (self.c-self.a)))
        self.range = self.c - self.a
        self._beta_dist = numpyro.distributions.Beta(self.alpha, self.beta)
        
    def sample(self, key, sample_shape=()):
        # draw from the Beta on [0, 1], then shift and scale to [a, c]
        return self._beta_dist.sample(key, sample_shape)*self.range + self.a
    
    def log_prob(self, value):
        # map value from [a, c] back to [0, 1] and evaluate the Beta density there
        x = (value - self.a) / self.range
        x = jnp.clip(x,0,1)
        return self._beta_dist.log_prob(x) - jnp.log(self.range)

And here is my code for fitting the PERT distribution:

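import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, init_to_value

class Pert_fitting: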
    def __init__(self,data):
        self.data = data
        
    def _pert_model(self, data):
        param1 = numpyro.sample('param1', dist.Gamma(1.0,1.0))
        param2 = numpyro.sample('param2', dist.Gamma(1.0,1.0))
        param3 = numpyro.sample('param3', dist.Gamma(1.0,1.0))
        with numpyro.plate('data_plate', len(data)):
            observation = numpyro.sample('observation', Pert(param1, param2, param3), obs=jnp.array(data))
            
    def run_mcmc(self):
        rng_key = random.PRNGKey(0)
        rng_key, rng_key_ = random.split(rng_key)
        nuts_kernel = NUTS(self._pert_model, init_strategy = init_to_value(values={'param1':0,'param2':3.0,'param3':50.0 }))
        mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
        #with numpyro.validation_enabled():
        mcmc.run(rng_key_, self.data)
        return mcmc
ptf = Pert_fitting(obs_data)
mcmc= ptf.run_mcmc()

I have also tried running it with numpyro.validation_enabled(), which gives the following error:

<ipython-input-48-01ac64a0b237> in _pert_model(self, data)
      8         param3 = numpyro.sample('param3', dist.Gamma(1.0,1.0))
      9         with numpyro.plate('data_plate', len(data)):
---> 10             observation = numpyro.sample('observation', Pert(param1, param2, param3), obs=jnp.array(data))
     11 
     12     def run_mcmc(self):

<ipython-input-26-046522502922> in __init__(self, min_val, ml_val, max_val, lamb, **kwargs)
      8         self.rng_key = random.PRNGKey(0)
      9         super().__init__(batch_shape=(), event_shape=())
---> 10         self.build()
     11 
     12     def build(self):

<ipython-input-26-046522502922> in build(self)
     14         self.beta = 1 + (self.lamb * ((self.c-self.b) / (self.c-self.a)))
     15         self.range = self.c - self.a
---> 16         self._beta_dist = numpyro.distributions.Beta(self.alpha, self.beta)
     17 
     18     def sample(self,key, sample_shape=()):

~/.local/lib/python3.6/site-packages/numpyro/distributions/continuous.py in __init__(self, concentration1, concentration0, validate_args)
     59         self.concentration0 = jnp.broadcast_to(concentration0, batch_shape)
     60         self._dirichlet = Dirichlet(jnp.stack([self.concentration1, self.concentration0],
---> 61                                               axis=-1))
     62         super(Beta, self).__init__(batch_shape=batch_shape, validate_args=validate_args)
     63 

~/.local/lib/python3.6/site-packages/numpyro/distributions/continuous.py in __init__(self, concentration, validate_args)
    117         super(Dirichlet, self).__init__(batch_shape=batch_shape,
    118                                         event_shape=event_shape,
--> 119                                         validate_args=validate_args)
    120 
    121     def sample(self, key, sample_shape=()):

~/.local/lib/python3.6/site-packages/numpyro/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
    142                 if not_jax_tracer(is_valid):
    143                     if not is_valid:
--> 144                         raise ValueError("The parameter {} has invalid values".format(param))
    145         super(Distribution, self).__init__()
    146 

ValueError: The parameter concentration has invalid values

Can you please help me debug this? I am not sure whether the problem is in my definition of the PERT distribution or in how I run the model.

More generally, what are the possible ways of resolving “RuntimeError: Cannot find valid initial parameters. Please check your model again.”? I get the same error in other models too, which otherwise work perfectly fine but fail on a few datasets.

Thanks,

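Hi @umaruza, I think your PERT distribution requires min_val < ml_val < max_val; otherwise alpha or beta can become non-positive, which is an invalid Beta concentration (the ValueError above). With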

        param1 = numpyro.sample('param1', dist.Gamma(1.0,1.0))
        param2 = numpyro.sample('param2', dist.Gamma(1.0,1.0))
        param3 = numpyro.sample('param3', dist.Gamma(1.0,1.0))

we can’t guarantee that this condition holds. You might use an ordered prior for param, as in the ordinal regression tutorial, and then unpack it into param1, param2, param3. For example:

param = numpyro.sample("param", dist.TransformedDistribution(
    dist.Normal(0, 1).expand([3]), dist.transforms.OrderedTransform()))
param1, param2, param3 = param

Hi @fehiepsi,
Thank you for the reply.

I changed the implementation as you suggested, but the problem still persists. Can you please tell me if I am missing anything?

class Pert_fitting:
    def __init__(self,data):
        self.data = data
        
    def _pert_model(self, data):
        #param = numpyro.sample("param", dist.TransformedDistribution(dist.Gamma(1.0, 1.0).expand([3]), dist.transforms.OrderedTransform()))
        param = numpyro.sample("param", dist.TransformedDistribution(dist.Normal(0, 1).expand([3]), dist.transforms.OrderedTransform()))
        param1, param2, param3 = param
        with numpyro.plate('data_plate', len(data)):
            observation = numpyro.sample('observation', Pert(param1, param2, param3), obs=jnp.array(data))
            
    def run_mcmc(self):
        rng_key = random.PRNGKey(0)
        rng_key, rng_key_ = random.split(rng_key)
        #nuts_kernel = NUTS(self._pert_model, init_strategy = init_to_value(values={'param1':0,'param2':3.0,'param3':50.0 }))
        nuts_kernel = NUTS(self._pert_model)
        mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
        with numpyro.validation_enabled():
            mcmc.run(rng_key_, self.data)
        return mcmc

I am also getting the warnings below:

/home/ubuntu/.local/lib/python3.6/site-packages/numpyro/distributions/distribution.py:246: UserWarning: Out-of-support values provided to log prob method. The value argument should be within the support.
  warnings.warn('Out-of-support values provided to log prob method. '
/home/ubuntu/.local/lib/python3.6/site-packages/numpyro/distributions/distribution.py:246: UserWarning: Out-of-support values provided to log prob method. The value argument should be within the support.
  warnings.warn('Out-of-support values provided to log prob method. '

Do let me know if you need any other information.

Thanks,

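How about using x = jnp.clip(x, 1e-6, 1 - 1e-6) in the log probability of Pert? Clipping to the closed interval [0, 1] still lets x reach exactly 0 or 1, where the Beta log density can be infinite, so the model's log density is not finite. If your data contains NaN values, the same issue will also happen.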

By the way, because your data should lie in the support (a, c), how about letting a lie in (-inf, x.min()), b in (x.min(), x.max()), and c in (x.max(), inf)? I think inference will be more stable that way (no clipping is required).

Thank you @fehiepsi, x = jnp.clip(x, 1e-6, 1 - 1e-6) worked for me. I’ll let you know if I face any other challenges.

Can you please tell me what x is here, and where to set these constraints?

Yeah it works. :smile:

what is x here

I meant that x is your data. To set those constraints, you can put corresponding priors on -a, b, and c, e.g. TruncatedNormal(low=-x.min()) for -a (so that a < x.min()), Uniform(x.min(), x.max()) for b, and TruncatedNormal(low=x.max()) for c.