# Cannot find valid initial parameters when using NUTS and BernoulliProbs

Hi everyone,

I get a `RuntimeError: Cannot find valid initial parameters. Please check your model again.` and don’t quite know how to solve it.

The model in question is quite simple:

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan

def model(X, y=None):
    init_weight = numpyro.sample('initial_weight', dist.Normal(jnp.zeros(3), 10).to_event(1)).flatten()
    drift = numpyro.sample('drift', dist.Normal(jnp.zeros(3), 0.01).to_event(1))

    def transition(weight_prev, xs):
        weight_curr = numpyro.deterministic('weights', weight_prev + drift)
        return weight_curr, (weight_curr)

    _, (weights) = scan(transition, init_weight, None, length=len(X))
    logits = numpyro.deterministic('logits', X[:,0] * weights[:,0] + X[:,1] * weights[:,1] + weights[:,2])
    obs = numpyro.sample('y', dist.Bernoulli(probs=jax.nn.sigmoid(logits)), obs=y)
```

I know that it’s possible to express this model in a simpler way, but in reality I am using a more complicated model that requires `scan`. The one above is just a minimal example to reproduce the problem.

To reproduce the problem, I create data using an array `X = np.random.randn(200, 2)`.
Then I use the model to sample `y` given `X` (so the model itself is the data-generating process).
Then I try to fit the model to `X` and `y` to recover the model’s parameters, and at that point I get the error.

This problem doesn’t arise if I use `obs = numpyro.sample('y', dist.Bernoulli(logits=logits), obs=y)`, so the sigmoid seems to be the problem.
Now in my actual model, I need to use probabilities instead of logits because I am assuming that the generating process sometimes makes mistakes/lapses, which I’d like to model in terms of a mixture model of a probability that depends on X and one that is independent of X.
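
For readers following along, the lapse mixture described above could look roughly like the sketch below. This is only an illustration in plain Python: the names `lapse_mixture_prob`, `lapse_rate`, and `lapse_prob` are hypothetical and do not come from the original model.

```python
import math

def sigmoid(x):
    # Logistic function, written so that exp() never overflows.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def lapse_mixture_prob(logit, lapse_rate, lapse_prob=0.5):
    # Hypothetical helper: with probability `lapse_rate` the response ignores X
    # and is 1 with probability `lapse_prob`; otherwise it follows the
    # logistic model driven by `logit`.
    return lapse_rate * lapse_prob + (1.0 - lapse_rate) * sigmoid(logit)

# Even for extreme logits the mixture stays away from exactly 0 and 1.
p_high = lapse_mixture_prob(40.0, lapse_rate=0.05)
p_low = lapse_mixture_prob(-40.0, lapse_rate=0.05)
print(p_high, p_low)
```

Because the mixture is a weighted sum of probabilities, it cannot be written directly as a single logit, which is why the `probs` parameterization is needed here. With a lapse rate that is not too close to 0, the result lies between `lapse_rate * lapse_prob` and `1 - lapse_rate * (1 - lapse_prob)`.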

Any suggestions how to solve this? Thanks in advance!

P.S. This is my first post here and I just want to say that I absolutely love using NumPyro - Thanks to all the developers!

I suspected that maybe the problem is numerical in nature, and while I still can’t rule it out, what I can say is that replacing `jax.nn.sigmoid` with

```python
def sigmoid(x):
    return jnp.where(x >= 0,
                     1 / (1 + jnp.exp(-x)),
                     jnp.exp(x) / (1 + jnp.exp(x)))
```

does not help.
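
One possible reason a more careful sigmoid makes no difference: the issue may not be overflow inside the sigmoid but saturation of its output. For a large enough logit, the result rounds to exactly 1.0, and the log probability of observing 0 is then log(0). The sketch below illustrates this in plain Python with float64; JAX defaults to float32, where saturation sets in sooner. The function names are illustrative and are not NumPyro's internals.

```python
import math

def sigmoid(x):
    # Overflow-safe logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def softplus(x):
    # log(1 + exp(x)) without overflow.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def logpmf_from_probs(p, y):
    # Naive Bernoulli log-pmf from a probability; -inf when the log hits 0.
    q = p if y == 1 else 1.0 - p
    return math.log(q) if q > 0.0 else -math.inf

def logpmf_from_logits(logit, y):
    # Same quantity from the logit: log p = -softplus(-logit),
    # log(1 - p) = -softplus(logit). No intermediate probability is formed.
    return -softplus(-logit) if y == 1 else -softplus(logit)

logit = 40.0
p = sigmoid(logit)
print(p == 1.0)                      # True: the sigmoid has saturated
print(logpmf_from_probs(p, 0))       # -inf
print(logpmf_from_logits(logit, 0))  # -40.0
```

This is consistent with the observation that `Bernoulli(logits=logits)` works while `Bernoulli(probs=jax.nn.sigmoid(logits))` does not: a single observation with a saturated probability on the wrong side is enough to make the total log density non-finite.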

Conditioning the model on the previously sampled initial state and fitting the conditioned model doesn’t help either.

Welcome, @julianstastny! I think you need to clamp your probs to `(finfo.tiny, 1 - finfo.eps)`. If `probs=0` and `value=1`, you will likely get an invalid log probability, because there is no way to observe `value=1` with `probs=0`. How about making a PR that uses `clamp_probs` directly in the `BernoulliProbs.log_prob` implementation?
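
To make the clamping suggestion concrete, here is a minimal plain-Python sketch. It uses float64 limits from `sys.float_info` as stand-ins for `finfo.tiny` and `finfo.eps`, and the local `clamp_probs` and `bernoulli_logpmf` functions are illustrative substitutes, not NumPyro's own code.

```python
import math
import sys

TINY = sys.float_info.min     # smallest positive normal float64 (stand-in for finfo.tiny)
EPS = sys.float_info.epsilon  # float64 machine epsilon (stand-in for finfo.eps)

def clamp_probs(p):
    # Keep p inside [TINY, 1 - EPS] so both log(p) and log(1 - p) are finite.
    return min(max(p, TINY), 1.0 - EPS)

def bernoulli_logpmf(p, y):
    # Bernoulli log-pmf evaluated on the clamped probability.
    p = clamp_probs(p)
    return math.log(p) if y == 1 else math.log1p(-p)

print(bernoulli_logpmf(0.0, 1))  # very negative, but finite
print(bernoulli_logpmf(1.0, 0))  # finite as well
print(clamp_probs(0.5))          # interior values are left unchanged
```

The impossible cases (`probs=0` with `value=1`, or `probs=1` with `value=0`) now give a large negative but finite log probability, so the sampler can still evaluate the density and its gradient at the initial point.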

Thanks, this works! I’ve made a pull request: PR 1107

(I noticed only after making it that you suggested doing it in BernoulliProbs.log_prob, though maybe it does make more sense to do it in the initialization, as I’ve done?)

Thanks! I think it is better to put the clamp logic in the `log_prob` implementation, because the `sample` method does not require probs to be clamped.

Makes sense. Here is the pull request
