"Problem with initial parameters" in Bayesian fitting

Hi, I am trying to do a Bayesian fit of a microlensing light curve. I prepared this model and adapted some parts using JAX:

def _binary_lens_magnification(tau, u_0, q, s, alpha, tol=0.01):
    """Return the total binary-lens magnification at each trajectory point.

    The source position ``(tau, u_0)`` is rotated by ``alpha`` and shifted by
    ``s*q/(1+q)``. The fifth-order polynomial in the image position ``z`` is
    then solved at every point. Only roots satisfying the lens equation
    ``zeta = z - conj(W1(z))`` to within ``tol`` count as images, and their
    absolute magnifications ``1 / (1 - |W2|**2)`` are summed.

    Rejected roots contribute 0 because they are not images. (Forcing their
    ``W2`` to 0, as ``jnp.nan_to_num`` would, adds a spurious 1 per root.)

    Args:
        tau: (t - t_0) / t_E per epoch; assumed 1-D (one polynomial per entry).
        u_0: impact parameter of the trajectory.
        q: relative weight of the lens at ``-s`` (mass ratio).
        s: separation of the two lenses.
        alpha: trajectory rotation angle in radians.
        tol: maximum lens-equation residual for a root to count as an image.

    Returns:
        Array with the summed absolute magnification of the accepted images,
        one value per entry of ``tau``.
    """
    # Rotate the straight-line trajectory by alpha and shift the origin.
    cos_a = jnp.cos(alpha)
    sin_a = jnp.sin(alpha)
    x = tau * cos_a - u_0 * sin_a - s * q / (1. + q)
    y = tau * sin_a + u_0 * cos_a
    zeta = x + 1j * y
    zc = zeta.conjugate()
    az2 = jnp.abs(zeta) ** 2

    # Polynomial coefficients, passed to jnp.roots (highest power first).
    cofs0 = (1 + q) ** 2 * (s + zc) * zc
    cofs1 = (1 + q) * (s * (q - az2 * (1 + q))
                       + (1 + q) * ((1 + 2 * s ** 2) - az2 + 2 * s * zc) * zc)
    cofs2 = (1 + q) * (s ** 2 * q - s * (1 + q) * zeta
                       + (2 * s + s ** 3 * (1 + q) + s ** 2 * (1 + q) * zc) * zc
                       - 2 * az2 * (1 + q) * (1 + s ** 2 + s * zc))
    cofs3 = -(1 + q) * (s * q + s ** 2 * (q - 1) * zc
                        + (1 + q + s ** 2 * (2 + q)) * zeta
                        + az2 * (2 * s * (2 + q) + s ** 2 * (1 + q) * (s + zc)))
    cofs4 = -s * (1 + q) * ((2 + s ** 2) * zeta + 2 * s * az2) - s ** 2 * q
    cofs5 = -s ** 2 * zeta
    coefs = jnp.vstack((cofs0, cofs1, cofs2, cofs3, cofs4, cofs5))
    # One polynomial per column; strip_zeros=False keeps the output shape
    # fixed, as required when tracing under jit/grad.
    z0 = jnp.apply_along_axis(jnp.roots, 0, coefs, strip_zeros=False)

    # Accept only the roots that actually solve the lens equation.
    W1 = 1. / (1. + q) * (1. / z0 + q / (z0 + s))
    is_image = jnp.abs(z0 - W1.conjugate() - zeta) <= tol
    # "Double where": swap rejected roots for a harmless point far from both
    # lenses, so neither the forward pass nor its gradient sees inf/NaN.
    z_img = jnp.where(is_image, z0, 1e3)
    W2 = -1. / (1. + q) * (1. / z_img ** 2 + q / (z_img + s) ** 2)
    mu = 1. / (1. - jnp.abs(W2) ** 2)
    return jnp.sum(jnp.where(is_image, jnp.abs(mu), 0.), axis=0)


