Hello,
Much progress on bayes distillbert with tyxe. Please help me with the below issue. As you surely know by now this model has driven me to insanity. Should be just a bug now.
import pyro
import pyro.distributions as dist
import tyxe as ezbnn
from tyxe import SupervisedBNN
from tyxe import bnn
from tyxe import util
import pyro.nn as pynn
import pyro.poutine as poutine
from pyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO
import torchvision
from torch.utils.data import DataLoader
from transformers import DistilBertForSequenceClassification, AdamW
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
model.to(device)
model.train()
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
def _empty_guide(*args, **kwargs):
return {}
class SupervisedBertBNN(SupervisedBNN):
def __init__(self, net, prior, observation_model, net_guide_builder=None, observation_guide_builder=None, name=""):
super().__init__(net, prior, net_guide_builder, name=name)
self.observation_model = observation_model
weight_sample_sites = list(util.pyro_sample_sites(self.net))
if observation_guide_builder is not None:
self.observation_guide = observation_guide_builder(poutine.block(
self.model, hide=weight_sample_sites + [self.observation_model.observation_name]))
else:
self.observation_guide = _empty_guide
def guide(self, input_ids, attention_mask, obs=None):
result = self.net_guide(input_ids, attention_mask) or {}
result.update(self.observation_guide(input_ids, attention_mask, obs) or {})
return result
def fit(self, data_loader, optim, num_epochs, callback=None, closed_form_kl=True, device=None):
old_training_state = self.net.training
self.net.train(True)
loss = TraceMeanField_ELBO() if closed_form_kl else Trace_ELBO()
svi = SVI(self.model, self.guide, optim, loss=loss)
for i in range(num_epochs):
elbo = 0.
num_batch = 1
for num_batch, data in enumerate(iter(data_loader), 1):
elbo += svi.step(data["input_ids"].to(device), data['attention_mask'].to(device), data['labels'].to(device)) #data["input_ids"].to(device)
print("epoch: " + i + " batch: " + num_batch + "/" + len(train_dataset)/16)
# the callback can stop training by returning True
if callback is not None and callback(self, i, elbo / num_batch):
break
self.net.train(old_training_state)
return svi
def model(self, inputs, attention_mask, obs=None):
predictions = self(inputs, attention_mask)
self.observation_model(predictions[0], obs)
return predictions[0]
def guide(self, inputs, attention_mask, obs=None):
result = self.net_guide(inputs, attention_mask) or {}
result.update(self.observation_guide(input, attention_mask, obs) or {})
return result
this returns this error:
KeyError Traceback (most recent call last)
<ipython-input-15-364d06b641b5> in <module>()
1 optim = pyro.optim.AdamW({"lr": 5e-5})
2 #prediction = bayes_bert.forward(train_dataset[0]['input_ids'].unsqueeze(1), train_dataset[0]['attention_mask'])
----> 3 svi = bayes_bert.fit(train_loader, optim, 3)
3 frames
/usr/local/lib/python3.6/dist-packages/pyro/infer/trace_mean_field_elbo.py in _differentiable_loss_particle(self, model_trace, guide_trace)
94 elbo_particle = elbo_particle + model_site["log_prob_sum"]
95 else:
---> 96 guide_site = guide_trace.nodes[name]
97 if is_validation_enabled():
98 check_fully_reparametrized(guide_site)
KeyError: 'distilbert.embeddings.word_embeddings.weight'
I gather that basically the attention mask is screwing up this process. Just as some background I need to pass in the input_ids and attention_mask as the model input. As you will see in the new subclass I needed to separate the transformers tuple to get its raw outputs. I imagine the attention_mask has some irregularities like that as well. Please help. I want to experiment so badly! Very excited. Mad props to karalets – tyxe is awesome! Super excited to contribute!
Thanks for your help I know this is awfully troubleshooty.