Problem with a Bayesian neural network trained with SVI

I am using the following simple example to learn how to use SVI (stochastic variational inference) for a Bayesian NN. Here is the code:

from bokeh.plotting import figure, show
from bokeh.io import output_notebook
import numpy as np
import scipy.stats as stats

import jax.numpy as jnp
import jax.random as random
from jax.random import PRNGKey

import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import Predictive, SVI, Trace_ELBO
# Direct Bokeh's show() output into the Jupyter notebook instead of a standalone HTML file.
output_notebook()

def actfun(x):
    """Hidden-layer activation: element-wise hyperbolic tangent of ``x``."""
    activated = jnp.tanh(x)
    return activated
def model(x,N,data):
    """Bayesian one-hidden-layer tanh network (the failing version, as posted).

    NOTE(review): the script below calls this with ``x`` as
    ``jnp.atleast_2d(x).T`` (shape (N, 1)) but ``data`` as a 1-D array of
    length N. ``y`` computed here has shape (N, 1), so scoring
    ``Normal(y, sigma)`` against a length-N ``data`` broadcasts to an (N, N)
    log-likelihood instead of N independent terms, which is why training
    appears to do nothing. The corrected version later in this thread reshapes
    ``data`` to (N, 1) and wraps "obs" in ``numpyro.plate("data", N)`` with
    ``.to_event(1)``. ``N`` is not used in this version.
    """
    n_neurons = 2
    
    # Priors: standard Normal on every first-layer weight/bias, shape (1, n_neurons).
    w1 = numpyro.sample("w1",dist.Normal(jnp.zeros((1,n_neurons)),jnp.ones((1,n_neurons))) )
    b1 = numpyro.sample("b1",dist.Normal(jnp.zeros((1,n_neurons)),jnp.ones((1,n_neurons))))
    
    # Output layer: (n_neurons, 1) weights and a scalar bias.
    w2 = numpyro.sample("w2",dist.Normal(jnp.zeros((n_neurons,1)), jnp.ones((n_neurons,1))))
    b2 = numpyro.sample("b2",dist.Normal(0., 1.))
    
    out1 = actfun(jnp.matmul(x,w1)+b1)
    y = jnp.matmul(out1,w2)+b2  # shape (N, 1) when x is (N, 1)
    
    # Observation noise; Beta(2, 2) restricts sigma to (0, 1).
    sigma = numpyro.sample("sigma",dist.Beta(2.,2.))

    # NOTE(review): no plate, and y is (N, 1) while data is (N,) -- see docstring.
    numpyro.sample("obs",dist.Normal(y,sigma),obs=data)

def guide(x,N,data):
    """Hand-written mean-field Normal guide for ``model`` (the failing version, as posted).

    NOTE(review): issues visible in this version, all addressed in the
    corrected version later in the thread:
      * the ``*_var`` params are passed to ``dist.Normal`` as the *scale*
        (standard deviation), and they have no ``constraint=constraints.positive``,
        so the optimizer is free to push them to zero or negative values;
      * ``sigma`` is drawn from a fixed ``Beta(2, 2)`` with no learnable
        parameters, so its variational posterior can never move away from the prior;
      * ``out1`` and ``y`` are computed but never used;
      * every scale starts at 1.0; a reply in this thread recommends much
        smaller initial values (e.g. 0.01).
    """
    n_neurons = 2
    
    w1_mean = numpyro.param("w1_mean",jnp.zeros((1,n_neurons)))
    w1_var = numpyro.param("w1_var",jnp.ones((1,n_neurons)))
    b1_mean = numpyro.param("b1_mean",jnp.zeros((1,n_neurons)))
    b1_var = numpyro.param("b1_var",jnp.ones((1,n_neurons)))

    
    w1 = numpyro.sample("w1",dist.Normal(w1_mean,w1_var) )
    b1 = numpyro.sample("b1",dist.Normal(b1_mean,b1_var))
    
    w2_mean = numpyro.param("w2_mean",jnp.zeros((n_neurons,1)))
    w2_var = numpyro.param("w2_var",jnp.ones((n_neurons,1)))
    b2_mean = numpyro.param("b2_mean",0.)
    b2_var = numpyro.param("b2_var",1.)
    
    w2 = numpyro.sample("w2",dist.Normal(w2_mean, w2_var))
    b2 = numpyro.sample("b2",dist.Normal(b2_mean, b2_var))
    
    # NOTE(review): dead computation -- the guide never uses out1 or y.
    out1 = actfun(jnp.matmul(x,w1)+b1)
    y = jnp.matmul(out1,w2)+b2
    
    # NOTE(review): fixed Beta(2, 2) -- no learnable parameters for sigma.
    sigma = numpyro.sample("sigma",dist.Beta(2.,2.))

# Generate data: N points on the line y = 2x with Gaussian noise (std 0.5).
# NOTE(review): x and data are both 1-D here (shape (N,)); the model's
# prediction for jnp.atleast_2d(x).T is (N, 1), and that shape mismatch is the
# bug. The working version in this thread builds both as (N, 1) columns.
N = 100
x = np.linspace(0,10,N)
data = 2*x + np.random.normal(0,0.5,N)

# Solve the SVI
# NOTE(review): 2000 steps at step_size=0.0005 is far fewer than the 70000
# steps used in the working version below.
optimizer = numpyro.optim.Adam(step_size=0.0005)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 2000, jnp.atleast_2d(x).T, N, data)

# Sample the distribution of weights, biases and measurement standard deviation
# (1000 draws from the fitted guide).
predictive = Predictive(guide, params=svi_result.params, num_samples=1000)
samples = predictive(random.PRNGKey(1), jnp.atleast_2d(x).T,N, data)

# Use the model to predict the output: push each guide draw through the model
# with data=None so "obs" is sampled rather than conditioned on.
predictive_output = Predictive(model, samples)(PRNGKey(2), jnp.atleast_2d(x).T,N, None)["obs"]

# Plot the results
# NOTE(review): mean_prediction keeps the trailing size-1 axis of y and is not
# flattened before p.line; percentiles is computed but never plotted.
mean_prediction = jnp.mean(predictive_output, axis=0)
percentiles = np.percentile(predictive_output, [5.0, 95.0], axis=0)
p = figure()
p.cross(x,data)
p.line(x, mean_prediction)
show(p)

The code does not work. I get the following results. It seems that the NN is not trained at all. Can you see where the mistake is?

I didn’t take a close look, but initializing parameters like b1_var at 1 (as opposed to e.g. 0.01) is generally a bad idea. You might try using an AutoGuide instead of constructing a custom one.

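Thanks @martinjankowiak, I was not aware of AutoGuide.
I solved the problem: I was passing the data to the model and guide in the wrong way. The observations must have the same (N, 1) shape as the network output and be scored inside a `numpyro.plate`, and the guide's scale parameters need a positive constraint. It now works both with an AutoGuide (AutoNormal) and with the custom guide. Here’s the code if anyone is interested: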

from bokeh.plotting import figure, show
from bokeh.io import output_notebook
import numpy as np
import scipy.stats as stats

import jax.numpy as jnp
import jax.random as random
from jax.random import PRNGKey

import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import Predictive, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal


def actfun(x):
    """Activation used by the hidden layer (element-wise ``tanh``)."""
    result = jnp.tanh(x)
    return result

def model(x, N, data, *, n_neurons=2):
    """Bayesian one-hidden-layer tanh network for 1-D regression.

    Args:
        x: inputs. The matmul with ``w1`` (shape ``(1, n_neurons)``) needs a
            trailing axis of size 1; the script passes shape ``(N, 1)``.
        N: number of observations; the size of the ``"data"`` plate.
        data: observed targets (shape ``(N, 1)`` in the script), or ``None``
            to sample ``"obs"`` from the model, e.g. under ``Predictive``.
        n_neurons: width of the hidden layer. Keyword-only; the default of 2
            matches the original hard-coded value. The guide must use the same width.
    """
    # First layer: priors centred at 0.1 with scale 0.1, shape (1, n_neurons).
    w1 = numpyro.sample(
        "w1", dist.Normal(0.1 * jnp.ones((1, n_neurons)), 0.1 * jnp.ones((1, n_neurons)))
    )
    b1 = numpyro.sample(
        "b1", dist.Normal(0.1 * jnp.ones((1, n_neurons)), 0.1 * jnp.ones((1, n_neurons)))
    )

    # Output layer: standard-Normal (n_neurons, 1) weights and scalar bias.
    w2 = numpyro.sample(
        "w2", dist.Normal(jnp.zeros((n_neurons, 1)), jnp.ones((n_neurons, 1)))
    )
    b2 = numpyro.sample("b2", dist.Normal(0.0, 1.0))

    hidden = actfun(jnp.matmul(x, w1) + b1)
    y = jnp.matmul(hidden, w2) + b2  # (N, 1) when x is (N, 1)

    # Observation noise; Beta(2, 2) restricts sigma to (0, 1).
    sigma = numpyro.sample("sigma", dist.Beta(2.0, 2.0))

    # y has a trailing size-1 axis: to_event(1) moves it into the event shape
    # so the remaining batch shape (N,) lines up with the plate.
    with numpyro.plate("data", N):
        numpyro.sample("obs", dist.Normal(y, sigma).to_event(1), obs=data)

def guide(x, N, data, *, n_neurons=2):
    """Mean-field variational guide matching ``model``.

    Every weight and bias gets an independent Normal with a learnable location
    (param ``"<site>_mean"``) and a learnable positive scale (param
    ``"<site>_var"``). Despite the param name, that value is passed to
    ``dist.Normal`` as the standard deviation, not the variance. ``sigma``
    gets a Beta with learnable positive concentrations.

    ``x``, ``N`` and ``data`` are not used. They are accepted because SVI calls
    the guide with the same arguments as the model. ``n_neurons`` must match
    the width used by ``model`` (default 2).
    """
    # NOTE(review): the initial scales of 1.0 are kept from the posted code;
    # much smaller initial scales (e.g. 0.01) are generally recommended.
    w1_mean = numpyro.param("w1_mean", jnp.zeros((1, n_neurons)))
    w1_scale = numpyro.param("w1_var", jnp.ones((1, n_neurons)), constraint=constraints.positive)
    b1_mean = numpyro.param("b1_mean", jnp.zeros((1, n_neurons)))
    b1_scale = numpyro.param("b1_var", jnp.ones((1, n_neurons)), constraint=constraints.positive)

    numpyro.sample("w1", dist.Normal(w1_mean, w1_scale))
    numpyro.sample("b1", dist.Normal(b1_mean, b1_scale))

    w2_mean = numpyro.param("w2_mean", jnp.zeros((n_neurons, 1)))
    w2_scale = numpyro.param("w2_var", jnp.ones((n_neurons, 1)), constraint=constraints.positive)
    b2_mean = numpyro.param("b2_mean", 0.0)
    b2_scale = numpyro.param("b2_var", 1.0, constraint=constraints.positive)

    numpyro.sample("w2", dist.Normal(w2_mean, w2_scale))
    numpyro.sample("b2", dist.Normal(b2_mean, b2_scale))

    # The guide only has to propose latent values; no forward pass is needed.
    sigma_a = numpyro.param("sigma_a", 2.0, constraint=constraints.positive)
    sigma_b = numpyro.param("sigma_b", 2.0, constraint=constraints.positive)
    numpyro.sample("sigma", dist.Beta(sigma_a, sigma_b))

# Synthetic data: N points on the line y = 2x plus Gaussian noise (std 0.5).
# Both x and data are (N, 1) columns to match the network's (N, 1) output.
# A seeded generator makes the example reproducible.
N = 100
x = np.atleast_2d(np.linspace(0, 10, N)).T
rng = np.random.default_rng(0)
data = 2 * x + rng.normal(0, 0.5, (N, 1))

optimizer = numpyro.optim.Adam(step_size=0.0005)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 70000, x, N, data)

# Draw 1000 sets of weights/biases/sigma from the fitted guide, then push
# each set through the model (data=None) to get posterior-predictive "obs".
predictive = Predictive(guide, params=svi_result.params, num_samples=1000)
samples = predictive(random.PRNGKey(1), x, N, None)
predictive_output = Predictive(model, posterior_samples=samples)(
    random.PRNGKey(2), x, N, None
)["obs"]

# Reduce over the sample axis and drop the trailing size-1 axis so every
# series handed to Bokeh is 1-D.
x_flat = x.squeeze()
mean_prediction = np.asarray(jnp.mean(predictive_output, axis=0)).squeeze()
percentiles = np.percentile(predictive_output, [5.0, 95.0], axis=0)
lower = percentiles[0].squeeze()
upper = percentiles[1].squeeze()

p = figure()
# scatter(marker="cross") replaces the cross() glyph method deprecated in Bokeh 3.4.
p.scatter(x_flat, data.squeeze(), marker="cross")
p.line(x_flat, mean_prediction)
p.varea(x_flat, lower, upper, fill_alpha=0.4, fill_color="lightblue")
show(p)