Preparing Data for BERT Training


Source: MachineLearningMastery.com

BERT is an encoder-only transformer model pretrained on the masked language model (MLM) and next sentence prediction (NSP) tasks before being fine-tuned for various NLP tasks. Pretraining requires special data preparation. In this article, you will learn how to:

  • Create masked language model (MLM) training data
  • Create next sentence prediction (NSP) training data
  • Set up labels for BERT training
  • Use Hugging Face datasets to store the training data

Let’s get started.

Preparing Data for BERT Training
Photo by Daniel Gimbel. Some rights reserved.

Overview

This article is divided into four parts; they are:

  • Preparing Documents
  • Creating Sentence Pairs from Document
  • Masking Tokens
  • Saving the Training Data for Reuse

Preparing Documents

Compared to decoder-only models, BERT’s pretraining is more involved. As mentioned in the previous post, pretraining optimizes a combined loss from both the MLM and NSP tasks. Therefore, the training data must be labeled for both tasks.

Let’s follow Google’s BERT implementation using the Wikitext-2 or Wikitext-103 dataset. Each line in the dataset is either empty, a title line starting with “=“, or regular text. Only regular text lines are used for training.

BERT training requires two “sentences” per sample. For simplicity, define:

  • A “sentence” is a line of text in the dataset
  • A document is a sequence of consecutive “sentences”, separated by empty lines or title lines

Assuming you have trained a tokenizer as in the previous post, let’s create a function to collect text into a list of documents:

import tokenizers
from datasets import load_dataset

def create_docs(path, name, tokenizer):
    """Load wikitext dataset and extract text as documents"""
    dataset = load_dataset(path, name, split="train")
    docs = []
    for line in dataset["text"]:
        line = line.strip()
        if not line or line.startswith("="):
            docs.append([])   # new document encountered
        else:
            tokens = tokenizer.encode(line).ids
            docs[-1].append(tokens)
    docs = [doc for doc in docs if doc]  # remove empty documents
    return docs

# load the tokenizer
tokenizer = tokenizers.Tokenizer.from_file("wikitext-103_wordpiece.json")
docs = create_docs("wikitext", "wikitext-103-raw-v1", tokenizer)

This code processes text lines sequentially. When encountering a document delimiter, it creates a new list for subsequent lines. Each line is stored as a list of integers from the tokenizer.

Tracking documents is essential for the NSP task: two sentences form a “next sentence” pair only if both come from the same document.
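To make the structure of docs concrete, here is a minimal sketch with made-up lines (not actual Wikitext content) that applies the same delimiter rule, reusing the tokenizer loaded above; docs[d][s] is the list of token IDs for sentence s of document d:

# a minimal illustration with made-up lines; real Wikitext text differs
lines = [
    " = Some Title = ",                     # title line: starts a new document
    "First sentence of the document.",
    "Second sentence of the document.",
    "",                                     # empty line: also a delimiter
    " = Another Title = ",
    "A sentence in the next document.",
]
docs_demo = []
for line in lines:
    line = line.strip()
    if not line or line.startswith("="):
        docs_demo.append([])
    else:
        docs_demo[-1].append(tokenizer.encode(line).ids)
docs_demo = [doc for doc in docs_demo if doc]
print(len(docs_demo))        # 2 documents
print(len(docs_demo[0]))     # 2 sentences in the first document
print(docs_demo[0][0][:5])   # first few token IDs of the first sentence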

Creating Sentence Pairs from Document

The next step extracts sentence pairs from documents. Pairs can be consecutive sentences from the same document or random sentences from different documents. Let’s use the following algorithm to create the pairs:

  1. Scan each sentence from each document as the first sentence
  2. For the second sentence, pick either the next sentence from the same document or a random sentence from another document

But there is a constraint: total length of the sentence pair must not exceed BERT’s maximum sequence length. You need to truncate the sentences if necessary.

Here’s how you can implement this algorithm in Python:

...
import random

# copy the document
chunks = []
for chunk in all_docs[doc_idx]:
    chunks.append(chunk)

# exhaust chunks and create samples
while chunks:
    # scan until target token length
    running_length = 0
    end = 1
    while end < len(chunks) and running_length < target_length:
        running_length += len(chunks[end-1])
        end += 1
    # randomly separate the chunk into two segments
    sep = random.randint(1, end-1) if end > 1 else 1
    sentence_a = [tok for chunk in chunks[:sep] for tok in chunk]
    sentence_b = [tok for chunk in chunks[sep:end] for tok in chunk]
    # sentence B: may be from another document
    if not sentence_b or random.random() < 0.5:
        # find another document (must not be the same as doc_idx)
        b_idx = random.randint(0, len(all_docs)-2)
        if b_idx >= doc_idx:
            b_idx += 1
        # sentence B starts from a random position in the new document
        sentence_b = []
        running_length = len(sentence_a)
        i = random.randint(0, len(all_docs[b_idx])-1)
        while i < len(all_docs[b_idx]) and running_length < target_length:
            sentence_b.extend(all_docs[b_idx][i])
            running_length += len(all_docs[b_idx][i])
            i += 1
        is_random_next = True
        chunks = chunks[sep:]
    else:
        is_random_next = False
        chunks = chunks[end:]
    # the pair is found
    pair = (sentence_a, sentence_b)

This code creates pairs from a given document at index doc_idx. The initial for-loop copies sentences as chunks to avoid mutating the original document. The while-loop scans chunks until reaching the target token length, then randomly splits them into two segments.

With 50% probability, the second sentence is replaced with text from another document. This large if-block sets the NSP label (recorded in is_random_next) and, when needed, samples sentences from another document.

At each iteration’s end, chunks is updated to retain unused portions. The document is exhausted when this list empties. Both sentence_a and sentence_b are lists of integer token IDs.

This approach follows Google’s original BERT implementation, though it doesn’t exhaust all possible combinations. The pairs created above may be longer than the target sequence length and you need to truncate them. The truncation is implemented as follows:

def truncate_seq_pair(sentence_a, sentence_b, max_num_tokens):
    while len(sentence_a) + len(sentence_b) > max_num_tokens:
        # pick the longer sentence to remove tokens from
        candidate = sentence_a if len(sentence_a) > len(sentence_b) else sentence_b
        # remove one token from either end in equal probabilities
        if random.random() < 0.5:
            candidate.pop(0)
        else:
            candidate.pop()

Truncation is applied iteratively to the longer sentence until the total length fits within the maximum. Tokens are removed from either end with equal probability, so the result may be a chunk from the middle of the original sentence, hence the name “chunks” in the code above.
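Here is a quick sketch with toy token IDs (assuming the truncate_seq_pair function above and import random); the exact tokens removed vary from run to run because the end to trim is chosen at random:

a = list(range(10))          # toy token IDs for sentence A
b = list(range(100, 104))    # toy token IDs for sentence B
truncate_seq_pair(a, b, max_num_tokens=8)
print(len(a) + len(b))       # 8: only the longer sentence (a) was trimmed
print(a, b)                  # a is now a 4-token chunk from somewhere inside the original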

Masking Tokens

Masked tokens are the most critical part of the BERT training data. The original paper specifies that 15% of tokens are masked. More precisely, the model is trained to predict the original token at only 15% of the positions. Within this 15%, a selected token is handled as follows:

  • 80% of the time, the token is replaced with the [MASK] token.
  • 10% of the time, the token is replaced with a random token from the vocabulary.
  • 10% of the time, the token is unchanged.

The model must correctly predict the original token in all cases. After creating pairs, let’s implement masking as follows:

def create_sample(sentence_a, sentence_b, is_random_next, tokenizer,
                  max_seq_length=512, mask_prob=0.15, max_predictions_per_seq=20):
    # Collect id of special tokens
    cls_id = tokenizer.token_to_id("[CLS]")
    sep_id = tokenizer.token_to_id("[SEP]")
    mask_id = tokenizer.token_to_id("[MASK]")
    pad_id = tokenizer.padding["pad_id"]
    # adjust length to fit the max sequence length
    truncate_seq_pair(sentence_a, sentence_b, max_seq_length - 3)
    num_pad = max_seq_length - len(sentence_a) - len(sentence_b) - 3
    # create unmodified tokens sequence
    tokens = [cls_id] + sentence_a + [sep_id] + sentence_b + [sep_id] + ([pad_id] * num_pad)
    seg_id = [0] * (len(sentence_a) + 2) + [1] * (len(sentence_b) + 1) + [-1] * num_pad
    assert len(tokens) == len(seg_id) == max_seq_length
    # create the prediction targets
    cand_indices = [i for i, tok in enumerate(tokens) if tok not in [cls_id, sep_id, pad_id]]
    random.shuffle(cand_indices)
    num_predictions = int(round((len(sentence_a) + len(sentence_b)) * mask_prob))
    num_predictions = min(max_predictions_per_seq, max(1, num_predictions))
    mlm_positions = sorted(cand_indices[:num_predictions])
    mlm_labels = []
    # randomly mask some tokens
    for i in mlm_positions:
        mlm_labels.append(tokens[i])
        # prob 0.8 replace with [MASK], prob 0.1 replace with random word, prob 0.1 keep original
        if random.random() < 0.8:
            tokens[i] = mask_id
        elif random.random() < 0.5:
            tokens[i] = random.randint(4, tokenizer.get_vocab_size() - 1)
    ret = {
        "tokens": tokens,
        "segment_ids": seg_id,
        "is_random_next": is_random_next,
        "masked_positions": mlm_positions,
        "masked_labels": mlm_labels,
    }
    return ret

This function creates a token sequence of the form [CLS] sentence A [SEP] sentence B [SEP], where sentence A and sentence B are the sentence pair with some tokens masked. The IDs of the special tokens come from the tokenizer.

First, the sentence pair is truncated to fit the maximum sequence length, reserving space for three special tokens. The sequence is then padded to the expected length. Segment labels are created to distinguish sentences: 0 for the first sentence, 1 for the second, and -1 for padding.

All non-special tokens are masking candidates. Their indices are shuffled, and the first num_predictions positions are selected. This number depends on mask_prob (default 15%) and is capped at max_predictions_per_seq (default 20):

num_predictions = int(round((len(sentence_a) + len(sentence_b)) * mask_prob))

num_predictions = min(max_predictions_per_seq, max(1, num_predictions))

The variable mlm_positions is a list of the indices of the masked positions, in ascending order. The variable mlm_labels is a list of the original tokens at the masked positions. When a random token from the vocabulary is needed, you pick one from the tokenizer:

tokens[i] = random.randint(4, tokenizer.get_vocab_size() - 1)

The first four tokens in the vocabulary are special tokens, so they are never drawn as random replacements. The dictionary ret that is returned is the “sample” you will use to train the BERT model.
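As a quick sanity check, you can print the first few vocabulary entries; the exact set and order of special tokens depends on how the tokenizer was trained in the previous post:

# the first few IDs should map to special tokens such as [PAD], [CLS], [SEP], [MASK]
for i in range(4):
    print(i, tokenizer.id_to_token(i))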

Saving the Training Data for Reuse

So far, you learned how to process the raw dataset into sentence pairs with masked tokens for the MLM and NSP training tasks. With the code above, you can create a list of dictionaries as the training data. However, a plain list may not be the most convenient way to serve the training loop.

The standard way to serve the training data in PyTorch code is to use the Dataset class. That is, to define a class like the following:

import torch

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

The key is to provide the __len__ and __getitem__ methods, which return the total number of samples and the sample at a given index, respectively. However, this may not be the best fit for BERT pretraining because you would need to load the entire dataset into memory when you initialize the class, which is inefficient for a large dataset.
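For illustration, this is how such a class would typically feed a training loop. Here, data is assumed to be a list of the sample dictionaries produced by create_sample, and collate_fn is a hypothetical helper (not part of the original code) that batches only the fixed-length fields:

import torch
from torch.utils.data import DataLoader

def collate_fn(batch):
    # stack the fixed-length fields into tensors; masked positions and labels
    # vary in length per sample, so keep them as plain Python lists here
    return {
        "tokens": torch.tensor([s["tokens"] for s in batch]),
        "segment_ids": torch.tensor([s["segment_ids"] for s in batch]),
        "is_random_next": torch.tensor([s["is_random_next"] for s in batch]),
        "masked_positions": [s["masked_positions"] for s in batch],
        "masked_labels": [s["masked_labels"] for s in batch],
    }

loader = DataLoader(CustomDataset(data), batch_size=8, shuffle=True, collate_fn=collate_fn)
for batch in loader:
    print(batch["tokens"].shape)   # torch.Size([8, 512])
    break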

An alternative is to use the Dataset class from the Hugging Face datasets library. It hides many data management details so that you can focus on more important things. Let’s assume you created a generator function that yields samples:

def create_dataset(docs, tokenizer):
    ...
    for doc in docs:
        for sample in create_samples(doc):
            yield sample

You can create a dataset object with:

...
from datasets import Dataset

dataset = Dataset.from_generator(create_dataset, gen_kwargs={"docs": docs, "tokenizer": tokenizer})
dataset.to_parquet("wikitext-103_train_data.parquet")

These two lines of code pull all samples from the generator function and save them to a file in parquet format. Depending on the size of your dataset, this may take a while. The gen_kwargs argument is a dictionary of keyword arguments to pass to the generator function; it should match how you defined the function.

Once you have the dataset in parquet format, you can load it back, and try to print a few samples:

...
dataset = Dataset.from_parquet("wikitext-103_train_data.parquet", streaming=True)
for i, sample in enumerate(dataset):
    if i > 5:
        break
    print(sample)

This is where the parquet format shines. The Hugging Face datasets library also supports JSON and CSV formats. But parquet is a compressed columnar format that is more efficient for data storage and retrieval. Setting streaming=True is optional. This allows you to load only the part of the dataset that you are using rather than loading the entire dataset into memory at once.
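As a sketch of how this connects back to PyTorch, you can load the parquet file as a regular (non-streaming) Dataset, ask the library to return PyTorch tensors, and wrap it in a DataLoader. The variable-length columns (masked_positions, masked_labels) are left out here and would need a custom collate function:

from datasets import Dataset
from torch.utils.data import DataLoader

dataset = Dataset.from_parquet("wikitext-103_train_data.parquet")
# return PyTorch tensors for the fixed-length columns only
torch_dataset = dataset.with_format("torch", columns=["tokens", "segment_ids", "is_random_next"])
loader = DataLoader(torch_dataset, batch_size=8)
for batch in loader:
    print(batch["tokens"].shape)   # torch.Size([8, 512])
    break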

Putting everything together, this is the complete code:

"""Process the WikiText dataset for training the BERT model. Using Hugging Face
datasets library.
"""

import time
import random
from typing import Iterator

import tokenizers
from datasets import load_dataset, Dataset

# path and name of each dataset
DATASETS = {
    "wikitext-2": ("wikitext", "wikitext-2-raw-v1"),
    "wikitext-103": ("wikitext", "wikitext-103-raw-v1"),
}
PATH, NAME = DATASETS["wikitext-103"]
TOKENIZER_PATH = "wikitext-103_wordpiece.json"


def create_docs(path: str, name: str, tokenizer: tokenizers.Tokenizer) -> list[list[list[int]]]:
    """Load wikitext dataset and extract text as documents"""
    dataset = load_dataset(path, name, split="train")
    docs: list[list[list[int]]] = []
    for line in dataset["text"]:
        line = line.strip()
        if not line or line.startswith("="):
            docs.append([])   # new document encountered
        else:
            tokens = tokenizer.encode(line).ids
            docs[-1].append(tokens)
    docs = [doc for doc in docs if doc]  # remove empty documents
    return docs


def create_dataset(
    docs: list[list[list[int]]],
    tokenizer: tokenizers.Tokenizer,
    max_seq_length: int = 512,
    doc_repeat: int = 10,
    mask_prob: float = 0.15,
    short_seq_prob: float = 0.1,
    max_predictions_per_seq: int = 20,
) -> Iterator[dict]:
    """Generate samples from all documents"""
    doc_indices = list(range(len(docs))) * doc_repeat
    for doc_idx in doc_indices:
        yield from generate_samples(doc_idx, docs, tokenizer, max_seq_length, mask_prob, short_seq_prob, max_predictions_per_seq)


def generate_samples(
    doc_idx: int,
    all_docs: list[list[list[int]]],
    tokenizer: tokenizers.Tokenizer,
    max_seq_length: int = 512,
    mask_prob: float = 0.15,
    short_seq_prob: float = 0.1,
    max_predictions_per_seq: int = 20,
) -> Iterator[dict]:
    """Generate samples from a given document"""
    # number of tokens to extract from this doc, excluding [CLS], [SEP], [SEP]
    target_length = max_seq_length - 3
    if random.random() < short_seq_prob:
        # shorter sequence is used 10% of the time
        target_length = random.randint(2, target_length)
    # copy the document
    chunks = []
    for chunk in all_docs[doc_idx]:
        chunks.append(chunk)
    # exhaust chunks and create samples
    while chunks:
        # scan until target token length
        running_length = 0
        end = 1
        while end < len(chunks) and running_length < target_length:
            running_length += len(chunks[end-1])
            end += 1
        # randomly separate the chunk into two segments
        sep = random.randint(1, end-1) if end > 1 else 1
        sentence_a = [tok for chunk in chunks[:sep] for tok in chunk]
        sentence_b = [tok for chunk in chunks[sep:end] for tok in chunk]
        # sentence B: may be from another document
        if not sentence_b or random.random() < 0.5:
            # find another document (must not be the same as doc_idx)
            b_idx = random.randint(0, len(all_docs)-2)
            if b_idx >= doc_idx:
                b_idx += 1
            # sentence B starts from a random position in the new document
            sentence_b = []
            running_length = len(sentence_a)
            i = random.randint(0, len(all_docs[b_idx])-1)
            while i < len(all_docs[b_idx]) and running_length < target_length:
                sentence_b.extend(all_docs[b_idx][i])
                running_length += len(all_docs[b_idx][i])
                i += 1
            is_random_next = True
            chunks = chunks[sep:]
        else:
            is_random_next = False
            chunks = chunks[end:]
        # create a sample from the pair
        yield create_sample(sentence_a, sentence_b, is_random_next, tokenizer, max_seq_length, mask_prob, max_predictions_per_seq)


def create_sample(
    sentence_a: list[int],
    sentence_b: list[int],
    is_random_next: bool,
    tokenizer: tokenizers.Tokenizer,
    max_seq_length: int = 512,
    mask_prob: float = 0.15,
    max_predictions_per_seq: int = 20,
) -> dict:
    """Create a sample from a pair of sentences"""
    # Collect id of special tokens
    cls_id = tokenizer.token_to_id("[CLS]")
    sep_id = tokenizer.token_to_id("[SEP]")
    mask_id = tokenizer.token_to_id("[MASK]")
    pad_id = tokenizer.padding["pad_id"]
    # adjust length to fit the max sequence length
    truncate_seq_pair(sentence_a, sentence_b, max_seq_length - 3)
    num_pad = max_seq_length - len(sentence_a) - len(sentence_b) - 3
    # create unmodified tokens sequence
    tokens = [cls_id] + sentence_a + [sep_id] + sentence_b + [sep_id] + ([pad_id] * num_pad)
    seg_id = [0] * (len(sentence_a) + 2) + [1] * (len(sentence_b) + 1) + [-1] * num_pad
    assert len(tokens) == len(seg_id) == max_seq_length
    # create the prediction targets
    cand_indices = [i for i, tok in enumerate(tokens) if tok not in [cls_id, sep_id, pad_id]]
    random.shuffle(cand_indices)
    num_predictions = int(round((len(sentence_a) + len(sentence_b)) * mask_prob))
    num_predictions = min(max_predictions_per_seq, max(1, num_predictions))
    mlm_positions = sorted(cand_indices[:num_predictions])
    mlm_labels = []
    # randomly mask some tokens
    for i in mlm_positions:
        mlm_labels.append(tokens[i])
        # prob 0.8 replace with [MASK], prob 0.1 replace with random word, prob 0.1 keep original
        if random.random() < 0.8:
            tokens[i] = mask_id
        elif random.random() < 0.5:
            tokens[i] = random.randint(4, tokenizer.get_vocab_size() - 1)
    ret = {
        "tokens": tokens,
        "segment_ids": seg_id,
        "is_random_next": is_random_next,
        "masked_positions": mlm_positions,
        "masked_labels": mlm_labels,
    }
    return ret


def truncate_seq_pair(sentence_a: list[int], sentence_b: list[int], max_num_tokens: int) -> None:
    """Truncate a pair of sequences until below a maximum sequence length."""
    while len(sentence_a) + len(sentence_b) > max_num_tokens:
        # pick the longer sentence to remove tokens from
        candidate = sentence_a if len(sentence_a) > len(sentence_b) else sentence_b
        # remove one token from either end in equal probabilities
        if random.random() < 0.5:
            candidate.pop(0)
        else:
            candidate.pop()


if __name__ == "__main__":
    print(time.time(), "started")
    tokenizer = tokenizers.Tokenizer.from_file(TOKENIZER_PATH)
    print(time.time(), "loaded tokenizer")
    docs = create_docs(PATH, NAME, tokenizer)
    print(time.time(), "created docs with %d documents" % len(docs))
    dataset = Dataset.from_generator(create_dataset, gen_kwargs={"docs": docs, "tokenizer": tokenizer})
    print(time.time(), "created dataset from generator")
    # Save dataset to parquet file
    dataset.to_parquet("wikitext-103_train_data.parquet")
    print(time.time(), "saved dataset to parquet file")
    # Load dataset from parquet file
    dataset = Dataset.from_parquet("wikitext-103_train_data.parquet", streaming=True)
    print(time.time(), "loaded dataset from parquet file")
    # Print a few samples
    for i, sample in enumerate(dataset):
        print(i)
        print(sample)
        print()
        if i >= 3:
            break
    print(time.time(), "finished")

Running this code, you will see output like the following:

1763913652.5099447 started
1763913652.5830114 loaded tokenizer
1763913817.1271229 created docs with 268854 documents
Generating train split: 4268307 examples [11:42:36, 101.25 examples/s]
Loading dataset shards: 100%|█████████████████████████████| 73/73 [00:10<00:00, 7.05it/s]
1763956111.2021146 created dataset from generator
Creating parquet from Arrow format: 100%|█████████████| 4269/4269 [06:13<00:00, 11.42ba/s]
1763956487.0040812 saved dataset to parquet file
Generating train split: 4268307 examples [06:22, 11168.96 examples/s]
Loading dataset shards: 100%|█████████████████████████████| 74/74 [00:09<00:00, 8.20it/s]
1763956881.6215432 loaded dataset from parquet file
0
{'tokens': [1, 10887, 4875, …, 0, 0], 'segment_ids': [0, 0, …, 1, 1, …, -1, -1],
 'is_random_next': True, 'masked_positions': [29, 58, …],
 'masked_labels': [15, 8551, …]}

1
{'tokens': [1, 8792, 9150, …, 0, 0], 'segment_ids': [0, 0, …, 1, 1, …, -1, -1],
 'is_random_next': True, 'masked_positions': [15, 19, …],
 'masked_labels': [8522, 9100, …]}

2
{'tokens': [1, 8506, 8556, …, 0, 0], 'segment_ids': [0, 0, …, 1, 1, …, -1, -1],
 'is_random_next': False, 'masked_positions': [3, 8, …],
 'masked_labels': [19367, 29188, …]}

3
{'tokens': [1, 8544, 8910, …, 0, 0], 'segment_ids': [0, 0, …, 1, 1, …, -1, -1],
 'is_random_next': False, 'masked_positions': [13, 16, …],
 'masked_labels': [8656, 12114, …]}

1763956881.6688802 finished

The timestamps printed along the way show how much time each step takes. Processing the Wikitext-103 dataset takes several hours. Once complete, the parquet file enables fast, efficient iteration over the samples. For testing, you can use the smaller Wikitext-2 dataset instead, which runs in a few minutes; only the dataset selection needs to change, as in the sketch below.
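To run that quick test, only the dataset selection near the top of the script needs to change; the output filename below is an arbitrary choice for this sketch:

PATH, NAME = DATASETS["wikitext-2"]
...
dataset.to_parquet("wikitext-2_train_data.parquet")   # arbitrary filename for the smaller dataset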


Summary

In this article, you learned how to prepare the data for BERT training. You learned how to create masked language model (MLM) training data and next sentence prediction (NSP) training data. You also learned how to save the data in parquet format for reuse.
