Posted on ::

1. 引言

在人工智能快速爆发的今天,各种大语言模型(LLM)层出不穷,参数规模动辄达到千亿级、模型版本不断迭代更新,难免让人感觉跟不上模型发展的节奏。

如果把视角切换到底层结构,会发现从早期的 GPT-2 到如今主流的模型,大多数模型的架构依然建立在 Transformer Decoder-only 之上,并没有发生颠覆性的变化。对比 Qwen3、GLM,Llama 等开源大模型,它们在结构上高度相似,差异更多体现在具体模块的实现方式以及参数配置方面。

图:大语言模型结构对比

如果你对这些模块的功能和实现已经比较了解,可以直接看完整的代码实现

图:大语言模型基础结构和推理流程

当理解了这套基础结构及其推理流程,再去看其他大语言模型的论文或源码就会容易很多。从零构建一个大语言模型,就像搭积木一样,只要理解每个模块的功能,再按照合理方式将其拼接起来,就能搭建出完整的模型。

本文将以 Qwen3 的模型结构为基础,沿前向传播的路径,逐步介绍各个模块的作用与原理,并给出对应的 PyTorch 代码实现。在此基础上,我们将进一步把这些模块组装成一个完整的 Qwen3 模型,最后通过加载 Qwen3-0.6B 的预训练权重,对模型进行推理验证。

图:Qwen3-0.6B模型结构和推理流程

2. 分词器(Tokenizer)

用户输入的提示词并不会以字符串的形式直接输入给大语言模型,模型真正能够处理的是张量(Tensor),而不是自然语言本身,原始文本必须先转换成一串数字才能进一步处理。另外模型本身并不直接负责转换,这项工作通常由一个独立组件Tokenizer完成,主要工作流程通常分为两步:

  • 切分:将原始文本切分为若干Token。

  • 编码:根据词表(Vocabulary)查询映射关系,将每个 Token 转换为对应的整数 ID。

例如,句子“今天天气很好”经过切分后,可能得到["今天", "天气", "很", "好"],再通过词表映射后得到编码为:[10941, 1487, 25896, 148483],后续将这串数字输入给模型。

TokenID
今天10941
天气1487
25896
148483

表:词表映射

在大语言模型中,Token表示文本被切分后的最小处理单元。这个单元不一定是一个汉字、一个单词,甚至也不一定是一个可见字符。例如:句子“今天天气很好”可能被切分为 ["今天", "天气", "很", "好"],也可能被切分为 ["今", "天", "天气", "很好"],具体如何切分并不由语法规则直接决定,而是由模型所采用的分词方法决定。

2.1 如何分词

你可能会想直接以“字符”或“单词”为单位进行切分,两种方法都较为直观,但在实际应用中都存在明显局限。

  • 按字符分词:将每一个字符都视为一个独立的 Token,例如 'a''b''中''文' 等。这种方法的优点在于词表规模较小,通常只需覆盖有限数量的常见字符。但缺点为文本在切分后会形成较长的序列,不仅增加计算开销、降低推理速度,也使模型很难高效地捕捉语义。例如,单独一个字符'中' 往往并不具备完整语义,必须结合上下文才能理解。

  • 按单词分词:将每个完整单词视为一个Token,这种方式在形式上更贴近人类对语言的直观理解,但会带来词表规模过大的问题。以英语为例,同一个词往往具有多种词形变化,如 runrunningranruns 等;此外还存在大量专有名词、新词、复合词和领域术语。如果将这些形式都作为独立的 Token,词表规模就会迅速膨胀。词表过大不仅会导致嵌入层(Embedding Layer)的参数数量显著增加,还会提高模型训练、存储的成本。更重要的是,当模型遇到训练语料中未出现过的新词时,往往无法将其进一步拆分,只能映射为统一的未知标记(unknown),造成信息损失。

因此,无论是按字符分词,还是按单词分词,都难以同时兼顾词表规模和序列长度这两个关键因素。

为了在词表规模和序列长度之间找到平衡,现代大语言模型通常采用子词级别的分词方法。所谓子词,可以理解为介于字符和单词之间的单元,它既可能是一个完整的单词,也可能是词根、词缀,或者单词中的某一部分。

子词分词方法的优势在于:对于高频词,可以直接作为一个Token;而对于低频词或复杂词,可以将其拆分为多个更小但更具规律性的子词单元。例如,tokenization 可以被拆分为 tokenization。这样既控制了词表规模,又增强了对未见词和复杂词的表示能力。

那么如何拆分出这些子词?子词的划分虽然可以参考一定的语言学规则,但依赖人工规则难以适应大规模、多样化的真实语料。现代分词方法通常借助统计规律从语料中自动学习出有哪些子词,并得到一个词表,其中最具代表性的方法是 BPE(Byte Pair Encoding)。

2.2 BPE(Byte-Pair Encoding)

BPE 最初是一种用于数据压缩的算法,后来被引入自然语言处理领域,并逐渐发展成为子词分词中的常用方法之一。当前许多大语言模型(如GPT、Qwen、Llama)都采用了BPE方法构建词表和分词,它能够在词表规模和序列长度之间取得较好的平衡。

从使用流程来看,BPE 可以分为两个环节:一是在训练阶段根据语料统计结果构建词表和合并规则,二是在推理阶段,根据这些合并规则对输入的文本进行切分,并将切分结果通过词表映射为Token ID序列。

下面分别介绍 BPE 的词表构建过程以及文本切分与编码流程。

2.2.1 词表训练流程

BPE 的训练过程本质上就是词表的构建过程,它从最细粒度的基础符号表示出发,通过反复合并语料中出现频率最高的相邻符号对,将其合并为一个新的符号,重复这一过程,表中的符号会逐渐从基础符号扩展为更大的子词单元,直到达到预设词表大小或合并次数上限。

BPE 词表的训练流程可以归纳为以下几个步骤:

  1. 初始化基础词表。 对训练语料中的所有文本进行预切分,将每个词拆解为最小粒度的基础符号序列(字符或字节),并将所有出现过的基础符号收集为初始词表$V_0$。同时统计每个词在语料中的出现频次。

  2. 统计相邻符号对频率。 遍历语料中所有经过拆分后的符号序列,统计每一对相邻符号$(s_i, s_{i+1})$在整个语料中的共现频次。

  3. 合并最高频符号对。 找到频次最高的相邻符号对$(s_a, s_b)$,将其合并为一个新符号$s_{ab}$,加入词表;同时将该合并操作$(s_a, s_b) \rightarrow s_{ab}$记录到合并规则表中。

  4. 更新语料表示。 将语料中所有出现的相邻$s_a, s_b$替换为新符号$s_{ab}$,得到更新后的符号序列。

  5. 重复迭代。 回到步骤二,在更新后的语料上重新统计相邻符号对频率,继续合并,直到满足以下任一终止条件:

    1. 词表大小达到预设上限$|V| = V_{\max}$;

    2. 已完成预设的最大合并次数;

训练结束后,最终输出两项关键产物:一是包含所有基础符号与合并产生的子词单元词表;二是按合并顺序排列的合并规则表,它决定了后续对新文本进行分词时的合并优先级。

注意这里的基础符号可以是字符,也可以是字节,但在实际工程实现中,分词器往往以字节作为底层的初始表示,为了便于说明,例子仍以字符为单位进行演示。

为了更直观地理解,下面用一个简化的中文语料为例进行说明。假设训练语料中包含以下词语:

人工智能,人工标注,智能系统,智能助手

在初始化阶段,BPE 会先将这些词语拆分为最基本的字符序列:

  • 人工智能 → 人 工 智 能

  • 人工标注 → 人 工 标 注

  • 智能系统 → 智 能 系 统

  • 智能助手 → 智 能 助 手

此时,词表中只包含最基础的字符单位{人:0, 工:1, 智:2, 能:3, 标:4, 注:5, 系:6, 统:7, 助:8, 手:9},接下来,BPE 会在整个语料范围内统计所有相邻符号对的出现频率。

  • (人, 工)人工智能人工标注 中出现,共 2 次

  • (智, 能)人工智能智能系统智能助手 中出现,共 3 次

  • (系, 统)智能系统 中出现,共 1 次

  • (助, 手)智能助手 中出现,共 1 次

  • (标, 注)人工标注 中出现,共 1 次

当前频率最高的相邻符号对是 (智, 能),那么 BPE 会首先将其合并为一个新的符号“智能”并加入词表。于是,语料更新为:

  • 人工智能 → 人 工 智能

  • 人工标注 → 人 工 标 注

  • 智能系统 → 智能 系 统

  • 智能助手 → 智能 助 手

随后重新统计新的相邻符号对频率,这次发现(人, 工)的频率较高,于是继续将其合并为“人工”:

  • 人工智能 → 人工 智能

  • 人工标注 → 人工 标 注

  • 智能系统 → 智能 系 统

  • 智能助手 → 智能 助 手

继续迭代,还可以将 (系, 统)合并为“系统”,将(助, 手) 合并为“助手”,将(标, 注)合并为“标注”。经过若干轮迭代后,词表中除了原始基础符号外,还包含新加入的子词单元,得到的词表如下:

{人:0, 工:1, 智:2, 能:3, 标:4, 注:5, 系:6, 统:7, 助:8, 手:9, 智能:10, 人工:11, 系统:12, 助手:13, 标注:14}

与此同时,BPE 还会记录这些子词单元的合并规则,生成合并规则表,并按照规则优先级从高到低排序,越早被合并的符号对,代表着在训练语料中越普遍、优先级越高。例如:

优先级合并规则
1(智, 能) → 智能
2(人, 工) → 人工
3(系, 统) → 系统
4(助, 手) → 助手
5(标, 注) → 标注

表:合并规则表

由此可见,BPE 词表中的子词单元并不是人工指定的,而是通过语料统计逐步学习得到的。词表构建完成后,分词器就可以利用这些子词单元和对应的合并规则,对新的输入文本进行切分和编码。

2.2.2 文本切分和编码流程

当 BPE 分词器训练好之后,它内部就保存了一张词表和一份按先后顺序排列的合并规则表。模型处理用户输入时,需要依据合并规则表对文本进行切分。具体步骤如下:

  1. 切分为基础符号:将输入文本按照词表训练时相同的方式,拆分为最小粒度的基础符号(字符或字节)序列。

  2. 按序应用合并规则: 根据训练时生成的合并规则表,严格按照规则的优先级顺序,逐条扫描当前符号序列。若序列中存在某条规则对应的相邻符号对,则将其合并为新的子词单元;若不存在,则跳过该条规则,继续检查下一条。重复此过程,直到所有规则都已检查完毕。

