Fine-tuning a BERT model with transformers

Set up a custom Dataset, fine-tune BERT with the Transformers Trainer, and export the model via ONNX.

This post describes a simple way to get started with fine-tuning transformer models. It covers the basics and introduces you to the amazing Trainer class from the transformers library. I will leave important topics such as hyperparameter tuning, cross-validation and more detailed model validation to follow-up posts.

We use a dataset built from the COVID-19 Open Research Dataset Challenge. This work is one small piece of a larger project: building the cord19 search app.

You can run the code in Google Colab, but do not forget to enable GPU support.

Install required libraries

pip install pandas transformers

Load the dataset

In order to fine-tune the BERT models for the cord19 application we need to generate a set of query-document features as well as labels that indicate which documents are relevant for the specific queries. For this exercise we will use the query string to represent the query and the title string to represent the documents.

from pandas import read_csv

training_data = read_csv("")

Table 1

There are 50 unique queries.


For each query we have a list of documents, divided between relevant (label=1) and irrelevant (label=0).

training_data[["title", "label"]].groupby("label").count()

Table 2

Data split

We use a simple split into train and validation sets for illustration purposes. Even though we have more than 50 thousand data points when we consider unique query and document pairs, I believe this specific case would benefit from cross-validation, since only 50 queries have relevance judgments.

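from sklearn.model_selection import train_test_split

# The arguments were truncated in the original; the "query" column name and test_size=0.2 are assumptions.
train_queries, val_queries, train_docs, val_docs, train_labels, val_labels = train_test_split(
    training_data["query"].tolist(), training_data["title"].tolist(), training_data["label"].tolist(), test_size=0.2
)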
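Because only 50 queries carry relevance judgments, a random split over pairs lets the same query land in both sets. One alternative, sketched here as an illustration rather than as part of the original routine, is to hold out whole queries. The function name `split_by_query` and the toy data are hypothetical.

```python
import random

def split_by_query(queries, docs, labels, val_fraction=0.2, seed=0):
    """Hold out whole queries so that no query appears in both splits."""
    unique_queries = sorted(set(queries))
    rng = random.Random(seed)
    rng.shuffle(unique_queries)
    n_val = max(1, int(len(unique_queries) * val_fraction))
    val_set = set(unique_queries[:n_val])

    train, val = ([], [], []), ([], [], [])
    for q, d, y in zip(queries, docs, labels):
        target = val if q in val_set else train
        target[0].append(q)
        target[1].append(d)
        target[2].append(y)
    return train, val

# Toy data: 5 queries with 4 candidate titles each (hypothetical values).
queries = [f"query {i}" for i in range(5) for _ in range(4)]
docs = [f"title {j}" for j in range(20)]
labels = [j % 2 for j in range(20)]

(train_q, train_d, train_y), (val_q, val_d, val_y) = split_by_query(queries, docs, labels)
```

Repeating this with different held-out queries would give a query-level cross-validation, which is the follow-up this post leaves for later.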

Create BERT encodings

Create train and validation encodings. To do that, we need to choose which BERT model to use. We will use padding and truncation because the training routine expects all tensors within a batch to have the same dimensions.

from transformers import BertTokenizerFast

model_name = "google/bert_uncased_L-4_H-512_A-8"
tokenizer = BertTokenizerFast.from_pretrained(model_name)

train_encodings = tokenizer(train_queries, train_docs, truncation=True, padding='max_length', max_length=128)
val_encodings = tokenizer(val_queries, val_docs, truncation=True, padding='max_length', max_length=128)
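To make the effect of padding and truncation concrete, here is a simplified sketch of what happens to a single sequence of token ids. It is not the tokenizer's actual algorithm (which also handles special tokens and sentence pairs); `pad_or_truncate` is a hypothetical helper and the ids are arbitrary integers.

```python
def pad_or_truncate(token_ids, max_length, pad_id=0):
    """Return (input_ids, attention_mask), both exactly max_length long."""
    kept = token_ids[:max_length]            # truncation: drop anything past max_length
    padding = max_length - len(kept)         # padding: how many slots are still empty
    input_ids = kept + [pad_id] * padding
    attention_mask = [1] * len(kept) + [0] * padding  # 1 = real token, 0 = padding
    return input_ids, attention_mask

short_ids, short_mask = pad_or_truncate([5, 6, 7], max_length=6)
long_ids, long_mask = pad_or_truncate(list(range(1, 11)), max_length=6)
```

Both results have length 6, so they can be stacked into one tensor, and the attention mask tells the model which positions to ignore.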

Create a custom dataset

Now that we have the encodings and the labels we can create a Dataset object as described in the transformers webpage about custom datasets.

import torch

class Cord19Dataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

train_dataset = Cord19Dataset(train_encodings, train_labels)
val_dataset = Cord19Dataset(val_encodings, val_labels)

Fine-tune the BERT model

We are going to use BertForSequenceClassification, since we are trying to classify query and document pairs into two distinct classes (non-relevant, relevant).

from transformers import BertForSequenceClassification
model = BertForSequenceClassification.from_pretrained(model_name)
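With two classes, the classification head produces two logits per query-document pair. If you want a single relevance score from them, one common option is a softmax over the pair. The sketch below uses plain Python and assumes index 1 is the relevant class, matching label=1 above; `relevance_probability` is a hypothetical helper.

```python
import math

def relevance_probability(logits):
    """Softmax over [non-relevant, relevant] logits; returns P(relevant)."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    return exps[1] / sum(exps)

# Made-up logits for three query-document pairs.
scores = [relevance_probability(pair) for pair in ([0.0, 0.0], [-1.0, 2.0], [3.0, -3.0])]
```

Equal logits give a score of 0.5, and the score moves toward 1 as the second logit grows relative to the first.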

We can set requires_grad to False for all the base model parameters in order to fine-tune only the task-specific parameters.

for param in model.base_model.parameters():
    param.requires_grad = False

We can then fine-tune the model with Trainer. Below is a basic routine with an out-of-the-box set of parameters. Care should be taken when choosing these parameters, but that is outside the scope of this piece.

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",          # output directory
    evaluation_strategy="epoch",     # evaluate at the end of each epoch (recent transformers releases rename this to eval_strategy)
    num_train_epochs=3,              # total number of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=64,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    save_total_limit=1,              # limit the total amount of checkpoints. Deletes the older checkpoints.
)

trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=val_dataset             # evaluation dataset
)

trainer.train()
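When picking values such as warmup_steps, it helps to know how many optimizer steps the run will take. The sketch below does that arithmetic for single-device training without gradient accumulation. `schedule_summary` is a hypothetical helper, and the 40,000 training pairs are an illustrative figure, not the size of the actual split.

```python
import math

def schedule_summary(num_train_examples, batch_size, num_epochs, warmup_steps):
    """Optimizer steps for single-device training without gradient accumulation."""
    steps_per_epoch = math.ceil(num_train_examples / batch_size)
    total_steps = steps_per_epoch * num_epochs
    return {
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        # Share of the run spent warming up the learning rate.
        "warmup_fraction": min(warmup_steps, total_steps) / total_steps,
    }

# Hypothetical dataset size combined with the batch size, epochs and warmup used above.
summary = schedule_summary(40_000, batch_size=16, num_epochs=3, warmup_steps=500)
```

Under these assumptions the warmup covers roughly 7% of training. With a much smaller dataset, the same 500 steps could take up most of the run.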


Export the model to ONNX

Once training is complete we can export the model using the ONNX format to be deployed elsewhere. I assume below that you have access to a GPU, which you can get from Google Colab for example.

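from torch.onnx import export

device = torch.device("cuda")

model_onnx_path = "model.onnx"
# The dummy input was truncated in the original; using the first training
# example as a batch of one is an assumption.
example = train_dataset[0]
# Order follows the positional arguments of the BERT forward method:
# input_ids, attention_mask, token_type_ids.
input_names = ["input_ids", "attention_mask", "token_type_ids"]
dummy_input = tuple(example[name].unsqueeze(0).to(device) for name in input_names)
output_names = ["logits"]
export(
    model, dummy_input, model_onnx_path, input_names=input_names,
    output_names=output_names, verbose=False, opset_version=11
)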

Fine-tuning a BERT model for search applications

Thiago Martins

Vespa Data Scientist

How to ensure training and serving encoding compatibility

There are cases where the inputs to your Transformer model are pairs of sentences, but you want to process each sentence of the pair at different times due to your application’s nature.

The search use case

Search applications are one example. They involve a large collection of documents that can be pre-processed and stored before a search action is required. A query, on the other hand, triggers a search action, and we can only process it in real time. The goal of a search app is to return the most relevant documents for the query as quickly as possible. By applying the tokenizer to the documents as soon as we feed them to the application, we only need to tokenize the query when a search action is required, which saves time.
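The idea of tokenizing documents at feed time and queries at search time can be sketched with a toy example. The whitespace tokenizer, the vocabulary and the `ToyIndex` class below are all hypothetical stand-ins. A real application would use the BERT tokenizer and a search engine such as Vespa.

```python
# Hypothetical toy vocabulary and whitespace tokenizer, standing in for a real one.
VOCAB = {"[UNK]": 0, "covid": 1, "vaccine": 2, "trial": 3, "transmission": 4, "in": 5, "children": 6}

def tokenize(text):
    return [VOCAB.get(word, VOCAB["[UNK]"]) for word in text.lower().split()]

class ToyIndex:
    def __init__(self):
        self.doc_tokens = {}

    def feed(self, doc_id, title):
        # Feed time: tokenize once and store the ids with the document.
        self.doc_tokens[doc_id] = tokenize(title)

    def search_inputs(self, query):
        # Query time: only the query is tokenized; document ids are looked up.
        query_tokens = tokenize(query)
        return [(doc_id, query_tokens, tokens) for doc_id, tokens in self.doc_tokens.items()]

index = ToyIndex()
index.feed("d1", "COVID vaccine trial")
index.feed("d2", "Transmission in children")
pairs = index.search_inputs("covid transmission")
```

Each entry in `pairs` holds the freshly tokenized query next to document tokens that were computed ahead of time, ready to be joined into one model input.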

In addition to applying the tokenizer at different times, you also want to retain control over how the pair of sentences is encoded. For search, you might want a joint input vector of length 128 in which the query, which is usually shorter than the document, contributes up to 32 tokens while the document can take up to 96 tokens.
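The 32/96 split has to leave room for BERT's special tokens in the joint layout `[CLS] query [SEP] document [SEP]`. The small sketch below works out how many content tokens each side actually gets. `token_budget` is a hypothetical helper, and the subtraction mirrors the max_length values passed to the tokenizer in create_bert_encodings later in this post.

```python
def token_budget(query_input_size, doc_input_size):
    """Content tokens left after reserving slots for [CLS] and the two [SEP] tokens."""
    return {
        "query_tokens": query_input_size - 2,  # room for [CLS] and the first [SEP]
        "doc_tokens": doc_input_size - 1,      # room for the final [SEP]
        "total_length": query_input_size + doc_input_size,
    }

budget = token_budget(query_input_size=32, doc_input_size=96)
```

So a query keeps at most 30 word pieces and a document at most 95, and the joined input still fits in 128 positions.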

Training and serving compatibility

When training a Transformer model for search, you want to ensure that the training data follows the same pattern used by the search engine serving the final model. I have written a blog post on how to get started with BERT model fine-tuning using the transformers library. This piece adapts that training routine with a custom encoding, built from two separate tokenizer calls, to reproduce how a Vespa application would serve the model once deployed.

Create independent BERT encodings

The only change required is simple but essential. In my previous post, we discussed the vanilla case where we simply applied the tokenizer directly to the pairs of queries and documents.

from transformers import BertTokenizerFast

model_name = "google/bert_uncased_L-4_H-512_A-8"
tokenizer = BertTokenizerFast.from_pretrained(model_name)

train_encodings = tokenizer(train_queries, train_docs, truncation=True, padding='max_length', max_length=128)
val_encodings = tokenizer(val_queries, val_docs, truncation=True, padding='max_length', max_length=128)

In the search case, we create the create_bert_encodings function, which applies the tokenizer twice: once for the queries and once for the documents. In addition to allowing different max_length values for queries and documents, we set add_special_tokens=False and do not use padding, since our custom code adds the special tokens and the padding when joining the tokens generated by the tokenizer.

def create_bert_encodings(queries, docs, tokenizer, query_input_size, doc_input_size):
    # Reserve room for [CLS] and [SEP] around the query, and for the final [SEP] after the document.
    queries_encodings = tokenizer(
        queries, truncation=True, max_length=query_input_size-2, add_special_tokens=False
    )
    docs_encodings = tokenizer(
        docs, truncation=True, max_length=doc_input_size-1, add_special_tokens=False
    )

    # The definitions of these constants are missing from the text;
    # reading them from the tokenizer is an assumption.
    TOKEN_NONE = tokenizer.pad_token_id
    TOKEN_CLS = tokenizer.cls_token_id
    TOKEN_SEP = tokenizer.sep_token_id
    max_length = query_input_size + doc_input_size

    input_ids = []
    token_type_ids = []
    attention_mask = []
    for query_input_ids, doc_input_ids in zip(queries_encodings["input_ids"], docs_encodings["input_ids"]):
        # create input id
        input_id = [TOKEN_CLS] + query_input_ids + [TOKEN_SEP] + doc_input_ids + [TOKEN_SEP]
        number_tokens = len(input_id)
        padding_length = max(max_length - number_tokens, 0)
        input_ids.append(input_id + [TOKEN_NONE] * padding_length)
        # create token id: 0 for the query segment, 1 for the document segment, 0 for padding
        token_type_ids.append(
            [0] * (len(query_input_ids) + 2) + [1] * (len(doc_input_ids) + 1) + [0] * padding_length
        )
        # create attention_mask: 1 for real tokens, 0 for padding
        attention_mask.append([1] * number_tokens + [0] * padding_length)

    encodings = {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_mask
    }
    return encodings
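Since the special tokens and padding are now assembled by hand, it can be worth checking the result before training. The sketch below is one possible sanity check on a dictionary with the same three keys. `validate_encodings` is a hypothetical helper, and the example row uses a total length of 8 with arbitrary placeholder ids.

```python
def validate_encodings(encodings, total_length):
    """Raise AssertionError if any row breaks the expected fixed-length layout."""
    rows = zip(encodings["input_ids"], encodings["token_type_ids"], encodings["attention_mask"])
    for input_id, token_type_id, mask in rows:
        # Every row must have the same fixed length.
        assert len(input_id) == len(token_type_id) == len(mask) == total_length
        # Real tokens come first, padding after.
        n_real = sum(mask)
        assert mask == [1] * n_real + [0] * (total_length - n_real)
        # Segment ids are 0s (query), then 1s (document), then 0s for padding.
        n_doc = sum(token_type_id)
        n_query = n_real - n_doc
        assert token_type_id == [0] * n_query + [1] * n_doc + [0] * (total_length - n_real)
    return True

# Hand-built example row; the token ids are arbitrary placeholders.
example = {
    "input_ids": [[101, 7, 102, 8, 9, 102, 0, 0]],
    "token_type_ids": [[0, 0, 0, 1, 1, 1, 0, 0]],
    "attention_mask": [[1, 1, 1, 1, 1, 1, 0, 0]],
}
is_valid = validate_encodings(example, total_length=8)
```

The same check could be run on the output of create_bert_encodings with total_length set to the sum of the query and document input sizes.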

We then create the train_encodings and val_encodings required by the training routine. Everything else in the training routine works just the same.

from transformers import BertTokenizerFast

model_name = "google/bert_uncased_L-4_H-512_A-8"
tokenizer = BertTokenizerFast.from_pretrained(model_name)

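# The arguments were truncated in the original; 32 and 96 follow the query/document split discussed above.
train_encodings = create_bert_encodings(
    train_queries, train_docs, tokenizer, query_input_size=32, doc_input_size=96
)

val_encodings = create_bert_encodings(
    val_queries, val_docs, tokenizer, query_input_size=32, doc_input_size=96
)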

Conclusion and future work

Training a model to deploy in a search application requires us to ensure that the training encodings are compatible with the encodings used at serving time. We generate document encodings offline when feeding the documents to the search engine, and create query encodings at run-time when a query arrives. It is often useful to set different maximum lengths for queries and documents, among other possible configurations.

We showed how to customize BERT model encodings to ensure this training and serving compatibility. However, a better approach is to build tools that bridge the gap between training and serving by allowing users to request training data that respects by default the encodings used when serving the model. pyvespa will include such integration to make it easier for Vespa users to train BERT models without having to adjust the encoding generation manually as we did above.