Hi everyone!
I’m quite new at probabilistic programming. I’m trying to convert simple hierarchical two-layer bayesian neural network for make moons dataset from pymc3 to numpyro. You can see the pymc3 version at pyprobml.
Here is the implementation.
''' A two-layer bayesian neural network with computational flow given by D_X => D_H => D_H => D_Y where D_H is the number of hidden units. ''' def model(X, Y, D_H): N, D_X = X.shape D_Y = 1 # sample first layer (we put unit normal priors on all weights) w1 = numpyro.sample("w1", dist.Normal(jnp.zeros((D_X, D_H)), jnp.ones((D_X, D_H)))) # D_X D_H z1 = jnp.tanh(jnp.matmul(X, w1)) # N D_H <= first layer of activations # sample second layer w2 = numpyro.sample("w2", dist.Normal(jnp.zeros((D_H, D_H)), jnp.ones((D_H, D_H)))) # D_H D_H z2 = jnp.tanh(jnp.matmul(z1, w2)) # N D_H <= second layer of activations # sample final layer of weights and neural network output w3 = numpyro.sample("w3", dist.Normal(jnp.zeros((D_H, D_Y)), jnp.ones((D_H, D_Y)))) # D_H D_Y z3 = jnp.matmul(z2, w3) # N D_Y <= output of the neural network # observe data with numpyro.plate("classes", N): Y = numpyro.sample("Y", dist.Bernoulli(logits=z3), obs=Y) #return Y def run_inference(model, rng_key, X, Y, D_H): kernel = NUTS(model) guide = AutoDiagonalNormal(model) svi = SVI(model, guide, optim=numpyro.optim.Adam(step_size=1e-2), loss=numpyro.infer.Trace_ELBO()) svi_result = svi.run(rng_key, 10000, X, Y, D_H) params = svi_result.params svi_state = svi.init(rng_key, X, Y, D_H) pred_train, pred_test = [], [] D_H = 5 for X_train, Y_train, X_test, Y_test in zip(Xs_train, Ys_train, Xs_test, Ys_test): # do inference rng_key, rng_key_train, rng_key_test = random.split(random.PRNGKey(0), 3) samples = run_inference(model, rng_key, X_train, Y_train, D_H) # predict Y_train at inputs X_train vmap_args = (samples, random.split(rng_key_train, 1000)) predictions = vmap(lambda samples, rng_key: predict(model, rng_key, samples, X_train, D_H))(*vmap_args) predictions = predictions[..., 0] # compute mean prediction and confidence interval around median mean_prediction = jnp.mean(predictions, axis=0) pred_train.append(mean_prediction > 0.5) print(mean_prediction) break
Here are the results.
> [[0.473 0.496 0.517 0.504 0.512 0.501 0.511 0.484 0.487 0.49 0.465 0.506
> 0.494 0.496 0.524 0.488 0.521 0.478 0.499 0.472 0.493 0.52 0.489 0.508
> 0.502 0.517 0.501 0.498 0.502 0.515 0.509 0.509 0.489 0.501 0.516 0.49
> 0.495 0.505 0.5 0.503 0.474 0.51 0.484 0.47 0.506 0.512 0.516 0.516
> 0.5 0.497]]
I have examined all documentations, examples and codes under the hood at numpyro. However, I could not solve the problem that the loss slightly decreases and then remains firm at very high value. Furthermore, the mean predictions are so close to one half. It implies that it cannot learn anything and the prediction of any sample point is just a coin flip.
Thanks for your answer in advance.