# 概述

本篇博客基于 2025 斯坦福 CS336 课程的第一章节作业要求,系统梳理了其中关于 Transformer 模型构建与训练的核心内容。该部分涵盖了词嵌入层、RoPE 位置编码、注意力机制等基础模块,并在此基础上逐步搭建一个标准的 Transformer 模型,同时实现相应的优化器等,完成语言模型的端到端训练流程。

原始作业内容包含了大量底层模块的实现,例如线性层、交叉熵损失函数等,旨在帮助初学者通过亲手实践加深对原理的理解。由于本文定位为总结与梳理,因此不会面面俱到,而是选取我个人认为较为关键、具有代表性的部分进行介绍与归纳。


# Transformer 基础模块

本章节将介绍在构建 Transformer 模型过程中所需的基础模块组件及其具体代码实现。所有相关代码均统一放置在 module 包下,即创建一个名为 module 的目录,并在该目录中实现和组织各类基础组件的代码文件。

# 数学记号

在深度学习论文中,通常会使用行向量的形式来统一表示模型的计算过程。例如在线性层中,常见的表达方式为:

y=xW+by = xW^{\top} + b

其中 xR1×dinx \in \mathbb{R}^{1 \times d_{in}} 是输入矩阵,WRdout×dinW \in \mathbb{R}^{d_{out} \times d_{in}} 是权重矩阵,bRdoutb \in \mathbb{R}^{d_{out}} 是偏置向量,yR1×douty \in \mathbb{R}^{1 \times d_{out}} 是输出矩阵。

在这种表示方式下,输入矩阵 xx 的每一行会与 WW^{\top} 对应的列进行矩阵乘法,本质上等价于与权重矩阵 WW 的某一行进行点积运算。这种写法在深度学习中非常常见。该表示方法的另外一个隐晦的好处是与代码实现保持一致。因为 PyTorch 主要采用行优先的存储方式,所以像上述线性层中的参数,我们通常存储 WW,在计算时使用其转置视图,而并不直接存储 WW^{\top}

需要注意的是,这种记号方式与线性代数教材中更常见的列向量表示方式有所不同。在线性代数中,我们更习惯将单个样本表示为列向量,此时线性变换通常写作:

y=Wx+by=Wx+b

为了方便阅读,本文在理论推导部分将采用列向量的表示方式。但在具体代码实现中,仍然沿用深度学习按行组织样本的表示方式进行存储与计算。

# 线性层和嵌入层

线性层和嵌入层属于 Transformer 模型中最基础的组件。线性层(Linear Layer)实现了一个简单的线性变换,通常用于特征变换和输出层。而嵌入层(Embedding Layer)则是将离散的 token ID 映射到连续的向量空间中。这两个模块实现都非常直接,PyTorch 中已经提供了现成的 torch.nn.Lineartorch.nn.Embedding 类,我们可以直接使用它们来构建模型,所以这里也不再细述。

本小节主要介绍参数初始化的相关细节。

模型参数初始化对于模型训练十分重要。如果参数初始化设置不当,可能会导致模型无法收敛或者收敛到局部最优解甚至出现梯度爆炸、梯度消失等问题。

第一章作业中指定的参数初始化标准为:

  • 线性层:N(μ=0,σ2=2din+dout)\mathcal{N}(\mu=0, \sigma^2=\frac{2}{d_{in}+d_{out}}) 并且截断在 [3σ,3σ][-3\sigma, 3\sigma] 范围内。
  • 嵌入层:N(μ=0,σ2=1\mathcal{N}(\mu=0, \sigma^2=1 并且截断在 [3,3][-3, 3] 范围内。
  • RMSNorm:初始为 1 的常数值。

具体的实现方式是使用 torch.nn.init.trunc_normal_ 函数来进行截断正态分布的初始化。

下面给出线性层的具体实现,就是简单的矩阵乘法,参数使用 torch.nn.Parameter 来定义,并遵循上述初始化标准:

# linear.py
import torch
import torch.nn as nn
class Linear(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weights = torch.nn.Parameter(
            torch.empty((out_features, in_features), device=device, dtype=dtype)
        )
        self._init_weights()
    def _init_weights(self):
        nn.init.trunc_normal_(
            self.weights, mean=0.0, std=(2 / (self.in_features + self.out_features)) ** 0.5, a=-3, b=3
        )
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.matmul(x, self.weights.T)

嵌入层就不多赘述了,采用索引方式从嵌入矩阵中获取对应的向量即可。

# RMSNorm

RMSNorm (Root Mean Square Layer Normalization) 是一种归一化方法,常用于 Transformer 模型中。与传统的 LayerNorm 不同,RMSNorm 只使用输入的均方根(Root Mean Square)来进行归一化,而不使用均值。这种方法在某些情况下可以提高模型的训练稳定性和性能。

RMSNorm 的计算公式如下:

RMSNorm(x)=x1di=1dxi2+ϵγ\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d} \sum_{i=1}^{d} x_i^2 + \epsilon}} \cdot \gamma

xx 是输入向量,dd 是输入的维度,ϵ\epsilon 是一个小常数,用于防止除零错误,γ\gamma 是一个可学习的缩放参数,按照上节的参数初始化标准,γ\gamma 的初始值为 1。

在语言模型当中,中间的激活值形状通常是 (batch_size,seq_len,hidden_dim)(\text{batch\_size}, \text{seq\_len}, \text{hidden\_dim}),因此在计算 RMSNorm 时,我们会沿着最后一个特征维度(hidden_dim)进行归一化,且每个样本的归一化是独立进行的。

norm.py 文件中实现 RMSNorm 层,代码如下:

# norm.py
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.eps = eps
        # Initialize scale parameter to ones
        self.scale = torch.nn.Parameter(torch.ones((d_model,), device=device, dtype=dtype))
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the RMSNorm layer.
        Args:
            x (torch.Tensor): Input tensor of shape (..., d_model).
        Returns:
            torch.Tensor: Output tensor of the same shape as input.
        """
        in_dtype = x.dtype
        x = x.to(self.scale.dtype)
        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)  # (..., 1)
        result = (x / rms) * self.scale
        return result.to(in_dtype)

# Position-Wise Feed-Forward Network

传统的基于 ReLU 激活函数的前馈神经网络(Feed-Forward Network, FFN)在某些任务中可能存在表达能力不足的问题。为了提升模型的性能,近年来一些研究提出了 GLU(Gated Linear Units)机制,通过引入门控机制来增强模型的非线性表达能力(移除偏置项):

GLU(x)=σ(W1x)(W2x)\text{GLU}(x) = \sigma(W_1 x) \odot (W_2 x)

其中 σ\sigma 表示激活函数,\odot 表示逐元素乘法,W1W_1W2W_2 是两个线性变换的权重矩阵。GLU 通过将输入分成两部分,一部分经过激活函数处理后作为门控信号,另一部分直接进行线性变换,然后将两者进行逐元素乘法,实现了更复杂的非线性变换。

作业中实现的是 SwiGLU(Switchable Gated Linear Units)前馈网络,其中 σ\sigma 采用 SiLU(Sigmoid Linear Unit)激活函数,计算公式如下:

SiLU(x)=x1+ex\text{SiLU}(x) = \frac{x}{1 + e^{-x}}

最终的 SwiGLU 前馈网络的计算公式为:

SwiGLU(x)=W2(SiLU(W1x))(W3x)\text{SwiGLU}(x) = W_2(\text{SiLU}(W_1 x)) \odot (W_3 x)

其中 xRdmodelx \in \mathbb{R}^{d_{model}} 是输入向量,W1Rdff×dmodelW_1 \in \mathbb{R}^{d_{ff} \times d_{model}}W2Rdmodel×dffW_2 \in \mathbb{R}^{d_{model} \times d_{ff}}W3Rdmodel×dmodelW_3 \in \mathbb{R}^{d_{model} \times d_{model}} 分别是三个线性变换的权重矩阵。

swiglu.py 文件中实现 SwiGLU 前馈网络,代码如下:

# swiglu.py
import torch
import torch.nn as nn
def silu(x: torch.Tensor) -> torch.Tensor:
    return x * torch.sigmoid(x)
class SwiGLU(nn.Module):
    def __init__(
        self,
        d_model: int,
        d_ff: int | None = None,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.d_ff = d_ff if d_ff is not None else 8 * d_model // 3
        self.weights1 = torch.nn.Parameter(torch.empty(self.d_ff, d_model, device=device, dtype=dtype))
        self.weights2 = torch.nn.Parameter(torch.empty(self.d_model, self.d_ff, device=device, dtype=dtype))
        self.weights3 = torch.nn.Parameter(torch.empty(self.d_ff, d_model, device=device, dtype=dtype))
        self._init_weights()
    def _init_weights(self):
        std = (2 / (self.d_model + self.d_ff)) ** 0.5
        nn.init.trunc_normal_(self.weights1, mean=0.0, std=std, a=-3 * std, b=3 * std)
        nn.init.trunc_normal_(self.weights2, mean=0.0, std=std, a=-3 * std, b=3 * std)
        nn.init.trunc_normal_(self.weights3, mean=0.0, std=std, a=-3 * std, b=3 * std)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = torch.matmul(x, self.weights1.T)  # W1*x
        x3 = torch.matmul(x, self.weights3.T)  # W3*x
        x2 = silu(x1) * x3  # SiLu(x1) * x3
        out = torch.matmul(x2, self.weights2.T)  # W2*x2
        return out

# RoPE 旋转位置编码

RoPE(Rotary Positional Encoding)是目前主流的大模型中广泛使用的一种位置编码方法。与传统的绝对位置编码不同,RoPE 通过对向量进行旋转变换,将位置信息隐式注入到注意力计算中,从而实现对相对位置关系的建模。

RoPE 的理论推导见参考视频

二维空间举例来说,对于一个二维向量 aa,逆时针旋转一个角度 θ\theta 等价于与如下的旋转矩阵相乘:

R(θ)=[cosθsinθsinθcosθ]R(\theta) = \begin{bmatrix}\cos \theta & -\sin \theta \\ \sin \theta & \cos \theta\end{bmatrix}

RoPE 基于二维旋转的思想,拓展到高维空间。对于序列位置 ii 的 token,不妨记作 xiRdmodelx_i \in \mathbb{R}^{d_{model}},我们将其按维度两两分组,共有 dmodel/2d_{model}/2 组,每组对应一个二维子空间。对于第 k{1,2,...,dmodel/2}k \in \{1, 2, ..., d_{model}/2\} 组,其旋转矩阵为:

Rki=[cosθi,ksinθi,ksinθi,kcosθi,k]R^i_k = \begin{bmatrix}\cos \theta_{i,k} & -\sin \theta_{i,k} \\ \sin \theta_{i,k} & \cos \theta_{i,k}\end{bmatrix}

进而完整的旋转矩阵如下:

Ri=[R1i000R2i000Rdmodel/2i]R^i = \begin{bmatrix}R^i_1 & 0 & \cdots & 0 \\ 0 & R^i_2 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & R^i_{d_{model}/2}\end{bmatrix}

即对于输入向量 xix_i,RoPE 执行如下计算:

RoPE(xi)=Rixi\text{RoPE}(x_i) = R^i x_i

这样,在计算注意力之前,RoPE 会将位置信息编码到 query 和 key 向量中,使得模型能够更好地捕捉序列中元素之间的相对位置关系。这里有如下结论,具体推导见参考资料:经过旋转变换后的 query 和 key 向量的注意分数只与其相对位置有关

<Riq,Rjk>=<q,Rjik><R^i q, R^j k> = <q, R^{j-i} k>

实现中,θi,k\theta_{i,k} 的计算方式如下,Θ\Theta 是可设置的超参数:

θi,k=iΘ2(k1)/dmodel\theta_{i,k} = \frac{i}{\Theta^{2(k-1)/d_{model}}}

对应代码实现如下,文件为 rope.py

# rope.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class RoPE(nn.Module):
    def __init__(
        self,
        theta: float,
        d_k: int,
        max_seq_len: int,
        device: torch.device | None = None,
    ):
        super().__init__()
        self.theta = theta
        self.d_k = d_k
        self.max_seq_len = max_seq_len
        self.device = device
        self._register()
    def _register(self):
        pos = (
            torch.arange(self.max_seq_len, device=self.device).unsqueeze(1).repeat(1, self.d_k // 2)
        )  # (max_seq_len, d_k // 2)
        inv_freq = 1.0 / (self.theta ** (2 * torch.arange(self.d_k // 2, device=self.device) / self.d_k)) 
        self.register_buffer("sin_cache", torch.sin(pos * inv_freq))  # (max_seq_len, d_k // 2)
        self.register_buffer("cos_cache", torch.cos(pos * inv_freq))  # (max_seq_len, d_k // 2)
    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:        
        """
        Forward pass of the RoPE module.
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_length, d_k).
            token_positions (torch.Tensor): Tensor of shape (batch_size, seq_length).
        Returns:
            torch.Tensor: Output tensor of shape (batch_size, seq_length, d_k).
        """
        sin_pos = self.sin_cache[token_positions]  # (batch_size, seq_length, d_k // 2)
        cos_pos = self.cos_cache[token_positions]  # (batch_size, seq_length, d_k // 2)
        x_odd = x[..., 1::2]  # (batch_size, seq_length, d_k // 2)
        x_even = x[..., 0::2]  # (batch_size, seq_length, d_k // 2)
        x_even_rotated = x_even * cos_pos - x_odd * sin_pos
        x_odd_rotated = x_even * sin_pos + x_odd * cos_pos
	    
        # (batch_size, seq_length, d_k // 2, 2)
        x_rotated = torch.stack([x_even_rotated, x_odd_rotated], dim=-1)  
        x_rotated = x_rotated.view(x.shape)  # (batch_size, seq_length, d_k)
        return x_rotated

核心思想是,预先计算了所有位置的 sin 和 cos 值,并将它们注册为 buffer,在前向传播时可以直接索引使用,避免了重复计算,提高效率。

# 缩放点积注意力

缩放点积注意力是 Transformer 原论文中提出的一种注意力机制,也是语言模型中最常用的注意力机制之一。其计算公式如下:

Attention(Q,K,V)=softmax(QKdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right) V

查询向量 QRnq×dkQ \in \mathbb{R}^{n_q \times d_k} 与键向量 KRnk×dkK \in \mathbb{R}^{n_k \times d_k} 的转置进行点积计算,得到一个注意力分数矩阵,然后除以 dk\sqrt{d_k} 进行缩放(dkd_k 是向量的维度),最后通过 softmax 函数得到权重分布,并与值向量 VRnv×dvV \in \mathbb{R}^{n_v \times d_v} 相乘得到最终的注意力输出。注意,这里的 KKVV 是一一对应的,即 nk=nvn_k = n_v

首先实现 softmax 操作( softmax.py )。虽然 PyTorch 已经提供了 torch.nn.Softmax 类,但我觉得有必要手动实现一个版本。其中为了数值稳定性,在计算 softmax 时会先减去输入维度上的最大值,避免指数函数计算时出现溢出:

Softmax(xi)=eximax(x)jexjmax(x)\text{Softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_{j} e^{x_j - \max(x)}}

# softmax.py
import torch
def softmax(
    input: torch.Tensor,
    dim: int = -1,
) -> torch.Tensor:
    """
    Applies the softmax function to the input tensor along the specified dimension.
    Args:
        input (torch.Tensor): The input tensor.
        dim (int): The dimension along which to apply the softmax function. Default is -1.
    Returns:
        torch.Tensor: The tensor after applying the softmax function.
    """
    exp_input = torch.exp(input - torch.max(input, dim=dim, keepdim=True).values)
    sum_exp = torch.sum(exp_input, dim=dim, keepdim=True)
    return exp_input / sum_exp

基于上述介绍,最终可以实现 scale_dot_product_attention 函数,该函数置于 attention.py 模块文件中。

该函数接受查询、键和值向量,还接受一个可选的 mask 参数,用于在计算注意力分数时屏蔽掉某些位置(例如未来位置或填充位置),类型为布尔张量, True 表示该位置可以参与注意力计算, False 表示该位置需要被屏蔽掉。如果提供了 mask,则在计算注意力分数后,将 mask 中对应位置的分数设置为一个非常大的负数,以确保这些位置在 softmax 后的权重接近于零。

# attention.py
import torch
import torch.nn as nn
def scale_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor = None,
) -> torch.Tensor:
    """
    Computes the scaled dot-product attention.
    Args:
        query (torch.Tensor): The query tensor of shape (..., seq_len_q, d).
        key (torch.Tensor): The key tensor of shape (..., seq_len_k, d).
        value (torch.Tensor): The value tensor of shape (..., seq_len_k, d).
        mask (torch.Tensor, optional): An optional mask tensor of shape (..., seq_len_q, seq_len_k).
    Returns:
        torch.Tensor: The result of the attention mechanism.
    """
    assert key.size(-2) == value.size(-2), "Key and Value must have the same sequence length"
    scale = torch.sqrt(torch.tensor(query.size(-1), dtype=torch.float32))
    attention_scores = torch.matmul(query, key.transpose(-2, -1)) / scale  # (..., seq_len_q, seq_len_k)
    if mask is not None:
        attention_scores = attention_scores.masked_fill(mask == False, float("-inf"))
    attention_weights = torch.softmax(attention_scores, dim=-1)  # (..., seq_len_q, seq_len_k)
    output = torch.matmul(attention_weights, value)
    return output

# 多头自注意力

多头自注意力是 Transformer 模型的核心模块,将输入 token 通过线性层分别映射为多个头的查询、键和值向量,然后在每个头上独立计算缩放点积注意力,最后将各个头的输出拼接起来,再通过一个线性层进行变换得到最终的多头自注意力的输出。

给定 xR...×dmodelx \in \mathbb{R}^{... \times d_{model}} 作为输入,通过线性变换得到:

Q=Wqx,K=Wkx,V=WvxQ = W_q x, \quad K = W_k x, \quad V = W_v x

WqRhdk×dmodelW_q \in \mathbb{R}^{h \cdot d_k \times d_{model}}WkRhdk×dmodelW_k \in \mathbb{R}^{h \cdot d_k \times d_{model}}WvRhdv×dmodelW_v \in \mathbb{R}^{h \cdot d_v \times d_{model}} 分别是查询、键和值的线性变换权重矩阵,由此通过一次矩阵乘法将输入映射到多个头的查询、键和值空间中,避免单独为每个头独立计算查询、键和值向量,提高计算效率。虽然注意力头的数量 hh 和每个头的维度 dkd_kdvd_v 是超参数,但经验做法通常会满足 hdk=hdv=dmodelh \cdot d_k = h \cdot d_v = d_{model}

多头注意力计算公式如下:

MultiHead(Q,K,V)=WoConcat(head1,head2,...,headh)\text{MultiHead}(Q, K, V) = W_o \cdot\text{Concat}(\text{head}_1, \text{head}_2, ..., \text{head}_h)

对于每个头 headihead_i,执行 2.6 节所示的缩放点积注意力:

headi=Attention(Qi,Ki,Vi)head_i = \text{Attention}(Q_i, K_i, V_i)

QiQ_iKiK_iViV_i 分别是第 ii 个头对应的查询、键和值向量,WoRdmodel×hdvW_o \in \mathbb{R}^{d_{model} \times h \cdot d_v} 是多头注意力输出的线性变换权重矩阵。

在计算注意力时,Transformer 模型使用一个因果掩码(Causal Mask)来确保模型只能关注当前 token 之前的 token,防止信息泄露。该掩码是一个上三角矩阵,其中主对角线及其下方的元素为 True ,表示这些位置可以参与注意力计算,而上方的元素为 False ,表示这些位置需要被屏蔽掉。

最后,还要考虑在计算注意力之前,先对查询和键向量应用 RoPE 位置编码,以注入位置信息,使得模型能够更好地捕捉序列中元素之间的相对位置关系。注意,这里是针对每个头进行的 RoPE 位置编码,而不是针对整个输入向量进行的。

下面给出多头自注意力模块的具体实现,包括 MultiHeadSelfAttentionMultiHeadRoPESelfAttention 两个类,前者是标准的多头自注意力实现,后者则在前者的基础上加入了 RoPE 位置编码,均写在 attention.py 模块中:

# attention.py
from cs336_basics.module.rope import RoPE
class MultiHeadSelfAttention(nn.Module):
    """Implements Multi-Head Self-Attention mechanism."""
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.Q_weights = nn.Linear(d_model, d_model, bias=False)
        self.K_weights = nn.Linear(d_model, d_model, bias=False)
        self.V_weights = nn.Linear(d_model, d_model, bias=False)
        self.out_weights = nn.Linear(d_model, d_model, bias=False)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (..., seq_len, head, d_model // head)
        Q = self.Q_weights(x).view(*x.shape[:-1], self.num_heads, self.d_model // self.num_heads)
        K = self.K_weights(x).view(*x.shape[:-1], self.num_heads, self.d_model // self.num_heads)
        V = self.V_weights(x).view(*x.shape[:-1], self.num_heads, self.d_model // self.num_heads)
        # (..., head, seq_len, d_model // head)
        Q = Q.transpose(-3, -2).contiguous()
        K = K.transpose(-3, -2).contiguous()
        V = V.transpose(-3, -2).contiguous()
        # Create casual mask
        seq_len = x.size(-2)
        casual_mask = torch.tril(torch.ones((seq_len, seq_len), device=x.device)).bool()
        # calculate attention and combine heads
        # (..., head, seq_len, d_model // head)
        attention_output = scale_dot_product_attention(Q, K, V, mask=casual_mask)
        # (..., seq_len, head, d_model // head)
        attention_output = attention_output.transpose(-3, -2).contiguous().view(*x.shape[:-1], self.d_model)
        output = self.out_weights(attention_output)
        return output
class MultiHeadRoPESelfAttention(MultiHeadSelfAttention):
    """Implements Multi-Head Self-Attention mechanism with RoPE."""
    def __init__(self, d_model: int, num_heads: int, max_seq_len: int, theta: float):
        super().__init__(d_model, num_heads)
        self.d_k = d_model // num_heads
        self.rope = RoPE(theta=theta, d_k=self.d_k, max_seq_len=max_seq_len)
    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        """
        x: torch.Tensor of shape (..., seq_len, d_model)
        token_positions: torch.Tensor of shape (..., seq_len)
        """
        # prepare Q, K, V
        # (..., seq_len, head, d_k)
        Q = self.Q_weights(x).view(*x.shape[:-1], self.num_heads, self.d_k)
        K = self.K_weights(x).view(*x.shape[:-1], self.num_heads, self.d_k)
        V = self.V_weights(x).view(*x.shape[:-1], self.num_heads, self.d_k)
        # (..., head, seq_len, d_model // head)
        Q = Q.transpose(-3, -2).contiguous()
        K = K.transpose(-3, -2).contiguous()
        V = V.transpose(-3, -2).contiguous()
        # reshape token_positions
        # (..., head, seq_len)
        token_positions = token_positions.unsqueeze(-2)
        token_positions = token_positions.expand(*token_positions.shape[:-2], self.num_heads, -1)
        # apply RoPE
        # (..., head, seq_len, d_k)
        Q = self.rope(Q, token_positions)
        K = self.rope(K, token_positions)
        # create casual mask
        seq_len = x.size(-2)
        casual_mask = torch.tril(torch.ones((seq_len, seq_len), device=x.device)).bool()
        # calculate attention and combine heads
        # (..., head, seq_len, d_model // head)
        attention_output = scale_dot_product_attention(Q, K, V, mask=casual_mask)
        # (..., seq_len, head, d_model // head)
        attention_output = attention_output.transpose(-3, -2).contiguous().view(*x.shape[:-1], self.d_model)
        output = self.out_weights(attention_output)
        return output

# Transformer 模型

第二节介绍了构建 Transformer 模型所需的基础模块组件,接下来我们将基于这些组件搭建一个 Transformer 模型。课程作业的 Transformer 模型架构如下图所示:

Transformer 模型架构

由图示可知,Transformer 模型首先通过一个嵌入层将输入的 token ID 映射到连续的向量空间中,然后进入多个堆叠的 Transformer 基本块(Transformer Block),最后经过归一化和线性层转为输出的词汇表概率分布。

其中 Transformer 基本块是最为核心的组件。每个基本块包含一个多头自注意力机制模块和一个前馈神经网络模块。课程实现的基本块架构如下:

Pre-Norm Transformer Block

这是一种 Pre-Norm Transformer Block,在多头自注意力模块和前馈神经网络模块之前都进行了 RMSNorm 归一化处理,并且归一化操作不影响残差连接的主体路径。相比于 Post-Norm Transformer Block,Pre-Norm 结构在训练深层 Transformer 模型时更为稳定,能够有效缓解梯度消失问题。

我们基于前面实现的多头自注意力模块、前馈神经网络、RMSNorm 等,在 transformer.py 模块中实现 TransformerBlock 类,代码如下:

# transformer.py
import torch
import torch.nn as nn
from cs336_basics.module.attention import MultiHeadRoPESelfAttention
from cs336_basics.module.swiglu import SwiGLU
from cs336_basics.module.norm import RMSNorm
from cs336_basics.module.embedding import Embedding
class TransformerBlock(nn.Module):
    """
    A single Transformer block consisting of Multi-Head Self-Attention and Feed-Forward Network.
    """
    def __init__(self, d_model: int, num_heads: int, d_ff: int, max_seq_len: int, theta: float):
        super().__init__()
        self.attention = MultiHeadRoPESelfAttention(d_model, num_heads, max_seq_len, theta)
        self.norm1 = RMSNorm(d_model)
        self.ffn = SwiGLU(d_model, d_ff)
        self.norm2 = RMSNorm(d_model)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the Transformer block.
        Args:
            x (torch.Tensor): Input tensor of shape (..., seq_len, d_model).
        Returns:
            torch.Tensor: Output tensor of the same shape as input.
        """
        # Multi-Head Self-Attention with residual connection
        token_positions = torch.arange(x.size(-2), device=x.device)  # (seq_len,)
        x = x + self.attention(self.norm1(x), token_positions)
        # Feed-Forward Network with residual connection
        x = x + self.ffn(self.norm2(x))
        return x

