I get a `RuntimeError: Cannot find valid initial parameters. Please check your model again.` and don’t quite know how to solve it.
The model in question is quite simple:
```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan


def model(X, y=None):
    init_weight = numpyro.sample('initial_weight', dist.Normal(jnp.zeros(3), 10).to_event(1)).flatten()
    drift = numpyro.sample('drift', dist.Normal(jnp.zeros(3), 0.01).to_event(1))

    # Random walk: each step adds the same drift to the previous weights.
    def transition(weight_prev, xs):
        weight_curr = numpyro.deterministic('weights', weight_prev + drift)
        return weight_curr, (weight_curr)

    _, (weights) = scan(transition, init_weight, None, length=len(X))
    logits = numpyro.deterministic('logits', X[:, 0] * weights[:, 0] + X[:, 1] * weights[:, 1] + weights[:, 2])
    obs = numpyro.sample('y', dist.Bernoulli(probs=jax.nn.sigmoid(logits)), obs=y)
```
I know that it’s possible to express this model in a simpler way, but my actual model is more complicated and requires `scan`; the one above is just a minimal example that reproduces the problem.
To reproduce the problem, I create data using an array `X = np.random.randn(200, 2)`.
Then I use the model to sample y given X, so the model is actually the data-generating process.
Then I try to fit the model to X and y to recover its parameters, and at that point I get the error.
This problem doesn’t arise if I use `obs = numpyro.sample('y', dist.Bernoulli(logits=logits), obs=y)` instead, so the sigmoid seems to be the problem.
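A small JAX check of why the sigmoid might matter. The logit value 40 is an arbitrary illustration, and whether NumPyro’s `Bernoulli(probs=...)` evaluates `log(1 - p)` exactly like this depends on the installed version (some versions may clamp probabilities), so treat it as a plausible explanation rather than a confirmed one. In float32, `jax.nn.sigmoid` rounds to exactly 1.0 for large logits, so the log-likelihood of a `y = 0` observation becomes `-inf`, while the logit form stays finite. Large logits are easy to hit here: the initial weights have a broad `Normal(0, 10)` prior, and if initialisation places the drift away from zero, the scan adds it 200 times.

```python
import jax
import jax.numpy as jnp

# Illustrative logit magnitude (hypothetical value, not taken from the model).
logit = jnp.float32(40.0)

# Probability parameterisation: sigmoid(40) rounds to exactly 1.0 in float32,
# so log(1 - p) for an observation y = 0 becomes log(0) = -inf.
p = jax.nn.sigmoid(logit)
log_prob_y0_from_probs = jnp.log1p(-p)

# Logit parameterisation: log(1 - sigmoid(x)) == log_sigmoid(-x), computed
# without ever forming p, so the result stays finite (about -40).
log_prob_y0_from_logits = jax.nn.log_sigmoid(-logit)

print(p, log_prob_y0_from_probs, log_prob_y0_from_logits)
```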
In my actual model, however, I need to work with probabilities instead of logits: I assume that the generating process sometimes makes mistakes (lapses), which I’d like to model as a mixture of a probability that depends on X and one that is independent of X.
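One way to keep such a lapse mixture while still handing NumPyro logits is to compute the logit of the mixture probability in log space. This is a sketch under assumptions: the mixture is taken to be p = (1 - lapse_rate) * sigmoid(logits) + lapse_rate * lapse_prob, and `mixture_logits`, `lapse_rate` and `lapse_prob` are hypothetical names, not NumPyro API.

```python
import jax
import jax.numpy as jnp


def mixture_logits(logits, lapse_rate, lapse_prob):
    """Logit of p = (1 - lapse_rate) * sigmoid(logits) + lapse_rate * lapse_prob.

    Both log(p) and log(1 - p) are assembled with logaddexp, so p itself is
    never formed and the result stays finite even when sigmoid(logits) would
    round to 0 or 1. `lapse_rate` and `lapse_prob` must lie strictly in (0, 1).
    """
    log_keep = jnp.log1p(-lapse_rate)  # log(1 - lapse_rate)
    log_lapse = jnp.log(lapse_rate)
    # log p = log[(1 - lapse_rate) * sigmoid(x) + lapse_rate * lapse_prob]
    log_p = jnp.logaddexp(log_keep + jax.nn.log_sigmoid(logits),
                          log_lapse + jnp.log(lapse_prob))
    # log(1 - p) = log[(1 - lapse_rate) * sigmoid(-x) + lapse_rate * (1 - lapse_prob)]
    log_1mp = jnp.logaddexp(log_keep + jax.nn.log_sigmoid(-logits),
                            log_lapse + jnp.log1p(-lapse_prob))
    return log_p - log_1mp
```

In the model, the observation line would then become something like `numpyro.sample('y', dist.Bernoulli(logits=mixture_logits(logits, lapse_rate, lapse_prob)), obs=y)`, with `lapse_rate` and `lapse_prob` either fixed or given their own (hypothetical) priors.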
Any suggestions on how to solve this? Thanks in advance!
P.S. This is my first post here and I just want to say that I absolutely love using NumPyro - Thanks to all the developers!