Hello,
I have provided the reproducible example below. The PyTorch model I am working with is HuggingFace Transformer (SEE: GitHub - huggingface/transformers: 🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX., I am using RobertaForMultipleChoice
to be more specific). Thank you for looking into my case; I hope this is helpful.
from transformers import RobertaTokenizer, RobertaForMultipleChoice, AdamW, get_constant_schedule
from transformers import PreTrainedTokenizer
import torch
import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist
import pyro.nn.module as module
import pyro.infer.autoguide.guides as guides
from torch import nn
from pyro.optim import Adam
from pyro.infer import SVI
from pyro.infer import Trace_ELBO
# define input_ids and attention_masks
input_ids = torch.tensor([[[ 0, 102, 1816, 16, 2343, 816, 14545, 15, 10, 165,
11, 41, 11894, 6545, 479, 5, 1011, 2, 2, 354,
22362, 5134, 124, 8, 7264, 81, 5, 1161, 1533, 498,
479, 2, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1],
[ 0, 102, 1816, 16, 2343, 816, 14545, 15, 10, 165,
11, 41, 11894, 6545, 479, 5, 1011, 2, 2, 354,
3148, 11, 10, 2698, 8, 10, 313, 1420, 69, 39,
1028, 479, 2, 1, 1, 1, 1, 1, 1, 1,
1, 1],
[ 0, 102, 1816, 16, 2343, 816, 14545, 15, 10, 165,
11, 41, 11894, 6545, 479, 5, 1011, 2, 2, 354,
5629, 66, 1706, 5, 1161, 8, 5, 1816, 386, 7,
23322, 5225, 5, 1011, 479, 2, 1, 1, 1, 1,
1, 1],
[ 0, 102, 1816, 16, 2343, 816, 14545, 15, 10, 165,
11, 41, 11894, 6545, 479, 5, 1011, 2, 2, 354,
5629, 31, 69, 865, 3987, 8, 79, 1388, 3022, 24,
124, 8, 7264, 25, 79, 36989, 24, 160, 9, 4257,
479, 2]]])
attention_mask = torch.tensor([[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]])
mc_labels = torch.tensor([0])
# get the pre-trained HuggingFace RobertaForMultipleChoice model
model_RobertaForMultipleChoice = RobertaForMultipleChoice.from_pretrained('roberta-base', output_hidden_states = True)
# convert the HuggingFace Transformer model into a Bayesian Pyro model
module.to_pyro_module_(model_RobertaForMultipleChoice)
# Now we can attempt to be fully Bayesian:
for m in model_RobertaForMultipleChoice.modules():
for name, value in list(m.named_parameters(recurse=False)):
setattr(m, name, module.PyroSample(prior=dist.Normal(0, 1)
.expand(value.shape)
.to_event(value.dim())))
# define parameters for training
guide_diag_normal = guides.AutoDiagonalNormal(model_RobertaForMultipleChoice)
optimizer = Adam({"lr": 0.000000055})
scheduler = pyro.optim.StepLR({'optimizer': optimizer, 'optim_args': {'lr': 0.000000055}})
# define SVI
svi_diag_normal = SVI(model_RobertaForMultipleChoice, guide_diag_normal, optimizer, loss=Trace_ELBO())
# calculate loss from SVI
# ERRORS ARE GENERATED HERE
svi_loss = svi_diag_normal.step(input_ids = input_ids, attention_mask = attention_mask, labels = mc_labels)
Below is the error that gets displayed:
ERRORS:
Traceback (most recent call last)
svi_loss = svi_diag_normal.step(input_ids = input_ids, attention_mask = attention_mask, labels = mc_labels)
Traceback (most recent call last):
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/transformers/modeling_utils.py", line 150, in dtype
return next(self.parameters()).dtype
StopIteration
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/infer/elbo.py", line 170, in _get_traces
yield self._get_trace(model, guide, args, kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/infer/trace_elbo.py", line 53, in _get_trace
"flat", self.max_plate_nesting, model, guide, args, kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/infer/enum.py", line 44, in get_importance_trace
guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 185, in get_trace
self(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 165, in __call__
ret = self.fn(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/nn/module.py", line 290, in __call__
return super().__call__(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/infer/autoguide/guides.py", line 679, in forward
self._setup_prototype(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/infer/autoguide/guides.py", line 819, in _setup_prototype
super()._setup_prototype(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/infer/autoguide/guides.py", line 577, in _setup_prototype
super()._setup_prototype(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/infer/autoguide/guides.py", line 156, in _setup_prototype
self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 11, in _context_wrap
return fn(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 185, in get_trace
self(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 165, in __call__
ret = self.fn(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 11, in _context_wrap
return fn(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 11, in _context_wrap
return fn(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/nn/module.py", line 290, in __call__
return super().__call__(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/transformers/modeling_roberta.py", line 441, in forward
output_hidden_states=output_hidden_states,
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/nn/module.py", line 290, in __call__
return super().__call__(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/transformers/modeling_bert.py", line 732, in forward
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/transformers/modeling_utils.py", line 228, in get_extended_attention_mask
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/transformers/modeling_utils.py", line 159, in dtype
first_tuple = next(gen)
StopIteration
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<ipython-input-13-9f26219c3f71>", line 1, in <module>
svi_loss = svi_diag_normal.step(input_ids = input_ids, attention_mask = attention_mask, labels = mc_labels)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/infer/svi.py", line 128, in step
loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/infer/trace_elbo.py", line 126, in loss_and_grads
for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
RuntimeError: generator raised StopIteration