Explaining a Text Classification Model with LIME
In the previous two posts we covered how to train a text classification model. Today, instead of only looking at accuracy on the text dataset, we will use LIME (Local Interpretable Model-Agnostic Explanations) to explain which words the trained model actually bases its predictions on.
How LIME Works
Around each local data point, LIME approximates the model's predictions to find a simpler, human-understandable decision rule. Because it only needs the model's outputs, it works with any model (hence model-agnostic).
For example, the model in the figure above is a binary classifier. The true decision function is non-linear (the blue and pink regions), but locally LIME can separate the red crosses from the blue dots with a simple linear function. Approximating the true model with a linear function in this way is exactly how LIME explains a model.
So how is that line fitted? LIME uses the algorithm shown in the figure below (a rough code sketch follows the list):
- Classifier f is the complex model that LIME wants to explain
- Number of samples N is how many perturbed versions of the input LIME generates, e.g. by reordering words, adding or removing words, or masking out part of an image (N largely determines how long the algorithm takes; the author's default is 5000, but that was far too slow for explaining BERT, and since the sentences I explain are fairly short I cut it down to 25)
- Instance x is the input to be explained
- x' is the human-interpretable version of that input; the simplest case is a binary vector marking whether each word in the bag of words appears (even if the model's actual input vectors are more complex, such as Word2vec embeddings)
- Similarity kernel is the function that computes the distance weighting between each perturbed sample and the original instance (the paper's SVM text classification example uses an RBF kernel)
- The weights are then computed with K-Lasso, a statistical method you can think of as a sparse variant of least squares
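To make this loop concrete, here is a rough sketch in Python. It is not the lime library's actual implementation: it assumes a bag-of-words interpretable representation, an exponential similarity kernel over the fraction of words removed, and uses scikit-learn's Lasso as a stand-in for K-Lasso. classifier_fn stands for any black-box model that maps a list of texts to class probabilities (like the BERT predictor defined later in this post).

import numpy as np
from sklearn.linear_model import Lasso

def explain_instance_sketch(text, classifier_fn, class_idx=1, num_samples=25, kernel_width=0.75):
    """Toy LIME loop: perturb -> query f -> weight by similarity -> fit a sparse linear model."""
    words = text.split()
    # x': binary vectors, 1 = keep the word, 0 = mask it out
    z_prime = np.random.randint(0, 2, size=(num_samples, len(words)))
    z_prime[0] = 1  # keep the unperturbed instance as the first sample
    perturbed = [" ".join(w for w, keep in zip(words, row) if keep) for row in z_prime]
    # f(z): probability of the class of interest from the complex model
    probs = classifier_fn(perturbed)[:, class_idx]
    # similarity kernel: samples with fewer words removed get larger weights
    distances = 1.0 - z_prime.mean(axis=1)  # fraction of words masked out
    weights = np.exp(-(distances ** 2) / kernel_width ** 2)
    # sparse weighted linear fit; the coefficients are the per-word explanation
    local_model = Lasso(alpha=0.01)
    local_model.fit(z_prime, probs, sample_weight=weights)
    return sorted(zip(words, local_model.coef_), key=lambda wc: -abs(wc[1]))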
Once individual inputs can be explained, the next step is to combine those individual explanations into a global understanding of the model. This step is a bit like feature selection: we look at which features the per-input linear explanations have in common, and features that show up more often are considered more important. The figure below is a simple illustration:
As for how those most important features are found, the paper uses a greedy algorithm:
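As a rough illustration (the names and details below are my own, not the paper's exact submodular-pick procedure), the greedy idea can be sketched as follows: sum each feature's absolute weight over all per-instance explanations to get a global importance score, then repeatedly pick the instance whose explanation covers the most not-yet-covered importance.

def greedy_pick_sketch(explanations, budget=5):
    """explanations: one dict {word: weight} per explained input.
    Greedily pick `budget` instances whose explanations cover the most
    global feature importance that has not been covered yet."""
    # global importance of each feature: sum of |weight| across all explanations
    importance = {}
    for exp in explanations:
        for word, weight in exp.items():
            importance[word] = importance.get(word, 0.0) + abs(weight)
    picked, covered = [], set()
    for _ in range(budget):
        best_i, best_gain = None, 0.0
        for i, exp in enumerate(explanations):
            if i in picked:
                continue
            gain = sum(importance[w] for w in exp if w not in covered)
            if gain > best_gain:
                best_i, best_gain = i, gain
        if best_i is None:  # nothing left to cover
            break
        picked.append(best_i)
        covered.update(explanations[best_i])
    return picked, importance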
Bert Predictor
Now that we understand how LIME works, let's use it to explain our fine-tuned BERT model. First we load the BERT model we want to explain and define the method LIME needs (input: a list of strings; output: a 2d array of prediction probabilities):
### refer to https://github.com/marcotcr/lime/issues/409
import torch.nn.functional as F
import finetune
import os
import numpy as np
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer)


class Prediction:
    def __init__(self):
        # outputs_dirs points to the directory the fine-tuned model was saved to
        self.model = AutoModelForSequenceClassification.from_pretrained(outputs_dirs)
        self.tokenizer = AutoTokenizer.from_pretrained(outputs_dirs)
        self.processor = finetune.ATProcessor(filename="/home/ubuntu/data/data_no_urls.csv")
        self.max_seq_length = 128
        self.device = "cpu"
        self.model.to(self.device)
    def predict_label(self, text_a, text_b):
        self.model.to(self.device)
        input_ids, input_mask, segment_ids = self.convert_text_to_features(text_a, text_b)
        with torch.no_grad():
            # pass the attention mask and segment ids by keyword so they are not mixed up
            outputs = self.model(input_ids, attention_mask=input_mask, token_type_ids=segment_ids)
        logits = outputs[0]
        logits = F.softmax(logits, dim=1)
        logits_label = torch.argmax(logits, dim=1)
        label = logits_label.detach().cpu().numpy()
        logits_confidence = logits[0][logits_label]
        label_confidence_ = logits_confidence.detach().cpu().numpy()
        return label, label_confidence_
    def _truncate_seq_pair(self, tokens_a, max_length):
        """Truncates a token sequence in place to the maximum length."""
        # Only a single sequence is used here, so simply drop tokens from the
        # end until the sequence fits within max_length.
        while len(tokens_a) > max_length:
            tokens_a.pop()
    def convert_text_to_features(self, text_a, text_b=None):
        cls_token = self.tokenizer.cls_token
        sep_token = self.tokenizer.sep_token
        sequence_a_segment_id = 0
        sequence_b_segment_id = 1
        cls_token_segment_id = 0  # BERT uses segment id 0 for the [CLS] token
        pad_token_segment_id = 0
        mask_padding_with_zero = True
        pad_token = 0
        tokens_a = self.tokenizer.tokenize(text_a)
        tokens_b = None
        self._truncate_seq_pair(tokens_a, self.max_seq_length - 2)
        tokens = tokens_a + [sep_token]
        segment_ids = [sequence_a_segment_id] * len(tokens)
        if tokens_b:
            tokens += tokens_b + [sep_token]
            segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)
        tokens = [cls_token] + tokens
        segment_ids = [cls_token_segment_id] + segment_ids
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
        # Zero-pad up to the sequence length.
        padding_length = self.max_seq_length - len(input_ids)
        input_ids = input_ids + ([pad_token] * padding_length)
        input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
        segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)
        assert len(input_ids) == self.max_seq_length
        assert len(input_mask) == self.max_seq_length
        assert len(segment_ids) == self.max_seq_length
        input_ids = torch.tensor([input_ids], dtype=torch.long).to(self.device)
        input_mask = torch.tensor([input_mask], dtype=torch.long).to(self.device)
        segment_ids = torch.tensor([segment_ids], dtype=torch.long).to(self.device)
        return input_ids, input_mask, segment_ids
    def predictor(self, text):
        # LIME passes in a list of perturbed versions of the original sentence
        examples = []
        for example in text:
            examples.append(self.convert_text_to_features(example))
        results = []
        for input_ids, input_mask, segment_ids in examples:
            with torch.no_grad():
                outputs = self.model(input_ids, attention_mask=input_mask, token_type_ids=segment_ids)
            logits = F.softmax(outputs[0], dim=1)
            results.append(logits.cpu().detach().numpy()[0])
        results_array = np.array(results)
        return results_array
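Before handing this to LIME, a quick smoke test helps confirm that predictor matches the interface LIME expects: a list of strings in, an (n_texts, n_classes) array of probabilities out. The two sentences below are made-up examples, and this assumes outputs_dirs and the CSV path above are valid on your machine.

# hypothetical smoke test; the two sentences are made-up examples
prediction = Prediction()
probs = prediction.predictor(["late night ramen places in Taipei",
                              "NBA playoff schedule for this week"])
print(probs.shape)        # expected (2, n_classes), e.g. (2, 6) here
print(probs.sum(axis=1))  # each row should sum to roughly 1.0 after softmax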
Explaining Text Classification with LIME
Next, we call LIME to explain the input text. Here I followed LIME's multi-class tutorial and adapted it to explain the fine-tuned BERT model.
from collections import Counter
from lime.lime_text import LimeTextExplainer
import pandas as pd

class_names = ['Food', 'Travel', 'Retail', 'Entertainment', 'Sports', 'Health']
prediction = Prediction()
explainer = LimeTextExplainer(class_names=class_names)
train_df = pd.read_csv("/home/ubuntu/data/data_no_urls.csv", sep=',')

for label, text in zip(train_df['0'], train_df['1']):
    exp = explainer.explain_instance(text, prediction.predictor, num_samples=25, labels=[0, 1, 2, 3, 4, 5])
    exp.available_labels()
    words = exp.as_list(label=label)
    print(words)
Here I use num_samples=25 (the paper author's default is 5000, but that is too slow, and since the sentences I want to explain are fairly short, 25 should be enough); the labels argument corresponds to the indices of class_names.
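To get back to the global picture discussed earlier, one option is to tally, per class, the words LIME assigns positive weight to across the whole dataset. This is one way to put the Counter imported above to work; the aggregation itself is my own choice rather than something the LIME tutorial prescribes.

# variant of the loop above that also tallies positively-weighted words per class
top_words = {name: Counter() for name in class_names}
for label, text in zip(train_df['0'], train_df['1']):
    exp = explainer.explain_instance(text, prediction.predictor, num_samples=25, labels=[label])
    for word, weight in exp.as_list(label=label):
        if weight > 0:
            top_words[class_names[label]][word] += 1

for name, counter in top_words.items():
    print(name, counter.most_common(10))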