3. **映射 ID**:合并结束后,符号序列中的每个元素都是词表中的一个子词单元。将每个子词单元查询词表,替换为其对应的整数编号(Token ID),即得到最终的编码结果。

2.2.3 Qwen 的 Tokenizer 实现

前面的例子为了展示方便,采用字符作为基础符号。为了贴近实际工程实现,下面的代码采用 UTF-8 字节码作为基础符号,总体思路没有变化。

class BytePairEncoder:
    def __init__(self, target_vocab_size):
        # 目标词表大小必须大于基础的 256 个字节
        self.target_vocab_size = target_vocab_size
        self.merges = {} # 记录合并规则,格式为:(token_a, token_b) -> new_token
        # 初始化基础词表 0~255
        self.vocab = {i: bytes([i]) for i in range(256)}
        
    def _get_stats(self, ids):
        # 统计序列中所有相邻 token 对的出现频率
        counts = {}
        for pair in zip(ids, ids[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        return counts
    def _merge(self, ids, pair, new_idx):
        # 将序列中所有匹配的 pair 替换为新的 token ID
        newids = []
        i = 0
        while i < len(ids):
            if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
                newids.append(new_idx)
                i += 2 # 匹配成功,跳过两个原始 token
            else:
                newids.append(ids[i])
                i += 1
        return newids
    def train(self, text: str):
        print(f"--- 开始训练 BPE,目标词表大小: {self.target_vocab_size} ---")
        # 将文本编码为原始的 UTF-8 字节流 IDs (0-255)
        tokens = list(text.encode("utf-8"))
        
        num_merges = self.target_vocab_size - 256
        for i in range(num_merges):
            stats = self._get_stats(tokens)
            if not stats:
                break
                
            # 找到频率最高的一对
            best_pair = max(stats, key=stats.get)
            new_idx = 256 + i
            
            # 执行合并
            tokens = self._merge(tokens, best_pair, new_idx)
            
            # 记录规则与词表映射
            self.merges[best_pair] = new_idx
            self.vocab[new_idx] = self.vocab[best_pair[0]] + self.vocab[best_pair[1]]
            print(f"Merge {i+1}: {best_pair} -> {new_idx} ({self.vocab[new_idx]})")
    def encode(self, text: str):
        # 推理阶段:应用训练好的合并规则
        tokens = list(text.encode("utf-8"))
        while len(tokens) >= 2:
            stats = self._get_stats(tokens)
            # 在当前文本的所有相邻对中,找到在训练集中最先合并的(即优先级最高的)
            pair = min(stats.keys(), key=lambda p: self.merges.get(p, float("inf")))
            
            if pair not in self.merges:
                break # 没有任何可识别的合并对,退出循环
                
            idx = self.merges[pair]
            tokens = self._merge(tokens, pair, idx)
        return tokens
    def decode(self, ids):
        # 根据词表将 ID 列表还原回 UTF-8 文本
        tokens = b"".join(self.vocab[idx] for idx in ids)
        return tokens.decode("utf-8", errors="replace")
# 简单测试代码
text_corpus = "我爱北京天安门,天安门上太阳升"
encoder = BytePairEncoder(target_vocab_size=260) # 合并4次
encoder.train(text_corpus)
encoded = encoder.encode("天安门")
print("\n'天安门' Encode 结果:", encoded)
print("Decode 结果:", encoder.decode(encoded))

2.3 特殊 Token

模型接收到的输入 Token 序列中,并非所有 Token 都直接来自用户输入的提示词。其中有一类 Token 主要用于表示结构和控制信息,称为特殊 Token。例如,<|im_start|> 表示一段对话内容的开始,<|im_end|> 表示一段对话内容的结束。

模型通过识别这些特殊 Token,来判断一段内容从何处开始、在何处结束,并区分当前发言者的身份(如用户或助手),甚至触发某些特定的内部处理流程。

比如你与模型进行了一次对话:

User: 你好! Assistant: 你好!请问有什么我可以帮你的?

在模型底层,模型看到的并不是一段没有结构的自然语言文本,而是由普通 Token 与特殊 Token 共同组成的序列:

<|im_start|>user
你好!
<|im_end|>
<|im_start|>assistant
你好!请问有什么我可以帮你的?
<|im_end|>

如果缺少这些特殊 Token,模型就很难准确区分对话中的角色边界,也无法判断哪部分内容属于用户输入、哪部分内容属于助手回复。这样一来,整段文本就可能被混杂在一起,从而影响模型对上下文结构的理解以及后续生成的准确性。

经过分词器处理后,原始文本被转换为 token id 序列。对于模型而言,这些 id 仍然只是离散符号,尚不具备可计算的连续表示。下一章将进一步介绍嵌入层如何将这些离散 id 映射到向量空间。

3. 嵌入层(Embedding Layer)

把每个离散的 token id 映射为一个连续的向量(vector)表示,这一步就是由 Embedding Layer(嵌入层) 完成,结合前一步的分词,总的流程为:

文本 → Token → Token ID → Embedding -> 向量

3.1 Embedding 计算

Embedding的计算过程很简单,类似一次“查表(lookup)”,把 token ID 映射成连续向量。

操作流程:

  1. 有一个可训练矩阵$W \in \mathbb{R}^{V \times d}$(词表大小 (V),维度 (d)),第 (i) 行对应 token ID=i 的向量

  2. Embedding 输出就是按 ID 取行:$\mathrm{Emb}(x_t) = W[x_t]$,比如 ID=2 取到蓝色那行向量,ID=5 取到红色那行向量。

  3. 把每个 token 的向量按顺序堆起来,得到形状$([t,d])$ 的序列向量,供后续模型使用。 $t$表示输入元素的个数。

Embedding层计算过程

3.2 Embedding 代码实现

PyTorch 中提供了标准的 Embedding 实现,用起来也很方便:

import torch
input_ids = torch.tensor([4, 1, 2, 3])
vocab_size = 6
output_dim = 3

torch.manual_seed(42)
embedding_layer = torch.nn.Embedding(vocab_size, output_dim)

print(embedding_layer.weight)
print("get embedding at pos 3:")
print(embedding_layer(torch.tensor([3])))

输出:

Parameter containing:
tensor([[ 1.9269,  1.4873, -0.4974],
        [ 0.4396, -0.7581,  1.0783],
        [ 0.8008,  1.6806,  0.3559],
        [-0.6866,  0.6105,  1.3347],
        [-0.2316,  0.0418, -0.2516],
        [ 0.8599, -0.3097, -0.3957]], requires_grad=True)

get embedding at pos 3:
tensor([[-0.6866,  0.6105,  1.3347]], grad_fn=<EmbeddingBackward0>)

3.3 为什么需要 Embedding?

3.3.1 Token ID 没有语义关系

分词器输出的是整数 ID,例如:

  • 输入文本:"Choices matter more than effort"

  • 分词后:[314, 2719, 98, 5021, ...]

这些整数只是词表里的索引,不携带任何可供模型利用的“语义距离”。对神经网络而言,314 与 315 的数值接近并不表示语义接近,它们只是随机分配离散符号。直接把 ID 当作数值输入线性层,会迫使模型学习一种毫无意义的数字符号,训练效果不好。

3.3.2 One-hot 编码的缺陷

自然语言处理中常见做法是把离散 ID 映射到连续向量空间,让“相近含义的 token id”在向量空间里可以被学习成“相近方向或相近距离”。

ID 映射到向量空间最直观的做法是把每个 token 表示成 one-hot 向量。假设词表大小为$𝑉$,那么每个 token 对应一个长度为$𝑉$的向量,只有一个位置是 1,其余都是 0。例如tokenId=2,词表长度为8, 对应的one-hot向量为:[0,0,1,0,0,0,0,0]

问题在于大模型的词表规模上万,one-hot 表示将带来两个问题:

  1. 维度爆炸:每个 token 都是几万维向量。

  2. 极度稀疏:绝大多数维度为 0,线性计算会浪费大量算力在“乘以 0”上。

3.3.3 Embedding 是如何学到语义的

Embedding解决了one-hot的问题,它是将高维、稀疏、离散的 One-Hot 向量,压缩并映射为低维、稠密、连续的向量。在这个新的向量空间中,语义相似的词在几何距离上也会彼此靠近。但这张表中的向量不是人工设计出来的,而是在模型训练过程中自动学习得到的。

训练初始时,Embedding 矩阵值通常是随机初始化,随着模型在海量语料上训练,反向传播会不断更新这些向量,使它们逐渐形成结构化的语义空间。

3.4 Embedding 层的本质

神经网络的基础结构由线性层(Linear Layer)组成,那Embedding层到底怎么来的?其实Embedding 层在数学上完全等价于先进行 One-hot 编码,然后再经过一个没有偏置项的全连接层(Linear / Dense Layer)。

下面是一个简单的数学推导:

假设隐藏层维度(Hidden Size,记作 $d$)为 4096。有一个形状为 $[V, d]$(即 $150000 \times 4096$)的权重矩阵 $W_{embedding}$。

如果拿 Token 的 One-hot 向量(形状为 $[1, V]$)去乘以这个权重矩阵:

$E_{dense} = X_{one-hot} \times W_{embedding}$

由于 $X_{one-hot}$ 中只有一个位置是 1,其余全为 0,根据矩阵乘法的规则,这个乘法的结果,恰好就是把 $W_{embedding}$ 矩阵的第 $i$ 行原封不动地提取出来。

图:$E\_{dense} = X\_{one-hot} \times W\_{embedding}$计算过程

既然数学上等价于“提取矩阵的某一行”,那么在工程代码实现时,深度学习框架(如 PyTorch 的 nn.Embedding)不会真的去生成那个庞大的 One-hot 矩阵并进行极其耗时的乘法运算。Embedding采用查表(Lookup Table) 操作,直接利用 Token ID 作为行索引(Row Index),以 O(1)的时间复杂度从权重矩阵中把对应的那一行向量取出来。

4. Transformer Decoder(核心)

4.1 总体结构

到目前为止,我们已经把文本变成了 Token ID,并用 Embedding 把离散符号映射到连续向量空间。接下来进入大模型真正“思考”的地方:Transformer Decoder 层。

图:Decoder层结构

本节按照模型数据向前传播的方向:

RMSNorm → RoPE → Attention → 残差连接 → RMSNorm → FeedForward → 残差连接

依次介绍每个模块的原理和实现,并最终组合成一个完整的 Transformer Decoder 层结构。

4.2 归一化(Normalization)

4.2.1 归一化解决什么问题

Transformer Decoder层的核心计算(Attention / FFN)本质上是大量的线性变换 + 非线性组合。堆叠层数一多,训练就很容易出现:

  • 激活值尺度在层间漂移(越叠越大或越小)

  • 梯度传播不稳定(尤其是深层网络)

  • 学习率稍微激进就炸(loss NaN、梯度爆炸)

归一化的作用是把每层输入的数值尺度限制在一个可控范围内,让神经网络更容易优化。

4.2.2 LayerNorm 为什么逐渐被 RMSNorm 替代

LayerNorm 和 RMSNorm 是大语言模型中常用的两种归一化方法:

LayerNorm:

对于输入向量 $x = (x_1, x_2, ..., x_n)$,LayerNorm 的计算公式为:

  1. 计算均值(Mean):$\mu = \frac{1}{n} \sum_{i=1}^n x_i$

  2. 计算方差(Variance):$\sigma^2 = \frac{1}{n} \sum_{i=1}^n (x_i - \mu)^2$

  3. 归一化:$y_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma_i + \beta_i$ (γ,β 是可学习的参数)

LayerNorm 的计算由两步组成:平移缩放

RMSNorm:

RMSNorm 的思路是:只缩放,不平移

这篇论文研究中发现,LayerNorm 中对模型收敛起决定性作用的其实是“方差缩放”,而不是“均值平移”。因此,RMSNorm 直接使用均方根(RMS)来进行归一化:

  1. 计算均方根(Root Mean Square):$RMS(x) = \sqrt{\frac{1}{n} \sum_{i=1}^n x_i^2}$

  2. 归一化与仿射变换:$y_i = \frac{x_i}{RMS(x) + \epsilon} \cdot \gamma_i$

在大语言模型中使用 RMSNorm,相比 LayerNorm 有如下优势

  1. **计算效率更高(最重要原因):**而在 RMSNorm 中,由于没有均值$\mu$,只需要直接对输入求平方和$x_i^2$ 即可算出 RMS,仅需一次数据遍历。在拥有数百亿参数的大模型中,归一化层在每个Transformer Decoder Block中会被高频调用。省去均值计算可以带来 10% ~ 50% 的归一化层计算加速,这对整体训练和推理的性能提升可观

  2. 模型效果几乎没有折损。

  3. 提高训练效率:在大模型训练中,通常会使用张量并行。在某些跨 GPU 的并行切分方案中,减少一次均值求和(Reduce)操作,就意味着有可能减少一次跨 GPU 的通信或同步开销,从而提升集群的整体训练效率。

4.2.3 RMSNorm 实现

使用pytorch实现:

class RMSNorm(nn.Module):
    def __init__(self, emb_dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.emb_dim = emb_dim
        self.weight = nn.Parameter(torch.ones(emb_dim)).float()

    def forward(self, x):
        means = x.pow(2).mean(dim=-1, keepdim=True)
        x_normed = x * torch.rsqrt(means + self.eps)
        return (x_normed * self.weight).to(dtype=x.dtype)

也可以直接使用PyTorch的内置方法torch.nn.RMSNorm实现,下面的代码是等价的

example_batch = torch.randn(2, 3, 4)

rms_norm = RMSNorm(emb_dim=example_batch.shape[-1])
rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-5)

4.3 自注意力机制(Self-Attention)

自注意力机制是为当前token动态地提取与其关联的上下文信息,重新得到该token的特征表示。在没有自注意力机制之前(比如早期的Word2Vec),每个token的特征向量是静态的。例如“Apple”这个词,无论它出现在“吃苹果”还是“苹果公司发布新iPhone”中,它的向量表示都是一样的。

自注意力机制突破了这种静态限制,让当前token去**“观察”周围的伙伴**。通过提取上下文特征,把原本静态的token向量融合了丰富的上下文信息。经过自注意力层后,代表“Apple”Token的那个向量里,已经融合了“吃”或者“iPhone”的信息。用带着上下文信息的token向量去预测下一个token自然会更准确。

上面讲了原理,其中很关键的一步是提取上下文的特征,那如何提取特征?如果只是简单地把句子中所有token的向量加起来求平均,会让不重要的token冲淡语义相关token的特征,所以需要先计算当前token与其他token的相关性,按相关性的比例再加权求和。下面是论文Attention Is All You Need中提出的计算方法:

$\mathrm{Attention}(Q,K,V)=\mathrm{softmax}!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$

单看这个公式还比较抽象,下面介绍下具体的思路,并举例说明计算的细节。

4.3.1 Q、K、V 计算

在自注意力机制中,Q、K、V三个字母分别代表 Query(查询)Key(键)Value(值)。用一个生活中的例子来理解它们的关系:你把所有输入的Token向量都看成一本本书,当你去图书馆找书,你在检索系统里输入的书名或者关键词就是 Query;而图书馆里每一本书的书脊上都贴着一个分类标签,这个标签就是 Key;你翻开那本书,从书里面的学习到的知识就是 Value

对于句子中的每一个token,经过Embedding层得到词向量矩阵$X_{t\times d}$,$t$表示token的个数,$d$表示向量的维度(就是Embedding层的维度)。模型内部会初始化三个权重矩阵:$W^Q$、$W^K$、$W^V$(这三个矩阵是模型在训练过程中不断调整数值),三个矩阵的维度都是$[d_{in}, d_{out}]$, 其中$d_{in} = d$

将词向量矩阵$X_{t\times d}$分别与这三个矩阵相乘,就得到了对应的 Q、K、V 矩阵

  • $Q_{t\times d_{out}} =X_{t\times d} \cdot W^Q$

  • $K_{t\times d_{out}} = X_{t\times d} \cdot W^K$

  • $V_{t\times d_{out}} = X_{t\times d} \cdot W^V$

图:计算Q、K、V

这里的矩阵映射等价于线性层计算,在pytorch中的实现方式为:

class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        

4.3.2 Self-Attention 计算

a. 计算步骤

Attention计算的流程主要三步:

计算相似度 —> 归一化 —> 特征融合

第一步:计算相似度

用第一个 Token 对应的的 $Q^{(1)}$向量,去和句子中所有 Token(包括自己)的 $K[t, d_{out}]$ 进行点乘运算。在数学上,两个向量的点乘结果越大,说明它们越相似、越匹配。点乘算出来的分数称为**注意力分数(Attention Score),**使用向量$Score_{[1,t]}$表示,对应矩阵乘法表示为$Q^{(1)}K^{\top}$

第二步:归一化(Softmax)

上一步算出的分数向量$Score_{[1, t]}$通过一个 Softmax 函数,把所有得分转换成 0 到 1 之间的概率值,且总和为 1。注意力分数变为注意力权重$ω_{[1, t]}$。

第三步:特征融合

$Q^{(1)}$的注意力权重$ω_{[1, t]}$与对应的$V^{(i)}$向量(维度$[t, d_{out}]$)相乘再相加,就得到了融合了其他token特征的新****向量$Attention_{[1, d_{out}]}$, 长度与$Q^{(i)}$一致。这一步使用矩阵乘法可以表示为$ω_{[1, t]}V_{t,d_{out}}$

把前三步的矩阵计算合并起来,得到公式:

$\mathrm{softmax}!\left(Q^{(1)}K^{\top}\right)V$

上面的过程是用第一个Token对应的$Q^{(1)}$计算出对应的$Attention_{[1, d_{out}]}$, 后续其他$t-1$个Token重复上面的计算步骤,得到$\mathrm{softmax}!\left(Q^{(2)}K^{\top}\right)V$, ... , $\mathrm{softmax}!\left(Q^{(t)}K^{\top}\right)V$,将$Q^{(1)}$到$Q^{(t)}$作为行向量合并一起就是矩阵$Q$,于是得到了公式:

$Attenttion(Q,K,V) = \mathrm{softmax}!\left(QK^{\top}\right)V$

计算出的$Attention_{[t, d_{out}]}$矩阵维度为$[t, d_{out}]$,表示每一个token都计算出了一个新的特征向量,长度为$d_{out}$。

对比论文中给出的公式:$\mathrm{Attention}(Q,K,V)=\mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d\_k}}\right)V$,还缺少除以$\sqrt{d\_k}$,$\sqrt{d\_k}$叫做缩放因子,值为矩阵$K$的第二个维度$d\_{out}$。缩放因子是为了解决当向量维度很大时,点积算出来的数值可能会非常大,softmax 输入很大会把权重推得非常“尖”,几乎变成“只选一个 token”。这会导致:
  • 梯度变小(softmax 饱和)

  • 学习不稳定

