Bayesian Residual Neural Network

import torch
import torch.nn as nn

import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample


class BayesianDynResNet(PyroModule):
    def __init__(self):
        super().__init__()
        HIDDEN_DIM = 50
        self.linear1 = PyroModule[nn.Linear](2, HIDDEN_DIM)
        self.linear1.weight = PyroSample(dist.Normal(0., 1.)
                                         .expand([HIDDEN_DIM, 2]).to_event(2))
        self.linear1.bias = PyroSample(dist.Normal(
            0., 10.).expand([HIDDEN_DIM]).to_event(1))
        self.linear2 = PyroModule[nn.Linear](HIDDEN_DIM, 2)
        self.linear2.weight = PyroSample(dist.Normal(
            0., 1.).expand([2, HIDDEN_DIM]).to_event(2))
        # the bias prior shape must match the layer's output dimension (2)
        self.linear2.bias = PyroSample(
            dist.Normal(0., 10.).expand([2]).to_event(1))

    def res_forward(self, x, w1, b1, w2, b2):
        # one residual update: a single tanh hidden layer
        res = x @ w1.T + b1
        res = torch.tanh(res)
        res = res @ w2.T + b2
        return res

    def propagate(self, t, y0):
        # one set of weights is sampled per model execution
        w1 = self.linear1.weight
        b1 = self.linear1.bias
        w2 = self.linear2.weight
        b2 = self.linear2.bias

        # the first step is already the initial condition
        y = [y0[None, ...]]
        for _ in range(t.shape[0] - 1):
            y0 = y0 + self.res_forward(y0, w1, b1, w2, b2)
            y.append(y0[None, ...])
        return torch.cat(y, dim=0)

    def forward(self, t, y0, y=None):
        sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
        mean = self.propagate(t, y0)
        obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean

# defining the model for inference
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal

model = BayesianDynResNet()
guide = AutoDiagonalNormal(model)
# define optimizer and variational inference
adam = pyro.optim.Adam({"lr": 1e-3})
svi = SVI(model, guide, adam, loss=Trace_ELBO())
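
I then train with a standard loop like this (where t_data, y0_data and y_data are my tensors holding the time grid, initial condition and observed trajectory):

pyro.clear_param_store()
for step in range(5000):
    loss = svi.step(t_data, y0_data, y_data)
    if step % 500 == 0:
        print(f"step {step}  ELBO loss = {loss:.2f}")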

Hi, so I’m trying to construct a residual neural network that can learn dynamical systems. The idea is to sample the weights of the linear layers before applying the ResNet block, and then use variational inference to learn the distribution of those weights. Sadly, the model doesn’t converge at all (although when using dropout, the model does learn the dynamics). Does anyone see any implementation failures?

:warning: BNNs are hard :warning:

I’d recommend starting with variational inference using an AutoDelta guide. That should be almost like a non-Bayesian nn. If that doesn’t train, you probably have a shape error. Once that works, I’d try relaxing one variable at a time to be an AutoNormal. One easy way to play around with different guide combinations is with EasyGuide:

from pyro.contrib.easyguide import easy_guide
from pyro.distributions import constraints

model = BayesianDynResNet()

# First make sure this works.
@easy_guide(model)
def guide_1(self, t, y0, y=None):
    self.group(".*").map_estimate()

# Then try to be Bayesian about just one variable.
@easy_guide(model)
def guide_2(self, t, y0, y=None):
    # Point estimate these.
    self.group("linear1.*").map_estimate()
    self.group("linear2.weight").map_estimate()
    self.map_estimate("sigma")  # every latent site needs to appear in the guide

    # Use effectively an AutoNormal on these.
    bayesian = self.group("linear2.bias")
    loc = pyro.param("loc", torch.zeros(bayesian.event_shape))
    scale = pyro.param("scale", torch.ones(bayesian.event_shape),
                       constraint=constraints.positive)
    bayesian.sample("bayesian", dist.Normal(loc, scale).to_event(1))

# Then try to be Bayesian about two variables.
@easy_guide(model)
def guide_3(self, t, y0, y=None):
    # Point estimate these.
    self.group("linear1.*").map_estimate()
    self.map_estimate("sigma")  # every latent site needs to appear in the guide

    # Use effectively an AutoNormal on these.
    bayesian = self.group("linear2.*")
    loc = pyro.param("loc", torch.zeros(bayesian.event_shape))
    scale = pyro.param("scale", torch.ones(bayesian.event_shape),
                       constraint=constraints.positive)
    bayesian.sample("bayesian", dist.Normal(loc, scale).to_event(1))
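
Any of these drops straight into your SVI setup in place of the autoguide, e.g.

svi = SVI(model, guide_2, adam, loss=Trace_ELBO())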

You could do that in other ways too, e.g. with poutine.block and an AutoGuideList of an AutoNormal and an AutoDelta.
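
For example, something like this sketch (using the site names from your model; untested):

from pyro import poutine
from pyro.infer.autoguide import AutoDelta, AutoGuideList, AutoNormal

guide = AutoGuideList(model)
# point-estimate everything except linear2.bias ...
guide.append(AutoDelta(poutine.block(model, hide=["linear2.bias"])))
# ... and use a mean-field Normal just for linear2.bias
guide.append(AutoNormal(poutine.block(model, expose=["linear2.bias"])))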