In case you are running on Google Colab

Mount your Google Drive to be able to load and save data to it.

from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive

Install the libraries used in this post.

!pip install transformers torch onnxruntime optuna

Prepare the dataset

Load the data

In order to fine-tune BERT models for the cord19 application, we need to generate a set of query-document features, as well as labels that indicate which documents are relevant to the specific queries. For this exercise we will use the query string to represent the query and the title string to represent the documents.

The file labelled_data_all.json contains the query strings, and the file training_all_jugdments_data.csv contains the labels and title strings. Those files were created and covered elsewhere, but you can download them here and here.

import json
from pandas import read_csv

labelled_data = json.load(open("/content/drive/My Drive/cord19/labelled_data_all.json", "r"))
training_data = read_csv("/content/drive/My Drive/cord19/training_all_jugdments_data.csv")

training_data has almost everything we need, except the query string.

training_data.head()
document_id query_id label title-full
0 005b2j4b 1 2 Monophyletic Relationship between Severe Acute...
1 00fmeepz 1 1 Comprehensive overview of COVID-19 based on cu...
2 010vptx3 1 2 The SARS, MERS and novel coronavirus (COVID-19...
3 0194oljo 1 1 Evidence for zoonotic origins of Middle East r...
4 021q9884 1 1 Deadly virus effortlessly hops species

The query string can be obtained from the labelled_data.

print(labelled_data[0]["query_id"], labelled_data[0]["query"])
1 coronavirus origin

Compatible BERT encodings

Since we are training a model that will be deployed in a search application, we need to ensure that the training encodings are compatible with the encodings used at serving time. At serving time, document encodings will be applied offline, when feeding the documents to the search engine, while the query encoding will be applied at run-time, upon arrival of the query. In addition, it might be relevant to use different maximum lengths for queries and documents.

def create_bert_encodings(queries, docs, tokenizer, query_input_size, doc_input_size):
    # Reserve room for the special tokens: [CLS] and [SEP] around the query,
    # and a final [SEP] after the document.
    queries_encodings = tokenizer(
        queries, truncation=True, max_length=query_input_size-2, add_special_tokens=False
    )
    docs_encodings = tokenizer(
        docs, truncation=True, max_length=doc_input_size-1, add_special_tokens=False
    )

    TOKEN_NONE = 0   # [PAD]
    TOKEN_CLS = 101  # [CLS]
    TOKEN_SEP = 102  # [SEP]

    # Pad every example to the combined query/document input size.
    max_input_length = query_input_size + doc_input_size

    input_ids = []
    token_type_ids = []
    attention_mask = []
    for query_input_ids, doc_input_ids in zip(queries_encodings["input_ids"], docs_encodings["input_ids"]):
        # create input ids: [CLS] query [SEP] doc [SEP], followed by padding
        input_id = [TOKEN_CLS] + query_input_ids + [TOKEN_SEP] + doc_input_ids + [TOKEN_SEP]
        number_tokens = len(input_id)
        padding_length = max(max_input_length - number_tokens, 0)
        input_id = input_id + [TOKEN_NONE] * padding_length
        input_ids.append(input_id)
        # create token type ids: 0 for the query segment, 1 for the document segment
        token_type_id = [0] * (len(query_input_ids) + 2) + [1] * (len(doc_input_ids) + 1) + [0] * padding_length
        token_type_ids.append(token_type_id)
        # create attention mask: 1 for real tokens, 0 for padding
        attention_mask.append([1] * number_tokens + [0] * padding_length)

    encodings = {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_mask
    }
    return encodings
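
As a quick sanity check of the layout, the sketch below encodes a single toy pair with the same model and truncation sizes used later in this post (the title is one of the examples shown above):

from transformers import BertTokenizerFast

check_tokenizer = BertTokenizerFast.from_pretrained("google/bert_uncased_L-4_H-512_A-8")
sample = create_bert_encodings(
    queries=["coronavirus origin"],
    docs=["Deadly virus effortlessly hops species"],
    tokenizer=check_tokenizer,
    query_input_size=24,
    doc_input_size=64,
)
# All three lists are padded to query_input_size + doc_input_size = 88 positions,
# and the attention mask sums to the number of real (non-padding) tokens.
print(len(sample["input_ids"][0]), sum(sample["attention_mask"][0]))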

Create Datasets

Create lists of queries (represented by the query string), docs (represented by the document titles) and labels from the labelled_data and training_data we loaded earlier.

train_queries = []
train_docs = []
train_labels = []
for data_point in labelled_data:
    query_id = data_point["query_id"]
    judgments = training_data[training_data["query_id"] == query_id]
    titles = judgments["title-full"].tolist()
    train_docs.extend(titles)
    train_labels.extend([1 if label > 0 else 0 for label in judgments["label"].tolist()])
    train_queries.extend([data_point["query"]] * len(titles))

We are going to use a simple split into train and validation sets for illustration purposes. The cord19 use case would probably benefit from cross-validation, since it has only 50 queries with relevance judgments.

from sklearn.model_selection import train_test_split
train_queries, val_queries, train_docs, val_docs, train_labels, val_labels = train_test_split(
    train_queries, train_docs, train_labels, test_size=.2
)
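
If the label distribution is skewed, a stratified split keeps the 0/1 ratio similar in both sets. A minimal variant of the call above, using scikit-learn's stratify and random_state parameters:

train_queries, val_queries, train_docs, val_docs, train_labels, val_labels = train_test_split(
    train_queries, train_docs, train_labels,
    test_size=.2,
    stratify=train_labels,  # preserve the label proportions in both splits
    random_state=42,        # example seed for reproducibility
)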

Create the train and validation encodings. In order to do that, we need to choose which BERT model to use, as well as the maximum sizes used for the resulting query and document token sequences.

model_name = "google/bert_uncased_L-4_H-512_A-8"
query_input_size = 24
doc_input_size = 64
from transformers import BertTokenizerFast
tokenizer = BertTokenizerFast.from_pretrained(model_name)

train_encodings = create_bert_encodings(
    queries=train_queries, 
    docs=train_docs, 
    tokenizer=tokenizer, 
    query_input_size=query_input_size, 
    doc_input_size=doc_input_size
)

val_encodings = create_bert_encodings(
    queries=val_queries, 
    docs=val_docs, 
    tokenizer=tokenizer, 
    query_input_size=query_input_size, 
    doc_input_size=doc_input_size
)

Now that we have the encodings and the labels, we can create a Dataset object as described in the transformers webpage about custom datasets.

import torch

class Cord19Dataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

train_dataset = Cord19Dataset(train_encodings, train_labels)
val_dataset = Cord19Dataset(val_encodings, val_labels)
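
Each dataset item is a dictionary of fixed-length tensors plus a label, which is exactly what the Trainer expects to batch; a quick check, assuming the input sizes chosen above:

item = train_dataset[0]
print(item["input_ids"].shape)  # torch.Size([88]): query_input_size + doc_input_size
print(item["labels"])           # tensor(0) or tensor(1)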

Fine-tune the BERT model

We can then fine-tune the model, updating only the task-specific weights (the base model is frozen inside model_init below).

Define the accuracy metric.

from transformers import EvalPrediction
import numpy as np

def compute_metrics(p: EvalPrediction):
    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    preds = np.argmax(preds, axis=1)
    return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}

Hyperparameter tuning with Optuna.

from transformers import BertForSequenceClassification, Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir='/content/results',    # output directory
    #evaluation_strategy="epoch",     # Evaluation is done at the end of each epoch.
    evaluation_strategy="steps",      # Evaluation is done (and logged) every eval_steps.
    eval_steps=1000,                  # Number of update steps between two evaluations 
    per_device_eval_batch_size=64,    # batch size for evaluation
    save_total_limit=1,               # limit the total amount of checkpoints. Deletes the older checkpoints.
)

def model_init():
    model = BertForSequenceClassification.from_pretrained(model_name)
    # Freeze the base BERT weights so that only the task-specific
    # classification head is updated during fine-tuning.
    for param in model.base_model.parameters():
        param.requires_grad = False
    return model

trainer = Trainer(
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=val_dataset,            # evaluation dataset
    compute_metrics=compute_metrics,     # metrics to be computed
    model_init=model_init                # Instantiate model before training starts
)

def my_hp_space(trial):
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-4, 1e-2, log=True),
        "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 20),
        "seed": trial.suggest_int("seed", 1, 40),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [4, 8, 16, 32, 64]),
    }

def my_objective(metrics):
    return metrics["eval_loss"]

best_run = trainer.hyperparameter_search(
    direction="minimize",
    hp_space=my_hp_space,
    compute_objective=my_objective,
    n_trials=100,
)

with open("/content/drive/My Drive/cord19/best_run.json", "w+") as f:
    f.write(json.dumps(best_run.hyperparameters))

Inspect best parameters

best_run.hyperparameters

Retrain using the best parameters and the entire dataset. We still need to create complete_dataset; one way to build it is sketched below. Since all the labelled data is now used for training, we also drop the per-epoch evaluation from the training arguments.
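
A minimal sketch, reusing the list construction and encoding steps from earlier in this post to build a dataset over all 50 queries (no held-out validation split):

all_queries = []
all_docs = []
all_labels = []
for data_point in labelled_data:
    query_id = data_point["query_id"]
    judgments = training_data[training_data["query_id"] == query_id]
    titles = judgments["title-full"].tolist()
    all_docs.extend(titles)
    all_labels.extend([1 if label > 0 else 0 for label in judgments["label"].tolist()])
    all_queries.extend([data_point["query"]] * len(titles))

complete_encodings = create_bert_encodings(
    queries=all_queries,
    docs=all_docs,
    tokenizer=tokenizer,
    query_input_size=query_input_size,
    doc_input_size=doc_input_size,
)
complete_dataset = Cord19Dataset(complete_encodings, all_labels)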

training_args = TrainingArguments(
    output_dir='/content/results',   # output directory
    save_total_limit=2,              # limit the total amount of checkpoints. Deletes the older checkpoints.
    **best_run.hyperparameters
)

trainer = Trainer(
    args=training_args,                  # training arguments, defined above
    train_dataset=complete_dataset,      # training dataset
    compute_metrics=compute_metrics,     # metrics to be computed
    model_init=model_init                # Instantiate model before training starts
)

trainer.train()
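
Finally, you will likely want to persist the fine-tuned model (and its tokenizer) to Google Drive so it survives the Colab session; the destination path below is just an example:

model_dir = "/content/drive/My Drive/cord19/fine_tuned_model"  # example path
trainer.save_model(model_dir)
tokenizer.save_pretrained(model_dir)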