# 概述

在自然语言处理领域,分词器(Tokenizer) 承担着至关重要的 “桥梁” 作用:它将人类的自然语言(字符串)转化为大型语言模型能够识别和处理的整数序列。该整数序列就是大模型的输入,而 token 指的就是这些整数数字。

在众多分词策略中,字节对编码(Byte Pair Encoding, BPE) 以其高效、可控的词汇表构建机制,成为了现代预训练模型(如 BERT,GPT,T5 等)的基石。BPE 成功解决了传统分词方法在处理海量词汇和未登录词时的痛点。

本篇博客基于 2025 斯坦福 CS336 课程的第一章节作业要求,整理 BPE 分词器的原理与实现流程,并从零开始实现一个基础的 BPE 分词器。


# Unicode 编码标准

介绍 BPE 算法之前,需要先了解计算机如何保存各种字符串的。

Unicode 是一个行业标准,为世界上所有字符(英文字母、中文、日文等)提供统一的表示。截止至 2024 年公布的 Unicode 16.0 版本已经定义了 154998 个字符的表示。它为每一个字符分配了一个独一无二的编号,即码点(Code Point)。Unicode 规范了字符集,而具体的存储和传输则依赖于不同的编码方案,其中最常用的是 UTF-8。UTF-8 属于变长编码,它使用 1 到 4 个字节来表示一个 Unicode 字符。对于 ASCII 字符(如英文字母和数字),它只占用 1 个字节。

Python 对 Unicode 和各种编码提供了强大且直观的支持。在实现 BPE 时,主要用到以下两个核心方法:

  • 编码:字符串对象的 .encode() 方法将 Unicode 字符串转成对应的字节串( bytes 对象)。例如要获取一个字符串的 UTF-8 字节表示,可以使用 string.encode(encoding='utf-8')

    print("✅".encode(encoding='utf-8'))  # b'\xe2\x9c\x85'
  • 解码:使用字节串对象的 .decode() 方法可以将字节串转换回 Unicode 字符串。例如按照 UTF-8 标准进行解码可以使用 bytes_object.decode(encoding='utf-8')

    print(b'\xe2\x9c\x85'.decode(encoding='utf-8'))  # '✅'

# BPE Tokenizer 训练算法

由于所有 Unicode 字符串都可以表示为字节序列,因此一个只有 256 单元(1 字节可以表示 0-255)的词表就足以表示所有字符串。但是这个词表是非常低效。比如前面一小节中,一个 字符就被编码成 3 个字节。即便按照平均标准,一个字符也需要约 2.5 个字节来表示,这对于现代大型语言模型的输入序列长度而言是不可承受的。Transformer 架构的核心特征使其时间复杂度与输入序列长度 LL 的平方成正比,即 O(L2)O(L^2)。并且,输入序列越长,训练和推理所需的计算资源(时间、GPU 显存)就越多,模型维护长期记忆和全局信息的能力越会受到限制。

BPE 算法通过学习语料库中的高频子词模式,将多个字节合并成一个具有语义信息的子词单元,从而大幅压缩序列长度。例如,如果在词表中学习到 b'\xe2\x9c\x85' ,那么只用一个整数就可以表示

BPE 分词器中,每一个 token 都是一段字节序列,它可能是一个字节,一个单词,也可能是一个单词的部分字节等,因此也被称为子词(subword)。

下面,具体讲解 BPE 分词器的训练算法流程,可以分成三个步骤:词表初始化、预分词、合并子词。

# 词表初始化

首先,词表必须包含 28=2562^8=256 个元素,与所有单字节一一对应,这 256 个字节构成了 BPE 的原子单元,保证了理论上任何 Unicode 字符串(通过 UTF-8 编码)都可以被表示和分解。

此外,我们还需要考虑特殊 token。所谓特殊 token 是模型可识别的,拥有特定语义功能的字符串(最后也用字节串表示),不会参与正常的 BPE 合并过程。比如 <cls><sep> 在 BERT 中就是表示分类和分隔的特殊 token。这些特殊 token 在后续的编码和训练过程中会被当作一个完整的单元进行特殊处理,因此在初始化时也需要加入到词表中。

