A Look into the Black Box of Transformers
by Daniel K Baissa
Transformer models are often seen as black boxes: it is not entirely clear how they arrive at their predictions. This has motivated a line of research into tools that let us peer inside these models and see how they produce their outputs. In this post we will look at how to visualize the attention that transformer models assign between tokens. This should give researchers a better understanding of the relationships their models find and the outputs they produce.
Let’s use the bert-base-uncased model.
from transformers import BertModel, BertTokenizer
import torch

model_ckpt = 'bert-base-uncased'
do_lower_case = True

# output_attentions=True makes the model return the attention weights
# from every layer alongside its usual outputs
model = BertModel.from_pretrained(model_ckpt, output_attentions=True)
tokenizer = BertTokenizer.from_pretrained(model_ckpt, do_lower_case=do_lower_case)
text = "This is a test"
inputs = tokenizer(text, return_tensors="pt")
print(f"Input tensor shape is:{inputs['input_ids'].size()}")
Input tensor shape is:torch.Size([1, 6])
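Notice that a four-word sentence produced six tokens. That is because the tokenizer wraps the input with BERT's special [CLS] and [SEP] tokens. A quick check, just decoding the IDs we already have:
# Decode the input IDs to see the special tokens the tokenizer added
print(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist()))
# ['[CLS]', 'this', 'is', 'a', 'test', '[SEP]']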
sentence_a = "Maia is a cute little baby who likes to explore her world!"
sentence_b = "Maia is a cute little baby who loves to explore her house."
inputs = tokenizer.encode_plus(sentence_a, sentence_b, return_tensors='pt', add_special_tokens=True)
token_type_ids = inputs['token_type_ids']
input_ids = inputs['input_ids']
attention = model(input_ids, token_type_ids=token_type_ids)[-1]
sentence_b_start = token_type_ids[0].tolist().index(1)
input_id_list = input_ids[0].tolist() # Batch index 0
tokens = tokenizer.convert_ids_to_tokens(input_id_list)
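Before visualizing anything, it is worth checking what attention actually contains. A minimal sketch (the shapes assume bert-base-uncased; the sequence length depends on the tokenized pair):
# attention is a tuple with one tensor per layer; each tensor is shaped
# [batch_size, num_heads, seq_len, seq_len]
print(len(attention))      # 12 layers in bert-base
print(attention[0].shape)  # torch.Size([1, 12, seq_len, seq_len])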
Let’s start with the head_view function from bertviz to visualize the attention layers one at a time.
from bertviz import head_view, model_view

head_view(attention, tokens)
If you run this in a Jupyter notebook, you should get an interactive view in which you can select an attention layer and visualize the attention relationships between tokens. Hover over any token to see the attention to and from that token. The colors correspond to the attention heads: double-click a color square at the top to filter to the head of your choice, or single-click squares to toggle individual heads on and off.
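head_view also accepts the sentence_b_start index we computed earlier. If your bertviz version supports this optional argument, passing it lets the view distinguish attention within each sentence from attention across the pair:
# Passing sentence_b_start separates intra- and inter-sentence attention
head_view(attention, tokens, sentence_b_start)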
This is nice, but what if you want a higher-level view that shows all attention relationships across every layer and head at once?
model_view(attention, tokens, sentence_b_start)
In the resulting grid, rows are layers and columns are attention heads. Remember that bert-base has 12 layers with 12 heads each, indexed 0 through 11 in the visualization.
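The visualizations are great for exploration, but sometimes you want the raw numbers behind them, for example which token a given head attends to most. A minimal sketch of pulling out a single layer/head matrix (the layer and head indices here are arbitrary choices for illustration):
# attention[layer] has shape [batch, heads, seq_len, seq_len]
layer, head = 5, 3  # arbitrary example indices
attn_matrix = attention[layer][0, head]  # [seq_len, seq_len]

# For each token, find the token it attends to most strongly
for i, token in enumerate(tokens):
    j = int(attn_matrix[i].argmax())
    print(f"{token:>12} -> {tokens[j]}")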
Inspired by: Transformers for NLP and Computer Vision, 3rd Edition
I hope this was helpful!
Best, Dan Baissa