# 概述
在自然语言处理领域,分词器(Tokenizer) 承担着至关重要的 “桥梁” 作用:它将人类的自然语言(字符串)转化为大型语言模型能够识别和处理的整数序列。该整数序列就是大模型的输入,而 token 指的就是这些整数数字。
在众多分词策略中,字节对编码(Byte Pair Encoding, BPE) 以其高效、可控的词汇表构建机制,成为了现代预训练模型(如 BERT,GPT,T5 等)的基石。BPE 成功解决了传统分词方法在处理海量词汇和未登录词时的痛点。
本篇博客基于 2025 斯坦福 CS336 课程的第一章节作业要求,整理 BPE 分词器的原理与实现流程,并从零开始实现一个基础的 BPE 分词器。
# Unicode 编码标准
介绍 BPE 算法之前,需要先了解计算机如何保存各种字符串的。
Unicode 是一个行业标准,为世界上所有字符(英文字母、中文、日文等)提供统一的表示。截止至 2024 年公布的 Unicode 16.0 版本已经定义了 154998 个字符的表示。它为每一个字符分配了一个独一无二的编号,即码点(Code Point)。Unicode 规范了字符集,而具体的存储和传输则依赖于不同的编码方案,其中最常用的是 UTF-8。UTF-8 属于变长编码,它使用 1 到 4 个字节来表示一个 Unicode 字符。对于 ASCII 字符(如英文字母和数字),它只占用 1 个字节。
Python 对 Unicode 和各种编码提供了强大且直观的支持。在实现 BPE 时,主要用到以下两个核心方法:
-
编码:字符串对象的
.encode()方法将 Unicode 字符串转成对应的字节串(bytes对象)。例如要获取一个字符串的 UTF-8 字节表示,可以使用string.encode(encoding='utf-8')。print("✅".encode(encoding='utf-8')) # b'\xe2\x9c\x85'
-
解码:使用字节串对象的
.decode()方法可以将字节串转换回 Unicode 字符串。例如按照 UTF-8 标准进行解码可以使用bytes_object.decode(encoding='utf-8')。print(b'\xe2\x9c\x85'.decode(encoding='utf-8')) # '✅'
# BPE Tokenizer 训练算法
由于所有 Unicode 字符串都可以表示为字节序列,因此一个只有 256 单元(1 字节可以表示 0-255)的词表就足以表示所有字符串。但是这个词表是非常低效。比如前面一小节中,一个 ✅ 字符就被编码成 3 个字节。即便按照平均标准,一个字符也需要约 2.5 个字节来表示,这对于现代大型语言模型的输入序列长度而言是不可承受的。Transformer 架构的核心特征使其时间复杂度与输入序列长度 的平方成正比,即 。并且,输入序列越长,训练和推理所需的计算资源(时间、GPU 显存)就越多,模型维护长期记忆和全局信息的能力越会受到限制。
BPE 算法通过学习语料库中的高频子词模式,将多个字节合并成一个具有语义信息的子词单元,从而大幅压缩序列长度。例如,如果在词表中学习到 b'\xe2\x9c\x85' ,那么只用一个整数就可以表示 ✅ 。
BPE 分词器中,每一个 token 都是一段字节序列,它可能是一个字节,一个单词,也可能是一个单词的部分字节等,因此也被称为子词(subword)。
下面,具体讲解 BPE 分词器的训练算法流程,可以分成三个步骤:词表初始化、预分词、合并子词。
# 词表初始化
首先,词表必须包含 个元素,与所有单字节一一对应,这 256 个字节构成了 BPE 的原子单元,保证了理论上任何 Unicode 字符串(通过 UTF-8 编码)都可以被表示和分解。
此外,我们还需要考虑特殊 token。所谓特殊 token 是模型可识别的,拥有特定语义功能的字符串(最后也用字节串表示),不会参与正常的 BPE 合并过程。比如 <cls> 、 <sep> 在 BERT 中就是表示分类和分隔的特殊 token。这些特殊 token 在后续的编码和训练过程中会被当作一个完整的单元进行特殊处理,因此在初始化时也需要加入到词表中。
# 预分词
在第三步合并子词时,我们会对语料库中频繁出现的相邻子词对进行合并。例如, AI 这个单词中 A 和 I 各自对应一个字节,由上述初始词表可以表示成两个 token。若 AI 这个单词在语料库中出现次数比较多,那么 A 和 I 在第三步就会被合并成新的 token。
那么预分词(Pre-tokenization)在干什么呢?我们一般期望,每个 token 都应该具有语义信息,比如上述的 AI 亦或者是和英语语法有关的 est 、 er 等等,而对于 y n 这种就不是一个好的 token 表示。但是,假设说,语料库中经常出现 "my name" 文本,那么在第三步合并子词时就有可能合并出 y n 这种 token。因此预分词的目的就是定义子词合并的合理边界,避免跨单词合并,从而保留语义结构。
CS336 第一章作业中,我们使用如下 GPT-2 的预分词正则表达式,它可以将一段话拆分成具有完整含义的子字符串序列:
import regex | |
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" | |
regex.findall(PAT, "你好世界, Hello World!") # [' 你好世界 ', ',', ' Hello', ' World', '!'] |
这里划分出来的子字符串列表就是就是下一步 BPE 合并操作的最小作用域。
预分词不仅可以保证合并出来的子词具有语义结构,还可以提高 BPE 训练的效率。在这一阶段,我们可以统计出每个预分词块(例如 Hello )在整个语料库中出现的总次数。在后续的合并计数中,我们只需将内部子词对的频率乘以该块的出现次数,即可避免对整个语料进行重复扫描,大大加速合并计数。
# 子词合并
子词合并就是 BPE 训练算法最核心的部分。
对于每个预分词后的子字符串,我们可以用当前的词表将其表示为一系列子词,进而得到整个语料库的子词表示。然后统计所有相邻的子词对的出现次数。注意,我们不统计预分词子字符串之间的子词对。然后将出现次数最多,字典序最大的子词对取出进行合并,得到新的子词添加到词表中,进而用新的子词替换原先出现的子词对。
BPE 中,子词就是 token,在词表中用字节串表示。
例如,在下面语料中,第一次合并时选择的子词对是 (b's', b't') ,进而我们将 b'st' 添加到词表中,然后将 [b'n', b'e', b'w', b'e', b's', b't'] 重新表示为 [b'n', b'e', b'w', b'e', b'st'] 。
low low low low low | |
lower lower widest widest widest | |
newest newest newest newest newest newest |
上述过程会不断迭代进行,直至词表达到预定义的大小。
# BPE Tokenizer 编码解码
在 BPE 算法训练结束后,我们得到了两份数据:
- 词汇表(
vocab): 包含了所有的基本字节和学习到的复合子词。 - 合并规则序列(
merges): 记录了 BPE 训练过程中,所有子词对被合并的顺序。
基于这两个数据,我们可以构建一个完整的 BPETokenizer 类,其核心功能是实现将自然语言转化为模型输入的 encode() 方法,以及反向的 decode() 方法。前者将给定的 Unicode 字符串转换为模型可识别的整数 Token ID 序列,后者则将给定的整数 Token ID 序列转换成 Unicode 字符串。
给定词表和待编码的 Unicode 字符串,其实有许多可行的编码方案。比如最简单的一种就是每个字节都编码成一个 token。但为了确保模型输入的一致性和效率,我们必须遵循一个最优的、确定的编码方案。按照 CS336 课程作业要求, BPETokenizer 的编码流程如下:
- 首先,使用与训练时完全相同的预分词方法,将输入字符串拆分成语义子字符串块,并将其转化为 UTF-8 字节序列,并初始化为一个字节一个 Token 的序列。
- 严格按照训练时学习到的
merges列表的顺序,从头到尾遍历每条合并规则。对于当前序列中的每一个预分词块,如果其中包含当前规则要合并的子词对,则进行合并,最终生成一个新的 Token ID 序列。 - 当
merges列表中的所有规则都应用完毕后,最终得到的 Token 序列就是我们期望的、压缩效率最高的编码结果。
上述编码流程完全仿照训练过程。这是因为 BPE 算法的本质是贪婪的,它学习到的 merges 列表反映了将序列压缩到词汇表大小限制内的最优、最频繁的合并路径。尽管对于单个待编码的 Unicode 字符串来说,通过这种方式编码得到的 Token 数目不一定是理论上最少的。但是,由于 merges 列表是基于整个大规模训练语料统计的最高频模式学习而来,因此在整个输入空间上,这种编码方式具有最优的期望压缩效率。
解码是一个相对直接的反向过程:遍历输入的整数序列,根据保存的 vocab 词汇表,查找每个 Token ID 对应的字节序列。将所有查找到的字节序列按顺序拼接成一个完整的字节串对象。最后使用字节串对象的 .decode(encoding='utf-8') 方法转成最终的 Unicode 字符串。
# 代码实现
本节将按照 2025 斯坦福 CS336 课程的作业要求,从零实现一个基础的 BPE 分词器,并补充相关的技术细节。
# 训练语料
下面是代码实现中用到的语料数据。
mkdir -p data | |
cd data | |
wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-train.txt | |
wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-valid.txt | |
wget https://huggingface.co/datasets/stanford-cs336/owt-sample/resolve/main/owt_train.txt.gz | |
gunzip owt_train.txt.gz | |
wget https://huggingface.co/datasets/stanford-cs336/owt-sample/resolve/main/owt_valid.txt.gz | |
gunzip owt_valid.txt.gz |
这些语料数据已经按照要求经过了预处理,可以直接用来训练 BPE Tokenizer。 valid 数据相比 train 数据少得多,用于快速验证算法的正确性。
# 目录结构
所有代码总共分成四个文件,如下所示:
- pretokenize.py | |
- train.py | |
- tokenizer.py | |
- utils.py |
pretokenize.py 保存与预分词相关的函数。
train.py 保存训练 BPE Tokenizer 的函数。
tokenizer.py 定义了 BPE Tokenizer 类,实现编码解码功能。
utils.py 保存了整个 BPE 算法实现过程中用到的数据结构。
# BPE 训练函数
# 工具模块
首先给出 utils.py 的文件内容,然后详细说明各部分的作用。
# utils.py | |
from typing import NamedTuple, List, Tuple | |
from dataclasses import dataclass, field | |
from collections import Counter, defaultdict | |
BYTES: List[bytes] = [bytes([i]) for i in range(256)] | |
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" | |
def split_bytes(data: bytes) -> Tuple[bytes]: | |
"""Split a bytestring into a list of single-byte bytestrings.""" | |
return tuple([BYTES[b] for b in data]) |
BYTES 保存了 256 个单字节对象,避免在分词时频繁创建新的字节对象,减少内存消耗。 PAT 是预分词使用的正则表达式字符串。
split_bytes 函数将字节串拆分成成单字节对象组成的元组。
# utils.py | |
class myHeap: | |
def __init__(self, mode="max", init_list=None): | |
self.mode = mode | |
self._init_data(init_list or []) | |
def _init_data(self, init_list) -> list: | |
self.data = list(init_list) | |
self.size = len(self.data) | |
if self.size == 0: | |
return | |
for i in range((len(self.data) - 1) // 2, -1, -1): | |
self._shift_down(i) | |
def _compare(self, a, b): | |
if self.mode == "max": | |
return a > b | |
else: | |
return a < b | |
def _shift_down(self, index): | |
"""Shift down the element at index to maintain heap property""" | |
while True: | |
left = 2 * index + 1 | |
right = 2 * index + 2 | |
maxIndex = index | |
if left < self.size and self._compare(self.data[left], self.data[maxIndex]): | |
maxIndex = left | |
if right < self.size and self._compare(self.data[right], self.data[maxIndex]): | |
maxIndex = right | |
if maxIndex != index: | |
self.data[index], self.data[maxIndex] = self.data[maxIndex], self.data[index] | |
index = maxIndex | |
else: | |
break | |
def push(self, item): | |
"""Push an item onto the heap""" | |
self.data.append(item) | |
self.size += 1 | |
index = self.size - 1 | |
while index > 0: | |
parent = (index - 1) // 2 | |
if self._compare(self.data[index], self.data[parent]): | |
self.data[index], self.data[parent] = self.data[parent], self.data[index] | |
index = parent | |
else: | |
break | |
def pop(self) -> any: | |
"""Pop the top item off the heap and return it""" | |
if self.size == 0: | |
return None | |
top_item = self.data[0] | |
self.data[0] = self.data[-1] | |
self.data.pop() | |
self.size -= 1 | |
if self.size > 0: | |
self._shift_down(0) | |
return top_item | |
def top(self) -> any: | |
"""Return the top item of the heap without removing it""" | |
return self.data[0] if self.size > 0 else None |
myHeap 类是手动实现的堆数据结构,支持大顶堆和小顶堆,用于高效获取频率最高,字典序最大的分词对。
# utils.py | |
class BytePair(NamedTuple): | |
left: bytes | |
right: bytes | |
@property | |
def merged_bytes(self): | |
return self.left + self.right | |
def __str__(self) -> str: | |
return f"({self.left.decode('utf-8', errors='ignore')} {self.right.decode('utf-8', errors='ignore')})" | |
def __repr__(self) -> str: | |
return self.__str__() |
BPE Tokenizer 中每个 token 都是一个字节串,每次合并时需要选择一个 token 对。为了方便管理,这里定义 BytePair 类,继承元组类型,表示 token 对,其中 merged_bytes 方法通过修饰器转换成属性,可以返回该 token 对合并后的字节串。
# utils.py | |
@dataclass | |
class WordRef: | |
tokens: list[bytes] | |
count: int | |
@property | |
def byte_pairs_dict(self) -> defaultdict[BytePair, int]: | |
bp_to_count = defaultdict(int) | |
for i in range(len(self.tokens) - 1): | |
bp = BytePair(self.tokens[i], self.tokens[i + 1]) | |
bp_to_count[bp] += self.count | |
return bp_to_count | |
def merge(self, bp_merged: BytePair): | |
"""merge the given byte pair in this word""" | |
new_tokens: list[bytes] = [] | |
i = 0 | |
while i < len(self.tokens): | |
if i < len(self.tokens) - 1 and self.tokens[i] == bp_merged.left and self.tokens[i + 1] == bp_merged.right: | |
new_tokens.append(bp_merged.merged_bytes) | |
i += 2 | |
else: | |
new_tokens.append(self.tokens[i]) | |
i += 1 | |
self.tokens = new_tokens | |
def __str__(self) -> str: | |
return f"{self.tokens} : {self.count}" | |
def __repr__(self) -> str: | |
return self.__str__() |
WordRef 类表示预分词后,每个有效的子字节串。属性 tokens 表示当前该字节串被表示成的 token 序列,属性 cnt 是预分词时统计该子字节串出现的次数。
由于 BPE 算法合并过程时需要统计 token 对数量,因此这里定义 byte_pairs_dict 返回该字节串的对应的 token 对统计结果。 merge 函数接受需要合并的 token 对,类型为 BytePair 对象,按要求合并,得到新的 token 序列。
# 预分词模块
介绍完 utils.py 的内容,再来讲讲预分词的代码如何实现。
预分词的功能说起来简单,只需要使用前文提到的正则表达式进行匹配得到结果即可。但难点在于如何兼备内存效率和计算效率。一般来说,BPE Tokenizer 训练的语料非常庞大, owt_train.txt 都有 11G,更别提现代大模型所需要的训练数据量。一次性读取所有文件内容进内存显然不合理,因此我们需要给语料数据进行分块,逐块进行预分词。
另外,正则表达式匹配是一个计算复杂度非常高的过程,直接对一个长文本进行预分词需要耗费非常多的时间。分块处理的另一个好处是可以使用多进程编程提高计算效率。
CS336 课程提供了文件分块处理的函数,如下所示:
# pretokenize.py | |
import os | |
import re | |
import regex | |
import multiprocessing | |
import io | |
from typing import BinaryIO, Dict, Tuple, List, Iterable | |
from collections import Counter | |
from cs336_basics.bpe.utils import split_bytes, PAT | |
def find_chunk_boundaries( | |
file: BinaryIO, | |
desired_num_chunks: int, | |
split_special_token: bytes, | |
) -> list[int]: | |
""" | |
Chunk the file into parts that can be counted independently. | |
May return fewer chunks if the boundaries end up overlapping. | |
""" | |
assert isinstance(split_special_token, bytes), "Must represent special token as a bytestring" | |
# Get total file size in bytes | |
file.seek(0, os.SEEK_END) | |
file_size = file.tell() | |
file.seek(0) | |
chunk_size = file_size // desired_num_chunks | |
# Initial guesses for chunk boundary locations, uniformly spaced | |
# Chunks start on previous index, don't include last index | |
chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)] | |
chunk_boundaries[-1] = file_size | |
mini_chunk_size = 4096 # Read ahead by 4k bytes at a time | |
for bi in range(1, len(chunk_boundaries) - 1): | |
initial_position = chunk_boundaries[bi] | |
file.seek(initial_position) # Start at boundary guess | |
while True: | |
mini_chunk = file.read(mini_chunk_size) # Read a mini chunk | |
# If EOF, this boundary should be at the end of the file | |
if mini_chunk == b"": | |
chunk_boundaries[bi] = file_size | |
break | |
# Find the special token in the mini chunk | |
found_at = mini_chunk.find(split_special_token) | |
if found_at != -1: | |
chunk_boundaries[bi] = initial_position + found_at | |
break | |
initial_position += mini_chunk_size | |
# Make sure all boundaries are unique, but might be fewer than desired_num_chunks | |
return sorted(set(chunk_boundaries)) |
给语料数据进行分块时要注意避免一个正常的单词被切分成两块,因此需要定义一个分块的边界。这个分块的边界通常由开发人员自己定义。本作业使用到的语料数据均使用一个特殊 token <|endoftext|> 来分割不同的文档,这也是本次作业中唯一使用到的特殊 token。
find_chunk_boundaries 函数正是基于这个特殊 token 来获取每个语料块的边界。简单来说,这个函数根据文件 IO 指针计算出整个文件内容所占的字节数,然后根据预期分块数 desired_num_chunks 均分得到初始分块边界。针对每个边界,逐步向后读取固定字节大小的内容,检测是否出现了特殊的分割 token,更新最终的边界。由于每个初始边界都会向后调整,因此最后有效的块数可能会少于预期值,因此返回时使用 sorted(set(chunk_boundaries)) 进行后处理。
def pretokenize_file(file: BinaryIO, special_tokens: List[str] = None) -> Counter[Tuple[bytes, ...], int]: | |
""" | |
Pre-tokenize a file into words and count their occurrences, using multiprocessing. | |
The primary function to execute the parallel pre-tokenization process. | |
""" | |
# 1. prepare special tokens | |
split_special_token = b"<|endoftext|>" | |
# 2. find chunk boundaries | |
num_processes = 16 | |
chunk_max_size = 1024 * 1024 * 500 # 500 MB | |
desired_chunks = max(num_processes, file.seek(0, os.SEEK_END) // chunk_max_size + 1) | |
boundaries = find_chunk_boundaries(file, desired_num_chunks=desired_chunks, split_special_token=split_special_token) | |
# 3. pretokenize chunks in parallel | |
def chunk_generator(): | |
for start, end in zip(boundaries[:-1], boundaries[1:]): | |
file.seek(start) | |
chunk_data = file.read(end - start) | |
yield (chunk_data, special_tokens) | |
final_words_counter: Counter[Tuple[bytes, ...], int] = Counter() | |
try: | |
with multiprocessing.Pool(processes=num_processes) as pool: | |
for result in pool.imap_unordered(pretokenize_chunk, chunk_generator()): | |
final_words_counter.update(result) | |
except Exception as e: | |
print(f"Multiprocessing failed, falling back to single process: {e}") | |
for chunk_data, special_tokens in chunk_generator(): | |
counter = pretokenize_chunk((chunk_data, special_tokens)) | |
final_words_counter.update(counter) | |
return final_words_counter |
pretokenize_file 函数为整个文件内容进行预分词,返回预分词后每个子字节串的统计结果。注意这里的子字节串被表示为单字节元组,用于后续创建 WordRef 对象。
该函数使用了多进程编程技术,定义每个分块的最大尺寸和可用进程数,调用 find_chunk_boundaries 获取每个分块的边界。然后通过 chunk_generator 迭代器动态读取文件指定分块内容进入内存。接着使用进程池和 pool.imap_unordered 分发预分词任务,每个子进程自行处理对应的分块语料内容。
子进程预分词的逻辑为 pretokenize_chunk 函数,具体实现如下:
def pretokenize_chunk(args) -> Counter[Tuple[bytes, ...], int]: | |
""" | |
Pre-tokenize a chunk, replace special tokens, and count word occurrences. | |
This function runs in parallel worker processes. | |
""" | |
chunk, special_tokens = args | |
chunk_str = chunk.decode("utf-8", errors="ignore") | |
words_counter = Counter() | |
last_end = 0 | |
# pretokenize between special tokens | |
if special_tokens and len(special_tokens) > 0: | |
for match in regex.finditer("|".join(re.escape(st) for st in special_tokens), chunk_str): | |
pre_token_text = chunk_str[last_end : match.start()] | |
for pre_token in regex.finditer(PAT, pre_token_text): | |
word_bytes = pre_token.group(0).encode("utf-8") | |
words_counter[split_bytes(word_bytes)] += 1 | |
del word_bytes # free memory | |
last_end = match.end() | |
# pretokenize after last special token | |
pre_token_text = chunk_str[last_end:] | |
for pre_token in regex.finditer(PAT, pre_token_text): | |
word_bytes = pre_token.group(0).encode("utf-8") | |
words_counter[split_bytes(word_bytes)] += 1 | |
del word_bytes | |
return words_counter |
预分词时,需要将特殊 token 从语料中剔除,因为它们不会参与任何子词合并的过程,它们本身就是一个 token。匹配时首先找到特殊 token 的位置,然后对特殊 token 直接的文本进行预分词。
注意,在正则匹配过程中,应该使用 finditer 函数避免一次性完成所有匹配任务,导致内存开销过大。
# 训练模块
基于语料库的统计数据,从字节级别开始,迭代合并最高频、最大字典序的相邻字节对,逐步构建一个指定大小的词汇表。
训练代码如下:
# train.py | |
import pdb | |
import pickle | |
from tqdm import tqdm | |
from collections import Counter, defaultdict | |
from cs336_basics.bpe.pretokenize import pretokenize_file | |
from cs336_basics.bpe.utils import myHeap, BytePair, WordRef | |
def train_bpe( | |
input_path: str, | |
vocab_size: int, | |
special_tokens: list[str] = None, | |
): | |
"""Train a BPE tokenizer on the given input file.""" | |
# 1. initialize byte-level vocab and add special tokens | |
vocab = {i: bytes([i]) for i in range(256)} # int to bytes | |
for special_token in special_tokens or []: | |
vocab[len(vocab)] = special_token.encode("utf-8", errors="ignore") | |
# 2. pretokenize input file into words and count occurrences | |
with open(input_path, "rb") as f: | |
word_counts = pretokenize_file(f, special_tokens) # Counter[Tuple[bytes, ...], int] | |
# 3. create WordRef objects and initialize byte pair stats | |
word_refs: list[WordRef] = [] | |
pair_to_count: defaultdict[BytePair, int] = defaultdict(int) | |
pair_to_word: defaultdict[BytePair, set[int]] = defaultdict(set) | |
for word_tuple, count in word_counts.items(): | |
word_ref = WordRef(tokens=list(word_tuple), count=count) | |
word_refs.append(WordRef(tokens=list(word_tuple), count=count)) | |
for bp, cnt in word_ref.byte_pairs_dict.items(): | |
pair_to_count[bp] += cnt | |
pair_to_word[bp].add(len(word_refs) - 1) | |
pair_heap = myHeap() | |
for bp, cnt in pair_to_count.items(): | |
pair_heap.push((cnt, bp)) | |
pair_heap_to_delete = dict() # heap lazy deletion map | |
# 4. merge BPEs until reaching the desired vocab size | |
merges = [] | |
with tqdm(total=vocab_size - len(vocab), desc="BPE Merging") as pbar: | |
while len(vocab) < vocab_size: | |
# Lazy deletion: pop until we find a valid top pair | |
while pair_heap.size > 0: | |
top_cnt, top_bp = pair_heap.pop() | |
if top_bp in pair_heap_to_delete: | |
top_cnt -= pair_heap_to_delete[top_bp] | |
if top_cnt > 0: | |
pair_heap.push((top_cnt, top_bp)) | |
del pair_heap_to_delete[top_bp] | |
else: | |
break | |
# merge best pair | |
assert top_bp is not None, "No more pairs to merge" | |
merges.append((top_bp.left, top_bp.right)) | |
new_token = top_bp.merged_bytes | |
vocab[len(vocab)] = new_token | |
# update word counts and pair stats | |
pair_to_count_change: defaultdict[BytePair, int] = defaultdict(int) | |
pair_heap_to_add = defaultdict(int) # cumulative count changes for lazy deletion | |
for word_ref_id in pair_to_word[top_bp]: | |
word_ref = word_refs[word_ref_id] | |
old_bp_dict = word_ref.byte_pairs_dict | |
word_ref.merge(top_bp) | |
new_bp_dict = word_ref.byte_pairs_dict | |
# update old pairs | |
for bp, cnt in old_bp_dict.items(): | |
if bp not in new_bp_dict and bp != top_bp: | |
pair_to_word[bp].remove(word_ref_id) | |
change = new_bp_dict.get(bp, 0) - cnt | |
pair_to_count_change[bp] += change | |
if change < 0: | |
pair_heap_to_delete[bp] = pair_heap_to_delete.get(bp, 0) - change | |
# update new pairs | |
for bp, cnt in new_bp_dict.items(): | |
if bp not in old_bp_dict: | |
pair_to_word[bp].add(word_ref_id) | |
pair_to_count_change[bp] += cnt | |
pair_heap_to_add[bp] += cnt | |
# apply count changes | |
for bp, change in pair_to_count_change.items(): | |
pair_to_count[bp] += change | |
if pair_to_count[bp] <= 0: | |
del pair_to_count[bp] | |
del pair_to_word[bp] | |
for bp, add_cnt in pair_heap_to_add.items(): | |
pair_heap.push((add_cnt, bp)) # push updated count to heap | |
del pair_heap_to_delete[top_bp] # remove merged pair from lazy deletion map | |
assert top_bp not in pair_to_count and top_bp not in pair_to_word | |
# update progress bar | |
pbar.update(1) | |
return vocab, merges |
整个函数流程基本符合第三节的内容,核心的三个数据是:
word_refs:预分词得到的WordRef对象列表。pair_to_count:相邻字节对的实时统计结果。pair_to_word:相邻字节对出现的WordRef索引。
每次合并时,从 pair_to_word 中取出受影响的 WordRef 对象,调用 merge 函数完成合并,并及时更新 pair_to_count 。
为了加速获取频率最高的字节对,这里使用了堆优化 + 延迟删除算法。合并时统计减少 / 删除的字节对,在下一次取最大值时进行延迟删除,并统计所有新增字节对更新堆。
# BPE Tokenizer
完成训练后,我们可以保存词表和合并规则序列,然后定义 BPETokenizer 类实现编码解码功能。
由于编码过程中也需要给字符串进行预分词,所以这里定义一个新的函数 pretokenize_text ,和前面预分词的逻辑类似,只是这里需要正确返回对应的特殊 token,因为它们也需要被编码成对应的 Token ID。
# pretokenize.py | |
def pretokenize_text(text: str, special_tokens: List[str] = None) -> Iterable[bytes]: | |
""" | |
Pre-tokenize a string into words, replacing special tokens. | |
Yields pre-tokenized byte sequences. | |
""" | |
last_end = 0 # Track the end of the last match | |
if special_tokens and len(special_tokens) > 0: | |
special_tokens.sort(key=len, reverse=True) # Longer tokens first | |
for match in re.finditer("|".join(re.escape(st) for st in special_tokens or []), text): | |
special_token = match.group() | |
# yield pre-tokens before the special token | |
for pre_token in regex.finditer(PAT, text[last_end : match.start()]): | |
yield pre_token.group().encode("utf-8") | |
yield special_token.encode("utf-8") | |
last_end = match.end() | |
for pre_token in regex.finditer(PAT, text[last_end:]): | |
yield pre_token.group().encode("utf-8") |
这里将函数写成一个迭代器的形式,迭代返回一个字节串对象,这么做是为了兼容后面不同的编码需求。
作业中, BPETokenizer 的数据结构如下:
vocab: Dict[int, bytes]:Token ID 到字节串对象字典。inv_vocab: Dict[bytes, int]:vocab的反向字典。merges: List[Tuple[bytes, bytes]]:合并规则序列。special_tokens: List[str]:特殊 token 列表。encode_cache: Dict[bytes, List[int]]:分词器编码的缓存结果,用于加速编码过程。from_file:从文件中构建分词器对象。_pretoken_to_ids:编码的核心逻辑,将一个预分词字节串转成对应的 Token ID,注意特殊 token 需要转换成其对应的 ID。encode:编码一个给定字符串,返回编码后的整数序列。encode_iterable:考虑到有时候需要编码的内容(比如文件)非常大,一次性读入内存不可取,因此传入字符串迭代器,返回编码后的整数 ID 迭代器,保证内存效率。decode:将给定的整数序列解码成对应的字符串。由于整数序列由用户给定,有可能最终得到的字节串无法通过 UTF-8 进行解码,这里使用.decode("utf-8", errors="replace")将这些无效字节转换成标准的 Unicode 替代字符。
废话少说,直接上代码实现:
# tokenizer.py | |
import pickle | |
import time | |
from typing import Dict, Tuple, List, Union, Iterable | |
from cs336_basics.bpe.pretokenize import pretokenize_text | |
from cs336_basics.bpe.utils import split_bytes, WordRef, BytePair | |
class Tokenizer: | |
vocab: Dict[int, bytes] | |
inv_vocab: Dict[bytes, int] | |
merges: List[Tuple[bytes, bytes]] | |
special_tokens: List[str] | |
encode_cache: Dict[bytes, List[int]] = {} | |
def __init__(self, vocab: Dict[int, bytes], merges: List[Tuple[bytes, bytes]], special_tokens: List[str] = None): | |
self.vocab = vocab | |
self.inv_vocab = {v: k for k, v in vocab.items()} | |
self.merges = merges | |
self.special_tokens = special_tokens or [] | |
@classmethod | |
def from_file(cls, vocab_filepath: str, merges_filepath: str, special_tokens: List[str] | None = None): | |
with open(vocab_filepath, "rb") as vf: | |
vocab = pickle.load(vf) | |
with open(merges_filepath, "rb") as mf: | |
merges = pickle.load(mf) | |
return cls(vocab=vocab, merges=merges, special_tokens=special_tokens) | |
def _pretoken_to_ids(self, text: bytes) -> Iterable[int]: | |
"""Encode a byte string into a list of token IDs using the BPE merges""" | |
if text in self.encode_cache: | |
yield from self.encode_cache[text] | |
return | |
word_ref = WordRef(tokens=list(split_bytes(text)), count=1) | |
byte_pairs_dict = word_ref.byte_pairs_dict | |
for merge_left, merge_right in self.merges: | |
bp = BytePair(merge_left, merge_right) | |
if bp not in byte_pairs_dict: | |
continue | |
word_ref.merge(bp) | |
byte_pairs_dict = word_ref.byte_pairs_dict | |
if len(word_ref.tokens) == 1: | |
break | |
token_ids = [] | |
for token in word_ref.tokens: | |
yield self.inv_vocab[token] | |
token_ids.append(self.inv_vocab[token]) | |
self.encode_cache[text] = token_ids | |
def encode(self, text: str) -> List[int]: | |
"""Encode a string into a list of token IDs using the BPE merges""" | |
encoded_ids = [] | |
# reset cache with special tokens | |
self.encode_cache = {st.encode("utf-8"): [self.inv_vocab[st.encode("utf-8")]] for st in self.special_tokens} | |
for pre_toeken in pretokenize_text(text, self.special_tokens): | |
for token_id in self._pretoken_to_ids(pre_toeken): | |
encoded_ids.append(token_id) | |
return encoded_ids | |
def encode_iterable(self, iterable: Iterable[str]) -> Iterable[int]: | |
"""Encode an iterable of strings into a flat iterable of token IDs using the BPE merges""" | |
self.encode_cache = {st.encode("utf-8"): [self.inv_vocab[st.encode("utf-8")]] for st in self.special_tokens} | |
for text in iterable: | |
for pre_toeken in pretokenize_text(text, self.special_tokens): | |
for token_id in self._pretoken_to_ids(pre_toeken): | |
yield token_id | |
def decode(self, ids: List[int]) -> str: | |
bytes_seq = [] | |
for i in ids: | |
if i >= len(self.vocab): | |
Warning(f"Token ID {i} not in vocab, replacing with b''") | |
bytes_seq.append(self.vocab.get(i, b"")) | |
seq = b"".join(self.vocab.get(i) for i in ids).decode("utf-8", errors="replace") | |
return seq |