# 预分词

在第三步合并子词时,我们会对语料库中频繁出现的相邻子词对进行合并。例如, AI 这个单词中 AI 各自对应一个字节,由上述初始词表可以表示成两个 token。若 AI 这个单词在语料库中出现次数比较多,那么 AI 在第三步就会被合并成新的 token。

那么预分词(Pre-tokenization)在干什么呢?我们一般期望,每个 token 都应该具有语义信息,比如上述的 AI 亦或者是和英语语法有关的 ester 等等,而对于 y n 这种就不是一个好的 token 表示。但是,假设说,语料库中经常出现 "my name" 文本,那么在第三步合并子词时就有可能合并出 y n 这种 token。因此预分词的目的就是定义子词合并的合理边界,避免跨单词合并,从而保留语义结构。

CS336 第一章作业中,我们使用如下 GPT-2 的预分词正则表达式,它可以将一段话拆分成具有完整含义的子字符串序列:

import regex
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
regex.findall(PAT, "你好世界, Hello World!")  # [' 你好世界 ', ',', ' Hello', ' World', '!']

这里划分出来的子字符串列表就是就是下一步 BPE 合并操作的最小作用域。

预分词不仅可以保证合并出来的子词具有语义结构,还可以提高 BPE 训练的效率。在这一阶段,我们可以统计出每个预分词块(例如 Hello )在整个语料库中出现的总次数。在后续的合并计数中,我们只需将内部子词对的频率乘以该块的出现次数,即可避免对整个语料进行重复扫描,大大加速合并计数。

# 子词合并

子词合并就是 BPE 训练算法最核心的部分。

对于每个预分词后的子字符串,我们可以用当前的词表将其表示为一系列子词,进而得到整个语料库的子词表示。然后统计所有相邻的子词对的出现次数。注意,我们不统计预分词子字符串之间的子词对。然后将出现次数最多,字典序最大的子词对取出进行合并,得到新的子词添加到词表中,进而用新的子词替换原先出现的子词对。

BPE 中,子词就是 token,在词表中用字节串表示。

例如,在下面语料中,第一次合并时选择的子词对是 (b's', b't') ,进而我们将 b'st' 添加到词表中,然后将 [b'n', b'e', b'w', b'e', b's', b't'] 重新表示为 [b'n', b'e', b'w', b'e', b'st']

low low low low low
lower lower widest widest widest
newest newest newest newest newest newest

上述过程会不断迭代进行,直至词表达到预定义的大小。


# BPE Tokenizer 编码解码

在 BPE 算法训练结束后,我们得到了两份数据:

  1. 词汇表( vocab ): 包含了所有的基本字节和学习到的复合子词。
  2. 合并规则序列( merges ): 记录了 BPE 训练过程中,所有子词对被合并的顺序。

基于这两个数据,我们可以构建一个完整的 BPETokenizer 类,其核心功能是实现将自然语言转化为模型输入的 encode() 方法,以及反向的 decode() 方法。前者将给定的 Unicode 字符串转换为模型可识别的整数 Token ID 序列,后者则将给定的整数 Token ID 序列转换成 Unicode 字符串。

给定词表和待编码的 Unicode 字符串,其实有许多可行的编码方案。比如最简单的一种就是每个字节都编码成一个 token。但为了确保模型输入的一致性和效率,我们必须遵循一个最优的、确定的编码方案。按照 CS336 课程作业要求, BPETokenizer 的编码流程如下:

  • 首先,使用与训练时完全相同的预分词方法,将输入字符串拆分成语义子字符串块,并将其转化为 UTF-8 字节序列,并初始化为一个字节一个 Token 的序列。
  • 严格按照训练时学习到的 merges 列表的顺序,从头到尾遍历每条合并规则。对于当前序列中的每一个预分词块,如果其中包含当前规则要合并的子词对,则进行合并,最终生成一个新的 Token ID 序列。
  • merges 列表中的所有规则都应用完毕后,最终得到的 Token 序列就是我们期望的、压缩效率最高的编码结果。

