Cannot find valid initial parameters when using NUTS and BernoulliProbs

Hi everyone,

I get "RuntimeError: Cannot find valid initial parameters. Please check your model again." and I don’t quite know how to solve it.

The model in question is quite simple:

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan

def model(X, y=None):
    init_weight = numpyro.sample('initial_weight', dist.Normal(jnp.zeros(3), 10).to_event(1)).flatten()
    drift = numpyro.sample('drift', dist.Normal(jnp.zeros(3), 0.01).to_event(1))
    def transition(weight_prev, xs):
        # deterministic drift: w_t = w_{t-1} + drift
        weight_curr = numpyro.deterministic('weights', weight_prev + drift)
        return weight_curr, weight_curr  # (carry, stacked output)
    # weights has shape (len(X), 3): one weight vector per observation
    _, weights = scan(transition, init_weight, None, length=len(X))
    logits = numpyro.deterministic('logits', X[:, 0] * weights[:, 0] + X[:, 1] * weights[:, 1] + weights[:, 2])
    obs = numpyro.sample('y', dist.Bernoulli(probs=jax.nn.sigmoid(logits)), obs=y)

I know that it’s possible to express this model more simply, but my real model is more complicated and requires scan; the one above is just a minimal example that reproduces the problem.
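For reference, the deterministic recursion in the minimal example unrolls to a linear trend, w_t = w_0 + t * drift for t = 1 to T. The sketch below uses plain jax.lax.scan (not NumPyro's scan wrapper, so no sample or deterministic sites are involved) and checks the recursion against the closed form; the helper names and the example values are illustrative only.

```python
import jax
import jax.numpy as jnp

def weights_via_scan(init_weight, drift, length):
    """Recursion w_t = w_{t-1} + drift, mirroring `transition` in the model above."""
    def transition(weight_prev, _):
        weight_curr = weight_prev + drift
        return weight_curr, weight_curr  # (carry, stacked output)
    _, weights = jax.lax.scan(transition, init_weight, None, length=length)
    return weights  # shape (length, 3)

def weights_closed_form(init_weight, drift, length):
    """Same trajectory without scan: w_t = w_0 + t * drift, t = 1..length."""
    t = jnp.arange(1, length + 1, dtype=init_weight.dtype)[:, None]
    return init_weight + t * drift

init_weight = jnp.array([0.5, -1.0, 0.2])
drift = jnp.array([0.01, 0.0, -0.02])
w_scan = weights_via_scan(init_weight, drift, 200)
w_closed = weights_closed_form(init_weight, drift, 200)
```

The two agree up to float32 rounding accumulated over the 200 additions.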

To reproduce the problem, I create an input array X = np.random.randn(200, 2).
Then I sample y from the model given X, so the model itself is the data-generating process.
Finally, I fit the model to X and y to recover its parameters, and that is when I get the error.

This problem doesn’t arise if I use obs = numpyro.sample('y', dist.Bernoulli(logits=logits), obs=y), so the sigmoid seems to be the problem.
In my actual model, however, I need to work with probabilities instead of logits: I assume the generating process sometimes makes mistakes (lapses), which I’d like to model as a mixture of a probability that depends on X and one that is independent of X.
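One way to write such a lapse mixture while keeping the likelihood finite for large |logits| is to combine the two components in log space. This is a hedged sketch, not NumPyro API: the function lapse_bernoulli_log_prob and its parameters lapse_rate and guess_rate are hypothetical, and it assumes p(y = 1) = (1 - lapse_rate) * sigmoid(logits) + lapse_rate * guess_rate.

```python
import jax
import jax.numpy as jnp

def lapse_bernoulli_log_prob(y, logits, lapse_rate, guess_rate):
    """Hypothetical helper: log p(y) for
    p(y = 1) = (1 - lapse_rate) * sigmoid(logits) + lapse_rate * guess_rate."""
    log_keep = jnp.log1p(-lapse_rate)   # log(1 - lapse_rate)
    log_lapse = jnp.log(lapse_rate)     # log(lapse_rate)
    # log p(y = 1): X-dependent component plus X-independent "guess" component
    log_p1 = jnp.logaddexp(log_keep + jax.nn.log_sigmoid(logits),
                           log_lapse + jnp.log(guess_rate))
    # log p(y = 0), using 1 - sigmoid(l) = sigmoid(-l)
    log_p0 = jnp.logaddexp(log_keep + jax.nn.log_sigmoid(-logits),
                           log_lapse + jnp.log1p(-guess_rate))
    return jnp.where(y == 1, log_p1, log_p0)

logits = jnp.array([-3.0, 0.0, 2.5, 500.0])
y = jnp.array([0, 1, 1, 0])
log_lik = lapse_bernoulli_log_prob(y, logits, 0.05, 0.5)  # finite even at logits = 500
```

Inside a model, a term like this could be added with numpyro.factor('y', log_lik.sum()) in place of the Bernoulli sample site; the alternative discussed below is to keep dist.Bernoulli(probs=...) and clamp the mixed probabilities.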

Any suggestions how to solve this? Thanks in advance!

P.S. This is my first post here and I just want to say that I absolutely love using NumPyro - Thanks to all the developers!

I suspected the problem might be numerical, and while I still can’t rule that out, I can say that replacing jax.nn.sigmoid with

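import jax.numpy as jnp

def sigmoid(x):
    # branch on the sign of x; the result can still round to exactly 0.0 or 1.0 in float32
    return jnp.where(x >= 0,
                     1 / (1 + jnp.exp(-x)),
                     jnp.exp(x) / (1 + jnp.exp(x)))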

does not help.
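A likely reason the rewrite does not help is float32 saturation (float32 is JAX's default precision): any formulation of the sigmoid rounds to exactly 1.0 once x exceeds roughly 17, and to exactly 0.0 for very negative x, after which log(p) or log(1 - p) is -inf. Bernoulli(logits=...) sidesteps this because its log-probability can be written with softplus, which stays finite. Why the logits would get that large at initialization is an assumption on my part: NUTS's default init_to_uniform strategy draws unconstrained values uniformly in (-2, 2), so drift may start near ±2 rather than near its prior scale of 0.01, and over 200 scan steps the weights, and hence the logits, can reach the hundreds. A small check in plain JAX:

```python
import jax
import jax.numpy as jnp

x = jnp.array([-200.0, 0.0, 200.0])  # float32 logits, including saturated extremes
p = jax.nn.sigmoid(x)                # [0.0, 0.5, 1.0] exactly in float32

# Bernoulli log-likelihood computed from probs: log(p) for y=1, log(1 - p) for y=0
ll_probs_y1 = jnp.log(p)             # -inf at x = -200
ll_probs_y0 = jnp.log1p(-p)          # -inf at x = 200

# The same quantities computed from logits via softplus stay finite
ll_logits_y1 = -jax.nn.softplus(-x)  # log(sigmoid(x))
ll_logits_y0 = -jax.nn.softplus(x)   # log(1 - sigmoid(x))
```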

Conditioning the model on the previously sampled initial state and fitting the conditioned model doesn’t help either.

Welcome, @julianstastny! I think you need to clamp your probs to (finfo.tiny, 1 - finfo.eps). If probs=0 and value=1, you will likely get an invalid log probability, because value=1 is impossible when probs=0. How about making a PR that uses clamp_probs directly in the BernoulliProbs.log_prob implementation?
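A minimal sketch of the suggested clamp in plain JAX. The function clamp_probs_sketch is a local stand-in that follows the (finfo.tiny, 1 - finfo.eps) bounds described above, not NumPyro's own clamp_probs code, and bernoulli_log_prob is likewise a hypothetical helper:

```python
import jax
import jax.numpy as jnp

def clamp_probs_sketch(probs):
    """Local stand-in: clip probs into [finfo.tiny, 1 - finfo.eps] for their dtype."""
    finfo = jnp.finfo(probs.dtype)
    return jnp.clip(probs, finfo.tiny, 1.0 - finfo.eps)

def bernoulli_log_prob(value, probs):
    """Bernoulli log-likelihood from probs, clamped so log(0) never occurs."""
    probs = clamp_probs_sketch(probs)
    return value * jnp.log(probs) + (1 - value) * jnp.log1p(-probs)

probs = jax.nn.sigmoid(jnp.array([-200.0, 200.0]))  # saturates to [0.0, 1.0]
value = jnp.array([1.0, 0.0])                        # the "impossible" outcomes
log_lik = bernoulli_log_prob(value, probs)           # finite, but very negative
```

Without the clamp, the first entry would be log(0) = -inf and the unused term 0 * log(0) would also turn into nan, so the initial-parameter search has nothing finite to work with.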


Thanks, this works! I’ve made a pull request: PR 1107

(Only after making it did I notice that you had suggested doing it in BernoulliProbs.log_prob. Maybe it makes more sense to do it in the initialization, as I’ve done?)

Thanks! I think it is better to put the clamp logic in the log_prob implementation, because the sample method does not require probs to be clamped.
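To illustrate why clamping is only needed in log_prob: drawing a sample with probs of exactly 0 or 1 is well defined (the outcome is simply always 0 or always 1), whereas log_prob has to evaluate log(0) for the impossible outcome. A hedged sketch using jax.random.bernoulli for sampling; sample_bernoulli and log_prob_bernoulli are hypothetical names, not NumPyro internals:

```python
import jax
import jax.numpy as jnp

def sample_bernoulli(key, probs, n):
    """Sampling needs no clamp: probs of exactly 0.0 or 1.0 give constant draws."""
    return jax.random.bernoulli(key, probs, shape=(n,) + probs.shape)

def log_prob_bernoulli(value, probs):
    """Only the log-density clamps, because log(0) would otherwise be evaluated."""
    finfo = jnp.finfo(probs.dtype)
    probs = jnp.clip(probs, finfo.tiny, 1.0 - finfo.eps)
    return jnp.where(value == 1, jnp.log(probs), jnp.log1p(-probs))

probs = jnp.array([0.0, 1.0])
draws = sample_bernoulli(jax.random.PRNGKey(0), probs, 1000)  # column 0 all False, column 1 all True
log_lik = log_prob_bernoulli(jnp.array([1, 0]), probs)        # finite instead of -inf
```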

Makes sense. Here is the pull request