除以 $\sqrt{d_k}$ 做缩放,让分数落在更合适的范围内。所以self-attention也叫做scaled dot-product attention

self-attention计算过程
b. Self-Attention 对应的 PyTorch 实现
class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
      """
      x: (B, t, d_in)   # B=batch_size, t=seq_len, d_in=输入特征维度
      """
      keys    = self.W_key(x)    # (B, t, d_out)
      queries = self.W_query(x)  # (B, t, d_out)
      values  = self.W_value(x)  # (B, t, d_out)
  
      # 2) 计算注意力分数:Q @ K^T
      # queries: (B, t, d_out)
      # 因为是3维矩阵,不能用K^T,使用keys.transpose(-2, -1): (B, d_out, t)
      attn_scores = queries @ keys.transpose(-2, -1)  # (B, t, t)
  
      # 3) 缩放 + softmax 得到注意力权重
      # keys.shape[-1] == d_out
      attn_weights = torch.softmax(
          attn_scores / (keys.shape[-1] ** 0.5),      # (B, t, t)
          dim=-1                                      # 对最后一维 t 做 softmax
      )                                               # -> (B, t, t)
  
      # 4) 加权求和得到带有上下文信息的向量
      # attn_weights: (B, t, t)
      # values:       (B, t, d_out)
      context_vec = attn_weights @ values               # (B, t, d_out)
  
      return context_vec                                # (B, t, d_out)

4.3.3 因果掩码(Causal Mask)

在文本生成任务中,模型的工作方式是自回归(Autoregressive),即根据已经生成的词来预测下一个词。如果模型提前“偷看”了答案,会导致模型在真正预测时准确率下降。