上述编码流程完全仿照训练过程。这是因为 BPE 算法的本质是贪婪的,它学习到的 merges 列表反映了将序列压缩到词汇表大小限制内的最优、最频繁的合并路径。尽管对于单个待编码的 Unicode 字符串来说,通过这种方式编码得到的 Token 数目不一定是理论上最少的。但是,由于 merges 列表是基于整个大规模训练语料统计的最高频模式学习而来,因此在整个输入空间上,这种编码方式具有最优的期望压缩效率

解码是一个相对直接的反向过程:遍历输入的整数序列,根据保存的 vocab 词汇表,查找每个 Token ID 对应的字节序列。将所有查找到的字节序列按顺序拼接成一个完整的字节串对象。最后使用字节串对象的 .decode(encoding='utf-8') 方法转成最终的 Unicode 字符串。


# 代码实现

本节将按照 2025 斯坦福 CS336 课程的作业要求,从零实现一个基础的 BPE 分词器,并补充相关的技术细节。

# 训练语料

下面是代码实现中用到的语料数据。

mkdir -p data
cd data
wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-train.txt
wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-valid.txt
wget https://huggingface.co/datasets/stanford-cs336/owt-sample/resolve/main/owt_train.txt.gz
gunzip owt_train.txt.gz
wget https://huggingface.co/datasets/stanford-cs336/owt-sample/resolve/main/owt_valid.txt.gz
gunzip owt_valid.txt.gz

这些语料数据已经按照要求经过了预处理,可以直接用来训练 BPE Tokenizer。 valid 数据相比 train 数据少得多,用于快速验证算法的正确性。

# 目录结构

所有代码总共分成四个文件,如下所示:

- pretokenize.py
- train.py
- tokenizer.py
- utils.py

pretokenize.py 保存与预分词相关的函数。

train.py 保存训练 BPE Tokenizer 的函数。

tokenizer.py 定义了 BPE Tokenizer 类,实现编码解码功能。

utils.py 保存了整个 BPE 算法实现过程中用到的数据结构。

# BPE 训练函数

# 工具模块

首先给出 utils.py 的文件内容,然后详细说明各部分的作用。

# utils.py
from typing import NamedTuple, List, Tuple
from dataclasses import dataclass, field
from collections import Counter, defaultdict
BYTES: List[bytes] = [bytes([i]) for i in range(256)]
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
def split_bytes(data: bytes) -> Tuple[bytes]:
    """Split a bytestring into a list of single-byte bytestrings."""
    return tuple([BYTES[b] for b in data])

BYTES 保存了 256 个单字节对象,避免在分词时频繁创建新的字节对象,减少内存消耗。 PAT 是预分词使用的正则表达式字符串。

split_bytes 函数将字节串拆分成成单字节对象组成的元组。

