03_Seq2Seq与注意力机制
1 Seq2Seq 模型
Seq2Seq 并不是一种全新的网络架构,而是传统序列模型的一种应用。传统的自然语言处理任务(如文本分类、序列标注)以静态输出为主,其目标是预测固定类别或标签。
而在机器翻译、文本摘要、问答系统等任务中,输入和输出都是长度动态可变的序列,使用传统的单一序列模型,只能完成 N vs 1,或者 N vs N,因此提出了 Seq2Seq(Sequence to Sequence,序列到序列)模型,可以完成 N vs M 的任务。
1.1 基本结构
Seq2Seq 模型由一个编码器(Encoder)和一个解码器(Decoder)构成。
- 编码器:提取输入序列的语义信息,并压缩为上下文向量(Context Vector)
- 解码器:使用来自编码器的上下文向量,逐步生成目标序列

1.1.1 编码器
编码器主要由一个循环神经网络(RNN/LSTM/GRU)构成,依次接收每个 token 的输入,并在每个时间步更新隐藏状态。每个隐藏状态都携带了截止到当前时间步的历史信息,最终在最后一个时间步形成一个包含整句信息的隐藏状态。最终的隐藏状态就可以作为上下文向量(context vector) 传递给解码器用于指导后续的序列生成。
这个循环神经网络也可以采用双向结构(结合前文与后文信息)或多层结构(提取更深的语义特征),以增强模型理解能力。

1.1.2 解码器
解码器主要也由一个循环神经网络组成,这个循环神经网络将编码传来的上下文向量作为初始隐藏状态(也就是编码器的历史记忆),并接收一个特殊起始符
<sos>(start of sentence)作为第一个时间步的输入,用于预测第一个
token。
在后续的每个时间步,将上一步的预测结果作为下一步的输入,并生成一个 token 作为输出。这种的方式被称为自回归生成(Autoregressive Generation),它确保了生成结果的连贯性。
直到最终生成了一个特殊结束符
<eos> (end of sentence),表示句子生成完成。
起始符和结束符会在训练数据中显式添加,模型会在训练中学会何时开始、如何续写,以及何时结束,从而掌握完整的生成流程。

1.2 交叉熵损失
解码器每一个时间步的输出经过线性层转换为对每个词的预测得分,使用 softmax 转换为概率分布,使用标准的损失函数————多元交叉熵损失函数。在每一步,我们都力求最大化模型分配给正确标签的概率。

1.3 训练和推理策略
编码器在训练阶段和推理阶段使用相同的策略,而解码器会使用不同的策略。
1.3.1 训练阶段
在训练阶段,解码器会使用 Teacher Forcing 策略,每一个时间步的输入不是上一个时间步的输出,而是真实的目标序列。避免了在训练阶段由于初始预测错误,而在后续时间步中不断累积错误。有两个明显的好处:
- 训练更快,误差不会累积
- 梯度传播更稳定,有利于优化收敛
在预测完成之后,解码器每一步输出一个 token,每个时间步的损失本质上就是多分类任务的交叉熵,一个样本的总损失就是所有时间步的交叉熵之和。

1.3.2 推理阶段
在推理阶段,解码器生成方式采用自回归生成(Autoregressive Generation),每一步的输出会作为下一步的输入,逐步构造完整句子。
而每一个时间步的输出实质是一个所有词的概率分布,那么选择哪个词作为下一步的输入,有两种常见的词选择策略:
- 贪心策略(Greedy
Strategy):每次选择概率最大的词作为下一步的输入。
- 优点:简单,计算量小
- 缺点:容易陷入局部最优,生成不够多样
- 束搜索(beam
search):每次保留几个概率较大的假设词作为下一步的输入。束尺寸通常为
4-10。
- 优点:全局考虑,生成质量高,生成更加多样
- 缺点:计算量大