为了解决这个问题引入了因果掩码(Causal Mask)。它的核心思想是:强制让当前 token 只能看到它自己以及它之前的 token,把未来的 token 给遮挡起来,维持时间上的因果关系,实现方式是遮住部分$QK^{\top}$计算出的注意力矩阵中的元素。

如下图示:

计算masked attention scores

假设输入有 4 个 Token($t_1, t_2, t_3, t_4$),计算出的 $QK^\top$ 相似度矩阵是一个 $4 \times 4$ 的矩阵。矩阵的第 $i$ 行代表第 $i$ 个 Token 去关注其他 Token 的得分。 为了不让模型看到未来,我们把矩阵的 右上角(上三角部分) 的分数全部强制替换成一个极小的负数,通常是负无穷大($-\infty$)。

掩码处理后的得分矩阵看起来像这样:

$\begin{bmatrix} Score_{1,1} & \mathbf{-\infty} & \mathbf{-\infty} & \mathbf{-\infty} \ Score_{2,1} & Score_{2,2} & \mathbf{-\infty} & \mathbf{-\infty} \ Score_{3,1} & Score_{3,2} & Score_{3,3} & \mathbf{-\infty} \ Score_{4,1} & Score_{4,2} & Score_{4,3} & Score_{4,4} \ \end{bmatrix}$

a. 为什么要用 $-\infty$ 而不是 0

这是由Softmax函数中$e^{x_i}$的性质决定,

$Softmax(x_i) = \frac{e^{x_i}}{\sum e^{x_j}}$

如果在进入 Softmax 之前把未来的分数设为 0,经过 $e^0$ 计算后,它的权重会变成 $1$,未来 Token 依然会占有一部分权重。如果我们把分数设为 $-\infty$,由于 $e^{-\infty} \approx 0$,经过 Softmax 之后,未来 Token 分配到的注意力权重就会变成绝对的 $0$。

此时生成的注意力权重矩阵(Attention Weights)会变成一个标准的下三角矩阵

$\begin{bmatrix} 1.0 & 0 & 0 & 0 \ 0.6 & 0.4 & 0 & 0 \ 0.3 & 0.5 & 0.2 & 0 \ 0.1 & 0.2 & 0.4 & 0.3 \ \end{bmatrix}$

在第一行Token 1 的注意力 100% 都放在了自己身上,看不到 2、3、4。

在第三行Token 3 将注意力按比例分配给了 1、2、3,依然看不到 4。

经过Causal Mask 处理过的注意力机制,也叫做因果注意力机制

b. Causal Mask 的 PyTorch 实现

在代码中实现这个功能非常简单,PyTorch 提供了 torch.tril()(triangle lower,取下三角)方法。我们将 4.2.1 中的代码升级为带有因果掩码的 CausalSelfAttention

class CausalSelfAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
      
        # context_length 为允许输入的最大长度
        # 生成一个 context_length x context_length 的下三角全 1 矩阵
        mask = torch.tril(torch.ones(context_length, context_length))
        self.register_buffer("mask", mask)

    def forward(self, x):
        B, t, d_in = x.shape
        # 1) 线性映射得到 Q K V
        keys    = self.W_key(x)  
        queries = self.W_query(x)
        values  = self.W_value(x)
        # 2) 计算注意力分数:Q @ K^T
        attn_scores = queries @ keys.transpose(-2, -1)  # (B, t, t)
        # ----------------- 新增:Causal Mask 逻辑 -----------------
        # 截取当前序列长度对应的 mask (因为实际输入长度 t 可能小于 context_length)
        causal_mask = self.mask[:t, :t] == 0 
        # 将 mask 中为 True 的位置(即上三角部分)的得分,替换为负无穷
        attn_scores = attn_scores.masked_fill(causal_mask, float('-inf'))
        # ----------------------------------------------------------
        # 3) 缩放 + softmax 得到注意力权重
        # 此时得分为负无穷的位置,经过 softmax 后权重就变成了 0
        attn_weights = torch.softmax(
            attn_scores / (keys.shape[-1] ** 0.5),    
            dim=-1                                    
        )                                             
        # 4) 加权求和得到带有上下文信息的向量
        context_vec = attn_weights @ values             
        return context_vec                              

4.3.4 多头注意力机制(Multi-Head Attention)

自注意力(Single-Head Self-Attention)让 Token 有了观察全局的能力,那么**多头注意力机制(Multi-Head Attention)**就是让 Token 拥有了“多维度、多视角观察”的能力。当前的大语言模型无一例外全都使用了多头注意力。

假设你正在阅读这样一句话:

"The animal didn't cross the street because it was too tired."

当模型在处理 “it” 这个词时,需要提取上下文信息来搞清楚 “it” 到底指代什么。

  • 如果只有一个注意力头,模型可能会把全部注意力权重都放在“寻找语法关系”上,发现 “it” 是一个代词。

  • 但这不够,模型还需要知道 “it” 指代的是 “animal” 还是 “street”。

  • 这时候需要第二个注意力头专门负责“分析语义属性”(谁会觉得累?显然是 animal 而不是 street),第三个注意力头专门负责“分析时态和动作”……

多头注意力原理:把原本的一个注意力机制,拆解成多个独立的“头”。每个头使用自己独立的 $W^Q, W^K, W^V$ 权重矩阵,去学习输入序列的不同特征子空间(语法、语义、情感、逻辑等)。最后再把所有头的结果汇总起来。

用公式表示为:

$\mathrm{MultiHead}(Q,K,V) = \mathrm{Concat}(\mathrm{head}_1, ..., \mathrm{head}_h)W^O$

其中每一头的计算为:

$\mathrm{head}_i = \mathrm{Attention}(XW_i^Q, XW_i^K, XW_i^V)$

再用一个图来来展示这个过程:

多头注意力机制计算流程

对应的pytorch实现(基于上一节的MultiHeadAttention)如下:

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, num_heads, qkv_bias=False):
        super().__init__()      
        
        self.num_heads = num_heads
        # 每个注意力头负责的特征维度 (head_dim)
        self.head_dim = d_out // num_heads 
      
        # 定义 Q、K、V 的线性映射层
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
      
        # 最终输出的线性投影层 (W^O),用于融合各个头的特征
        self.out_proj = nn.Linear(d_out, d_out)
      
        # context_length 为允许输入的最大长度
        # 生成一个 context_length x context_length 的下三角全 1 矩阵
        mask = torch.tril(torch.ones(context_length, context_length))
        self.register_buffer("mask", mask)

    def forward(self, x):
        B, t, d_in = x.shape
      
        # 形状: (B, t, d_out)
        keys    = self.W_key(x)  
        queries = self.W_query(x)
        values  = self.W_value(x)

        # 按头拆分并交换维度
        # view: (B, t, d_out) -> (B, t, num_heads, head_dim)
        # transpose(1, 2): 交换 t 和 num_heads,变为 -> (B, num_heads, t, head_dim)
        keys = keys.view(B, t, self.num_heads, self.head_dim).transpose(1, 2)
        queries = queries.view(B, t, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(B, t, self.num_heads, self.head_dim).transpose(1, 2)

        # 并行计算注意力分数
        # queries: (B, num_heads, t, head_dim) 
        # keys转置: (B, num_heads, head_dim, t)
        # 结果 attn_scores: (B, num_heads, t, t)
        attn_scores = queries @ keys.transpose(-2, -1) 
      
        # 应用因果掩码
        causal_mask = self.mask[:t, :t] == 0 
        # mask的广播机制会自动应用到所有的 num_heads 上
        attn_scores = attn_scores.masked_fill(causal_mask, float('-inf'))
      
        # 缩放因子现在是基于 head_dim
        attn_weights = torch.softmax(attn_scores / (self.head_dim ** 0.5), dim=-1)

        # 加权求和得到各个头独立提取的上下文向量
        # (B, num_heads, t, t) @ (B, num_heads, t, head_dim) -> (B, num_heads, t, head_dim)
        context_vec = attn_weights @ values 

        # 拼装
        # transpose(1, 2): (B, num_heads, t, head_dim) -> (B, t, num_heads, head_dim)
        # contiguous(): 保证内存连续性
        # view: 重新展平成 (B, t, d_out)
        context_vec = context_vec.transpose(1, 2).contiguous().view(B, t, -1)

        # 最后经过一个线性层,融合所有头抽取的信息
        output = self.out_proj(context_vec) # (B, t, d_out)

        return output

4.4 位置编码(Positional Encoding)

4.4.1 作用

Attention通过计算 Token 之间的相关性来提取上下文特征。但 Attention 公式本身只看内容,不知道顺序。举个例子,下面两句话包含的词一样,但顺序不同,语义完全不同:

  • “小明打了小红”

  • “小红打了小明”

所以需要额外引入“位置信息”,告诉模型这个词在句子里的位置。获取位置的方法有两类:

  1. 绝对位置编码:在每个输入序列的元素上添加一个位置向量,以表示该元素在序列中的具体位置。这个位置向量通常通过正弦和余弦函数生成,具有周期性,能够捕捉序列中的相对位置信息 。

  2. 相对位置编码:不直接编码“我在第几个位置”,编码“我与其他 token 之间的相对距离”。

4.4.2 RoPE 位置编码

RoPE通过绝对位置编码的形式,实现相对位置编码的效果,目前几乎所有主流的开源大模型(如 LLaMA、Qwen、ChatGLM、Mistral 等)都在使用它。它的数学原理比较复杂,后面单独写一篇文章介绍,现在只需要知道经过RoPE,注意力机制计算出的向量包含了位置信息。

4.4.3 RoPE PyTorch 实现

sin,cos参数计算

def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):
    """Precompute cosine and sine tables for Rotary Position Embedding (RoPE)."""
    assert head_dim % 2 == 0, "head_dim must be even for RoPE"

    inverse_frequencies = 1.0 / (
        theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim)
    )

    position_indices = torch.arange(context_length, dtype=dtype)

    # Outer product: each position × each frequency → (context_length, head_dim // 2)
    angles = position_indices.unsqueeze(1) * inverse_frequencies.unsqueeze(0)

    # Duplicate to cover full head_dim → (context_length, head_dim)
    angles = torch.cat([angles, angles], dim=1)

    return torch.cos(angles), torch.sin(angles)

RoPE计算