# utils.py
class myHeap:
    def __init__(self, mode="max", init_list=None):
        self.mode = mode
        self._init_data(init_list or [])
    def _init_data(self, init_list) -> list:
        self.data = list(init_list)
        self.size = len(self.data)
        if self.size == 0:
            return
        for i in range((len(self.data) - 1) // 2, -1, -1):
            self._shift_down(i)
    def _compare(self, a, b):
        if self.mode == "max":
            return a > b
        else:
            return a < b
    def _shift_down(self, index):
        """Shift down the element at index to maintain heap property"""
        while True:
            left = 2 * index + 1
            right = 2 * index + 2
            maxIndex = index
            if left < self.size and self._compare(self.data[left], self.data[maxIndex]):
                maxIndex = left
            if right < self.size and self._compare(self.data[right], self.data[maxIndex]):
                maxIndex = right
            if maxIndex != index:
                self.data[index], self.data[maxIndex] = self.data[maxIndex], self.data[index]
                index = maxIndex
            else:
                break
    def push(self, item):
        """Push an item onto the heap"""
        self.data.append(item)
        self.size += 1
        index = self.size - 1
        while index > 0:
            parent = (index - 1) // 2
            if self._compare(self.data[index], self.data[parent]):
                self.data[index], self.data[parent] = self.data[parent], self.data[index]
                index = parent
            else:
                break
    def pop(self) -> any:
        """Pop the top item off the heap and return it"""
        if self.size == 0:
            return None
        top_item = self.data[0]
        self.data[0] = self.data[-1]
        self.data.pop()
        self.size -= 1
        if self.size > 0:
            self._shift_down(0)
        return top_item
    def top(self) -> any:
        """Return the top item of the heap without removing it"""
        return self.data[0] if self.size > 0 else None

myHeap 类是手动实现的堆数据结构,支持大顶堆和小顶堆,用于高效获取频率最高,字典序最大的分词对。

# utils.py
class BytePair(NamedTuple):
    left: bytes
    right: bytes
    @property
    def merged_bytes(self):
        return self.left + self.right
    def __str__(self) -> str:
        return f"({self.left.decode('utf-8', errors='ignore')} {self.right.decode('utf-8', errors='ignore')})"
    def __repr__(self) -> str:
        return self.__str__()

BPE Tokenizer 中每个 token 都是一个字节串,每次合并时需要选择一个 token 对。为了方便管理,这里定义 BytePair 类,继承元组类型,表示 token 对,其中 merged_bytes 方法通过修饰器转换成属性,可以返回该 token 对合并后的字节串。

# utils.py
@dataclass
class WordRef:
    tokens: list[bytes]
    count: int
    @property
    def byte_pairs_dict(self) -> defaultdict[BytePair, int]:
        bp_to_count = defaultdict(int)
        for i in range(len(self.tokens) - 1):
            bp = BytePair(self.tokens[i], self.tokens[i + 1])
            bp_to_count[bp] += self.count
        return bp_to_count
    def merge(self, bp_merged: BytePair):
        """merge the given byte pair in this word"""
        new_tokens: list[bytes] = []
        i = 0
        while i < len(self.tokens):
            if i < len(self.tokens) - 1 and self.tokens[i] == bp_merged.left and self.tokens[i + 1] == bp_merged.right:
                new_tokens.append(bp_merged.merged_bytes)
                i += 2
            else:
                new_tokens.append(self.tokens[i])
                i += 1
        self.tokens = new_tokens
    def __str__(self) -> str:
        return f"{self.tokens} : {self.count}"
    def __repr__(self) -> str:
        return self.__str__()

WordRef 类表示预分词后,每个有效的子字节串。属性 tokens 表示当前该字节串被表示成的 token 序列,属性 cnt 是预分词时统计该子字节串出现的次数。

由于 BPE 算法合并过程时需要统计 token 对数量,因此这里定义 byte_pairs_dict 返回该字节串的对应的 token 对统计结果。 merge 函数接受需要合并的 token 对,类型为 BytePair 对象,按要求合并,得到新的 token 序列。

# 预分词模块

介绍完 utils.py 的内容,再来讲讲预分词的代码如何实现。

预分词的功能说起来简单,只需要使用前文提到的正则表达式进行匹配得到结果即可。但难点在于如何兼备内存效率和计算效率。一般来说,BPE Tokenizer 训练的语料非常庞大, owt_train.txt 都有 11G,更别提现代大模型所需要的训练数据量。一次性读取所有文件内容进内存显然不合理,因此我们需要给语料数据进行分块,逐块进行预分词。

另外,正则表达式匹配是一个计算复杂度非常高的过程,直接对一个长文本进行预分词需要耗费非常多的时间。分块处理的另一个好处是可以使用多进程编程提高计算效率。

CS336 课程提供了文件分块处理的函数,如下所示:

# pretokenize.py
import os
import re
import regex
import multiprocessing
import io
from typing import BinaryIO, Dict, Tuple, List, Iterable
from collections import Counter
from cs336_basics.bpe.utils import split_bytes, PAT
def find_chunk_boundaries(
    file: BinaryIO,
    desired_num_chunks: int,
    split_special_token: bytes,
) -> list[int]:
    """
    Chunk the file into parts that can be counted independently.
    May return fewer chunks if the boundaries end up overlapping.
    """
    assert isinstance(split_special_token, bytes), "Must represent special token as a bytestring"
    # Get total file size in bytes
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)
    chunk_size = file_size // desired_num_chunks
    # Initial guesses for chunk boundary locations, uniformly spaced
    # Chunks start on previous index, don't include last index
    chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
    chunk_boundaries[-1] = file_size
    mini_chunk_size = 4096  # Read ahead by 4k bytes at a time
    for bi in range(1, len(chunk_boundaries) - 1):
        initial_position = chunk_boundaries[bi]
        file.seek(initial_position)  # Start at boundary guess
        while True:
            mini_chunk = file.read(mini_chunk_size)  # Read a mini chunk
            # If EOF, this boundary should be at the end of the file
            if mini_chunk == b"":
                chunk_boundaries[bi] = file_size
                break
            # Find the special token in the mini chunk
            found_at = mini_chunk.find(split_special_token)
            if found_at != -1:
                chunk_boundaries[bi] = initial_position + found_at
                break
            initial_position += mini_chunk_size
    # Make sure all boundaries are unique, but might be fewer than desired_num_chunks
    return sorted(set(chunk_boundaries))

