An Easy-to-Understand Walkthrough of a BPE Tokenizer Implementation
BPE (Byte Pair Encoding)
BPE (Byte Pair Encoding) is a subword tokenization algorithm based on frequency statistics, widely used in modern natural language processing, especially in large language models such as GPT and LLaMA (BERT uses the closely related WordPiece algorithm). Its core idea is to build a compact, efficient vocabulary by repeatedly merging the most frequent adjacent character pairs.
The core ideas of BPE:
- Start from individual characters and progressively merge high-frequency character pairs.
- The result is a subword vocabulary that can represent common words directly while still being able to decompose unseen words.
- This keeps the vocabulary size manageable and effectively avoids the out-of-vocabulary (OOV) problem.
Training Process
The BPE training workflow is as follows:
Training corpus: Hello World , Hey Wow
1. Read the training corpus and split it into sentences and words
# filepaths: list of files containing the training corpus
def create_vocab(filepaths: List[str]) -> Dict[str, int]:
    # Dictionary recording every word and its occurrence count
    vocab = defaultdict(int)
    for path in tqdm(filepaths, desc='Creating vocabulary'):
        text = open(path, 'r', encoding='utf-8-sig').read()
        # Use NLTK's sent_tokenize to split the text into sentences
        # (splitting on periods and other punctuation while respecting context)
        sentences = sent_tokenize(text)
        # Iterate over the sentences
        for sentence in sentences:
            # Use NLTK's wordpunct_tokenize to split each sentence into word tokens
            tokens = wordpunct_tokenize(sentence)
            # Record the occurrence count of each token
            for token in tokens:
                vocab[token] += 1
    # vocab: dictionary mapping each word to its occurrence count
    return vocab
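For the example corpus, wordpunct_tokenize splits the single sentence on whitespace and punctuation, so the resulting word-frequency dictionary would look roughly as follows (a sketch assuming the corpus file 'corpus.txt', a hypothetical filename, contains exactly the text above):

    # create_vocab(['corpus.txt'])
    # -> {'Hello': 1, 'World': 1, ',': 1, 'Hey': 1, 'Wow': 1}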
2. Filter low-frequency words out of vocab
def truncate_vocab(vocab: Dict[str, int], mincount: int) -> None:
    tokens = list(vocab.keys())
    for token in tokens:
        if vocab[token] < mincount:
            del vocab[token]
In this example mincount is set to 1, so no words are filtered out.
3. Preprocess the data
- Split every word in the training corpus into individual characters and append the special marker </w> to indicate the end of the word.
def prepare_bpe_vocab(vocab: Dict[str, int]) -> Dict[str, int]:
    bpe_vocab = {}
    # Iterate over every word in vocab
    for token in vocab:
        # Separate each character with a space and append ' </w>' to mark the end of the word
        ntoken = ' '.join(list(token)) + ' </w>'
        bpe_vocab[ntoken] = vocab[token]
    return bpe_vocab
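Applied to the example corpus, this turns the word-level dictionary into a character-level one (a sketch of the expected result):

    # prepare_bpe_vocab({'Hello': 1, 'World': 1, ',': 1, 'Hey': 1, 'Wow': 1})
    # -> {'H e l l o </w>': 1, 'W o r l d </w>': 1, ', </w>': 1,
    #     'H e y </w>': 1, 'W o w </w>': 1}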
4. Run N merge iterations, merging the most frequent character pair in each iteration
The standard BPE merge loop:

# The vocabulary is finalized only after `merges` high-frequency pairs have been merged
for i in trange(merges, desc='Merging'):
    # 1. Count the occurrences of every adjacent character pair
    pairs = get_stats(vocab)
    # 2. Pick the currently most frequent pair
    best = max(pairs, key=pairs.get)
    # 3. Merge that pair throughout the vocabulary
    vocab = merge_vocab(best, vocab)

This implementation additionally records every merged subword pair and its frequency (the traditional BPE algorithm does not include this step):

###### Record the historically merged highest-frequency subword pairs and their frequencies
###### (not part of the traditional BPE algorithm)
merged_pair_freqs = defaultdict(int)
# The vocabulary is finalized only after `merges` high-frequency pairs have been merged
for _ in trange(merges, desc='Merging'):
    # 1. Count the occurrences of every adjacent character pair
    pairs = get_stats(vocab)
    # 2. Pick the currently most frequent pair
    best_pair = max(pairs.items(), key=lambda x: x[1])
    ###### Record the global frequency of this subword pair (not part of traditional BPE) ######
    best_subword = ''.join(best_pair[0])
    best_freq = best_pair[1]
    merged_pair_freqs[best_subword] += best_freq
    # 3. Merge that pair throughout the vocabulary
    vocab = merge_vocab(best_pair[0], vocab)
4.1 Count the occurrences of every adjacent character pair
def get_stats(vocab: Dict[str, int]) -> Dict[Tuple[str, str], int]:
    pairs = defaultdict(int)
    for word, freq in vocab.items():
        # Split every preprocessed word in vocab on spaces
        symbols = word.split()
        # Count the occurrences of every adjacent symbol pair, weighted by the word frequency
        for i in range(len(symbols)-1):
            pairs[symbols[i], symbols[i+1]] += freq
    return pairs
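On the example corpus the first call to get_stats yields counts such as ('H', 'e'): 2 (from Hello and Hey) and ('W', 'o'): 2 (from World and Wow), with every other adjacent pair occurring once (a sketch of the expected counts, not an exhaustive listing):

    # get_stats(bpe_vocab)
    # -> {('H', 'e'): 2, ('W', 'o'): 2, ('e', 'l'): 1, ('l', 'l'): 1, ...}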
4.2 Pick the currently most frequent character pair
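In this implementation this step is just the single max call already shown in the merge loop above:

    best = max(pairs, key=pairs.get)
    # or, in the variant that also records the merge frequency:
    best_pair = max(pairs.items(), key=lambda x: x[1])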
4.3 Merge the currently most frequent character pair
def merge_vocab(pair: Tuple[str, str], v_in: Dict[str, int]) -> Dict[str, int]:
    # 1. Join the two symbols of the most frequent pair with a space, e.g. "H e"
    bigram = re.escape(' '.join(pair))
    v_out = {}
    # 2. Regex that matches "H e" only when "H" and "e" are standalone symbols,
    #    so forms such as "HH e" or "H ee" are not matched
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    # 3. Iterate over every word in vocab
    for word in v_in:
        # 3.1 Replace every matched "H e" with "He"
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = v_in[word]
    # 4. Return the vocab with the most frequent pair merged
    return v_out
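For example, merging the pair ('H', 'e') in the example vocab rewrites every word in which the two symbols appear side by side (a sketch of the expected result):

    # merge_vocab(('H', 'e'), bpe_vocab)
    # 'H e l l o </w>' -> 'He l l o </w>'
    # 'H e y </w>'     -> 'He y </w>'
    # all other entries are unchanged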
5. Build the final frequency table (the occurrence count of every subword) from the vocab produced by the N merge iterations
def count_byte_freqs(vocab: Dict[str, int]) -> Dict[str, int]:
    freqs = defaultdict(int)
    for word in vocab:
        # 1. Split on spaces
        bytes_ = word.split(' ')
        # 2. Increment the count of every subword by one
        for byte in bytes_:
            freqs[byte] += 1
    # 3. Add a few special tokens
    for token in ['<line/>', '</line>', '<pad>', '<unk>']:
        freqs[token] += 1
    return freqs
6. Build the final vocabulary maps from the frequency table
def create_vocab_maps(freqs: Dict[str, int]) -> (Dict[str, int], Dict[int, str]):
    # 1. Sort by frequency, from highest to lowest
    ordered_freqs = sorted(freqs.items(), key=lambda x: x[1], reverse=True)
    vocab_to_idx, idx_to_vocab = {}, {}
    for i in range(len(ordered_freqs)):
        # 2. Build the vocabulary maps
        word, freq = ordered_freqs[i]
        vocab_to_idx[word] = i
        idx_to_vocab[i] = word
    return vocab_to_idx, idx_to_vocab
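Subwords are therefore assigned IDs in descending order of frequency. A purely illustrative (hypothetical) input and output:

    # create_vocab_maps({'</w>': 5, 'He': 2, 'l': 2, '<unk>': 1})
    # -> vocab_to_idx = {'</w>': 0, 'He': 1, 'l': 2, '<unk>': 3}
    #    idx_to_vocab = {0: '</w>', 1: 'He', 2: 'l', 3: '<unk>'}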
7. freqs = final subword frequencies + frequencies of the historically merged pairs (the traditional BPE algorithm does not include this step)
freqs.update(merged_pair_freqs)
8. Finally, the frequency table and vocabulary maps produced by training are typically written to files
def save(self, path: str) -> None:
    # 1. The frequency table records the merge rules: which subwords exist and how often
    #    they occur, which determines merge priority at tokenization time
    with open(f'{path}/freqs.json', 'w', encoding='utf-8') as outfile:
        json.dump(self.freqs, outfile, indent=4, ensure_ascii=False)
    # 2. The regular vocabulary maps
    with open(f'{path}/vocab_to_idx.json', 'w', encoding='utf-8') as outfile:
        json.dump(self.vocab_to_idx, outfile, indent=4, ensure_ascii=False)
    with open(f'{path}/idx_to_vocab.json', 'w', encoding='utf-8') as outfile:
        json.dump(self.idx_to_vocab, outfile, indent=4, ensure_ascii=False)
The complete code for the BPE training process is shown below:
def train_bpe(filepaths: List[str], mincount: int, merges: int) -> 'BytePairTokenizer':
    vocab = create_vocab(filepaths)
    truncate_vocab(vocab, mincount)
    vocab = prepare_bpe_vocab(vocab)
    merged_pair_freqs = defaultdict(int)              # (not part of traditional BPE)
    for _ in trange(merges, desc='Merging'):
        pairs = get_stats(vocab)
        best_pair = max(pairs.items(), key=lambda x: x[1])
        best_subword = ''.join(best_pair[0])          # (not part of traditional BPE)
        best_freq = best_pair[1]                      # (not part of traditional BPE)
        merged_pair_freqs[best_subword] += best_freq  # (not part of traditional BPE)
        vocab = merge_vocab(best_pair[0], vocab)
    freqs = count_byte_freqs(vocab)
    vocab_to_idx, idx_to_vocab = create_vocab_maps(freqs)
    freqs.update(merged_pair_freqs)                   # (not part of traditional BPE)
    return BytePairTokenizer(freqs, vocab_to_idx, idx_to_vocab)
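A minimal usage sketch of the training entry point (the corpus filename and output directory below are assumptions, not part of the original code):

    # Train on one corpus file, keep all words, perform 10 merges, then save the result
    tokenizer = train_bpe(filepaths=['corpus.txt'], mincount=1, merges=10)
    tokenizer.save('./bpe_model')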
Tokenization Process
1. Split the input text into sentences and words
# Use NLTK's sent_tokenize to split the input text into sentences
lines = sent_tokenize(open(filepath, encoding='utf-8-sig').read())
tokens = []
# Iterate over all sentences
for line in lines:
    if len(line) > 1:
        tokens += get_line_ids(line, tokenizer)
def get_line_ids(line: str, tokenizer: BytePairTokenizer) -> List[int]:
    # Split the sentence into words
    tokens = wordpunct_tokenize(line)
    # Turn each word (str) into a list of characters and append '</w>' at the end
    tokens = [list(t) + ['</w>'] for t in tokens]
    ...
Taking the input "Hello World" as an example:
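At this point tokens holds one character list per word (a sketch of the expected value):

    # [['H', 'e', 'l', 'l', 'o', '</w>'],
    #  ['W', 'o', 'r', 'l', 'd', '</w>']]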
2. For each word in the sentence, merge its subwords and map them to vocabulary IDs, producing the token-ID list for the sentence
def get_line_ids(line: str, tokenizer: BytePairTokenizer) -> List[int]:
    ...
    lineids = []
    for token in tokens:
        # 2.1 Merge the subwords of the current word until no further merge is possible
        token = tokenizer.merge_bytes(token)
        # 2.2 Map each subword of the current word to its ID in the vocabulary
        ids = tokenizer.get_byte_ids(token)
        lineids += ids
    # Wrap the sentence with the start-of-line and end-of-line tokens
    sol_id = tokenizer.get_byte_id('<line/>')
    eol_id = tokenizer.get_byte_id('</line>')
    lineids = [sol_id] + lineids + [eol_id]
    return lineids
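End to end, a call on the example sentence would produce something like the following (the concrete ID values depend entirely on the trained vocabulary, so the numbers shown here are purely hypothetical):

    # get_line_ids('Hello World', tokenizer)
    # -> [sol_id, <IDs of the merged subwords of 'Hello'>, <IDs of the merged subwords of 'World'>, eol_id]
    # e.g. [4, 17, 9, 21, 5]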
2.1 Merge the subwords of each word until no further merge is possible
# Merge the subwords of the current word until no further merge is possible
def merge_bytes(self, bytes_: List[str]) -> List[str]:
    bytes_, merged = self.merge_max_pair(bytes_)
    while merged:
        bytes_, merged = self.merge_max_pair(bytes_)
    return bytes_

def merge_max_pair(self, bytes_: List[str]) -> (List[str], bool):
    # 1. Find the adjacent pair with the highest frequency
    max_pair = self.get_max_pair_idxs(bytes_)
    merged = True if max_pair is not None else False
    if merged:
        # 2. Merge that pair
        bytes_ = bytes_[:max_pair[0]] + \
                 [''.join(bytes_[max_pair[0]:max_pair[1]+1])] + \
                 bytes_[max_pair[1]+1:]
    return bytes_, merged

def get_max_pair_idxs(self, bytes_) -> Tuple[int, int]:
    pairs = {}
    # 1. Iterate over all adjacent pair combinations
    for i in range(1, len(bytes_)):
        pair = ''.join(bytes_[i-1:i+1])
        # 2. If the merged pair exists in the frequency table, record its count
        if pair in self.freqs:
            pairs[(i-1, i)] = self.freqs[pair]
    # 3. Return the indices of the pair with the highest count
    return None if len(pairs) == 0 else max(pairs, key=pairs.get)
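A sketch of how a single word is merged, assuming (hypothetically) that the trained frequency table contains the subwords 'He' and 'Hey' but none of the other candidate pairs:

    # merge_bytes(['H', 'e', 'y', '</w>'])
    # step 1: ['H', 'e', 'y', '</w>'] -> ['He', 'y', '</w>']   # 'He' is in freqs
    # step 2: ['He', 'y', '</w>']     -> ['Hey', '</w>']       # 'Hey' is in freqs
    # step 3: 'Hey</w>' is not in freqs, so merging stops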
2.2 Map each subword of the current word to its corresponding vocabulary ID
def get_byte_ids(self, bytes_):
    ids = []
    for byte in bytes_:
        if byte in self.vocab_to_idx:
            ids.append(self.vocab_to_idx[byte])
        else:
            # Subwords missing from the vocabulary are mapped to the <unk> ID
            ids.append(self.vocab_to_idx[self.unk])
    return ids
Appendix
The complete BPE tokenizer implementation:
from typing import Tuple, Dict, List
from collections import defaultdict
import json, re
from nltk import wordpunct_tokenize, sent_tokenize
from tqdm import trange, tqdm


class BytePairTokenizer:

    def __init__(self, freqs: Dict[str, int], vocab_to_idx: Dict[str, int],
                 idx_to_vocab: Dict[int, str]):
        """ Initialize byte pair tokenizer

        Args:
            freqs: frequency dictionary of vocabulary
            vocab_to_idx: map of vocabulary words to indices
            idx_to_vocab: map of vocabulary indices to words
        """
        self.vocab_to_idx = vocab_to_idx
        self.idx_to_vocab = idx_to_vocab
        self.freqs = freqs
        self.sol = '<line/>'
        self.eol = '</line>'
        self.pad = '<pad>'
        self.unk = '<unk>'
        self.eow = '</w>'

    def get_sol(self) -> str:
        return self.sol

    def get_eol(self) -> str:
        return self.eol

    def get_pad(self) -> str:
        return self.pad

    def get_unk(self) -> str:
        return self.unk

    def get_eow(self) -> str:
        return self.eow

    def get_byte(self, byte_id: int) -> str:
        return self.idx_to_vocab[byte_id]

    def get_byte_id(self, byte: str) -> int:
        unk_id = self.vocab_to_idx[self.unk]
        bid = self.vocab_to_idx[byte] if byte in self.vocab_to_idx else unk_id
        return bid

    def get_byte_ids(self, bytes_):
        """ Get byte ids for each byte in provided list """
        ids = []
        for byte in bytes_:
            if byte in self.vocab_to_idx:
                ids.append(self.vocab_to_idx[byte])
            else:
                ids.append(self.vocab_to_idx[self.unk])
        return ids

    def get_bytes(self, byte_ids: List[int]) -> List[str]:
        """ Given a list of byte ids return corresponding bytes

        Args:
            byte_ids: list of byte ids

        Returns:
            (List[str]): list of bytes
        """
        tokens = []
        for byte_id in byte_ids:
            tokens.append(self.idx_to_vocab[byte_id])
        return tokens

    def merge_bytes(self, bytes_: List[str]) -> List[str]:
        """ Return list of bytes with max pair merged

        Args:
            bytes_: list to merge max pair in

        Returns:
            (List[str]): list of bytes with all max pair occurrences merged
        """
        bytes_, merged = self.merge_max_pair(bytes_)
        while merged:
            bytes_, merged = self.merge_max_pair(bytes_)
        return bytes_

    def merge_max_pair(self, bytes_: List[str]) -> (List[str], bool):
        """ Takes in a list of bytes and merges the max pair if possible

        Args:
            bytes_: list of bytes to merge max pair in

        Returns:
            (bytes_): list of bytes with max pair merged
            (bool): flag indicating whether merge occurred
        """
        max_pair = self.get_max_pair_idxs(bytes_)
        merged = True if max_pair is not None else False
        if merged:
            bytes_ = bytes_[:max_pair[0]] + \
                     [''.join(bytes_[max_pair[0]:max_pair[1]+1])] + \
                     bytes_[max_pair[1]+1:]
        return bytes_, merged

    def get_max_pair_idxs(self, bytes_) -> Tuple[int, int]:
        """ Get indices of the maximum-frequency byte pair in a list of bytes

        Args:
            bytes_: list of bytes to find maximum pair from

        Returns:
            (Tuple[int, int]): indices of the maximum-frequency byte pair
        """
        pairs = {}
        for i in range(1, len(bytes_)):
            pair = ''.join(bytes_[i-1:i+1])
            if pair in self.freqs:
                pairs[(i-1, i)] = self.freqs[pair]
        return None if len(pairs) == 0 else max(pairs, key=pairs.get)

    def save(self, path: str) -> None:
        with open(f'{path}/freqs.json', 'w', encoding='utf-8') as outfile:
            json.dump(self.freqs, outfile, indent=4, ensure_ascii=False)
        with open(f'{path}/vocab_to_idx.json', 'w', encoding='utf-8') as outfile:
            json.dump(self.vocab_to_idx, outfile, indent=4, ensure_ascii=False)
        with open(f'{path}/idx_to_vocab.json', 'w', encoding='utf-8') as outfile:
            json.dump(self.idx_to_vocab, outfile, indent=4, ensure_ascii=False)

    @staticmethod
    def load(path: str) -> 'BytePairTokenizer':
        with open(f'{path}/freqs.json', 'r', encoding='utf-8') as infile:
            freqs = json.load(infile)
        with open(f'{path}/vocab_to_idx.json', 'r', encoding='utf-8') as infile:
            vocab_to_idx = json.load(infile)
        with open(f'{path}/idx_to_vocab.json', 'r', encoding='utf-8') as infile:
            idx_to_vocab = json.load(infile)
        return BytePairTokenizer(freqs, vocab_to_idx, idx_to_vocab)

    @staticmethod
    def train_bpe(filepaths: List[str], mincount: int, merges: int) -> 'BytePairTokenizer':
        vocab = create_vocab(filepaths)
        truncate_vocab(vocab, mincount)
        vocab = prepare_bpe_vocab(vocab)
        merged_pair_freqs = defaultdict(int)
        for _ in trange(merges, desc='Merging'):
            pairs = get_stats(vocab)
            if not pairs:
                break
            best_pair = max(pairs.items(), key=lambda x: x[1])
            best_subword = ''.join(best_pair[0])
            best_freq = best_pair[1]
            merged_pair_freqs[best_subword] += best_freq
            vocab = merge_vocab(best_pair[0], vocab)
        freqs = count_byte_freqs(vocab)
        vocab_to_idx, idx_to_vocab = create_vocab_maps(freqs)
        freqs.update(merged_pair_freqs)
        return BytePairTokenizer(freqs, vocab_to_idx, idx_to_vocab)


def create_vocab(filepaths: List[str]) -> Dict[str, int]:
    """ Create dictionary of vocabulary frequencies in given list of files

    Args:
        filepaths: list of filepaths to collect vocabulary from

    Returns:
        (Dict[str, int]): dictionary mapping vocabulary terms to their frequency
    """
    vocab = defaultdict(int)
    for path in tqdm(filepaths, desc='Creating vocabulary'):
        text = open(path, 'r', encoding='utf-8-sig').read()
        sentences = sent_tokenize(text)
        for sentence in sentences:
            tokens = wordpunct_tokenize(sentence)
            for token in tokens:
                vocab[token] += 1
    return vocab


def truncate_vocab(vocab: Dict[str, int], mincount: int) -> None:
    """ Truncate vocabulary dictionary based on a minimum count

    Args:
        vocab: frequency mapping dictionary to truncate
        mincount: minimum count for members of dictionary (words with lower
                  frequencies will be removed)
    """
    tokens = list(vocab.keys())
    for token in tokens:
        if vocab[token] < mincount:
            del vocab[token]


def prepare_bpe_vocab(vocab: Dict[str, int]) -> Dict[str, int]:
    """ Prepare vocabulary frequency dictionary for byte-pair generation.
        End-of-word byte '</w>' added to words, every character separated by space

    Args:
        vocab: vocabulary frequency dictionary to prepare

    Returns:
        (Dict[str, int]): byte-pair ready vocabulary frequency dictionary
    """
    bpe_vocab = {}
    for token in vocab:
        ntoken = ' '.join(list(token)) + ' </w>'
        bpe_vocab[ntoken] = vocab[token]
    return bpe_vocab


def get_stats(vocab: Dict[str, int]) -> Dict[Tuple[str, str], int]:
    """ Count all bytepairs in a dictionary containing vocabulary frequencies

    Args:
        vocab: dictionary mapping words to their frequency

    Returns:
        (Dict[Tuple[str, str], int]): dictionary containing byte pair frequencies
    """
    pairs = defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols)-1):
            pairs[symbols[i], symbols[i+1]] += freq
    return pairs


def merge_vocab(pair: Tuple[str, str], v_in: Dict[str, int]) -> Dict[str, int]:
    """ Merge all instances of given byte pair in vocabulary frequency dictionary

    Args:
        pair: byte pair to merge
        v_in: vocabulary to merge byte pair in

    Returns:
        (Dict[str, int]): resulting vocabulary with all instances of given byte
                          pair merged
    """
    bigram = re.escape(' '.join(pair))
    v_out = {}
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word in v_in:
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = v_in[word]
    return v_out


def count_byte_freqs(vocab: Dict[str, int]) -> Dict[str, int]:
    freqs = defaultdict(int)
    for word in vocab:
        bytes_ = word.split(' ')
        for byte in bytes_:
            freqs[byte] += 1
    for token in ['<line/>', '</line>', '<pad>', '<unk>']:
        freqs[token] += 1
    return freqs


def create_vocab_maps(freqs: Dict[str, int]) -> (Dict[str, int], Dict[int, str]):
    """ Create map of vocabulary terms to indices and vice versa. Word indices
        are in order of their frequency in the provided vocabulary

    Args:
        freqs: dictionary mapping vocabulary terms to their frequencies

    Returns:
        (Dict[str, int]): dictionary mapping vocab to indices
        (Dict[int, str]): dictionary mapping indices to vocab
    """
    ordered_freqs = sorted(freqs.items(), key=lambda x: x[1], reverse=True)
    vocab_to_idx, idx_to_vocab = {}, {}
    for i in range(len(ordered_freqs)):
        word, freq = ordered_freqs[i]
        vocab_to_idx[word] = i
        idx_to_vocab[i] = word
    return vocab_to_idx, idx_to_vocab
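A short end-to-end usage sketch of the class above (the corpus filename and output directory are illustrative assumptions):

    # Train on one corpus file, then tokenize a single line the same way get_line_ids does
    tokenizer = BytePairTokenizer.train_bpe(['corpus.txt'], mincount=1, merges=10)

    words = [list(t) + [tokenizer.get_eow()] for t in wordpunct_tokenize('Hello World')]
    ids = []
    for word in words:
        ids += tokenizer.get_byte_ids(tokenizer.merge_bytes(word))
    ids = [tokenizer.get_byte_id(tokenizer.get_sol())] + ids + [tokenizer.get_byte_id(tokenizer.get_eol())]
    print(ids)                       # token IDs for the line
    print(tokenizer.get_bytes(ids))  # the corresponding subwords

    tokenizer.save('./bpe_model')    # persist the frequency table and vocabulary maps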