def apply_rope(hidden_states, cos, sin):
    """
    Apply rotary position embedding to hidden states.

    Args:
        hidden_states: Tensor of shape (batch_size, num_heads, sequence_length, head_dim).
        cos: Precomputed cosine table of shape (max_sequence_length, head_dim).
        sin: Precomputed sine table of shape (max_sequence_length, head_dim).
    """
    batch_size, num_heads, sequence_length, head_dim = hidden_states.shape
    assert head_dim % 2 == 0, "head_dim must be even for RoPE"

    first_half = hidden_states[..., : head_dim // 2]
    second_half = hidden_states[..., head_dim // 2:]

    # Broadcast cos/sin to (1, 1, sequence_length, head_dim)
    cos = cos[:sequence_length, :].unsqueeze(0).unsqueeze(0)
    sin = sin[:sequence_length, :].unsqueeze(0).unsqueeze(0)

    rotated = torch.cat((-second_half, first_half), dim=-1)
    rotated_hidden_states = (hidden_states * cos) + (rotated * sin)

    return rotated_hidden_states.to(dtype=hidden_states.dtype)

4.4.4 在 Multi-Head Attention 中接入 RoPE

RoPE 一般加在 Q、K 被拆成多头之后、计算注意力分数之前:

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, num_heads, qkv_bias=False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        assert d_out % num_heads == 0

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)

        mask = torch.tril(torch.ones(context_length, context_length))
        self.register_buffer("mask", mask)

        cos, sin = precompute_rope_params(
            head_dim=self.head_dim,
            context_length=context_length
        )
        self.register_buffer("cos", cos)
        self.register_buffer("sin", sin)

    def forward(self, x):
        B, t, _ = x.shape

        keys    = self.W_key(x)
        queries = self.W_query(x)
        values  = self.W_value(x)

        keys    = keys.view(B, t, self.num_heads, self.head_dim).transpose(1, 2)
        queries = queries.view(B, t, self.num_heads, self.head_dim).transpose(1, 2)
        values  = values.view(B, t, self.num_heads, self.head_dim).transpose(1, 2)

        # 在 Q、K 上应用 RoPE
        queries = apply_rope(queries, self.cos, self.sin)
        keys    = apply_rope(keys, self.cos, self.sin)

        attn_scores = queries @ keys.transpose(-2, -1)

        causal_mask = self.mask[:t, :t] == 0
        attn_scores = attn_scores.masked_fill(causal_mask, float('-inf'))

        attn_weights = torch.softmax(attn_scores / (self.head_dim ** 0.5), dim=-1)
        context_vec = attn_weights @ values

        context_vec = context_vec.transpose(1, 2).contiguous().view(B, t, -1)
        output = self.out_proj(context_vec)
        return output

4.5 前馈神经网络(FFN)

通过注意力机制拿到丰富的上下文后,每个 Token 会独立地通过一个全连接的前馈神经网络(FFN)。FFN 通常包含维度很大的隐藏层(通常是输入维度的 4 倍或更大),这个庞大的非线性空间实际上充当了大模型的知识库与记忆区。模型在预训练中学到知识和逻辑推理能力,绝大部分都固化为 FFN 的权重参数。

4.5.1 经典 FFN

经典的 FFN 结构是两层线性层

其中:
  • 第一层把维度升高(升维)

  • 激活函数引入非线性,常用 ReLU 或 GELU

  • 第二层再把维度投回原始大小

例如:

  • 输入维度:4096

  • 中间维度:11008 或 14336

  • 输出维度:4096

中间层维度通常会比输入大很多(输入维度的4倍),这样模型可以在更高维空间中学习到更复杂的特征组合。

FFN 的公式为:

$x = W_2(GELU(W_1(x)))$

4.5.2 SwiGLU

现代大模型普遍采用了效果更好的 SwiGLU 结构。SwiGLU (Swish-Gated Linear Unit)引入了门控(Gating)机制。它通过两个平行的线性层来处理输入:

  1. 一条路负责计算基础特征。

  2. 另一条路经过SiLU激活函数后作为“门控信号”,去控制第一条路的特征哪些该保留,哪些该丢弃。

SwiGLU 的公式为:

$\mathrm{SwiGLU}(x) = \mathrm{SiLU}(xW_1) \odot (xW_2)$

SwiGLU FFN 实现:

import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        """
        dim: 模型的输入维度 d_model
        hidden_dim: FFN 内部的扩展维度
        """
        super().__init__()
        # SwiGLU 需要三个权重矩阵
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # 线性变换
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # 门控信号
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # 降维输出

    def forward(self, x):
        return self.w2(F.silu(self.w3(x)) * self.w1(x))

4.5.3 MoE(Mixture of Experts)

让所有 Token 都经过同一个超大的 FFN,随着参数量的增大计算成本会快速上升。MoE(Mixture of Experts,专家混合)解决了这个问题。

MoE中不再让每个 Token 都经过同一套 FFN,而是准备多组不同的 FFN(称为 Experts,专家),由一个路由器(Router)为每个 Token 动态选择少数几个专家进行计算。这样模型的总参数量可以非常大,但每次前向传播时,真正被激活参与计算的参数量只占一小部分。

图:MoE架构

MoE 架构包含两个主要部分:

  1. Router(路由器):本质上是一个简单的线性层,负责阅读输入的 Token 特征,并计算出该 Token 分配给各个专家的概率权重。

  2. Experts(专家网络):通常是 N 个并行的传统 FFN 或 SwiGLU FFN。

工作流程: 通常采用 Top-K 路由机制(比如有 8 个专家,每次选 Top-2)。对于每个输入的 Token:

  • Router 给 8 个专家打分。

  • 选取得分最高的 2 个专家。

  • 这个 Token 只会被送入这 2 个专家进行计算(其余 6 个专家处于休眠状态,不消耗计算资源)。

  • 将这 2 个专家的输出结果,按照 Router 给出的得分进行加权求和,得到最终输出。

Qwen3 也提供了 MoE 架构的模型,比如 Qwen3-235B-A22B。235B 表示模型总参数量约为 2350 亿,每次推理时实际激活的参数量为 220 亿(22B)。由于只激活其中一部分参数,它也可以称为 sparse model;与之相对,前一节的 FFN 属于 dense model。本文基于 Qwen3-0.6B 架构实现,这里不再展开 MoE 架构的代码实现。

4.6 残差连接(Residual Connection)

Transformer Decoder 不只有一层,而是由很多层 Decoder Block 反复堆叠起来的(例如 LLaMA-3-70B 有 80 层 Decoder Block)。在神经网络中,层数越深,反向传播时梯度就越容易消失,导致深层网络训练困难。为了解决这个问题,Transformer Decoder 引入了残差连接。

在 Decoder Block 中,通常会分别在 Attention 和 FFN 后添加残差连接:

$x=x+Attention(Norm(x))$

$x=x+FFN(Norm(x))$

残差在深度学习已经有广泛的应用,这里不再介绍其原理,感兴趣的话可以看论文 《Deep Residual Learning for Image Recognition》

4.7 组装 Transformer Decoder

现在已经介绍完一个 Decoder 层里的核心模块:

  1. RMSNorm

  2. RoPE

  3. Multi-Head Self-Attention

  4. 残差连接

  5. RMSNorm

  6. FeedForward(SwiGLU)

  7. 残差连接

一个完整的 Transformer Decoder Layer 的前向传播流程为:

$x→RMSNorm→RoPE+Attention→Residual→RMSNorm→FeedForward→Residual$

如下图所示:

图:Transformer Decoder

对应的pytorch代码实现如下:

class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"]
        self.ff = FeedForward(cfg)
        self.norm1 = RMSNorm(cfg["emb_dim"])
        self.norm2 = RMSNorm(cfg["emb_dim"])

    def forward(self, x, mask, cos, sin):
        """
        Args:
            x:    输入隐藏状态,形状 (B, T, d)
                  B = batch_size
                  T = num_tokens
                  d = emb_dim
            mask: 注意力掩码,通常用于 causal mask 或 padding mask
            cos:  RoPE 的 cos 部分,供 Q/K 旋转位置编码使用
            sin:  RoPE 的 sin 部分,供 Q/K 旋转位置编码使用
                  形状通常与 cos 相同

        Returns:
            输出张量,形状 (B, T, d)
        """

        # ===== Attention Block =====
        # 保存残差分支输入
        shortcut = x                        # (B, T, d)
        # 先做归一化,再送入注意力模块(Pre-Norm 结构)
        x = self.norm1(x)                   # (B, T, d) -> (B, T, d)
        # 分组查询注意力
        # 输入: x, mask, cos, sin
        # 输出: (B, T, d)
        x = self.att(x, mask, cos, sin)     # (B, T, d) -> (B, T, d)
        # 残差连接
        x = x + shortcut                    # (B, T, d)

        # ===== FeedForward Block =====
        # 保存第二个残差分支输入
        shortcut = x                        # (B, T, d)
        # 先做归一化,再送入前馈网络
        x = self.norm2(x)                   # (B, T, d) -> (B, T, d)
        # 前馈网络
        x = self.ff(x)                      # (B, T, d) -> (B, T, d)
        # 残差连接
        x = x + shortcut                    # (B, T, d)

        return x                            # (B, T, d)

5. 输出层

前面经过 Embedding、多个 Transformer Decoder Block 之后,模型已经把输入文本逐步转换成了一串带有丰富上下文信息的隐藏状态(Hidden States),但仍然不是我们最终想要的“下一个 token ”。所以还需要一个全连接层(Linear Layer)把隐藏状态重新映射成词表上每个 token 的打分(logits),这里的全连接层也叫做LMHead。

假设模型最后的隐状态向量为 $\mathbf{h} \in \mathbb{R}^{d_{model}}$,其中 $d_{model}$ 是隐藏层维度,词表大小为$V$,LM Head 只包含一个权重矩阵 $\mathbf{W} \in \mathbb{R}^{V \times d_{model}}$,经过一次线性变换:

$\mathbf{z} = \mathbf{W} \cdot \mathbf{h}$

输出的 $\mathbf{z} \in \mathbb{R}^V$ 被称为 Logits(未归一化的对数概率)。$\mathbf{z}$向量中的第 $i$ 个元素,代表了词表中第 $i$ 个 Token 成为“下一个词”的原始得分

对应的代码实现为:

class LMHead(nn.Module):
    def __init__(self, hidden_size, vocab_size, bias=False):
        super().__init__()
        # 定义一个线性层,将 hidden_size 映射到 vocab_size
        self.linear = nn.Linear(hidden_size, vocab_size, bias=bias)
        
    def forward(self, hidden_states):
        """
        hidden_states 维度: [batch_size, seq_len, hidden_size]
        返回 logits 维度: [batch_size, seq_len, vocab_size]
        """
        logits = self.linear(hidden_states)
        return logits

6. 组装 Qwen3

到这里我们已经介绍完一个大语言模型前向传播中的主要模块:

  1. Tokenizer:把文本切分成 token id

  2. Embedding:把离散 id 映射成连续向量

  3. 多层 Transformer Decoder Block:提取和学习上下文特征

  4. Final Norm:稳定输出分布

  5. LM Head:映射回词表,得到 logits

把它们按顺序拼装起来得到一个完整的 Qwen3 Decoder-only 模型。

图:qwen3的完整模型结构

对应的pytorch实现:

class Qwen3Model(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])
        self.decoder_blocks = nn.ModuleList(
            [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )

        self.final_norm = RMSNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

        # Reusable utilities
        if cfg["head_dim"] is None:
            head_dim = cfg["emb_dim"] // cfg["n_heads"]
        else:
            head_dim = cfg["head_dim"]
        cos, sin = compute_rope_params(
            head_dim=head_dim,
            theta_base=cfg["rope_base"],
            context_length=cfg["context_length"]
        )
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        self.cfg = cfg

    def forward(self, in_idx):
        # Forward pass
        tok_embeds = self.tok_emb(in_idx)
        x = tok_embeds

        num_tokens = x.shape[1]
        mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)
        
        for decoder in self.decoder_blocks:
            x = decoder(x, mask, self.cos, self.sin)
        x = self.final_norm(x)
        logits = self.out_head(x.to(self.cfg["dtype"]))
        return logits