给语料数据进行分块时要注意避免一个正常的单词被切分成两块,因此需要定义一个分块的边界。这个分块的边界通常由开发人员自己定义。本作业使用到的语料数据均使用一个特殊 token <|endoftext|> 来分割不同的文档,这也是本次作业中唯一使用到的特殊 token。

find_chunk_boundaries 函数正是基于这个特殊 token 来获取每个语料块的边界。简单来说,这个函数根据文件 IO 指针计算出整个文件内容所占的字节数,然后根据预期分块数 desired_num_chunks 均分得到初始分块边界。针对每个边界,逐步向后读取固定字节大小的内容,检测是否出现了特殊的分割 token,更新最终的边界。由于每个初始边界都会向后调整,因此最后有效的块数可能会少于预期值,因此返回时使用 sorted(set(chunk_boundaries)) 进行后处理。

def pretokenize_file(file: BinaryIO, special_tokens: List[str] = None) -> Counter[Tuple[bytes, ...], int]:
    """
    Pre-tokenize a file into words and count their occurrences, using multiprocessing.
    The primary function to execute the parallel pre-tokenization process.
    """
    # 1. prepare special tokens
    split_special_token = b"<|endoftext|>"
    # 2. find chunk boundaries
    num_processes = 16
    chunk_max_size = 1024 * 1024 * 500  # 500 MB
    desired_chunks = max(num_processes, file.seek(0, os.SEEK_END) // chunk_max_size + 1)
    boundaries = find_chunk_boundaries(file, desired_num_chunks=desired_chunks, split_special_token=split_special_token)
    # 3. pretokenize chunks in parallel
    def chunk_generator():
        for start, end in zip(boundaries[:-1], boundaries[1:]):
            file.seek(start)
            chunk_data = file.read(end - start)
            yield (chunk_data, special_tokens)
    final_words_counter: Counter[Tuple[bytes, ...], int] = Counter()
    try:
        with multiprocessing.Pool(processes=num_processes) as pool:
            for result in pool.imap_unordered(pretokenize_chunk, chunk_generator()):
                final_words_counter.update(result)
    except Exception as e:
        print(f"Multiprocessing failed, falling back to single process: {e}")
        for chunk_data, special_tokens in chunk_generator():
            counter = pretokenize_chunk((chunk_data, special_tokens))
            final_words_counter.update(counter)
    return final_words_counter

pretokenize_file 函数为整个文件内容进行预分词,返回预分词后每个子字节串的统计结果。注意这里的子字节串被表示为单字节元组,用于后续创建 WordRef 对象。

该函数使用了多进程编程技术,定义每个分块的最大尺寸和可用进程数,调用 find_chunk_boundaries 获取每个分块的边界。然后通过 chunk_generator 迭代器动态读取文件指定分块内容进入内存。接着使用进程池和 pool.imap_unordered 分发预分词任务,每个子进程自行处理对应的分块语料内容。

子进程预分词的逻辑为 pretokenize_chunk 函数,具体实现如下:

def pretokenize_chunk(args) -> Counter[Tuple[bytes, ...], int]:
    """
    Pre-tokenize a chunk, replace special tokens, and count word occurrences.
    This function runs in parallel worker processes.
    """
    chunk, special_tokens = args
    chunk_str = chunk.decode("utf-8", errors="ignore")
    words_counter = Counter()
    last_end = 0
    # pretokenize between special tokens
    if special_tokens and len(special_tokens) > 0:
        for match in regex.finditer("|".join(re.escape(st) for st in special_tokens), chunk_str):
            pre_token_text = chunk_str[last_end : match.start()]
            for pre_token in regex.finditer(PAT, pre_token_text):
                word_bytes = pre_token.group(0).encode("utf-8")
                words_counter[split_bytes(word_bytes)] += 1
                del word_bytes  # free memory
            last_end = match.end()
    # pretokenize after last special token
    pre_token_text = chunk_str[last_end:]
    for pre_token in regex.finditer(PAT, pre_token_text):
        word_bytes = pre_token.group(0).encode("utf-8")
        words_counter[split_bytes(word_bytes)] += 1
        del word_bytes
    return words_counter

预分词时,需要将特殊 token 从语料中剔除,因为它们不会参与任何子词合并的过程,它们本身就是一个 token。匹配时首先找到特殊 token 的位置,然后对特殊 token 直接的文本进行预分词。

注意,在正则匹配过程中,应该使用 finditer 函数避免一次性完成所有匹配任务,导致内存开销过大。

# 训练模块

基于语料库的统计数据,从字节级别开始,迭代合并最高频、最大字典序的相邻字节对,逐步构建一个指定大小的词汇表。

训练代码如下:

# train.py
import pdb
import pickle
from tqdm import tqdm
from collections import Counter, defaultdict
from cs336_basics.bpe.pretokenize import pretokenize_file
from cs336_basics.bpe.utils import myHeap, BytePair, WordRef
def train_bpe(
    input_path: str,
    vocab_size: int,
    special_tokens: list[str] = None,
):
    """Train a BPE tokenizer on the given input file."""
    # 1. initialize byte-level vocab and add special tokens
    vocab = {i: bytes([i]) for i in range(256)}  # int to bytes
    for special_token in special_tokens or []:
        vocab[len(vocab)] = special_token.encode("utf-8", errors="ignore")
    # 2. pretokenize input file into words and count occurrences
    with open(input_path, "rb") as f:
        word_counts = pretokenize_file(f, special_tokens)  # Counter[Tuple[bytes, ...], int]
    # 3. create WordRef objects and initialize byte pair stats
    word_refs: list[WordRef] = []
    pair_to_count: defaultdict[BytePair, int] = defaultdict(int)
    pair_to_word: defaultdict[BytePair, set[int]] = defaultdict(set)
    for word_tuple, count in word_counts.items():
        word_ref = WordRef(tokens=list(word_tuple), count=count)
        word_refs.append(WordRef(tokens=list(word_tuple), count=count))
        for bp, cnt in word_ref.byte_pairs_dict.items():
            pair_to_count[bp] += cnt
            pair_to_word[bp].add(len(word_refs) - 1)
    pair_heap = myHeap()
    for bp, cnt in pair_to_count.items():
        pair_heap.push((cnt, bp))
    pair_heap_to_delete = dict()  # heap lazy deletion map
    # 4. merge BPEs until reaching the desired vocab size
    merges = []
    with tqdm(total=vocab_size - len(vocab), desc="BPE Merging") as pbar:
        while len(vocab) < vocab_size:
            # Lazy deletion: pop until we find a valid top pair
            while pair_heap.size > 0:
                top_cnt, top_bp = pair_heap.pop()
                if top_bp in pair_heap_to_delete:
                    top_cnt -= pair_heap_to_delete[top_bp]
                    if top_cnt > 0:
                        pair_heap.push((top_cnt, top_bp))
                    del pair_heap_to_delete[top_bp]
                else:
                    break
            # merge best pair
            assert top_bp is not None, "No more pairs to merge"
            merges.append((top_bp.left, top_bp.right))
            new_token = top_bp.merged_bytes
            vocab[len(vocab)] = new_token
            # update word counts and pair stats
            pair_to_count_change: defaultdict[BytePair, int] = defaultdict(int)
            pair_heap_to_add = defaultdict(int)  # cumulative count changes for lazy deletion
            for word_ref_id in pair_to_word[top_bp]:
                word_ref = word_refs[word_ref_id]
                old_bp_dict = word_ref.byte_pairs_dict
                word_ref.merge(top_bp)
                new_bp_dict = word_ref.byte_pairs_dict
                # update old pairs
                for bp, cnt in old_bp_dict.items():
                    if bp not in new_bp_dict and bp != top_bp:
                        pair_to_word[bp].remove(word_ref_id)
                    change = new_bp_dict.get(bp, 0) - cnt
                    pair_to_count_change[bp] += change
                    if change < 0:
                        pair_heap_to_delete[bp] = pair_heap_to_delete.get(bp, 0) - change
                # update new pairs
                for bp, cnt in new_bp_dict.items():
                    if bp not in old_bp_dict:
                        pair_to_word[bp].add(word_ref_id)
                        pair_to_count_change[bp] += cnt
                        pair_heap_to_add[bp] += cnt
            # apply count changes
            for bp, change in pair_to_count_change.items():
                pair_to_count[bp] += change
                if pair_to_count[bp] <= 0:
                    del pair_to_count[bp]
                    del pair_to_word[bp]
            for bp, add_cnt in pair_heap_to_add.items():
                pair_heap.push((add_cnt, bp))  # push updated count to heap
            del pair_heap_to_delete[top_bp]  # remove merged pair from lazy deletion map
            assert top_bp not in pair_to_count and top_bp not in pair_to_word
            # update progress bar
            pbar.update(1)
    return vocab, merges

整个函数流程基本符合第三节的内容,核心的三个数据是:

  • word_refs :预分词得到的 WordRef 对象列表。
  • pair_to_count :相邻字节对的实时统计结果。
  • pair_to_word :相邻字节对出现的 WordRef 索引。

每次合并时,从 pair_to_word 中取出受影响的 WordRef 对象,调用 merge 函数完成合并,并及时更新 pair_to_count

为了加速获取频率最高的字节对,这里使用了堆优化 + 延迟删除算法。合并时统计减少 / 删除的字节对,在下一次取最大值时进行延迟删除,并统计所有新增字节对更新堆。

# BPE Tokenizer

完成训练后,我们可以保存词表和合并规则序列,然后定义 BPETokenizer 类实现编码解码功能。

由于编码过程中也需要给字符串进行预分词,所以这里定义一个新的函数 pretokenize_text ,和前面预分词的逻辑类似,只是这里需要正确返回对应的特殊 token,因为它们也需要被编码成对应的 Token ID。

# pretokenize.py
def pretokenize_text(text: str, special_tokens: List[str] = None) -> Iterable[bytes]:
    """
    Pre-tokenize a string into words, replacing special tokens.
    Yields pre-tokenized byte sequences.
    """
    last_end = 0  # Track the end of the last match
    if special_tokens and len(special_tokens) > 0:
        special_tokens.sort(key=len, reverse=True)  # Longer tokens first
        for match in re.finditer("|".join(re.escape(st) for st in special_tokens or []), text):
            special_token = match.group()
            # yield pre-tokens before the special token
            for pre_token in regex.finditer(PAT, text[last_end : match.start()]):
                yield pre_token.group().encode("utf-8")
            yield special_token.encode("utf-8")
            last_end = match.end()
    for pre_token in regex.finditer(PAT, text[last_end:]):
        yield pre_token.group().encode("utf-8")

这里将函数写成一个迭代器的形式,迭代返回一个字节串对象,这么做是为了兼容后面不同的编码需求。

作业中, BPETokenizer 的数据结构如下:

  • vocab: Dict[int, bytes] :Token ID 到字节串对象字典。
  • inv_vocab: Dict[bytes, int]vocab 的反向字典。
  • merges: List[Tuple[bytes, bytes]] :合并规则序列。
  • special_tokens: List[str] :特殊 token 列表。
  • encode_cache: Dict[bytes, List[int]] :分词器编码的缓存结果,用于加速编码过程。
  • from_file :从文件中构建分词器对象。
  • _pretoken_to_ids :编码的核心逻辑,将一个预分词字节串转成对应的 Token ID,注意特殊 token 需要转换成其对应的 ID。
  • encode :编码一个给定字符串,返回编码后的整数序列。
  • encode_iterable :考虑到有时候需要编码的内容(比如文件)非常大,一次性读入内存不可取,因此传入字符串迭代器,返回编码后的整数 ID 迭代器,保证内存效率。
  • decode :将给定的整数序列解码成对应的字符串。由于整数序列由用户给定,有可能最终得到的字节串无法通过 UTF-8 进行解码,这里使用 .decode("utf-8", errors="replace") 将这些无效字节转换成标准的 Unicode 替代字符。

废话少说,直接上代码实现:

# tokenizer.py
import pickle
import time
from typing import Dict, Tuple, List, Union, Iterable
from cs336_basics.bpe.pretokenize import pretokenize_text
from cs336_basics.bpe.utils import split_bytes, WordRef, BytePair
class Tokenizer:
    vocab: Dict[int, bytes]
    inv_vocab: Dict[bytes, int]
    merges: List[Tuple[bytes, bytes]]
    special_tokens: List[str]
    encode_cache: Dict[bytes, List[int]] = {}
    def __init__(self, vocab: Dict[int, bytes], merges: List[Tuple[bytes, bytes]], special_tokens: List[str] = None):
        self.vocab = vocab
        self.inv_vocab = {v: k for k, v in vocab.items()}
        self.merges = merges
        self.special_tokens = special_tokens or []
    @classmethod
    def from_file(cls, vocab_filepath: str, merges_filepath: str, special_tokens: List[str] | None = None):
        with open(vocab_filepath, "rb") as vf:
            vocab = pickle.load(vf)
        with open(merges_filepath, "rb") as mf:
            merges = pickle.load(mf)
        return cls(vocab=vocab, merges=merges, special_tokens=special_tokens)
    def _pretoken_to_ids(self, text: bytes) -> Iterable[int]:
        """Encode a byte string into a list of token IDs using the BPE merges"""
        if text in self.encode_cache:
            yield from self.encode_cache[text]
            return
        word_ref = WordRef(tokens=list(split_bytes(text)), count=1)
        byte_pairs_dict = word_ref.byte_pairs_dict
        for merge_left, merge_right in self.merges:
            bp = BytePair(merge_left, merge_right)
            if bp not in byte_pairs_dict:
                continue
            word_ref.merge(bp)
            byte_pairs_dict = word_ref.byte_pairs_dict
            if len(word_ref.tokens) == 1:
                break
        token_ids = []
        for token in word_ref.tokens:
            yield self.inv_vocab[token]
            token_ids.append(self.inv_vocab[token])
        self.encode_cache[text] = token_ids
    def encode(self, text: str) -> List[int]:
        """Encode a string into a list of token IDs using the BPE merges"""
        encoded_ids = []
        # reset cache with special tokens
        self.encode_cache = {st.encode("utf-8"): [self.inv_vocab[st.encode("utf-8")]] for st in self.special_tokens}
        for pre_toeken in pretokenize_text(text, self.special_tokens):
            for token_id in self._pretoken_to_ids(pre_toeken):
                encoded_ids.append(token_id)
        return encoded_ids
    def encode_iterable(self, iterable: Iterable[str]) -> Iterable[int]:
        """Encode an iterable of strings into a flat iterable of token IDs using the BPE merges"""
        self.encode_cache = {st.encode("utf-8"): [self.inv_vocab[st.encode("utf-8")]] for st in self.special_tokens}
        for text in iterable:
            for pre_toeken in pretokenize_text(text, self.special_tokens):
                for token_id in self._pretoken_to_ids(pre_toeken):
                    yield token_id
    def decode(self, ids: List[int]) -> str:
        bytes_seq = []
        for i in ids:
            if i >= len(self.vocab):
                Warning(f"Token ID {i} not in vocab, replacing with b''")
            bytes_seq.append(self.vocab.get(i, b""))
        seq = b"".join(self.vocab.get(i) for i in ids).decode("utf-8", errors="replace")
        return seq