01 · Transformer 结构与 Attention 计算

目标:理解 LLM 推理时到底在算什么——从 token 输入到新语义向量输出的完整计算流程


一、推理的本质:一切都是矩阵乘法

神经网络是一个函数:输入一组数字,输出一组数字。推理(Inference)就是执行这个函数,不涉及任何参数更新。

LLM 推理:

  • 输入:一串 token(文字被切分并编码成数字序列)
  • 输出:下一个 token 的概率分布

整个推理过程中,绝大部分计算量只有一种运算:矩阵乘法(GEMM)。这是 GPU 适合做推理的根本原因——GPU 专为大规模并行矩阵乘法设计。


二、Attention 出现之前:RNN 的问题

RNN 处理语言序列的方式是逐词传递隐状态

1
2
3
"猫" → [RNN] → h₁
"喜欢" → [RNN, 读入 h₁] → h₂
"鱼" → [RNN, 读入 h₂] → h₃

问题:隐状态每传一步就被压缩一次,越早的词信息越难保留。处理”鱼”时,”猫”的信息已经被稀释了两次。这叫长程依赖问题

另一个问题:必须严格按顺序处理,无法并行,训练极慢。


三、Attention 的核心思想

Attention 不依赖隐状态传递,而是让每个词直接访问序列中任意其他词的信息

处理”鱼”时,不靠传来的隐状态,而是直接问:整个序列里哪个词和我最相关?→ 直接看到”猫”,借来它的信息。

维度 RNN Transformer
信息传递 逐步压缩传递 任意两词直接交互
长距离依赖 难,信息随距离衰减 容易,距离不影响
并行计算 不能(必须顺序) 能(所有位置同时算)
计算复杂度 O(n) O(n²)(两两交互)

Attention 用 O(n²) 的计算代价,换来了并行性和全局依赖能力。


四、Self-Attention 计算流程

以”猫 喜欢 鱼”三个词为例,计算”鱼”的新向量。

第一步:生成 Q、K、V

每个词的原始向量 x,分别乘以三个权重矩阵,得到三个角色:

1
2
3
Q (Query)  = x × W_Q   → "我在找什么信息?"
K (Key) = x × W_K → "我能提供什么信息的标签?"
V (Value) = x × W_V → "我实际携带的内容"

注意:K 和 V 的职责不同。

  • K 是”门牌号”——用来和其他词的 Q 做匹配,决定相关度
  • V 是”房间里的东西”——被关注之后,真正流进其他词的内容

两者用不同权重矩阵学出,可以独立优化:K 学”怎么让别人找到我”,V 学”被找到后给出什么”。

第二步:Q 和所有 K 做点积,得到相关度分数

1
2
3
score(鱼, 猫)   = Q_鱼 · K_猫   = 0.74
score(鱼, 喜欢) = Q_鱼 · K_喜欢 = 0.24
score(鱼, 鱼) = Q_鱼 · K_鱼 = 0.62

点积越大 = “鱼”认为这个词越相关。

第三步:除以 √d_k,再 Softmax

除以 √d_k 的原因:d_k 维的向量点积是 d_k 个乘积的求和,维度越高数值越大。不除的话,Softmax 输入过大,输出极端(最高分接近 1,其余接近 0),导致训练时梯度消失。

1
[0.74, 0.24, 0.62] ÷ √2 → Softmax → [0.42, 0.26, 0.32]

Softmax 把任意数字变成加起来等于 1 的概率分布:

1
Softmax(xᵢ) = e^xᵢ / Σ e^xⱼ

使用 Softmax 而非直接归一化的原因:Softmax 是指数函数,会放大差距——相关度高的词拿到更多注意力,低的被进一步压制。这让模型能学到”强焦点”。

第四步:用注意力权重对所有 V 做加权求和

1
输出 = 0.42 × V_猫 + 0.26 × V_喜欢 + 0.32 × V_鱼

这是线性组合,不是点积。结果是”鱼”融合了整个序列语义后的新向量。

完整公式:

1
Attention(Q, K, V) = Softmax(Q Kᵀ / √d_k) × V

五、Multi-Head Attention

单组 Q/K/V 只能捕捉一种”关注视角”。Multi-Head Attention 并行跑 h 组独立的 Attention,每组用不同权重矩阵,最后拼接投影:

1
2
Head_i = Attention(Q_i, K_i, V_i)
输出 = Concat(Head_1, ..., Head_h) × W_O

W_O 是拼接后的投影矩阵,把 h 倍维度压回原始维度。

不同的 head 会学到不同的关注模式(语法关系、语义关系、指代关系等)。


六、Transformer Block 完整结构

1
2
3
4
5
6
7
8
9
10
11
12
13
输入 x

Self-Attention(W_Q / W_K / W_V / W_O)
↓ + 残差连接(加上原始 x,防止梯度消失)

Layer Norm(归一化,稳定训练)

FFN:x × W₁ → ReLU → × W₂
↓ + 残差连接

Layer Norm

输出(进入下一个 Block)

FFN 的作用:

Self-Attention 做的是词间信息路由——决定从哪些词借多少信息,结果是线性组合。

FFN 做的是每个词独立的非线性加工——对混合后的向量做深度语义变换,互不影响。

Self-Attention FFN
交互范围 整个序列(词间交互) 单个位置(独立处理)
变换类型 线性组合 非线性变换(含 ReLU)
作用 收集上下文信息 消化加工语义

为什么需要非线性(ReLU):

纯矩阵乘法是线性变换,多层叠加等价于一层,深度无意义。ReLU(负数清零,正数保留)引入非线性,让网络能表达条件逻辑,无法被合并压缩。

FFN 的 W₁/W₂ 矩阵乘法(GEMM)是推理计算量的大头之一,是后续 kernel 优化的主要对象。


七、完整模型结构

1
2
3
4
5
6
7
8
9
输入 tokens

Embedding 层(token → 向量)+ 位置编码

Transformer Block × N 层(每层结构相同,参数不同)

输出层(向量 → 词表概率分布)

下一个 token

GPT-3:96 层,模型维度 12288,96 个 attention head,175B 参数。


参考材料

  1. The Illustrated Transformerhttps://jalammar.github.io/illustrated-transformer/
  2. Attention Is All You Needhttps://arxiv.org/abs/1706.03762