将多个 TransformerBlock 堆叠起来,并在输入端添加嵌入层,在输出端添加归一化层和线性层,就构成了完整的 Transformer 模型(注意,我们在模型中不显示添加 softmax)。代码实现如下:

# transformer.py
class Transformer(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        d_model: int,
        num_heads: int,
        d_ff: int,
        max_seq_len: int,
        theta: float,
    ):
        super().__init__()
        self.embedding = Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList(
            [TransformerBlock(d_model, num_heads, d_ff, max_seq_len, theta) for _ in range(num_layers)]
        )
        self.norm = RMSNorm(d_model)
        self.linear = nn.Linear(d_model, vocab_size, bias=False)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.embedding(x)  # (..., seq_len, d_model)
        for layer in self.layers:
            x = layer(x)  # (..., seq_len, d_model)
        x = self.linear(self.norm(x))  # (..., seq_len, vocab_size)
        return x

# 模型训练

# 训练目标与自监督学习

语言模型的核心任务是对自然语言序列进行概率建模,即学习联合概率分布:

P(x1,x2,...,xT)P(x_1, x_2, ..., x_T)

在自回归(Autoregressive)语言模型中,该联合概率通常通过链式法则分解为:

P(x1,x2,...,xT)=t=1TP(xtx<t)P(x_1, x_2, ..., x_T) = \prod_{t=1}^{T}P(x_t \mid x_{<t})