def model_d(observado, tiempo, mag_err=0.001):
    """NumPyro model for a binary-lens microlensing light curve in magnitudes.

    Args:
        observado: observed magnitudes, used as the ``"mag"`` observation.
        tiempo: observation times, one per entry of ``observado``.
        mag_err: standard deviation of the Gaussian magnitude likelihood.
            Defaults to the previously hard-coded 0.001 mag, which makes the
            likelihood extremely sharp; consider passing the real photometric
            errors instead.
    """
    # Centre the t_0 prior on the brightest epoch (smallest magnitude).
    w_min = np.argmin(observado)
    t_0 = numpyro.sample("t_0", dist.Uniform(tiempo[w_min] - 500, tiempo[w_min] + 500))
    u_0 = numpyro.sample("u_0", dist.Uniform(1e-30, 1.0))
    t_E = numpyro.sample("t_E", dist.Uniform(100., 300.0))
    q = numpyro.sample("q", dist.Uniform(1e-30, 1.0))
    s = numpyro.sample("s", dist.Uniform(1e-30, 2))
    alpha = numpyro.sample("alpha", dist.Uniform(1e-30, 2 * jnp.pi))

    A = _binary_lens_magnification((tiempo - t_0) / t_E, u_0, q, s, alpha)

    mag0 = numpyro.sample("mag0", dist.Uniform(17, 23))
    fs = numpyro.sample("fs", dist.Uniform(1e-30, 1.0))
    # Blending: the fraction fs of the baseline flux is magnified by A and the
    # remaining (1 - fs) is not, i.e. total flux ratio fs*A + (1 - fs).
    mu = mag0 - 2.5 * jnp.log10(fs * (A - 1) + 1)
    numpyro.sample("mag", dist.Normal(mu, mag_err), obs=observado)

As you can see, in some parts I had to find the roots of a polynomial with complex coefficients. To run the fit I use:

# Running HMC-NUTS.
# NOTE: the observation times are of order 2458762; float32 can only resolve
# steps of 0.25 at that magnitude, so consider calling numpyro.enable_x64() at
# program start, before the data arrays are created.
with numpyro.validation_enabled():  # validate distribution arguments during the run
    rng_key = random.PRNGKey(20)
    rng_key, rng_key_ = random.split(rng_key)
    num_warmup, num_samples = 4000, 10000
    kernel = NUTS(model_d, init_strategy=init_to_median(), target_accept_prob=0.9)
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
    mcmc.run(rng_key_, observado=observado, tiempo=tiempo)
    mcmc.print_summary()

After running it, this error appears: “Cannot find valid initial parameters. Please check your model again.”
I looked for answers in the forum but didn’t find anything similar. I also don’t know whether the problem is in the code or is a statistical misconception on my part, so I would really appreciate some light on this.
Thanks, Felipe
P.S.: `observado` is an array of values from 18 to 23, `tiempo` is another array, this time with values of the order of 2458762, and the NumPyro version is 0.10.

It’s hard to say what your problem(s) might be, but glancing at your post I can see at least three possible sources of issues:

  • the root finding may not be differentiable (?)
  • the parameters have a wide variety of scales, which can be problematic. It’s best if all parameters are reparameterized to be of order unity, e.g. you might write:
    t_E = numpyro.sample("t_E", dist.Uniform(1.0, 3.0)) * 100.0  # sample on an O(1) scale, rescale to [100, 300]
  • try 64-bit precision (`numpyro.enable_x64()`); see the docs. With times of order 2458762, float32 can only resolve steps of 0.25.

You might also try alternative initialization strategies (see the linked docs for an example).

Hi, thanks for your quick answer, but the problem persists. I have appended the exact error below. And if the root finding is not differentiable, can I write my own function to find the roots?
Thanks