1.4 代码实现
由于篇幅原因,这里只展示基本的模型架构。
- 编码器
1 | |
- 解码器
1 | |
注意双向编码器最后一个时间步的输出为 [batch_size, encoder_hidden_size * 2],单向解码器的隐状态输入维度 decoder_hidden_size 要等于 encoder_hidden_size * 2,才能保持维度匹配。
1.5 存在问题
Seq2Seq 架构下,编码器将整个输入序列转换为一个固定长度的上下文向量,作为解码器生成目标序列的唯一参考,这种机制存在以下问题:
- 语义丢失:无论多长的输入序列,都要被编码器压缩为一个上下文向量,导致信息被大大压缩,语义表达不完整。
- 缺乏动态感知:解码器生成目标序列时,只能参考唯一的上下文向量,不能有选择地关注输入序列中的不同部分。
2 注意力机制
为了解决 Seq2Seq 的问题,引入了 Attention 机制。核心思想是解码器在生成目标序列时,不再仅仅依靠静态的上下文向量,而是动态地从编码器各时间步隐状态中选取最相关的信息。这种机制能够使解码器在生成当前输出时,自动判断原始输入序列中哪些时间步信息最重要,从而提升生成质量。
2.1 核心机制
- 相关性计算:解码过程中,解码器在每个时间步 t 会计算其上一个隐状态 ht − 1 与编码器中每个隐状态 s1, s2, …, sm 的相关性。这个计算称为注意力评分函数,接收一个解码器状态和一个编码器状态,并返回一个标量值 score(ht − 1, sk)。
- 计算注意力权重:使用 softmax 函数计算将得分归一化为概率分布,作为注意力权重,表示各编码器隐状态在当前解码时的重要性。
- 更新上下文向量:解码器利用这些注意力权重,对编码器输出的所有隐状态 sk 进行加权求和,形成新的上下文向量 ct,用于聚合与当前预测最相关的信息。
- 解码信息融合:将新的上下文向量 ct 与解码器当前时间步的输入 xt 进行融合(通常使用拼接),融合后的向量作为解码器当前时间步的输入,进一步用于生成当前的输出。

下图以 RNN 为例,更为详细地描述了整个计算过程。

2.2 注意力评分函数
注意力评分函数有多种实现方式,虽然在结构上各有差异,但本质上都是用于衡量解码器当前隐藏状态与编码器各时间步隐藏状态之间的相关性。
2.2.1 点积注意力
点积注意力(Dot-Product Attention)通过计算解码器当前时间步的隐藏状态与编码器每个时间步的隐藏状态的点积,来衡量二者之间的相关性。如果两个向量方向越相似,它们的点积就越大,表示相关性越强,非常简单直接。
score(ht, sk) = ht ⋅ skT
两个有关联的词元,它们的向量如何变得相似?会在不断的训练当中为二者分配合适的词向量,使其变得相似,点积值也会变大。
为了提升数值稳定性,防止随着维度的增大,点积得分变得越来越大,从而在应用 softmax 时造成梯度消失或梯度爆炸等问题。可以把得分除以一个常数,通常是该向量维度的平方根,这就是缩放点积(Scaled Dot-Product Attention),也是 Transformer 中的实现方式。
$$score(h_t,s_k)=\frac{h_t \cdot {s_k}^T}{\sqrt{d}}$$
通用点积(General Dot-Product Attention)在点积的基础上引入了一个可学习的权重矩阵 W,先对编码器隐藏状态进行线性变换,再与解码器隐藏状态进行点积。不仅解决了编码器和解码器隐藏状态维度不一致的问题,还提升了模型的表达能力。
score(ht, sk) = ht ⋅ W ⋅ skT
2.2.2 加性注意力
加性注意力(Additive Attention)将解码器的隐状态和编码器的隐状态进行拼接后,通过一个全连接层,经过线性变换和非线性激活函数,最后和一个权重向量相乘,从而将多维向量投影为一个标量作为注意力得分。加性注意力通过引入非线性,模型能够捕捉更复杂的相似性关系。
score(ht, sk) = w2T ⋅ tanh (W1[ht, sk])
下图是 Bahdanau 模型的构造,由双向编码器组成,图上的计算方式就是加性注意力。

2.3 代码实现
PyTorch 提供了专门处理 3D 张量运算的模块
torch.bmm,全称是 Batch Matrix
Multiplication(批量矩阵乘法),不像 @ 和
torch.matmul 那样有广播机制,如果输入不是 3D
张量,直接报错,相当于提供了一个维度的强制检查。
torch.bmm(input, mat2),参数和形状:
- input: (b, n, m)
- mat2: (b, m, p)
其中 Batch 维度必须相同,这是并行的基础,内部维度和矩阵乘法规则相同(前一个矩阵的列数 = 后一个矩阵的行数),输出结果的形状为 (b, n, p)。
1 | |
tensor([[[ 0.2251, 1.9282, 0.2330, -0.7461],
[-1.3417, 3.0018, 1.7747, -1.1693],
[-4.1933, -2.5293, 3.4258, 1.8658]],
[[ 1.1425, 0.2468, 1.2771, 1.0578],
[-0.6750, 2.0262, 0.7572, 4.7870],
[-4.0731, -1.8482, -0.8492, 2.4882]]])
加入注意力机制,主要改变的是解码器部分,编码器不需要修改。
- 注意力机制
1 | |
- 解码器
1 | |
2.4 存在问题
尽管注意力机制极大地增强了 Seq2Seq 模型的建模能力,但是核心依然是基于传统的序列模型,RNN 结构始终存在两个问题无法得到根本解决:
- 长期依赖关系难以建模:在处理超长序列时,需要跨越多个时间步来传递信息,训练过程容易出现梯度消失。
- 无法并行计算:由于序列模型的时间步之间存在强依赖,必须顺序执行,无法利用硬件资源并行计算,限制了模型训练效率。