Text classification with Bayesian convolutions

Hello

I would like to use Pyro to implement a Bayesian neural network for text classification. I use BERT to get embeddings, then I apply convolutions and fully connected layers. The convolution and fully connected layers are Bayesian.

Here is the PyTorch model:

import torch
import torch.nn as nn
from transformers import BertModel

class BertClassifier(nn.Module):

    def __init__(self, emb_size, num_classes, dropout=0.5):
        super(BertClassifier, self).__init__()
        self.bert = BertModel.from_pretrained('DeepPavlov/rubert-base-cased-sentence')
        self.dropout = nn.Dropout(dropout)
        # ch_out, in1, out1, in2 are layer-size hyperparameters defined elsewhere
        self.conv1 = nn.Conv1d(emb_size, ch_out, 1)
        self.fc1 = nn.Linear(in1, out1)
        self.fc2 = nn.Linear(in2, num_classes)
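I omitted the forward pass above; it looks roughly like this (a sketch: the ReLU activations, the max pooling over the sequence, and the way features are fed into fc1 are illustrative and may differ from my actual code):

    def forward(self, input_ids, mask):
        # token embeddings from BERT: (batch, seq_len, emb_size)
        hidden = self.bert(input_ids=input_ids, attention_mask=mask).last_hidden_state
        # Conv1d expects (batch, channels, length), so move the embedding dim to the channel axis
        x = torch.relu(self.conv1(hidden.transpose(1, 2)))
        # pool over the sequence dimension -> (batch, ch_out)
        x = x.max(dim=2).values
        x = self.dropout(torch.relu(self.fc1(x)))
        return self.fc2(x)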

And here are the corresponding Pyro model and guide functions:

import pyro
from pyro.distributions import Normal, Categorical

# net is the BertClassifier instance created earlier

def model(input_ids, mask, label):

    conv1W_prior = Normal(loc=torch.zeros_like(net.conv1.weight), scale=torch.ones_like(net.conv1.weight))
    conv1b_prior = Normal(loc=torch.zeros_like(net.conv1.bias), scale=torch.ones_like(net.conv1.bias))

    fc1W_prior = Normal(loc=torch.zeros_like(net.fc1.weight), scale=torch.ones_like(net.fc1.weight))
    fc1b_prior = Normal(loc=torch.zeros_like(net.fc1.bias), scale=torch.ones_like(net.fc1.bias))

    fc2W_prior = Normal(loc=torch.zeros_like(net.fc2.weight), scale=torch.ones_like(net.fc2.weight))
    fc2b_prior = Normal(loc=torch.zeros_like(net.fc2.bias), scale=torch.ones_like(net.fc2.bias))

    priors = {'conv1.weight': conv1W_prior, 'conv1.bias': conv1b_prior,
              'fc1.weight': fc1W_prior, 'fc1.bias': fc1b_prior,
              'fc2.weight': fc2W_prior, 'fc2.bias': fc2b_prior}

    # lift module parameters to random variables sampled from the priors
    lifted_module = pyro.random_module("module", net, priors)
    # sample a classifier (which also samples the weights and biases)
    lifted_reg_model = lifted_module()

    lhat = log_softmax(lifted_reg_model(input_ids, mask))

    pyro.sample("obs", Categorical(logits=lhat), obs=label)  # Categorical returns class i with probability lhat[i]

log_softmax = nn.LogSoftmax(dim=1)
softplus = torch.nn.Softplus()

def guide(input_ids, mask, label):

    # First conv weight distribution priors
    conv1W_mu = torch.randn_like(net.conv1.weight)
    conv1W_sigma = torch.randn_like(net.conv1.weight)
    conv1W_mu_param = pyro.param("conv1W_mu", conv1W_mu)
    conv1W_sigma_param = softplus(pyro.param("conv1W_sigma", conv1W_sigma))
    conv1W_prior = Normal(loc=conv1W_mu_param, scale=conv1W_sigma_param)
    # First conv bias distribution priors
    conv1b_mu = torch.randn_like(net.conv1.bias)
    conv1b_sigma = torch.randn_like(net.conv1.bias)
    conv1b_mu_param = pyro.param("conv1b_mu", conv1b_mu)
    conv1b_sigma_param = softplus(pyro.param("conv1b_sigma", conv1b_sigma))
    conv1b_prior = Normal(loc=conv1b_mu_param, scale=conv1b_sigma_param)

    # Inner layer weight distribution priors
    fc1W_mu = torch.randn_like(net.fc1.weight)
    fc1W_sigma = torch.randn_like(net.fc1.weight)
    fc1W_mu_param = pyro.param("fc1W_mu", fc1W_mu)
    fc1W_sigma_param = softplus(pyro.param("fc1W_sigma", fc1W_sigma))
    fc1W_prior = Normal(loc=fc1W_mu_param, scale=fc1W_sigma_param).independent(1)
    # Inner layer bias distribution priors
    fc1b_mu = torch.randn_like(net.fc1.bias)
    fc1b_sigma = torch.randn_like(net.fc1.bias)
    fc1b_mu_param = pyro.param("fc1b_mu", fc1b_mu)
    fc1b_sigma_param = softplus(pyro.param("fc1b_sigma", fc1b_sigma))
    fc1b_prior = Normal(loc=fc1b_mu_param, scale=fc1b_sigma_param)

    # Outer layer weight distribution priors
    fc2W_mu = torch.randn_like(net.fc2.weight)
    fc2W_sigma = torch.randn_like(net.fc2.weight)
    fc2W_mu_param = pyro.param("fc2W_mu", fc2W_mu)
    fc2W_sigma_param = softplus(pyro.param("fc2W_sigma", fc2W_sigma))
    fc2W_prior = Normal(loc=fc2W_mu_param, scale=fc2W_sigma_param).independent(1)
    # Outer layer bias distribution priors
    fc2b_mu = torch.randn_like(net.fc2.bias)
    fc2b_sigma = torch.randn_like(net.fc2.bias)
    fc2b_mu_param = pyro.param("fc2b_mu", fc2b_mu)
    fc2b_sigma_param = softplus(pyro.param("fc2b_sigma", fc2b_sigma))
    fc2b_prior = Normal(loc=fc2b_mu_param, scale=fc2b_sigma_param)

    priors = {'conv1.weight': conv1W_prior, 'conv1.bias': conv1b_prior,
              'fc1.weight': fc1W_prior, 'fc1.bias': fc1b_prior,
              'fc2.weight': fc2W_prior, 'fc2.bias': fc2b_prior}

    lifted_module = pyro.random_module("module", net, priors)

    return lifted_module()

I use stochastic variational inference.
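The training loop is roughly the following (a sketch; the learning rate, num_epochs, and train_loader are illustrative):

from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

optim = Adam({"lr": 1e-3})  # learning rate is illustrative
svi = SVI(model, guide, optim, loss=Trace_ELBO())

for epoch in range(num_epochs):
    epoch_loss = 0.0
    for input_ids, mask, label in train_loader:
        # one gradient step on the ELBO for this minibatch
        epoch_loss += svi.step(input_ids, mask, label)
    print(epoch, epoch_loss / len(train_loader))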

While training, the loss decreases but the accuracy does not improve… In short, the model does not learn.
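Accuracy is computed by sampling networks from the guide and averaging their outputs, roughly like this (a sketch; num_samples and test_loader are illustrative):

def predict(input_ids, mask, num_samples=10):
    # draw several networks from the variational posterior and average their logits
    sampled_models = [guide(None, None, None) for _ in range(num_samples)]
    yhats = [m(input_ids, mask) for m in sampled_models]
    mean_logits = torch.stack(yhats).mean(dim=0)
    return mean_logits.argmax(dim=1)

correct = total = 0
for input_ids, mask, label in test_loader:
    preds = predict(input_ids, mask)
    correct += (preds == label).sum().item()
    total += label.size(0)
print("accuracy:", correct / total)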

It looks to me like it is impossible to implement such a thing in Pyro. Am I right, or is there something I am missing?

We recommend using TyXE for Bayesian neural networks in Pyro.
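Roughly, the TyXE pattern looks like this (a sketch for the classification head only, treating BERT as a deterministic feature extractor; names such as head, train_loader, and keyword arguments like dataset_size are illustrative and worth checking against the TyXE docs):

import torch
import pyro
import pyro.distributions as dist
import tyxe

# head: the conv + fully connected layers, taking BERT embeddings as input
prior = tyxe.priors.IIDPrior(dist.Normal(torch.tensor(0.), torch.tensor(1.)))
likelihood = tyxe.likelihoods.Categorical(dataset_size=len(train_dataset))
bnn = tyxe.VariationalBNN(head, prior, likelihood, tyxe.guides.AutoNormal)

optim = pyro.optim.Adam({"lr": 1e-3})
bnn.fit(train_loader, optim, num_epochs)  # train_loader yields (embeddings, label) batches

preds = bnn.predict(test_embeddings, num_predictions=10)  # averaged over posterior samples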