> --------------------------------------------------------------------------
> RuntimeError                              Traceback (most recent call last)
> /mnt/c/Users/jataq/Desktop/Paczynski-curve/zeustry.ipynb Cell 50 in <cell line: 5>()
>       9 kernel = NUTS(model_d, init_strategy=init_to_value(), target_accept_prob=0.9)
>      10 mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
> ---> 11 mcmc.run(rng_key_,observado=observado,tiempo=tiempo,)
>      12 mcmc.print_summary()
> 
> File ~/miniconda3/envs/zeus/lib/python3.10/site-packages/numpyro/infer/mcmc.py:593, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
>     591 map_args = (rng_key, init_state, init_params)
>     592 if self.num_chains == 1:
> --> 593     states_flat, last_state = partial_map_fn(map_args)
>     594     states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
>     595 else:
> 
> File ~/miniconda3/envs/zeus/lib/python3.10/site-packages/numpyro/infer/mcmc.py:381, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
>     379 rng_key, init_state, init_params = init
>     380 if init_state is None:
> --> 381     init_state = self.sampler.init(
>     382         rng_key,
>     383         self.num_warmup,
>     384         init_params,
>     385         model_args=args,
>     386         model_kwargs=kwargs,
>     387     )
>     388 sample_fn, postprocess_fn = self._get_cached_fns()
>     389 diagnostics = (
>     390     lambda x: self.sampler.get_diagnostics_str(x[0])
>     391     if rng_key.ndim == 1
>     392     else ""
>     393 )  # noqa: E731
> 
> File ~/miniconda3/envs/zeus/lib/python3.10/site-packages/numpyro/infer/hmc.py:706, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
>     701 # vectorized
>     702 else:
>     703     rng_key, rng_key_init_model = jnp.swapaxes(
>     704         vmap(random.split)(rng_key), 0, 1
>     705     )
> --> 706 init_params = self._init_state(
>     707     rng_key_init_model, model_args, model_kwargs, init_params
>     708 )
>     709 if self._potential_fn and init_params is None:
>     710     raise ValueError(
>     711         "Valid value of `init_params` must be provided with" " `potential_fn`."
>     712     )
> 
> File ~/miniconda3/envs/zeus/lib/python3.10/site-packages/numpyro/infer/hmc.py:652, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
>     650 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
>     651     if self._model is not None:
> --> 652         init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
>     653             rng_key,
>     654             self._model,
>     655             dynamic_args=True,
>     656             init_strategy=self._init_strategy,
>     657             model_args=model_args,
>     658             model_kwargs=model_kwargs,
>     659             forward_mode_differentiation=self._forward_mode_differentiation,
>     660         )
>     661         if self._init_fn is None:
>     662             self._init_fn, self._sample_fn = hmc(
>     663                 potential_fn_gen=potential_fn,
>     664                 kinetic_fn=self._kinetic_fn,
>     665                 algo=self._algo,
>     666             )
> 
> File ~/miniconda3/envs/zeus/lib/python3.10/site-packages/numpyro/infer/util.py:698, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
>     685                             w.message.args = (
>     686                                 "Site {}: {}".format(
>     687                                     site["name"], w.message.args[0]
>     688                                 ),
>     689                             ) + w.message.args[1:]
>     690                             warnings.showwarning(
>     691                                 w.message,
>     692                                 w.category,
>    (...)
>     696                                 line=w.line,
>     697                             )
> --> 698         raise RuntimeError(
>     699             "Cannot find valid initial parameters. Please check your model again."
>     700         )
>     701 return ModelInfo(
>     702     ParamInfo(init_params, pe, grad), potential_fn, postprocess_fn, model_trace
>     703 )
> 
> RuntimeError: Cannot find valid initial parameters. Please check your model again.

Did you try all my suggestions? Try initializing with a parameter value for which you know the root finder returns a reasonable output (not a NaN). Also, I’d generally avoid using `copy`.

Also, please note that I do not have time to debug your model in detail—I can only offer a few suggestions.

First, thanks for your answer. Second, I don’t want you to solve my model; I am only asking whether what I am trying to do is possible. Finally, I tried different initializations, and when I print, for example, the `mu` parameter, I don’t get any NaN.