TyXe DistilBERT: almost there


I've made a lot of progress on a Bayesian DistilBERT with TyXe. Please help me with the issue below. As you surely know by now, this model has driven me to insanity. It should just be a bug now.

    import torch
    import pyro
    import pyro.distributions as dist
    import tyxe as ezbnn
    from tyxe import SupervisedBNN
    from tyxe import bnn
    from tyxe import util
    import pyro.nn as pynn
    import pyro.poutine as poutine
    from pyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO
    from torch.utils.data import DataLoader
    from transformers import DistilBertForSequenceClassification

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')

    # train_dataset (tokenized examples with "input_ids", "attention_mask", "labels") is built elsewhere
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)


    def _empty_guide(*args, **kwargs):
        return {}


    class SupervisedBertBNN(SupervisedBNN):
        """SupervisedBNN variant whose model and guide take (input_ids, attention_mask, obs)."""

        def __init__(self, net, prior, observation_model, net_guide_builder=None,
                     observation_guide_builder=None, name=""):
            # NOTE: check these positional arguments against SupervisedBNN.__init__'s parameter order
            super().__init__(net, prior, net_guide_builder, name=name)
            self.observation_model = observation_model
            weight_sample_sites = list(util.pyro_sample_sites(self.net))
            if observation_guide_builder is not None:
                self.observation_guide = observation_guide_builder(poutine.block(
                    self.model, hide=weight_sample_sites + [self.observation_model.observation_name]))
            else:
                self.observation_guide = _empty_guide

        def model(self, input_ids, attention_mask, obs=None):
            # DistilBERT returns a tuple; element 0 holds the classification logits
            logits = self(input_ids, attention_mask)[0]
            self.observation_model(logits, obs)
            return logits

        def guide(self, input_ids, attention_mask, obs=None):
            # SVI passes the same arguments to model and guide, so the signatures must match
            result = self.net_guide(input_ids, attention_mask) or {}
            result.update(self.observation_guide(input_ids, attention_mask, obs) or {})
            return result

        def fit(self, data_loader, optim, num_epochs, callback=None, closed_form_kl=True, device=None):
            old_training_state = self.net.training
            self.net.train(True)

            loss = TraceMeanField_ELBO() if closed_form_kl else Trace_ELBO()
            svi = SVI(self.model, self.guide, optim, loss=loss)

            for i in range(num_epochs):
                elbo = 0.
                num_batch = 1
                for num_batch, data in enumerate(iter(data_loader), 1):
                    batch = [data[key] for key in ("input_ids", "attention_mask", "labels")]
                    if device is not None:
                        batch = [t.to(device) for t in batch]
                    elbo += svi.step(*batch)
                    print(f"epoch: {i} batch: {num_batch}/{len(data_loader)}")

                # the callback can stop training by returning True
                if callback is not None and callback(self, i, elbo / num_batch):
                    break

            # restore whatever train/eval mode the network was in before fitting
            self.net.train(old_training_state)
            return svi

This returns the following error:

    KeyError                                  Traceback (most recent call last)
    <ipython-input-15-364d06b641b5> in <module>()
          1 optim = pyro.optim.AdamW({"lr": 5e-5})
          2 #prediction = bayes_bert.forward(train_dataset[0]['input_ids'].unsqueeze(1), train_dataset[0]['attention_mask'])
    ----> 3 svi = bayes_bert.fit(train_loader, optim, 3)

    3 frames
    /usr/local/lib/python3.6/dist-packages/pyro/infer/trace_mean_field_elbo.py in _differentiable_loss_particle(self, model_trace, guide_trace)
         94                     elbo_particle = elbo_particle + model_site["log_prob_sum"]
         95                 else:
    ---> 96                     guide_site = guide_trace.nodes[name]
         97                     if is_validation_enabled():
         98                         check_fully_reparametrized(guide_site)

    KeyError: 'distilbert.embeddings.word_embeddings.weight'
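
For reference, the failing line (`guide_site = guide_trace.nodes[name]`) looks up every unobserved model sample site by name in the guide trace, while observed sites are scored directly (line 94). The framework-free toy below mimics that lookup. It assumes plain dicts stand in for Pyro's trace objects, and apart from the name in the error message the site names are only illustrative:

```python
# Toy stand-in for the lookup in TraceMeanField_ELBO: plain dicts play the role of
# Pyro traces (an assumption for illustration; real traces are poutine Trace objects).

def missing_guide_sites(model_sites, guide_sites):
    """Return the unobserved model sites that the guide never sampled.

    Each of these would hit guide_trace.nodes[name] and raise KeyError.
    """
    return [
        name
        for name, site in model_sites.items()
        if not site["is_observed"] and name not in guide_sites
    ]


# Illustrative site names; only the first one appears in the traceback.
model_sites = {
    "distilbert.embeddings.word_embeddings.weight": {"is_observed": False},
    "classifier.weight": {"is_observed": False},
    "obs": {"is_observed": True},  # likelihood site: scored directly, never looked up
}

empty_guide_sites = {}  # e.g. a guide that returns {} and samples nothing
full_guide_sites = {name: {} for name, site in model_sites.items() if not site["is_observed"]}

print(missing_guide_sites(model_sites, empty_guide_sites))
# ['distilbert.embeddings.word_embeddings.weight', 'classifier.weight']
print(missing_guide_sites(model_sites, full_guide_sites))
# []
```

In other words, a KeyError on a weight name means the guide trace for that step contained no site with that name, i.e. the guide did not sample that weight.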

I gather that the attention mask is basically what's breaking this. As background, I need to pass both input_ids and attention_mask as the model input. As you will see in the new subclass, I had to index into the tuple that transformers returns to get its raw outputs (the logits). I imagine attention_mask has some irregularities like that as well. Please help. I want to experiment so badly! Very excited. Mad props to karalets: TyXe is awesome! Super excited to contribute!
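One detail in the subclass worth double-checking: `SupervisedBertBNN.__init__` takes `observation_model` as its third parameter but forwards `(net, prior, net_guide_builder)` positionally to `super().__init__`. If `SupervisedBNN.__init__` has the same parameter order as the subclass (an assumption, not verified against the TyXe source), the guide builder is received as `observation_model` and the net guide falls back to an empty guide. The hypothetical stand-in classes below (not TyXe's) sketch that effect and a keyword-argument version that avoids it:

```python
# Hypothetical stand-ins for SupervisedBNN and the subclass; not TyXe's actual classes.

def _empty_guide(*args, **kwargs):
    return {}


class ParentBNN:
    # Assumed to mirror the subclass signature: observation_model comes third.
    def __init__(self, net, prior, observation_model, net_guide_builder=None, name=""):
        self.observation_model = observation_model
        self.net_guide = net_guide_builder(net) if net_guide_builder is not None else _empty_guide


class ChildBNN(ParentBNN):
    def __init__(self, net, prior, observation_model, net_guide_builder=None, name=""):
        # The guide builder lands in the parent's observation_model slot...
        super().__init__(net, prior, net_guide_builder, name=name)
        # ...and reassigning observation_model afterwards hides half of the mix-up.
        self.observation_model = observation_model


class FixedChildBNN(ParentBNN):
    def __init__(self, net, prior, observation_model, net_guide_builder=None, name=""):
        # Keyword arguments make each value land in the intended parameter.
        super().__init__(net, prior, observation_model=observation_model,
                         net_guide_builder=net_guide_builder, name=name)


def toy_guide_builder(net):
    # Hypothetical builder: returns a guide that "samples" one site per weight name.
    return lambda *args, **kwargs: {name: "sample" for name in net}


net = ["distilbert.embeddings.word_embeddings.weight"]  # toy "network": a list of weight names

broken = ChildBNN(net, prior=None, observation_model="likelihood",
                  net_guide_builder=toy_guide_builder)
fixed = FixedChildBNN(net, prior=None, observation_model="likelihood",
                      net_guide_builder=toy_guide_builder)

print(broken.net_guide())  # {} (the empty guide, so no weight sites are sampled)
print(fixed.net_guide())   # {'distilbert.embeddings.word_embeddings.weight': 'sample'}
```

Because the subclass sets `self.observation_model` again after the `super()` call, the likelihood still looks correct, and only the missing net guide would show up, as a missing weight site at ELBO time.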

Thanks for your help. I know this is an awful lot of troubleshooting.


Dear @Arcco96,

Let’s pool this in the "Tyxe/ezbnn contributions and help" thread; what I responded there is obviously still my answer.
Thank you for your detailed bug reports; I fully expect things to be buggy right now. Hopefully we can work on specific models once the basics of TyXe are stable, and you can help us figure out these bespoke models then.


Yes, and not to drive you totally insane, but is there any way you could quickly check how I set up the model and guide, so I can apply this technique? I think it's nearly there: it trains as long as you don't pass 'attention_mask', which is an input to the Hugging Face model. It's important that I get this working with 'attention_mask', as it will take a very long time to train. I urgently want to apply this to NLP. I will keep an interest in your project, in contributing to it, and in Bayesian NNs generally, albeit at a slower pace. For now, my focus is on words. Your help is much appreciated, if you have the time.