即在给定历史上下文 x<tx_{<t} 的条件下,预测当前位置 token xtx_t条件概率

在实际训练与推理过程中,输入文本首先通过 tokenizer 编码为词表中的 token ID 序列。模型在每一个时间步基于已有上下文预测下一个 token 的概率分布,并在生成阶段按该分布逐步生成文本。

现代大型语言模型的预训练过程,本质上是一种基于大规模语料的自监督学习(Self-Supervised Learning)。所谓 “自监督”,是指训练目标直接从原始数据中自动构造,而无需人工标注。例如在自回归语言模型中,输入为前 t1t-1 个 token,训练目标为第 tt 个 token。因此,模型通过预测下一个 token 进行训练,而监督信号本身来自文本数据。这种训练范式使得模型能够利用海量无标签文本进行学习,是大规模语言模型成功的关键基础。

由于语言模型输出的是条件概率,本质上是一个多分类问题(类别数等于词表大小)。因此训练时通常采用交叉熵损失函数,用于衡量模型预测分布与真实 token 之间的差异。

利用 Transformer 并行计算的特点,我们可以同时计算整个序列各个位置 token 的预测分布,并通过掩码机制确保模型只能访问当前 token 之前的上下文,从而高效地进行训练。

本节代码实现位于 train 包中,即创建一个 train 文件夹,并在其中实现训练相关的代码。

# 交叉熵损失函数

针对语言模型,给定一训练语料集 DD 包含多个文本序列,每个序列由 TT 个 token 组成,模型训练所用的交叉熵损失函数的数学定义如下:

L=1DT(x1,x2,...,xT)Dt=1TlogP(xtx<t)\mathcal{L} = -\frac{1}{|D|T} \sum_{(x_1, x_2, ..., x_T) \in D} \sum_{t=1}^{T} \log P(x_t \mid x_{<t})

我们在 train/utils.py 模块中实现一个函数 cross_entropy_loss 来计算交叉熵损失。该函数接受模型的输出 logits 和真实 token ID 作为输入,该 logits 默认是模型的原始输出,未经过 softmax 处理。因此在代码实现中,需要使用下述公式来计算 P(xtx<t)P(x_t \mid x_{<t})

P(xtx<t)=softmax(logitsxt)=elogitsxtmax(logitsx)jelogitsxjmax(logitsx)P(x_t \mid x_{<t}) = softmax(\text{logits}_{x_t}) = \frac{e^{\text{logits}_{x_t}-max(\text{logits}_{x})}}{\sum_{j} e^{\text{logits}_{x_j}-max(\text{logits}_x)}}

但由于 softmax 本身存在的数值稳定性问题,我们在计算交叉熵损失时,通常会利用对数的性质,将其改写成:

logP(xtx<t)=logitsxtmax(logitsx)log(jelogitsxjmax(logitsx))log P(x_t \mid x_{<t}) = \text{logits}_{x_t} - max(\text{logits}_x) - \log\left(\sum_{j} e^{\text{logits}_{x_j}-max(\text{logits}_x)}\right)

最终代码实现如下:

# utils.py
def cross_entropy_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """
    Computes the cross-entropy loss between logits and targets.
    Args:
        logits (torch.Tensor): Predicted logits of shape (..., num_classes).
        targets (torch.Tensor): Ground truth labels of shape (...,).
    Returns:
        torch.Tensor: Scalar tensor representing the cross-entropy loss.
    """
    # Flatten the tensors to apply cross-entropy loss
    logits_flat = logits.view(-1, logits.size(-1))  # (N, num_classes)
    targets_flat = targets.view(-1)  # (N,)
    size = logits_flat.size(0)
    # compute cross-entropy loss
    logits_target = logits_flat[torch.arange(size), targets_flat]  # (N,)
    logits_max = logits_flat.max(dim=-1).values  # (N,)
    loss = -logits_target + logits_max + torch.log(
        torch.exp(logits_flat - logits_max.unsqueeze(-1)).sum(dim=-1)
    )
    return loss.mean()

这里补充一下困惑度 (Perplexity) 的定义,困惑度是语言模型中常用的评估指标,它反映了模型对测试数据的预测能力。Perplexity 值越低,表示模型对数据的预测越好。Perplexity 与交叉熵损失之间的关系如下:

Perplexity=eL\text{Perplexity} = e^{\mathcal{L}}

因此可以很轻松地复用上述交叉熵损失函数来计算困惑度。

# SGD 优化器

深度学习基于损失函数,通过梯度下降算法来优化模型参数。其中最简单和最基本的梯度下降算法是随机梯度下降(Stochastic Gradient Descent, SGD),其更新规则如下:

θθηθL\theta \leftarrow \theta - \eta \nabla_\theta \mathcal{L}

其中 θ\theta 是模型参数,η\eta 是学习率,θL\nabla_\theta \mathcal{L} 是当前批次数据计算出的损失函数关于参数的梯度。

下面是课程提供的一个简单 SGD 优化器实现,其中学习率根据当前训练迭代次数进行衰减:

from collections.abc import Callable, Iterable
from typing import Optional
import torch
import math
class SGD(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3):
        if lr < 0:
            raise ValueError(f"Invalid learning rate: {lr}")
        defaults = {"lr": lr}
        super().__init__(params, defaults)
    def step(self, closure: Optional[Callable] = None):
        loss = None if closure is None else closure()
        for group in self.param_groups:
            lr = group["lr"]  # Get the learning rate.
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]  # Get state associated with p.
                t = state.get("t", 0)  # Get iteration number from the state, or initial value.
                grad = p.grad.data  # Get the gradient of loss with respect to p.
                p.data -= lr / math.sqrt(t + 1) * grad  # Update weight tensor in-place.
                state["t"] = t + 1  # Increment iteration number.
        return loss

基于 PyTorch 的 torch.optim.Optimizer 类可以定义自己的优化器,其中最关键的是 __init__ 方法和 step 方法:

  • __init__ 方法:自定义优化器的构造函数通常负责两件事。第一,接收并组织待优化参数 params ,可以是参数迭代器(如 model.parameters() ),表示公用一套超参数;也可以是参数组列表(list of dict),每组 dict 包含 "params" 以及该组独立的超参数(例如使用不同的学习率)。第二,将默认超参数写入 defualt 字典中,与 params 一起传递给基类的 __init__ 方法。这一步非常关键,会把参数整理成 self.param_groups (列表,每个元素是一个参数组 dict),并将 defaults 作为每个参数组的默认配置
  • step 方法:每次调用 step 方法时,优化器会根据当前参数的梯度来更新参数值。通常会遍历 self.param_groups 中的每个参数组,遍历参数组的每个参数,对其进行更新。更新规则根据具体优化算法而定,例如 SGD 就是按照上述公式进行更新。此外,大多数优化器都需要为每个参数额外保存信息,这些都存放在 self.state[p] 这个字典中,第一次遇到参数时通常需要初始化这些状态。

# AdamW 优化器

在现代深度学习训练中,SGD 虽然概念简单、实现直观,但往往需要精细的学习率策略与动量调参才能达到较好的收敛效果。因此在大多数 Transformer / 大模型训练实践中,更常用的是 AdamW:它在 Adam 的基础上引入了解耦权重衰减(decoupled weight decay),在训练稳定性与泛化表现上通常更可靠,也是本课程要求手动实现的优化器。

Adam 的核心是为每个参数维护两类动量统计量:

  • 一阶动量(梯度的指数滑动平均)

    mt=β1mt1+(1β1)gtm_t = \beta_1 m_{t-1} + (1-\beta_1) g_t

  • 二阶动量(梯度平方的指数滑动平均)

    vt=β2vt1+(1β2)gt2v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2

由于 mt,vtm_t, v_t 在训练初期会偏向 0,Adam 会使用偏置修正来得到更准确的估计,并据此对参数进行自适应步长更新。

AdamW 的关键区别在于 weight decay 的处理方式:在 Adam 中,若直接把 L2 正则项并入梯度,会与自适应缩放耦合,导致 “权重衰减” 的实际效果不再等价于传统意义的 weight decay。AdamW 采用 “解耦” 的方式:先按 Adam 的规则更新梯度方向,再额外做一次权重衰减更新,使 weight decay 更符合预期。

AdamW 对应的完整算法流程如下:

AdamW 优化器算法流程k

代码实现如下( optimizer.py ):

