Fine-tuning a BERT model for search applications
Create training and serving compatible encodings, a custom dataset, and a basic fine-tuning setup based on the Transformers library.
- Load the dataset
- Compatible BERT encodings
- Create Datasets
- Fine-tune the BERT model
- Export the model to ONNX
In order to fine-tune the BERT models for the cord19 application we need to generate a set of query-document features, as well as labels that indicate which documents are relevant for the specific queries. For this exercise we will use the query string to represent the query and the title string to represent the documents.
import json
from pandas import read_csv
labelled_data = json.load(open("labelled_data_all.json", "r"))
training_data = read_csv("training_all_jugdments_data.csv")
training_data has almost everything we need, except the query string.
training_data.head()
The query string can be obtained from the labelled_data.
print(labelled_data[0]["query_id"], labelled_data[0]["query"])
Since we are training a model that will be deployed in a search application, we need to ensure that the training encodings are compatible with the encodings used at serving time. At serving time, document encodings will be applied offline when feeding the documents to the search engine, while the query encoding will be applied at run-time upon arrival of the query. In addition, it might be relevant to use different maximum lengths for queries and documents.
def create_bert_encodings(queries, docs, tokenizer, query_input_size, doc_input_size):
    # Tokenize queries and documents separately, without special tokens, so the same
    # per-query and per-document encodings can be reproduced at serving time.
    # The max_length leaves room for the special tokens added manually below:
    # [CLS] and [SEP] for the query, one trailing [SEP] for the document.
    queries_encodings = tokenizer(
        queries, truncation=True, max_length=query_input_size-2, add_special_tokens=False
    )
    docs_encodings = tokenizer(
        docs, truncation=True, max_length=doc_input_size-1, add_special_tokens=False
    )

    TOKEN_NONE = 0
    TOKEN_CLS = 101
    TOKEN_SEP = 102

    input_ids = []
    token_type_ids = []
    attention_mask = []
    for query_input_ids, doc_input_ids in zip(queries_encodings["input_ids"], docs_encodings["input_ids"]):
        # create input ids: [CLS] query [SEP] doc [SEP], padded to a fixed length of 128
        input_id = [TOKEN_CLS] + query_input_ids + [TOKEN_SEP] + doc_input_ids + [TOKEN_SEP]
        number_tokens = len(input_id)
        padding_length = max(128 - number_tokens, 0)
        input_id = input_id + [TOKEN_NONE] * padding_length
        input_ids.append(input_id)
        # create token type ids: 0 for the query segment, 1 for the document segment
        token_type_id = (
            [0] * len([TOKEN_CLS] + query_input_ids + [TOKEN_SEP])
            + [1] * len(doc_input_ids + [TOKEN_SEP])
            + [TOKEN_NONE] * padding_length
        )
        token_type_ids.append(token_type_id)
        # create attention mask: 1 for real tokens, 0 for padding
        attention_mask.append([1] * number_tokens + [TOKEN_NONE] * padding_length)

    encodings = {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_mask
    }
    return encodings
Create lists of queries (represented by the query string), docs (represented by the document titles) and labels from the labelled_data and training_data that we loaded earlier.
train_queries = []
train_docs = []
train_labels = []
for data_point in labelled_data:
    query_id = data_point["query_id"]
    titles = training_data[training_data["query_id"] == query_id]["title-full"].tolist()
    train_docs.extend(titles)
    train_labels.extend(
        [1 if x > 0 else 0 for x in training_data[training_data["query_id"] == query_id]["label"].tolist()]
    )
    query = data_point["query"]
    train_queries.extend([query] * len(titles))
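As a quick sanity check (added here for illustration), the three lists should be aligned, with one entry per query-document pair:
# Each position across the three lists corresponds to one query-document pair.
assert len(train_queries) == len(train_docs) == len(train_labels)
print(len(train_labels), "query-document pairs,", sum(train_labels), "labelled relevant")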
We are going to use a simple data split into train and validation sets for illustration purposes. The cord19 use case would likely benefit from cross-validation instead, since it has only 50 queries with relevance judgements.
from sklearn.model_selection import train_test_split
train_queries, val_queries, train_docs, val_docs, train_labels, val_labels = train_test_split(
    train_queries, train_docs, train_labels, test_size=.2
)
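If cross-validation were used instead of the single split above, a query-grouped scheme keeps all documents of a query in the same fold, so that validation queries are never seen during training. A minimal sketch with scikit-learn's GroupKFold, applied to the full (unsplit) lists built above, could look like the following; using the query string as the group key is an assumption for illustration.
from sklearn.model_selection import GroupKFold

def query_grouped_folds(queries, docs, labels, n_splits=5):
    # Yield (train_indices, val_indices) pairs where all documents of a given
    # query end up in the same fold. Grouping by query_id would work equally well.
    group_kfold = GroupKFold(n_splits=n_splits)
    for train_idx, val_idx in group_kfold.split(docs, labels, groups=queries):
        yield train_idx, val_idx
Each fold could then be encoded and wrapped in a dataset exactly as done for the single split below.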
Create train and validation encodings. In order to do that we need to choose which BERT model to use, and the maximum sizes used for the resulting query and document token sequences.
model_name = "google/bert_uncased_L-4_H-512_A-8"
query_input_size=24
doc_input_size=64
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained(model_name)
train_encodings = create_bert_encodings(
    queries=train_queries,
    docs=train_docs,
    tokenizer=tokenizer,
    query_input_size=query_input_size,
    doc_input_size=doc_input_size
)
val_encodings = create_bert_encodings(
    queries=val_queries,
    docs=val_docs,
    tokenizer=tokenizer,
    query_input_size=query_input_size,
    doc_input_size=doc_input_size
)
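A quick check (added here for illustration) that every encoded example has the fixed length of 128 used for padding in create_bert_encodings, and that the encodings stay aligned with the labels:
# All examples are padded to the same fixed length of 128 tokens.
assert all(len(ids) == 128 for ids in train_encodings["input_ids"])
assert len(train_encodings["input_ids"]) == len(train_labels)
print(len(train_encodings["input_ids"]), "train examples,", len(val_encodings["input_ids"]), "validation examples")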
Now that we have the encodings and the labels, we can create a Dataset object as described in the transformers webpage about custom datasets.
import torch
class Cord19Dataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)
train_dataset = Cord19Dataset(train_encodings, train_labels)
val_dataset = Cord19Dataset(val_encodings, val_labels)
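To see what the Trainer will receive, we can inspect a single item (a small check added here): each item is a dict holding the three BERT input tensors plus the label.
item = train_dataset[0]
print({name: tuple(tensor.shape) for name, tensor in item.items()})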
We can then fine-tune the model, updating only the task-specific weights. Below is a basic routine with an out-of-the-box set of parameters. Care should be taken when choosing those parameters, but that is outside the scope of this piece.
from transformers import BertForSequenceClassification, Trainer, TrainingArguments
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=3,              # total number of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=64,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=10,
)
model = BertForSequenceClassification.from_pretrained(model_name)
for param in model.base_model.parameters():
    param.requires_grad = False
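Freezing the base model means only the classification head that BertForSequenceClassification adds on top of it will be updated during training. A quick count of trainable parameters confirms this (a check added here for illustration):
# Only the classification head on top of the frozen BERT encoder remains trainable.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable parameters: {trainable:,} of {total:,}")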
trainer = Trainer(
    model=model,                  # the instantiated 🤗 Transformers model to be trained
    args=training_args,           # training arguments, defined above
    train_dataset=train_dataset,  # training dataset
    eval_dataset=val_dataset      # evaluation dataset
)
trainer.train()
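Once training finishes, the validation set can be scored with the Trainer's built-in evaluation. By default this reports the evaluation loss; passing a compute_metrics function to the Trainer would add task metrics such as accuracy.
metrics = trainer.evaluate()
print(metrics)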
Once training is complete we can export the model using the ONNX format.
from torch.onnx import export
from pathlib import Path

model_onnx_path = Path(model_name + ".onnx")
# the model name contains a "/", so make sure the parent directory exists
model_onnx_path.parent.mkdir(parents=True, exist_ok=True)

dummy_input = (
    train_dataset[0]["input_ids"].unsqueeze(0),
    train_dataset[0]["token_type_ids"].unsqueeze(0),
    train_dataset[0]["attention_mask"].unsqueeze(0)
)
input_names = ["input_ids", "token_type_ids", "attention_mask"]
output_names = ["logits"]
export(
    model, dummy_input, model_onnx_path, input_names=input_names,
    output_names=output_names, verbose=False, opset_version=11
)
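As a sanity check (assuming the onnxruntime package is available), the exported model can be loaded and run on the same example used for tracing. Note that since no dynamic_axes were passed to export, the ONNX graph is traced with a fixed batch size of 1 and sequence length of 128.
import onnxruntime as ort

# Run the exported model on the example used as dummy input during export.
session = ort.InferenceSession(str(model_onnx_path), providers=["CPUExecutionProvider"])
onnx_inputs = {
    "input_ids": train_dataset[0]["input_ids"].unsqueeze(0).numpy(),
    "token_type_ids": train_dataset[0]["token_type_ids"].unsqueeze(0).numpy(),
    "attention_mask": train_dataset[0]["attention_mask"].unsqueeze(0).numpy(),
}
onnx_logits = session.run(["logits"], onnx_inputs)[0]
print(onnx_logits)  # should be close to the PyTorch logits for the same example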