7. 自回归推理和采样

经过前面的步骤,我们终于搭建好了一个完整的Qwen3模型,这一节我们将介绍如何使用模型生成文本,以及如何通过采样策略控制生成内容的多样性和质量。

模型在训练和推理时的工作模式有所不同:训练时模型采用并行计算,一次性看到所有历史 Token 预测下一个 Token,而在推理时模型采用的是**自回归(Autoregressive)**生成方式。

7.1 自回归推理流程

自回归推理的流程为用模型自己生成的输出,作为下一步的输入。

假设用户输入提示词(Prompt):“选择大于”,模型生成文本的流程如下:

第一步:将输入序列 ["选", "择", "大", "于"] 输入模型。模型进行前向传播,LM Head 输出词表大小的 logits 向量。提取最后一个位置(对应“于”字)的 logits,通过采样策略(如选概率最大的那个)选出一个 Token,假设是

**第二步:**把新生成的 拼接到原来的输入上,变成 ["选", "择", "大", "于", "努"]。再次输入模型,预测出下一个 Token 为

**第三步:**继续拼接,变成 ["选", "择", "大", "于", "努", "力"],预测下一个 Token。

结束条件:一直重复上述步骤,直到模型预测出了特殊结束符(如 <|im_end|>),或者达到了预设的最大生成长度(max_new_tokens),循环停止。

自回归推理python实现:

def generate_text_basic_stream(model, token_ids, max_new_tokens, eos_token_id=None, sampling_params=None):
    """使用指定的采样策略,逐个生成 token。

    参数:
        model: 语言模型。
        token_ids: 输入的 token id 张量,形状为 (1, seq_len),仅支持单条序列。
        max_new_tokens: 最多生成的新 token 数量。
        eos_token_id: 序列结束 token 的 id,用于提前停止生成。
        sampling_params: SamplingParams 实例,用于控制 temperature、top-k、top-p 等采样参数。
    """
    model.eval()
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # 使用当前 token 序列作为输入运行模型。
            outputs = model(token_ids)

            # 取出最后一个位置的 logits。
            last_token_logits = outputs[0, -1:]

            # 根据最后一个位置的 logits 采样下一个 token。
            next_token_id = sample(last_token_logits, sampling_params)

            # 如果采样结果是 EOS,则提前结束生成。
            if eos_token_id is not None and next_token_id.item() == eos_token_id:
                break

            # 返回新生成的 token id。
            yield next_token_id.item()

            # 将新生成的 token 拼接到原序列末尾,用于下一轮生成。
            token_ids = torch.cat([token_ids, next_token_id], dim=1)

7.2 采样策略与生成多样性

在自回归推理过程中,如何从模型输出的 logits(或概率分布)中选择下一个 Token,直接决定了生成文本的质量、多样性以及逻辑连贯性。如果仅仅每次都选择概率最高的 Token(即贪婪搜索),生成的文本往往虽然通顺但缺乏变化,容易陷入重复或枯燥的模式。为了平衡“确定性”与“创造性”,引入了采样策略和随机采样。

  • **采样策略:**先把原始分布用采样策略改造

  • 生成多样性:从改造后的分布按概率抽样

现代大语言模型通常将温度Temperature缩放Top-K 采样Top-P采样 结合使用。通过调整这些参数,控制模型输出的随机程度。

1. 贪婪搜索 (Greedy Search) 贪婪搜索是最基础的策略,即在每一步都选择概率最高(Logits 最大)的 Token。它保证了绝对的确定性,适用于事实问答、代码生成等需要严格逻辑和准确性的场景。

2. 温度缩放 (Temperature Scaling) 模型输出的原始分数称为 Logits。通常我们会使用 Softmax 函数将 Logits 转化为概率分布。温度缩放通过在 Softmax 之前将 Logits 除以一个标量 $T$(Temperature)来调整分布的形态:

  • $T = 1.0$:不改变原始分布。

  • $T < 1.0$:会使概率分布变得更加“尖锐”,拉大高分和低分 Token 之间的差距,降低长尾 Token 被选中的概率,从而使得输出更加确定和集中。

  • $T > 1.0$:会使概率分布变得更加“平缓”,增加低概率 Token 被选中的机会,提升文本的随机性和创造力。

  • $T = 0$: 退化为贪婪搜索。

3. Top-K 采样 Top-K 采样是一种暴力的截断方法。它将所有 Token 按照 Logits 降序排列,仅保留前 $K$ 个概率最高的 Token,这样可以防止模型在采样时意外生成概率极低的“乱码”或无关词汇。

4. Top-P 采样 (Nucleus Sampling) 与固定的 $K$ 值不同,Top-P 采样是一种动态截断策略,它同样将 Token 按概率降序排列,然后计算累积概率。当累积概率刚好超过预设的阈值$P$时,截断后续的所有 Token。

Top-K采样由于设置固定阈值,仍然可能引入极低概率的“垃圾词”,一般不推荐配置,目前主流做法推荐只通过Temperature和Top-P来控制随机性

采样策略的 Python 实现:

"""
Token的采样策略。

实现了 temperature 缩放、top-k 过滤和 top-p(nucleus)采样,
参考 vLLM 的采样方式。
"""

import torch
from dataclasses import dataclass
from typing import Optional

def _apply_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """按 temperature 对 logits 进行缩放。

    较低的 temperature 会使分布更尖锐(更确定),
    较高的 temperature 会使分布更平缓(更随机)。
    """
    if temperature == 0.0 or temperature == 1.0:
        return logits
    return logits / temperature