# optimizer.py
import torch
import math
from collections.abc import Callable, Iterable
from typing import Optional
class AdamW(torch.optim.Optimizer):
    def __init__(
        self,
        params,
        lr: float = 1e-3,
        betas: tuple = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.01,
    ):
        """
        Initializes the AdamW optimizer.
        Args:
            params (iterable): Iterable of parameters to optimize.
            lr (float): Learning rate.
            betas (tuple): Coefficients used for computing running averages of gradient and its square.
            eps (float): Term added to the denominator to improve numerical stability.
            weight_decay (float): Weight decay (L2 penalty).
        """
        assert lr > 0.0, "Learning rate must be positive."
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(AdamW, self).__init__(params, defaults)
    def step(self, closure: Optional[Callable] = None):
        loss = None if closure is None else closure()
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p.data)
                    state["exp_avg_sq"] = torch.zeros_like(p.data)
                # get state variables
                state["step"] += 1
                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]
                lr = group["lr"]
                eps = group["eps"]
                weight_decay = group["weight_decay"]
                t = state["step"]
                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).add_(grad.pow(2), alpha=1 - beta2)
                # Learning rate correction
                lr_t = lr * math.sqrt(1 - beta2**t) / (1 - beta1**t)
                # Update parameters
                p.data -= lr_t * exp_avg / (torch.sqrt(exp_avg_sq) + eps)
                p.data -= lr * weight_decay * p.data

AdamW 的实现可以按 “状态初始化 → 动量更新 → 偏置修正 → 参数更新 → 权重衰减” 的顺序理解:

  1. 为每个参数维护 state:包括 exp_avg 一阶动量 mtm_texp_avg_sq 二阶动量 vtv_tstep 步数 tt

  2. 动量更新

    • exp_avg 使用 β1\beta_1 做梯度的指数滑动平均。
    • exp_avg_sq 使用 β2\beta_2 做梯度平方的指数滑动平均。
  3. 偏置修正(bias correction):训练早期动量偏小,因此通过

    m^t=mt1β1t,v^t=vt1β2t\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}

    等价地在代码中体现为对学习率进行修正。

  4. 按 Adam 规则更新参数:使用 m^t/(v^t+ϵ)\hat{m}_t / (\sqrt{\hat{v}_t}+\epsilon) 作为更新方向与尺度

  5. 解耦的权重衰减:更新完参数后额外执行:

    wwlrλww \leftarrow w - lr \cdot \lambda \cdot w

    这一步不通过梯度,而是直接对参数做衰减,因此称为 “解耦”。

# 自适应学习率调整

学习率调度器(Learning Rate Scheduler)是训练过程中动态调整学习率的工具,能够帮助模型更快地收敛并达到更好的性能。

本课程使用的是经典的余弦退火(Cosine Annealing)学习率调度器,其核心思想是将学习率按照余弦函数的形式从初始值逐渐降低到一个最小值。给定最小学习率 etamineta_{min}, 最大学习率 ηmax\eta_{max},以及当前训练步数 tt,和预热步数 twarmupt_{warmup},以及总训练步数 TT,余弦退火学习率调度器的计算公式如下:

Warm up:学习率从 0 线性增加到 ηmax\eta_{max},帮助模型训练初期更稳定,也更不容易陷入局部最优。

ηt=ttwarmupηmax,for t<twarmup\eta_t = \frac{t}{t_{warmup}} \cdot \eta_{max}, \quad \text{for } t < t_{warmup}

Cosine Annealing:在预热阶段结束后,学习率按照余弦函数逐渐降低到 ηmin\eta_{min},有助于模型在训练后期更细致地调整参数。

ηt=ηmin+12(ηmaxηmin)(1+cos(πttwarmupTtwarmup)),for twarmuptT\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min}) \left(1 + \cos\left(\pi \cdot \frac{t - t_{warmup}}{T - t_{warmup}}\right)\right), \quad \text{for } t_{warmup} \leq t \leq T

Post Annealing:当训练步数超过总训练步数时,学习率保持在最小值 ηmin\eta_{min}

ηt=ηmin,for t>T\eta_t = \eta_{min}, \quad \text{for } t > T

utils.py 下实现 get_lr_cosine_schedule 函数来计算当前训练步数对应的学习率:

# utils.py
def get_lr_cosine_schedule(
    lr_min: float,
    lr_max: float,
    step: int,
    warmup_steps: int,
    cosine_steps: int,
) -> float:
    """Compute learning rate at given step with linear warmup and cosine decay.
    Args:
        lr_min (float): Minimum learning rate.
        lr_max (float): Maximum learning rate.
        step (int): Current training step.
        warmup_steps (int): Number of warmup steps T_w.
        cosine_steps (int): Number of steps for cosine decay T_c.
    Returns:
        float: Computed learning rate for the current step.
    """
    assert step >= 0, "Step must be non-negative"
    # warmup phase
    if step < warmup_steps:
        lr = lr_max * step / warmup_steps
    # cosine decay phase
    elif step <= cosine_steps:
        progress = (step - warmup_steps) / (cosine_steps - warmup_steps)
        lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))
    # post-annealing phase
    else:
        lr = lr_min
    return lr

# Gradient Clipping

在训练过程中,有时会遇到产生较大梯度的训练样本,这可能会破坏训练的稳定性。为了缓解这一问题,实践中常用的一种技术是梯度裁剪(gradient clipping)。其核心思想是,在每次执行优化器更新步骤之前,对参数梯度的范数施加一个上限约束。

给定所有参数的梯度 gg,我们先计算其 L2L_2 范数 g2\|g\|_2。如果该范数小于最大值 MM,则保持 gg 不变;否则,用 Mg2+ϵ\frac{M}{\|g\|_2 + \epsilon}(其中 ϵ\epsilon 是一个很小的数,用于保证数值稳定性)的系数对 gg 进行缩放。需要注意,缩放后梯度的范数将略小于 MM

代码实现如下( utils.py ):

# utils.py
def gradient_clipping(parameters: Iterable[torch.nn.Parameter], max_l2_norm: float) -> None:
    """
    Clips the gradients of the given parameters to have a maximum norm.
    Args:
        parameters (Iterable[torch.nn.Parameter]): Iterable of model parameters.
        max_l2_norm (float): Maximum allowed L2 norm for the gradients.
    """
    param_norm = torch.sqrt(
        sum(torch.sum(param.grad.data**2) for param in parameters if param.grad is not None)
    )
    if param_norm > max_l2_norm:
        clip_coef = max_l2_norm / (param_norm + 1e-6)
        for param in parameters:
            if param.grad is not None:
                param.grad.data.mul_(clip_coef)

提醒,这里的 L2L_2 范数是针对所有参数的梯度计算的,而不是单个参数的梯度。这种全局的梯度裁剪方式在训练大型模型时更为常见,因为它能够更有效地控制整体梯度的规模,避免某些参数的梯度过大导致训练不稳定。

# 预训练数据

在 BPE 实验中,我们已经在 data/inyStoriesV2-GPT4-train.txt 上进行了 tokenizer 的训练,并生成了对应的 token ID 序列保存成 npy 文件。接下来我们需要从这个 npy 文件中加载数据,并在训练过程中从中采样出一个 batch 的输入和目标。

按照课程要求,在 train/utils.py 中实现 get_data_batch 函数来完成这个功能。该函数接受一个 numpy 数组(token ID 序列),以及 batch size 和上下文长度等参数,返回一个 batch 的输入和目标张量。输入张量的 shape 是 (batch_size, context_length) ,目标张量的 shape 也是 (batch_size, context_length) ,其中目标张量是输入张量向右平移一位得到的,即每个位置的目标是下一个位置的 token ID。

