Source: MachineLearningMastery.com
BERT is an encoder-only transformer model pretrained on the masked language model (MLM) and next sentence prediction (NSP) tasks before being fine-tuned for various NLP tasks. Pretraining requires special data preparation. In this article, you will learn how to:
- Create masked language model (MLM) training data
- Create next sentence prediction (NSP) training data
- Set up labels for BERT training
- Use Hugging Face datasets to store the training data
Let’s get started.

Preparing Data for BERT Training
Photo by Daniel Gimbel. Some rights reserved.
Overview
This article is divided into four parts; they are:
- Preparing Documents
- Creating Sentence Pairs from Document
- Masking Tokens
- Saving the Training Data for Reuse
Preparing Documents
Unlike decoder-only models, BERT’s pretraining is more complex. As mentioned in the previous post, pretraining optimizes the combined loss from both MLM and NSP tasks. Therefore, training data must be labeled for both tasks.
Let’s follow Google’s BERT implementation using the Wikitext-2 or Wikitext-103 dataset. Each line in the dataset is either empty, a title line starting with "=", or regular text. Only regular text lines are used for training.
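For instance, you can peek at the first few raw lines to see this structure. This quick check uses the same load_dataset call as the code later in this section:

from datasets import load_dataset

# inspect how empty lines, "=" title lines, and regular text interleave
dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
for line in dataset["text"][:8]:
    print(repr(line.strip()))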
BERT training requires two “sentences” per sample. For simplicity, define:
- A “sentence” is a line of text in the dataset
- A document is a sequence of consecutive “sentences”, separated by empty lines or title lines
Assuming you have trained a tokenizer as in the previous post, let’s create a function to collect text into a list of documents:
import tokenizers
from datasets import load_dataset

def create_docs(path, name, tokenizer):
    """Load wikitext dataset and extract text as documents"""
    dataset = load_dataset(path, name, split="train")
    docs = []
    for line in dataset["text"]:
        line = line.strip()
        if not line or line.startswith("="):
            docs.append([])  # new document encountered
        else:
            tokens = tokenizer.encode(line).ids
            docs[-1].append(tokens)
    docs = [doc for doc in docs if doc]  # remove empty documents
    return docs

# load the tokenizer
tokenizer = tokenizers.Tokenizer.from_file("wikitext-103_wordpiece.json")
docs = create_docs("wikitext", "wikitext-103-raw-v1", tokenizer)
This code processes text lines sequentially. When encountering a document delimiter, it creates a new list for subsequent lines. Each line is stored as a list of integers from the tokenizer.
Tracking documents is essential for the NSP task: two sentences form a "next sentence" pair only if both come from the same document.
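As a quick sanity check, you can inspect the nested structure that create_docs() returns: a list of documents, each a list of sentences, each a list of token IDs.

# documents -> sentences -> token IDs
print(len(docs))         # number of documents
print(len(docs[0]))      # number of sentences in the first document
print(docs[0][0][:10])   # first ten token IDs of the first sentence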
Creating Sentence Pairs from Document
The next step extracts sentence pairs from documents. Pairs can be consecutive sentences from the same document or random sentences from different documents. Let’s use the following algorithm to create the pairs:
- Scan each sentence from each document as the first sentence
- For the second sentence, pick either the next sentence from the same document or a random sentence from another document
But there is a constraint: the total length of the sentence pair must not exceed BERT’s maximum sequence length. You need to truncate the sentences if necessary.
Here’s how you can implement this algorithm in Python:
...
import random

# copy the document
chunks = []
for chunk in all_docs[doc_idx]:
    chunks.append(chunk)

# exhaust chunks and create samples
while chunks:
    # scan until target token length
    running_length = 0
    end = 1
    while end < len(chunks) and running_length < target_length:
        running_length += len(chunks[end-1])
        end += 1

    # randomly separate the chunk into two segments
    sep = random.randint(1, end-1) if end > 1 else 1
    sentence_a = [tok for chunk in chunks[:sep] for tok in chunk]
    sentence_b = [tok for chunk in chunks[sep:end] for tok in chunk]

    # sentence B: may be from another document
    if not sentence_b or random.random() < 0.5:
        # find another document (must not be the same as doc_idx)
        b_idx = random.randint(0, len(all_docs)-2)
        if b_idx >= doc_idx:
            b_idx += 1
        # sentence B starts from a random position in the new document
        sentence_b = []
        running_length = len(sentence_a)
        i = random.randint(0, len(all_docs[b_idx])-1)
        while i < len(all_docs[b_idx]) and running_length < target_length:
            sentence_b.extend(all_docs[b_idx][i])
            running_length += len(all_docs[b_idx][i])
            i += 1
        is_random_next = True
        chunks = chunks[sep:]
    else:
        is_random_next = False
        chunks = chunks[end:]

    # the pair is found
    pair = (sentence_a, sentence_b)
This code creates pairs from a given document at index doc_idx. The initial for-loop copies sentences as chunks to avoid mutating the original document. The while-loop scans chunks until reaching the target token length, then randomly splits them into two segments.
With 50% probability, the second sentence is replaced with a random sentence from another document. This large if-block creates the NSP task label (recorded in is_random_next) and samples a sentence from another document.
At each iteration’s end, chunks is updated to retain unused portions. The document is exhausted when this list empties. Both sentence_a and sentence_b are lists of integer token IDs.
This approach follows Google’s original BERT implementation, though it doesn’t exhaust all possible combinations. The pairs created above may be longer than the target sequence length and you need to truncate them. The truncation is implemented as follows:
def truncate_seq_pair(sentence_a, sentence_b, max_num_tokens):
    while len(sentence_a) + len(sentence_b) > max_num_tokens:
        # pick the longer sentence to remove tokens from
        candidate = sentence_a if len(sentence_a) > len(sentence_b) else sentence_b
        # remove one token from either end in equal probabilities
        if random.random() < 0.5:
            candidate.pop(0)
        else:
            candidate.pop()
Truncation is applied iteratively to the longer sentence until the total length no longer exceeds the maximum. Tokens are removed from either end with equal probability, so the result may be a chunk from the middle of the original sentence, hence the name "chunks" in the code above.
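To see the in-place behavior, here is a small self-contained check with made-up token IDs. The exact tokens removed depend on the random draws, but the total length always ends at the maximum:

import random

sentence_a = list(range(10))        # stand-in token IDs
sentence_b = list(range(100, 108))  # eight more stand-in token IDs
truncate_seq_pair(sentence_a, sentence_b, 12)
print(len(sentence_a) + len(sentence_b))  # 12
print(sentence_a, sentence_b)             # trimmed from either end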
Masking Tokens
Masked tokens are the most critical part of the BERT training data. The original paper specifies that 15% of tokens are selected for masking. In effect, the model is trained to predict the original token at only 15% of the output positions. Within this 15%, a selected token is handled as follows:
- 80% of the time, the token is replaced with the [MASK] token.
- 10% of the time, the token is replaced with a random token from the vocabulary.
- 10% of the time, the token is unchanged.
The model must correctly predict the original token in all cases. After creating pairs, let’s implement masking as follows:
def create_sample(sentence_a, sentence_b, is_random_next, tokenizer,
                  max_seq_length=512, mask_prob=0.15, max_predictions_per_seq=20):
    # Collect id of special tokens
    cls_id = tokenizer.token_to_id("[CLS]")
    sep_id = tokenizer.token_to_id("[SEP]")
    mask_id = tokenizer.token_to_id("[MASK]")
    pad_id = tokenizer.padding["pad_id"]

    # adjust length to fit the max sequence length
    truncate_seq_pair(sentence_a, sentence_b, max_seq_length - 3)
    num_pad = max_seq_length - len(sentence_a) - len(sentence_b) - 3

    # create unmodified tokens sequence
    tokens = [cls_id] + sentence_a + [sep_id] + sentence_b + [sep_id] + ([pad_id] * num_pad)
    seg_id = [0] * (len(sentence_a) + 2) + [1] * (len(sentence_b) + 1) + [-1] * num_pad
    assert len(tokens) == len(seg_id) == max_seq_length

    # create the prediction targets
    cand_indices = [i for i, tok in enumerate(tokens) if tok not in [cls_id, sep_id, pad_id]]
    random.shuffle(cand_indices)
    num_predictions = int(round((len(sentence_a) + len(sentence_b)) * mask_prob))
    num_predictions = min(max_predictions_per_seq, max(1, num_predictions))
    mlm_positions = sorted(cand_indices[:num_predictions])
    mlm_labels = []
    for i in mlm_positions:
        mlm_labels.append(tokens[i])
        # prob 0.8 replace with [MASK], prob 0.1 replace with random word, prob 0.1 keep original
        if random.random() < 0.8:
            tokens[i] = mask_id
        elif random.random() < 0.5:
            tokens[i] = random.randint(4, tokenizer.get_vocab_size() - 1)  # random token from vocabulary

    ret = {
        "tokens": tokens,
        "segment_ids": seg_id,
        "is_random_next": is_random_next,
        "masked_positions": mlm_positions,
        "masked_labels": mlm_labels,
    }
    return ret
This function creates a token sequence of the form [CLS] A [SEP] B [SEP] followed by padding, where A and B are the sentence pair with masked tokens. The IDs of the special tokens come from the tokenizer.
First, the sentence pair is truncated to fit the maximum sequence length, reserving space for three special tokens. The sequence is then padded to the expected length. Segment labels are created to distinguish sentences: 0 for the first sentence, 1 for the second, and -1 for padding.
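As a toy illustration of this layout, using made-up IDs where 1 stands for [CLS], 2 for [SEP], and 0 for [PAD]:

# the same formulas as in create_sample(), on a tiny example
cls_id, sep_id, pad_id = 1, 2, 0
sentence_a, sentence_b, max_seq_length = [11, 12], [21, 22, 23], 10
num_pad = max_seq_length - len(sentence_a) - len(sentence_b) - 3
tokens = [cls_id] + sentence_a + [sep_id] + sentence_b + [sep_id] + [pad_id] * num_pad
seg_id = [0] * (len(sentence_a) + 2) + [1] * (len(sentence_b) + 1) + [-1] * num_pad
print(tokens)  # [1, 11, 12, 2, 21, 22, 23, 2, 0, 0]
print(seg_id)  # [0, 0, 0, 0, 1, 1, 1, 1, -1, -1]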
All non-special tokens are masking candidates. Their indices are shuffled, and the first num_predictions positions are selected. This number depends on mask_prob (default 15%) and is capped at max_predictions_per_seq (default 20):
num_predictions = int(round((len(sentence_a) + len(sentence_b)) * mask_prob))
num_predictions = min(max_predictions_per_seq, max(1, num_predictions))
The variable mlm_positions is a list of the indices of the masked positions, in ascending order. The variable mlm_labels is a list of the original tokens at the masked positions. When a random token from the vocabulary is needed, you pick one from the tokenizer:
tokens[i] = random.randint(4, tokenizer.get_vocab_size() - 1)
The first four tokens in the vocabulary are special tokens, so they are never drawn as the random replacement. The returned dictionary ret is the "sample" that you will use to train the BERT model.
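If you want to spot-check the function, you can call it on a short pair encoded with the tokenizer loaded earlier. This sketch assumes the tokenizer was configured with padding so that tokenizer.padding["pad_id"] is available; the two sentences are arbitrary examples:

import random

random.seed(42)
sentence_a = tokenizer.encode("The quick brown fox jumps over the lazy dog.").ids
sentence_b = tokenizer.encode("It was the best of times.").ids
sample = create_sample(sentence_a, sentence_b, True, tokenizer, max_seq_length=64)
print(sample["is_random_next"])    # True, as passed in
print(sample["masked_positions"])  # positions selected for the MLM task
print(sample["masked_labels"])     # original token IDs at those positions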
Saving the Training Data for Reuse
So far, you learned how to process the raw dataset into sentence pairs with masked tokens for the MLM and NSP training tasks. With the code above, you can create a list of dictionaries as the training data. However, a plain in-memory list may not be the best way to serve data to the training loop.
The standard way to serve the training data in PyTorch code is to use the Dataset class. That is, to define a class like the following:
import torch

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
The key is to provide the __len__ and __getitem__ methods to return the total number of samples and the sample at the given index in the dataset respectively. However, this may not be the best for BERT training because you would likely need to load the entire dataset into memory at once when you initialize the dataset class. This is not efficient when the dataset is large.
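For reference, this is how such a class behaves with a small, made-up list of samples:

data = [{"tokens": [1, 2, 3]}, {"tokens": [4, 5, 6]}]
dataset = CustomDataset(data)
print(len(dataset))  # 2
print(dataset[0])    # {'tokens': [1, 2, 3]}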
An alternative is to use the Dataset class from the Hugging Face datasets library. It hides many data management details so that you can focus on more important things. Let’s assume you created a generator function that yields samples:
def create_dataset(docs, tokenizer):
    ...
    for doc in docs:
        for sample in create_samples(doc):
            yield sample
You can create a dataset object with:
...
from datasets import Dataset

dataset = Dataset.from_generator(create_dataset, gen_kwargs={"docs": docs, "tokenizer": tokenizer})
dataset.to_parquet("wikitext-103_train_data.parquet")
These two lines of code pull all samples from the generator function and then save them to a file in parquet format. Depending on the size of your dataset, this may take a while. The gen_kwargs argument is a dictionary of keyword arguments to pass to the generator function; it should match how you defined the function.
Once you have the dataset in parquet format, you can load it back, and try to print a few samples:
...
dataset = Dataset.from_parquet("wikitext-103_train_data.parquet", streaming=True)
for i, sample in enumerate(dataset):
    if i > 5:
        break
    print(sample)
This is where the parquet format shines. The Hugging Face datasets library also supports JSON and CSV formats, but parquet is a compressed columnar format that is more efficient for storage and retrieval. Setting streaming=True is optional: it allows you to load only the part of the dataset that you are using rather than loading the entire dataset into memory at once.
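If you prefer, the same parquet file can also be streamed through load_dataset() instead of Dataset.from_parquet(). A minimal sketch, assuming the file saved above:

from datasets import load_dataset

streamed = load_dataset("parquet", data_files="wikitext-103_train_data.parquet",
                        split="train", streaming=True)
print(next(iter(streamed)).keys())  # the fields created by create_sample()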
Putting everything together, this is the complete code:
"""Process the WikiText dataset for training the BERT model.

Using Hugging Face datasets library.
"""

import time
import random
from typing import Iterator

import tokenizers
from datasets import load_dataset, Dataset

# path and name of each dataset
DATASETS = {
    "wikitext-2": ("wikitext", "wikitext-2-raw-v1"),
    "wikitext-103": ("wikitext", "wikitext-103-raw-v1"),
}
PATH, NAME = DATASETS["wikitext-103"]
TOKENIZER_PATH = "wikitext-103_wordpiece.json"


def create_docs(path: str, name: str, tokenizer: tokenizers.Tokenizer) -> list[list[list[int]]]:
    """Load wikitext dataset and extract text as documents"""
    dataset = load_dataset(path, name, split="train")
    docs: list[list[list[int]]] = []
    for line in dataset["text"]:
        line = line.strip()
        if not line or line.startswith("="):
            docs.append([])  # new document encountered
        else:
            tokens = tokenizer.encode(line).ids
            docs[-1].append(tokens)
    docs = [doc for doc in docs if doc]  # remove empty documents
    return docs


def create_dataset(
    docs: list[list[list[int]]],
    tokenizer: tokenizers.Tokenizer,
    max_seq_length: int = 512,
    doc_repeat: int = 10,
    mask_prob: float = 0.15,
    short_seq_prob: float = 0.1,
    max_predictions_per_seq: int = 20,
) -> Iterator[dict]:
    """Generate samples from all documents"""
    doc_indices = list(range(len(docs))) * doc_repeat
    for doc_idx in doc_indices:
        yield from generate_samples(doc_idx, docs, tokenizer, max_seq_length,
                                    mask_prob, short_seq_prob, max_predictions_per_seq)


def generate_samples(
    doc_idx: int,
    all_docs: list[list[list[int]]],
    tokenizer: tokenizers.Tokenizer,
    max_seq_length: int = 512,
    mask_prob: float = 0.15,
    short_seq_prob: float = 0.1,
    max_predictions_per_seq: int = 20,
) -> Iterator[dict]:
    """Generate samples from a given document"""
    # number of tokens to extract from this doc, excluding [CLS], [SEP], [SEP]
    target_length = max_seq_length - 3
    if random.random() < short_seq_prob:
        # shorter sequence is used 10% of the time
        target_length = random.randint(2, target_length)

    # copy the document
    chunks = []
    for chunk in all_docs[doc_idx]:
        chunks.append(chunk)

    # exhaust chunks and create samples
    while chunks:
        # scan until target token length
        running_length = 0
        end = 1
        while end < len(chunks) and running_length < target_length:
            running_length += len(chunks[end-1])
            end += 1

        # randomly separate the chunk into two segments
        sep = random.randint(1, end-1) if end > 1 else 1
        sentence_a = [tok for chunk in chunks[:sep] for tok in chunk]
        sentence_b = [tok for chunk in chunks[sep:end] for tok in chunk]

        # sentence B: may be from another document
        if not sentence_b or random.random() < 0.5:
            # find another document (must not be the same as doc_idx)
            b_idx = random.randint(0, len(all_docs)-2)
            if b_idx >= doc_idx:
                b_idx += 1
            # sentence B starts from a random position in the new document
            sentence_b = []
            running_length = len(sentence_a)
            i = random.randint(0, len(all_docs[b_idx])-1)
            while i < len(all_docs[b_idx]) and running_length < target_length:
                sentence_b.extend(all_docs[b_idx][i])
                running_length += len(all_docs[b_idx][i])
                i += 1
            is_random_next = True
            chunks = chunks[sep:]
        else:
            is_random_next = False
            chunks = chunks[end:]

        # create a sample from the pair
        yield create_sample(sentence_a, sentence_b, is_random_next, tokenizer,
                            max_seq_length, mask_prob, max_predictions_per_seq)


def create_sample(
    sentence_a: list[int],
    sentence_b: list[int],
    is_random_next: bool,
    tokenizer: tokenizers.Tokenizer,
    max_seq_length: int = 512,
    mask_prob: float = 0.15,
    max_predictions_per_seq: int = 20,
) -> dict:
    """Create a sample from a pair of sentences"""
    # Collect id of special tokens
    cls_id = tokenizer.token_to_id("[CLS]")
    sep_id = tokenizer.token_to_id("[SEP]")
    mask_id = tokenizer.token_to_id("[MASK]")
    pad_id = tokenizer.padding["pad_id"]

    # adjust length to fit the max sequence length
    truncate_seq_pair(sentence_a, sentence_b, max_seq_length - 3)
    num_pad = max_seq_length - len(sentence_a) - len(sentence_b) - 3

    # create unmodified tokens sequence
    tokens = [cls_id] + sentence_a + [sep_id] + sentence_b + [sep_id] + ([pad_id] * num_pad)
    seg_id = [0] * (len(sentence_a) + 2) + [1] * (len(sentence_b) + 1) + [-1] * num_pad
    assert len(tokens) == len(seg_id) == max_seq_length

    # create the prediction targets
    cand_indices = [i for i, tok in enumerate(tokens) if tok not in [cls_id, sep_id, pad_id]]
    random.shuffle(cand_indices)
    num_predictions = int(round((len(sentence_a) + len(sentence_b)) * mask_prob))
    num_predictions = min(max_predictions_per_seq, max(1, num_predictions))
    mlm_positions = sorted(cand_indices[:num_predictions])
    mlm_labels = []
    for i in mlm_positions:
        mlm_labels.append(tokens[i])
        # prob 0.8 replace with [MASK], prob 0.1 replace with random word, prob 0.1 keep original
        if random.random() < 0.8:
            tokens[i] = mask_id
        elif random.random() < 0.5:
            tokens[i] = random.randint(4, tokenizer.get_vocab_size() - 1)  # random token from vocabulary

    ret = {
        "tokens": tokens,
        "segment_ids": seg_id,
        "is_random_next": is_random_next,
        "masked_positions": mlm_positions,
        "masked_labels": mlm_labels,
    }
    return ret


def truncate_seq_pair(sentence_a: list[int], sentence_b: list[int], max_num_tokens: int) -> None:
    """Truncate a pair of sequences until below a maximum sequence length."""
    while len(sentence_a) + len(sentence_b) > max_num_tokens:
        # pick the longer sentence to remove tokens from
        candidate = sentence_a if len(sentence_a) > len(sentence_b) else sentence_b
        # remove one token from either end in equal probabilities
        if random.random() < 0.5:
            candidate.pop(0)
        else:
            candidate.pop()


if __name__ == "__main__":
    print(time.time(), "started")
    tokenizer = tokenizers.Tokenizer.from_file(TOKENIZER_PATH)
    print(time.time(), "loaded tokenizer")
    docs = create_docs(PATH, NAME, tokenizer)
    print(time.time(), "created docs with %d documents" % len(docs))
    dataset = Dataset.from_generator(create_dataset, gen_kwargs={"docs": docs, "tokenizer": tokenizer})
    print(time.time(), "created dataset from generator")

    # Save dataset to parquet file
    dataset.to_parquet("wikitext-103_train_data.parquet")
    print(time.time(), "saved dataset to parquet file")

    # Load dataset from parquet file
    dataset = Dataset.from_parquet("wikitext-103_train_data.parquet", streaming=True)
    print(time.time(), "loaded dataset from parquet file")

    # Print a few samples
    for i, sample in enumerate(dataset):
        print(i)
        print(sample)
        print()
        if i >= 3:
            break
    print(time.time(), "finished")
Running this code, you will see output like the following:
1763913652.5099447 started
1763913652.5830114 loaded tokenizer
1763913817.1271229 created docs with 268854 documents
Generating train split: 4268307 examples [11:42:36, 101.25 examples/s]
Loading dataset shards: 100%|█████████████████████████████| 73/73 [00:10<00:00, 7.05it/s]
1763956111.2021146 created dataset from generator
Creating parquet from Arrow format: 100%|█████████████| 4269/4269 [06:13<00:00, 11.42ba/s]
1763956487.0040812 saved dataset to parquet file
Generating train split: 4268307 examples [06:22, 11168.96 examples/s]
Loading dataset shards: 100%|█████████████████████████████| 74/74 [00:09<00:00, 8.20it/s]
1763956881.6215432 loaded dataset from parquet file
0
{'tokens': [1, 10887, 4875, ..., 0, 0], 'segment_ids': [0, 0, ..., 1, 1, ..., -1, -1], 'is_random_next': True, 'masked_positions': [29, 58, ...], 'masked_labels': [15, 8551, ...]}

1
{'tokens': [1, 8792, 9150, ..., 0, 0], 'segment_ids': [0, 0, ..., 1, 1, ..., -1, -1], 'is_random_next': True, 'masked_positions': [15, 19, ...], 'masked_labels': [8522, 9100, ...]}

2
{'tokens': [1, 8506, 8556, ..., 0, 0], 'segment_ids': [0, 0, ..., 1, 1, ..., -1, -1], 'is_random_next': False, 'masked_positions': [3, 8, ...], 'masked_labels': [19367, 29188, ...]}

3
{'tokens': [1, 8544, 8910, ..., 0, 0], 'segment_ids': [0, 0, ..., 1, 1, ..., -1, -1], 'is_random_next': False, 'masked_positions': [13, 16, ...], 'masked_labels': [8656, 12114, ...]}

1763956881.6688802 finished
The timestamps printed along the way are intentional: they show how much time each step takes. Processing the Wikitext-103 dataset takes several hours. Once complete, the parquet file enables fast, efficient iteration over the samples. For testing, you can use the smaller Wikitext-2 dataset instead; the code then finishes in a few minutes.
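For example, a quick test run only requires changing the dataset selection near the top of the complete script above; the same tokenizer file can be reused:

# in the complete script above, pick the smaller configuration for a quick test
PATH, NAME = DATASETS["wikitext-2"]
TOKENIZER_PATH = "wikitext-103_wordpiece.json"  # tokenizer trained in the previous post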
Further Readings
Below are some resources that you may find useful:
- Devlin et al. (2018). BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
- Google’s BERT implementation on GitHub
Summary
In this article, you learned how to prepare the data for BERT training. You learned how to create masked language model (MLM) training data and next sentence prediction (NSP) training data. You also learned how to save the data in parquet format for reuse.