def _apply_top_k(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """将 top-k 之外的 logits 置零。

    遵循 vLLM 的做法:按 logit 值从高到低排序,只保留前 top_k 个条目,
    其余位置都屏蔽为 -inf。
    """
    if top_k <= 0 or top_k >= logits.size(-1):
        return logits

    # 移除 logits 低于 top-k 阈值的 token
    top_k_values, _ = torch.topk(logits, top_k, dim=-1)
    min_top_k_value = top_k_values[:, -1].unsqueeze(-1)
    logits = logits.masked_fill(logits < min_top_k_value, float("-inf"))
    return logits

def _apply_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """应用 nucleus(top-p)采样。

    遵循 vLLM 的做法:
    1. 按 logits 从高到低排序。
    2. 根据排序后的 softmax 计算累计概率。
    3. 将累计概率超过 top_p 的 token 屏蔽掉。
    4. 再把屏蔽结果散射回原始 logit 的位置。
    """
    if top_p >= 1.0:
        return logits

    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

    # 创建掩码:保留累计概率 <= top_p 的 token
    # 向右偏移一位,这样第一个超过 top_p 的 token 仍会被保留
    sorted_mask = cumulative_probs - torch.softmax(sorted_logits, dim=-1) > top_p

    # 在排序后的顺序中,将被屏蔽的 logits 设为 -inf
    sorted_logits = sorted_logits.masked_fill(sorted_mask, float("-inf"))

    # 散射回原始顺序
    logits = torch.zeros_like(logits).scatter_(dim=-1, index=sorted_indices, src=sorted_logits)
    return logits

def sample(
    logits: torch.Tensor,
    sampling_params: Optional[SamplingParams] = None,
) -> torch.Tensor:
    """根据给定采样参数,从 logits 中采样下一个 token。

    处理顺序(遵循 vLLM 约定):
    1. Temperature 缩放
    2. Top-k 过滤
    3. Top-p(nucleus)过滤
    4. 多项式采样(若为 greedy 则取 argmax)

    参数:
        logits: 原始 logits,形状为 (batch_size, vocab_size)。
        sampling_params: 采样配置。如果为 None,则使用贪心解码。

    返回:
        形状为 (batch_size, 1) 的 token id。
    """
    if sampling_params is None or sampling_params.is_greedy:
        return torch.argmax(logits, dim=-1, keepdim=True)

    # 1. Temperature 缩放
    logits = _apply_temperature(logits, sampling_params.temperature)

    # 2. Top-k 过滤
    if sampling_params.top_k > 0:
        logits = _apply_top_k(logits, sampling_params.top_k)

    # 3. Top-p(nucleus)过滤
    if sampling_params.top_p < 1.0:
        logits = _apply_top_p(logits, sampling_params.top_p)

    # 4. 转换为概率并进行采样
    probs = torch.softmax(logits, dim=-1)
    next_token_id = torch.multinomial(probs, num_samples=1)

    return next_token_id

8. 推理优化

8.1 使用 KV Cache 推理加速

8.1.1 为什么需要 KV Cache

上节讲到,大模型的文本生成过程中采用自回归方式,即每次只生成一个 token。生成下一个 token 时,模型需要用前面所有的 token 作为上下文进行预测。

如果不做优化,每生成一个新 token,模型都需要重新计算此前序列中token对应的Key和Value矩阵。随着序列不断变长,重复计算会越来越多,推理开销也会持续增加,导致生成速度明显下降,尤其是在长文本场景下更为突出。

KV Cache 解决了上面的问题,它的核心思想是将每一步计算得到的 Key(K)Value(V) 向量缓存起来。在下一步生成时,模型无需再为旧 token 重复计算 K 和 V,只需要对新生成的 token 计算对应的 K、V,并将其追加到缓存中再进行后续Attention计算。

通过这种方式,KV Cache 大幅减少了重复K、V计算,从而提升了长文本生成时的推理效率。

例如我们要生成句子:“选择大于努力”,模型按如下步骤生成:

  1. 输入 [选, 择] -> 计算 Q, K, V -> 预测出 [大]

  2. 输入 [选, 择, 大] -> 计算 Q, K, V -> 预测出 [于]

  3. 输入 [选, 择, 大, 于] -> 计算 Q, K, V -> 预测出 [努]

观察第 2 步和第 3 步:当输入 [选, 择, 大, 于] 时,模型又重新计算了一遍 [选], [择], [大] 这三个词的特征表示(Q,K, V 矩阵)。随着生成文本越来越长,模型每前进一步,都要把前面几千个 Token 重新计算一遍,这会导致推理速度呈指数级下降。

8.1.2 KV Cache 流程

既然历史 Token 的内容没有改变,它们对应的 Key (K) 和 Value (V) 向量也不会改变。我们可以在显存中开辟一块空间,把每一层算好的历史 Token 的 K 和 V 保存下来,这块空间就叫做 KV Cache

引入 KV Cache 后的推理分为两个阶段:

  1. Prefill(预填充)阶段:处理用户输入的 Prompt(如 [选, 择]),一次性计算出所有的 K 和 V 并存入 Cache,预测出第一个新词 [大]

  2. Decode(解码)阶段:此时需要把最新生成的一个词 [大] 送入模型即可。模型只计算 [大] 这个词的 Q, K, V。然后,把 [大] 的 K 和 V 追加到 Cache 中。用 [大] 的 Q 去和 Cache 中所有的 K 做注意力计算,得出下一个词 [于]

流程可参考下图:

图:KV Cache 流程

8.1.3 KV Cache 代码实现

要在模型中接入 KV Cache,需要修改 MultiHeadAttentionforward 方法,让它能够接收历史缓存,拼接新的K、V并更新缓存。

def forward(self, x, kv_cache=None):
    B, t, d_in = x.shape
    
    # 1. 正常计算当前输入 x 的 Q, K, V
    queries = self.W_query(x).view(B, t, self.num_heads, self.head_dim).transpose(1, 2)
    keys    = self.W_key(x).view(B, t, self.num_heads, self.head_dim).transpose(1, 2)
    values  = self.W_value(x).view(B, t, self.num_heads, self.head_dim).transpose(1, 2)
    
    # 2. KV Cache 逻辑
    if kv_cache is not None:
        past_keys, past_values = kv_cache
        # 将当前新计算的 K, V 与历史的 K, V 拼接起来
        keys = torch.cat([past_keys, keys], dim=2)     # 在 sequence 维度 (dim=2) 拼接
        values = torch.cat([past_values, values], dim=2)
        
    # 3. 将更新后的 KV 保存,供下一步使用
    new_kv_cache = (keys, values)
    # 4. 计算注意力 (注意:此时 keys 包含了所有历史,但 queries 只有当前 1 个 token)
    # attn_scores: (B, num_heads, 1, seq_len_total)
    attn_scores = queries @ keys.transpose(-2, -1) / (self.head_dim ** 0.5)
    
    # (Decode 阶段由于 t=1,实际上不需要 Causal Mask,因为当前 token 看不到未来)
  
    attn_weights = torch.softmax(attn_scores, dim=-1)
    context_vec = (attn_weights @ values).transpose(1, 2).contiguous().view(B, t, -1)
    output = self.out_proj(context_vec)
    return output, new_kv_cache # 返回输出和更新后的缓存

8.2 GQA(Grouped-Query Attention)

8.2.1 KV Cache 占用大量显存

KV Cache 用空间换取了时间提升了推理速度,但它却带来了一个新问题:显存占用过大。

假设我们使用一个类似于 Qwen-7B 级别的模型,隐藏层维度 $d=4096$,层数 $L=32$,精度为 FP16(每个数值占 2 字节)。 如果要处理一段长度为 $4096$ 的上下文,并采用并发批处理(Batch Size = 128):

  • 单个 Token 的 KV Cache 占用:$2 \text{ (K和V)} \times 2 \text{ (字节)} \times 32 \text{ (层)} \times 4096 \text{ (维度)} = 512 \text{ KB}$。

  • 全局总 KV Cache 占用:$128 \times 4096 \times 512 \text{ KB} \approx \mathbf{268 \text{ GB}}$

仅是存放这些中间缓存,就需要 4 张 80G 的 A100 显卡。这使得大模型在面对超长上下文(如 Qwen3 支持的 128K 乃至 1M 上下文)或高并发 API 访问时,极易发生 OOM显存溢出。

8.2.2 从 MHA 到 MQA 再到 GQA

为了压缩 KV Cache 的体积,对注意力机制进行了一系列改进:

  1. MHA (Multi-Head Attention): 标准的多头注意力。Q、K、V 的头数完全一样(例如都是 32 个头)。效果最好,但 KV Cache 占用最大。

  2. MQA (Multi-Query Attention): 走向另一个极端。Q 依然保留 32 个头,但 K 和 V 全局只共享 1 个头。这直接把 KV Cache 压缩到了原来的 $\frac{1}{32}$,但代价是,所有的 Q 头都去关注同一个 K/V 空间,导致模型在处理复杂逻辑和细节信息时能力显著下降。

  3. GQA (Grouped-Query Attention): MHA 和 MQA 的折中方案,这也是目前包括 Qwen、LLaMA-3、GLM-4 在内主流开源大模型的标配。

#### 8.2.3 GQA 原理

GQA 的思路是“分组”。 假设模型 Q 有 32 个头(num_heads = 32),我们把它分成 8 组(num_kv_heads = 8)。 那么每组就包含 4 个 Q 头。在这一个组内的 4 个 Q 头,共享 1 个 K 头和 1 个 V 头。

如上图所示:

  • MHA:每一条线独立对应。

  • MQA:所有 Q 头连接到同一个 K/V 头。

  • GQA:相邻的几个 Q 头连接到同一个 K/V 头。

这样一来,KV 头的数量从 32 降到了 8,KV Cache 的大小直接缩小为原来的 1/4。

根据Google的论文《GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints》中的实验结果,GQA 的模型效果几乎与全量 MHA 保持一致,远超 MQA。

8.2.4 GQA 实现

和MultiHeadAttention实现基本一致,只是同一个group内用共享了key和value矩阵。

def forward(self, hidden_states, mask, cos, sin, start_position=0):
    """
    参数:
        hidden_states: 形状为 (batch_size, sequence_length, embedding_dim)
        mask: 布尔类型的因果掩码。
        cos, sin: RoPE 查找表。
        start_position: 使用 KV cache 时,RoPE 的位置偏移量。
                        0 表示不使用缓存 / 预填充阶段;>0 表示使用缓存进行解码。
    """
    batch_size, sequence_length, _ = hidden_states.shape

    queries = self.query_projection(hidden_states)
    keys = self.key_projection(hidden_states)
    values = self.value_projection(hidden_states)

    # 调整形状为 (batch_size, num_heads_or_kv_groups, sequence_length, head_dim)
    queries = queries.view(batch_size, sequence_length, self.num_heads, self.head_dim).transpose(1, 2)
    keys = keys.view(batch_size, sequence_length, self.num_kv_groups, self.head_dim).transpose(1, 2)
    values = values.view(batch_size, sequence_length, self.num_kv_groups, self.head_dim).transpose(1, 2)

    # 应用 RoPE
    queries = apply_rope(queries, cos, sin, offset=start_position)
    keys = apply_rope(keys, cos, sin, offset=start_position)

    # KV cache:将新的 keys/values 与缓存中的内容拼接起来
    if self.cached_keys is not None:
        keys = torch.cat([self.cached_keys, keys], dim=2)
        values = torch.cat([self.cached_values, values], dim=2)
    self.cached_keys = keys
    self.cached_values = values

    # 为了适配 GQA,将 KV 头重复到与 query 头数量一致(通过 expand 实现零拷贝)
    keys = repeat_kv(keys, self.heads_per_kv_group)
    values = repeat_kv(values, self.heads_per_kv_group)

    attention_scores = queries @ keys.transpose(2, 3)
    attention_scores = attention_scores.masked_fill(mask, -torch.inf)
    attention_weights = torch.softmax(attention_scores / self.head_dim**0.5, dim=-1)

    context_vectors = (attention_weights @ values).transpose(1, 2).reshape(batch_size, sequence_length, self.total_head_dim)
    return self.output_projection(context_vectors)

前面Qwen3的实现中还在使用MultiHeadAttention,我们将其替换为GroupQueryAttention。

## 9. 加载权重运行

现在我们已经从零搭建出了 Qwen3 的模型结构,但此时模型参数是随机初始化的,无法进行有意义的文本生成。要让模型具备真正的推理能力,还需要把官方已经训练好的权重加载进来。

这一节将介绍:

  1. 如何读取 Qwen3-0.6B 的配置与权重

  2. 如何把 Hugging Face 的参数名映射到我们自己实现的模型结构

  3. 如何完成推理

9.1 初始化实现的模型

在加载权重之前,先根据 Qwen3-0.6B 的配置,定义模型超参数.

QWEN3_0_6B_CONFIG = {
    "vocab_size": 151_936,           # 词表大小
    "context_length": 40_960,        # 训练该模型时使用的上下文长度
    "emb_dim": 1024,                 # 嵌入维度
    "n_heads": 16,                   # 注意力头的数量
    "n_layers": 28,                  # 层数
    "hidden_dim": 3072,              # FeedForward 中间层的维度大小
    "head_dim": 128,                 # GQA 中每个头的维度大小
    "qk_norm": True,                 # 是否在 GQA 中对 query 和 key 做归一化
    "n_kv_groups": 8,                # 分组查询注意力中的 Key-Value 分组数
    "rope_base": 1_000_000.0,        # RoPE 中 “theta” 的基数
    "dtype": torch.bfloat16,         # 使用较低精度的数据类型以减少内存占用
}
model = Qwen3Model(QWEN3_0_6B_CONFIG)
model.to(device)

9.2 加载 Tokenizer 和官方权重

# load original qwen weight
weights_filepath = os.path.join(model_local_dir, "model.safetensors")
pretrained_weights = load_file(weights_filepath)
# load weight into our qwen 
load_weights_into_qwen(model, QWEN3_0_6B_CONFIG, pretrained_weights)
del pretrained_weights
# load tokenizer
tokenizer_file_path = f"{model_save_dir}/tokenizer.json"
tokenizer = Qwen3Tokenizer(
    tokenizer_file_path=tokenizer_file_path,
    repo_id=model_repo_id,  
    apply_chat_template=True,
    add_generation_prompt=True,
    add_thinking=False
)

9.3 权重映射

我们实现的模型类和官方实现的模块命名方式有些不同,需要手动写一个映射函数,把官方参数拷贝到自己的模型中

def load_weights_into_qwen(model, model_config, pretrained_weights):
    def assign(target_param, source_weight, tensor_name="unknown"):
        if target_param.shape != source_weight.shape:
            raise ValueError(f"{tensor_name}: {target_param.shape} != {source_weight.shape}")
        
        with torch.no_grad():
            if isinstance(source_weight, torch.Tensor):
                target_param.copy_(source_weight)
            else:
                target_param.copy_(torch.as_tensor(source_weight, dtype=target_param.dtype, device=target_param.device))
    
        return target_param 

    model.token_embedding.weight = assign(model.token_embedding.weight, pretrained_weights["model.embed_tokens.weight"], "model.embed_tokens.weight")

    for layer_idx in range(model_config["n_layers"]):
        transformer_block = model.transformer_blocks[layer_idx]
        attention = transformer_block.attention

        # Q, K, V projections
        attention.query_projection.weight = assign(
            attention.query_projection.weight,
            pretrained_weights[f"model.layers.{layer_idx}.self_attn.q_proj.weight"],
            f"model.layers.{layer_idx}.self_attn.q_proj.weight"
        )
        attention.key_projection.weight = assign(
            attention.key_projection.weight,
            pretrained_weights[f"model.layers.{layer_idx}.self_attn.k_proj.weight"],
            f"model.layers.{layer_idx}.self_attn.k_proj.weight"
        )
        attention.value_projection.weight = assign(
            attention.value_projection.weight,
            pretrained_weights[f"model.layers.{layer_idx}.self_attn.v_proj.weight"],
            f"model.layers.{layer_idx}.self_attn.v_proj.weight"
        )

        # Output projection
        attention.output_projection.weight = assign(
            attention.output_projection.weight,
            pretrained_weights[f"model.layers.{layer_idx}.self_attn.o_proj.weight"],
            f"model.layers.{layer_idx}.self_attn.o_proj.weight"
        )

        # QK norms
        if hasattr(attention, "query_norm") and attention.query_norm is not None:
            attention.query_norm.scale = assign(
                attention.query_norm.scale,
                pretrained_weights[f"model.layers.{layer_idx}.self_attn.q_norm.weight"],
                f"model.layers.{layer_idx}.self_attn.q_norm.weight"
            )
        if hasattr(attention, "key_norm") and attention.key_norm is not None:
            attention.key_norm.scale = assign(
                attention.key_norm.scale,
                pretrained_weights[f"model.layers.{layer_idx}.self_attn.k_norm.weight"],
                f"model.layers.{layer_idx}.self_attn.k_norm.weight"
            )

        # Attention layernorm
        transformer_block.attention_norm.scale = assign(
            transformer_block.attention_norm.scale,
            pretrained_weights[f"model.layers.{layer_idx}.input_layernorm.weight"],
            f"model.layers.{layer_idx}.input_layernorm.weight"
        )

        # Feedforward weights
        transformer_block.feed_forward.gate_projection.weight = assign(
            transformer_block.feed_forward.gate_projection.weight,
            pretrained_weights[f"model.layers.{layer_idx}.mlp.gate_proj.weight"],
            f"model.layers.{layer_idx}.mlp.gate_proj.weight"
        )
        transformer_block.feed_forward.up_projection.weight = assign(
            transformer_block.feed_forward.up_projection.weight,
            pretrained_weights[f"model.layers.{layer_idx}.mlp.up_proj.weight"],
            f"model.layers.{layer_idx}.mlp.up_proj.weight"
        )
        transformer_block.feed_forward.down_projection.weight = assign(
            transformer_block.feed_forward.down_projection.weight,
            pretrained_weights[f"model.layers.{layer_idx}.mlp.down_proj.weight"],
            f"model.layers.{layer_idx}.mlp.down_proj.weight"
        )
        transformer_block.feed_forward_norm.scale = assign(
            transformer_block.feed_forward_norm.scale,
            pretrained_weights[f"model.layers.{layer_idx}.post_attention_layernorm.weight"],
            f"model.layers.{layer_idx}.post_attention_layernorm.weight"
        )

    # Final normalization and output head
    model.final_norm.scale = assign(model.final_norm.scale, pretrained_weights["model.norm.weight"], "model.norm.weight")

    if "lm_head.weight" in pretrained_weights:
        model.output_projection.weight = assign(model.output_projection.weight, pretrained_weights["lm_head.weight"], "lm_head.weight")
    else:
        model.output_projection.weight = model.token_embedding.weight
        print("Model uses weight tying.")

9.4 完整流程

# --------------------------------------------------
#  从modelscope下载模型
# --------------------------------------------------
model_repo_id = "Qwen/Qwen3-0.6B"
model_save_dir = f"model_repo/{model_repo_id}"
model_local_dir = snapshot_download(model_id=model_repo_id, local_dir=model_save_dir)

# --------------------------------------------------
#  初始化自己实现的模型
# --------------------------------------------------
model = Qwen3Model(QWEN3_0_6B_CONFIG)
model.to(device)

# --------------------------------------------------
#  加载 tokenizer和官方权重
# --------------------------------------------------
# load original qwen weight
weights_filepath = os.path.join(model_local_dir, "model.safetensors")
pretrained_weights = load_file(weights_filepath)
# load weight into our qwen 
load_weights_into_qwen(model, QWEN3_0_6B_CONFIG, pretrained_weights)
del pretrained_weights
# load tokenizer
tokenizer_file_path = f"{model_save_dir}/tokenizer.json"
tokenizer = Qwen3Tokenizer(
    tokenizer_file_path=tokenizer_file_path,
    repo_id=model_repo_id,  
    apply_chat_template=True,
    add_generation_prompt=True,
    add_thinking=False
)

# --------------------------------------------------
#  自回归生成
# --------------------------------------------------
prompt = input("Please input your prompt: ")
input_token_ids = tokenizer.encode(prompt)
input_token_ids_tensor = torch.tensor(input_token_ids, device=device).unsqueeze(0)

# sampling params (matching Qwen3-0.6B generation_config.json defaults)
sampling_params = SamplingParams(temperature=0.6, top_k=20, top_p=0.95)

if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()

start_time = time.perf_counter()
generated_tokens = 0

for token_id in generate_text_basic_stream(
    model=model,
    token_ids=input_token_ids_tensor,
    max_new_tokens=500,
    eos_token_id=tokenizer.eos_token_id,
    sampling_params=sampling_params
):
    generated_tokens += 1
    print(tokenizer.decode([token_id]), end="", flush=True)

elapsed_seconds = time.perf_counter() - start_time
generation_speed = generated_tokens / elapsed_seconds if elapsed_seconds > 0 else 0.0
print(f"\n\nGeneration speed: {generation_speed:.2f} tokens/sec")

if torch.cuda.is_available():
    def calc_gpu_gb(bytes_count):
        return f"{bytes_count / 1024 / 1024 / 1024:.2f} GB"

    print(f"GPU memory used: {calc_gpu_gb(torch.cuda.max_memory_allocated())}")

9.5 运行

python inference/qwen3_infer.py
Downloading Model from https://www.modelscope.cn to directory: /home/admin/workspace/easy-llm/model_repo/Qwen/Qwen3-0.6B
2026-03-09 20:21:53,981 - modelscope - INFO - Target directory already exists, skipping creation.
Please input your prompt: hi
Hello! How can I assist you today?

Generation speed: 13.02 tokens/sec
GPU memory used: 1.48 GB

10. 总结和展望

本文沿着一次推理时数据在模型内部流动的方向,梳理了大语言模型的核心结构,并以 Qwen3-0.6B 为例,实现了模型推理的完整流程。

本文依次介绍了:

  1. Tokenizer:将原始文本切分为 token,并映射为离散的 token id;

  2. Embedding Layer:将离散 id 映射到连续向量空间,使模型能够对文本进行数值计算;

  3. Transformer Decoder Block:包括 RMSNorm、RoPE、注意力机制、SwiGLU 前馈网络和残差连接,是模型提取上下文特征和建模语言规律的核心;

  4. LM Head:将隐藏状态映射回词表空间,输出下一个 token 的 logits;

  5. 自回归生成与采样策略:介绍了模型如何逐 token 生成文本,以及 temperature、top-k、top-p 等采样方法如何影响生成结果;

  6. 推理优化:说明了 KV Cache 和 GQA 为什么重要,以及它们如何在保证效果的同时提升推理速度、降低显存占用。

回头再看主流开源模型,会发现它们虽然名字不同、参数规模不同,但在核心思路上非常相似,没有脱离Transformer Decoder架构,理解一个通用的大模型基础架构,比记忆某一个具体模型的名字更重要。

实际上,当前大模型能力的突飞猛进,很大程度上源于海量高质量的训练数据和训练技巧的提升,下一篇文章《从零构建大语言模型(下):模型训练》将介绍从数据构造、损失函数设计、预训练与指令微调,到偏好优化、强化学习和蒸馏的完整训练流程。

如果你对 AI Agent 感兴趣,想进一步了解 MCP、Skill、SubAgent,以及 Claude Code、Qwen Code 等 AI 编程助手的实现原理,也可以继续阅读本系列文章:《从零构建AI Agent》

本项目完整代码仓库:

https://github.com/msober/learn-llm

留言

欢迎补充观点、指出疏漏,或者继续展开文中的问题。

Table of Contents