# utils.py
def get_data_batch(
    dataset: npt.NDArray,
    batch_size: int,
    context_length: int,
    device: torch.device = torch.device("cpu"),
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Take a numpy array of token IDs and return a batch of input-target pairs for training.
    Args:
        dataset (Iterable[int]): Iterable of token IDs.
        batch_size (int): Number of sequences in the batch.
        context_length (int): Length of each input sequence.
        device (torch.device): Device to place the tensors on.
    Returns:
        tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - inputs: Tensor of shape (batch_size, context_length)
            - targets: Tensor of shape (batch_size, context_length)
    """
    length = len(dataset)
    inputs = np.zeros((batch_size, context_length), dtype=np.int64)
    targets = np.zeros((batch_size, context_length), dtype=np.int64)
    input_indices = np.random.randint(0, length - context_length, size=batch_size)
    for i, idx in enumerate(input_indices):
        inputs[i] = dataset[idx : idx + context_length]
        targets[i] = dataset[idx + 1 : idx + context_length + 1]
    return torch.tensor(inputs, device=device), torch.tensor(targets, device=device)

注意,这里的训练数据已经通过 tokenizer 转换成了 token ID 的形式,并且保存在一个 npy 文件。而这个文件通常会非常大,无法一次性加载到内存中。因此在实际训练过程中,我们通常采用内存映射(memory-mapped) 的方式来访问数据,即使用 numpy.memmap 来创建一个内存映射对象,这样就可以像访问普通 numpy 数组一样访问磁盘上的数据,而不需要将整个文件加载到内存中。

其次,上述 get_data_batch 函数中,我们通过随机采样的方式从数据集中选取起始位置来构建输入和目标序列。其实并不是标准的数据加载方式,因为每一次我们都是从数据集中随机选取一个位置来构建一个 batch 的输入和目标,这样可能会导致训练过程中某些样本被过度采样,而另一些样本则很少被采样到。更常见的做法是使用一个数据加载器(DataLoader),它会在每个 epoch 结束后自动打乱数据,并按照一定的顺序来加载数据,确保每个样本在训练过程中都能被均匀地访问到。因此在实际训练中,我们通常会实现一个自定义的 Dataset 类来封装数据访问逻辑,并使用 PyTorch 的 DataLoader 来加载数据,这样可以更高效地进行训练。

为此,在 utils.py 下实现如下 TokenDataset 类:

class TokenDataset(Dataset):
    def __init__(self, file_path: str, context_length: int):
        self.data = np.load(file_path, mmap_mode="r")
        self.context_length = context_length
        self.total_size = len(self.data) - context_length - 1
    def __len__(self):
        return self.total_size
    def __getitem__(self, idx):
        x = self.data[idx : idx + self.context_length].astype(np.int64)
        y = self.data[idx + 1 : idx + self.context_length + 1].astype(np.int64)
        return torch.from_numpy(x), torch.from_numpy(y)

# Checkpoint

最后,为了能够在训练过程中保存模型的状态,以便后续恢复训练或者进行推理,我们需要实现一个 ** 模型检查点(checkpoint)** 的功能。模型检查点通常包含模型的参数、优化器的状态、当前训练步数等信息,可以通过 torch.savetorch.load 来进行保存和加载。

utils.py 中实现 save_checkpointload_checkpoint 两个函数来完成模型检查点的保存和加载:

# utils.py
def save_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    iteration: int,
    out: str | os.PathLike | typing.BinaryIO | typing.IO[bytes],
) -> None:
    """Save the model and optimizer state to a checkpoint file."""
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "iteration": iteration,
    }
    torch.save(checkpoint, out)
def load_checkpoint(
    src: str | os.PathLike | typing.BinaryIO | typing.IO[bytes],
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
) -> int:
    checkpoint = torch.load(src, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["iteration"]

# 训练脚本

基于上述实现的组件,我们可以编写一个训练脚本来进行模型的训练。该脚本将负责加载数据、初始化模型和优化器、执行训练循环,并在训练过程中保存检查点。

该脚本位于 cs336_basics/main.py ,代码实现如下:

# main.py
import argparse
import os
import time
import torch
import numpy as np
from torch.utils.data import DataLoader
from dataclasses import dataclass
from tqdm import tqdm
from cs336_basics.module import Transformer
from cs336_basics.train import (
    AdamW,
    TokenDataset,
    cross_entropy_loss,
    get_lr_cosine_schedule,
    gradient_clipping,
    load_checkpoint,
    save_checkpoint,
)
@dataclass
class TrainConfig:
    # Data paths.
    train_data: str
    val_data: str
    out_dir: str = "checkpoints"
    # Model hyperparameters.
    vocab_size: int = 32000
    context_length: int = 256
    d_model: int = 384
    d_ff: int = 1024
    num_layers: int = 6
    num_heads: int = 8
    rope_theta: float = 10000.0
    # Optimization settings.
    batch_size: int = 16
    learning_rate: float = 3e-4
    min_lr: float = 3e-5
    weight_decay: float = 0.01
    max_iters: int = 5000
    warmup_steps: int = 100
    grad_clip: float = 1.0
    # Runtime and logging.
    device: str = "cpu"
    seed: int = 42
    num_workers: int = 4
    log_interval: int = 50
    eval_interval: int = 200
    eval_batches: int = 20
    save_interval: int = 1000
    resume: str | None = None
def parse_args() -> TrainConfig:
    parser = argparse.ArgumentParser(description="Train a Transformer language model with TokenDataset.")
    parser.add_argument("--train-data", type=str, required=True, help="Path to train .npy token file.")
    parser.add_argument("--val-data", type=str, required=True, help="Path to val .npy token file.")
    parser.add_argument("--out-dir", type=str, default="checkpoints", help="Directory for checkpoints.")
    parser.add_argument("--batch-size", type=int, default=16, help="Batch size.")
    parser.add_argument("--max-iters", type=int, default=5000, help="Total train steps.")
    parser.add_argument("--device", type=str, default="cpu", help="Device to train on.")
    parser.add_argument("--resume", type=str, default=None, help="Checkpoint path to resume from.")
    args = parser.parse_args()
    return TrainConfig(
        train_data=args.train_data,
        val_data=args.val_data,
        out_dir=args.out_dir,
        batch_size=args.batch_size,
        max_iters=args.max_iters,
        device=args.device,
        resume=args.resume,
    )
@torch.no_grad()
def evaluate(model: torch.nn.Module, val_loader: DataLoader, config: TrainConfig, device: torch.device) -> float:
    model.eval()
    losses: list[float] = []
    for batch_idx, (x, y) in enumerate(val_loader):
        if batch_idx >= config.eval_batches:
            break
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = cross_entropy_loss(logits, y)
        losses.append(loss.item())
    model.train()
    return float(np.mean(losses)) if losses else 0.0
def build_dataloaders(config: TrainConfig, device: torch.device) -> tuple[DataLoader, DataLoader]:
    train_dataset = TokenDataset(config.train_data, config.context_length)
    val_dataset = TokenDataset(config.val_data, config.context_length)
    use_pin_memory = device.type == "cuda"
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=use_pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=use_pin_memory,
    )
    print(f"[Data] Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")
    return train_loader, val_loader
def build_model_and_optimizer(
    config: TrainConfig, device: torch.device
) -> tuple[torch.nn.Module, torch.optim.Optimizer]:
    # Create model and optimizer from config.
    model = Transformer(
        vocab_size=config.vocab_size,
        num_layers=config.num_layers,
        d_model=config.d_model,
        num_heads=config.num_heads,
        d_ff=config.d_ff,
        max_seq_len=config.context_length,
        theta=config.rope_theta,
    ).to(device)
    optimizer = AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    return model, optimizer
def resume(config: TrainConfig, model: torch.nn.Module, optimizer: torch.optim.Optimizer) -> int:
    # Restore model state if resume checkpoint is provided.
    if config.resume and os.path.exists(config.resume):
        print(f"[Resume] Loading from {config.resume} ...")
        return load_checkpoint(config.resume, model, optimizer) + 1
    return 1
def train(config: TrainConfig) -> None:
    # Setup
    os.makedirs(config.out_dir, exist_ok=True)
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    device = torch.device(config.device)
    train_loader, val_loader = build_dataloaders(config, device)
    model, optimizer = build_model_and_optimizer(config, device)
    iter_num = resume(config, model, optimizer)
    # main training loop
    model.train()
    train_iter = iter(train_loader)
    tick = time.time()
    pbar = tqdm(total=config.max_iters - iter_num + 1, desc="Training", unit="step")
    while iter_num <= config.max_iters:
        # Fetch next batch and restart iterator after one full pass.
        try:
            inputs, targets = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            inputs, targets = next(train_iter)
        inputs, targets = inputs.to(device), targets.to(device)
        # Update learning rate from warmup + cosine schedule.
        lr = get_lr_cosine_schedule(
            lr_min=config.min_lr,
            lr_max=config.learning_rate,
            step=iter_num,
            warmup_steps=config.warmup_steps,
            cosine_steps=config.max_iters,
        )
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
        # Compute loss and update parameters.
        optimizer.zero_grad()
        logits = model(inputs)
        loss = cross_entropy_loss(logits, targets)
        loss.backward()
        if config.grad_clip > 0.0:
            gradient_clipping(model.parameters(), config.grad_clip)
        optimizer.step()
        # Print train stats.
        if iter_num % config.log_interval == 0:
            elapsed_ms = (time.time() - tick) * 1000.0
            tick = time.time()
            print(f"[Train] step={iter_num} | loss={loss.item():.4f} | lr={lr:.2e} | time={elapsed_ms:.2f}ms")
            pbar.set_postfix(loss=f"{loss.item():.4f}", lr=f"{lr:.2e}")
        # Run validation.
        if iter_num % config.eval_interval == 0:
            val_loss = evaluate(model, val_loader, config, device)
            print(f"[Eval]  step={iter_num} | val_loss={val_loss:.4f}")
        # Save periodic checkpoints.
        if iter_num % config.save_interval == 0:
            ckpt_path = os.path.join(config.out_dir, f"ckpt_{iter_num}.pt")
            save_checkpoint(model, optimizer, iter_num, ckpt_path)
            print(f"[Save]  {ckpt_path}")
        iter_num += 1
        pbar.update(1)
    pbar.close()
    # Save the final checkpoint.
    final_path = os.path.join(config.out_dir, "ckpt_final.pt")
    save_checkpoint(model, optimizer, iter_num, final_path)
    print("[Done] Training finished.")
if __name__ == "__main__":
    train(parse_args())

该脚本采用模块化设计,将配置、数据、模型构建与训练循环解耦。主要的数据结构和核心函数说明如下:

  • TrainConfig 类:作为全局配置容器,统一管理训练所需的所有超参数。
  • parse_args 函数:利用 argparse 解析终端传入的参数(如 --batch-size , --resume ),并将其映射实例化为一个 TrainConfig 对象。
  • build_dataloaders 函数:构建训练集和验证集的数据加载器。注意,这里根据设备类型(CPU/GPU)自动设置 pin_memory 以加速数据传输。
  • build_model_and_optimizer 函数:根据 TrainConfig 中的架构参数初始化 Transformer 模型,并将其移动到指定设备。
  • evaluate 函数,在验证集上评估当前模型的性能。
  • resume :处理断点续训逻辑。如果提供断点目录,则加载模型权重和优化器状态,和训练步数。
  • train :函数:模型训练的主循环,包含数据加载、前向传播、损失计算、反向传播、梯度裁剪、参数更新、学习率调整、日志记录、验证评估和检查点保存等功能。

# 模型推理

最后实现一个 generate 函数来完成模型的推理功能。该函数接受一个训练好的模型、一个输入文本字符串,以及生成文本的最大长度等参数,返回模型生成的文本字符串。生成过程采用 ** 自回归(autoregressive)** 的方式,即每次根据当前输入序列预测下一个 token,并将其添加到输入序列中,直到达到最大长度或者生成结束标志。

def generate(
    model: Transformer,
    input_str: str,
    tokenizer: Tokenizer,
    max_length: int,
    eos_token: str,
    temperature: float = 1.0,
    top_k: int = 0,
    sample: bool = False,
    device: str = "cpu",
) -> str:
    # setup
    model.eval()
    input_tokens = tokenizer.encode(input_str)
    input_ids = torch.tensor(input_tokens, dtype=torch.long).unsqueeze(0)  # (1, seq_len)
    device = torch.device(device)
    model.to(device)
    input_ids = input_ids.to(device)
    output_ids = input_ids.tolist()[0]  # start with input tokens
    # generate tokens until max_length or EOS
    with torch.no_grad():
        for _ in range(max_length - len(input_tokens)):
            logits = model(input_ids)  # (1, seq_len, vocab_size)
            next_token_logits = logits[0, -1, :] / (temperature + 1e-8)  # (vocab_size,)
            if top_k > 0:
                top_k_values, _ = torch.topk(next_token_logits, top_k)
                next_token_logits[next_token_logits < top_k_values[-1]] = -float("inf")
            if sample:
                next_token_id = torch.multinomial(
                    F.softmax(next_token_logits, dim=-1), num_samples=1).item()
            else:
                next_token_id = torch.argmax(next_token_logits).item()
            output_ids.append(next_token_id)
            input_ids = torch.cat([input_ids, torch.tensor([[next_token_id]], device=device)], dim=1)
            if tokenizer.vocab[next_token_id] == eos_token.encode("utf-8"):
                break
    # decode output tokens to string
    output_str = tokenizer.decode(output_ids)
    return output_str

除了基本的自回归生成逻辑之外,这个函数还支持一些常用的生成策略:

  • 温度采样(Temperature Sampling):通过调整 logits 的温度参数来控制生成文本的多样性。较高的温度会使概率分布更平坦,增加生成文本的多样性;较低的温度会使概率分布更尖锐,倾向于生成更常见的 token。
  • Top-k 采样:通过限制候选 token 的数量来控制生成文本的质量和多样性。仅保留概率最高的 k 个 token,其他 token 的概率被设置为负无穷,这样在采样时只能从这 k 个 token 中选择。
  • 概率采样(Sampling):在生成下一个 token 时,可以选择直接取概率最高的 token(贪心策略),或者根据概率分布进行随机采样,这样可以增加生成文本的多样性。

这些都是现代大语言模型常用的生成策略,实际生成环境可能包含更多的控制机制,例如重复惩罚、长度惩罚、禁用特定 token 等等,以满足不同的生成需求和约束条件。


# 总结

本文基于斯坦福 CS336 第一章课程作业,系统性地完成了 Transformer 模型的从零构建与训练流程。我们不仅由浅入深地实现了 RMSNorm、RoPE 旋转位置编码、SwiGLU 前馈网络及多头注意力机制等核心组件,还手动编写了 AdamW 优化器与完整的训练循环。通过这一 “造轮子” 的实践过程,将抽象的数学公式转化为具体的 PyTorch 代码,也揭示了现代大语言模型的底层运作机理。