Normalization and Attention

Creation Time: 2025-04-25 19:37


tech notes

本节将对大模型中常用的RMSNorm和基于旋转位置编码的注意力机制做介绍

1 Normalization

1.1 LayerNorm:

layerNorm是自然语言处理任务中最为常用的一种正则化函数,和BatchNorm不同的在于,它计算的是每个样本的隐藏层的正则。

对于输入\(x \in \mathbb{R}^{B \times L \times H}\)(其中B表示batch size,L表示序列长度,H表示特征维度),LayerNorm通过如下步骤完成对特征维度的正则化处理。

  • 按照公式(1.1)得到均值\(\mu\)(一阶原点矩):
\[\mu = \frac{1}{H}\sum^H_{i=1} x_i \tag{1.1}\]
  • 按照公式(1.2)得到方差\(\sigma^2\)(二阶中心矩):
\[\sigma^2 = \frac{1}{H} \sum ^H _{i=1} (x_i - \mu)^2 \tag{1.2}\]
  • 最后再按照公式(1.3)得到正则化值:
\[\text{LayerNorm}(x) = \gamma \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \tag{1.3}\]

其中,\(\epsilon\)是为了避免除零错误而引入的一个很小的数,\(\gamma\)和\(\beta\)分别表示可学习的缩放参数和偏置。

一个简单的实现如下:

import torch
import torch.nn as nn

class LayerNorm(nn.Module):
  def __init__(self, dim: int, eps=1e-5):
    super().__init__()
    self.eps = eps
    self.gamma = nn.Parameter(torch.ones(dim))
    self.beta = nn.Parameter(torch.zeros(dim))
  
  def forward(self):
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    x_norm = (x - mean) / torch.sqrt(var + self.eps)

    return self.gamma * x_norm + self.beta

而在LLaMA等大模型中,用的更多的是RMSNorm(Root Mean Square Norm),具体计算如下:

1.2 RMSNorm:

  • 类似的输入\(x\),首先根据公式(1.4)计算出均方根(二阶原点矩开根号):
\[\text{RMS}(x) = \sqrt{\frac{1}{H} \sum ^H _{i=1}x^2_i} \tag{1.4}\]
  • 之后按公式(1.5)得到正则化的值:
\[\text{RMSNorm}(x) = \gamma \frac{x}{\text{RMS}(x) + \epsilon} \tag{1.5}\]

其中\(\gamma\)和\(\epsilon\)分别是可学习的缩放参数和防止除零操作。

和LayerNorm做对比,RMSNorm的主要区别在于:直接计算二阶原点矩,而不是先计算二阶中心矩,因此少了求均值的操作(没有了减均值的步骤)

正是因为少了一步操作,RMSNorm计算步骤更少,因此速度更快,对于对计算量消耗巨大的大模型而言非常有益。

实际验证发现RMSNorm与LayerNorm相比,性能上的损失并不大,但其带来的速度增益更可观。

其简单的实现如下:

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))
    
    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        x_norm = x / rms
        return self.gamma * x_norm

这里做了一个(并不严谨)非常简单的小实验,在NYT24数据集上做分类(不求三元组,只求每个样本中可能存在的关系,即多标签分类)

  • 模型方面,使用BERT的分词器和nn.Embedding构建了一个非常简单的模型,其参数量仅为7.92M
LayerNorm f1-score RMSNorm f1-score
Figure 1.1 LayerNorm和RMSNorm的F1值变化曲线

如图1.1所示,两个不同正则化函数的最好F1值分别为:

  • LayerNorm: 0.9585

  • RMSNorm: 0.9559

可见二者的性能差别并不大。

LayerNorm step loss RMSNorm step loss
Figure 1.2 LayerNorm和RMSNorm的loss变化曲线

图1.2显示了两个正则函数的loss变化,虽然一开始RMSNorm有着略微大一些的loss,但并不影响RMSNorm的学习效率。

二者最终的训练时间分别为:

  • LayerNorm: 231.2874秒

  • RMSNorm: 214.5342秒

可见RMSNorm确实有着更少的学习时间,所节省的时间在比7.92M大的多的模型上会更具优势。


2 Attention with RoPE

下文需要先了解绝对位置编码和相对位置编码,可参见:

旋转位置编码的来源和详细介绍可以参见:

RoFormer的提出动机非常简单,苏剑林在博客中自称是通过\(e^{im\theta}\bar{e^{in\theta}} = e^{i(m-n)\theta}\)发现可以构建出一种实际上是相对,表达上又是绝对形式的位置编码。

在介绍具体含义前,先引入两个前置知识:

2.1 前置知识

2.1.1 二维平面的旋转

从计算机图形学的视角来看,什么是旋转

假设有二维向量\(v = \begin{bmatrix} x \\ y \end{bmatrix}\),对他做二维平面旋转,旋转角度为\(\theta\),可以通过旋转矩阵实现。旋转矩阵的形式如公式(2.1)所示。

\[R_{\theta} = \begin{bmatrix} \cos\theta & -\sin\theta \\ \sin \theta & \cos \theta \\ \end{bmatrix} \tag{2.1}\]

因此旋转后的向量如公式(2.2)所示。

\[v' = R_{\theta} \cdot v \tag{2.2}\]

这个操作在计算机图像学中表示“将图像绕原点旋转\(\theta\)度”(在实现时需要先计算出图像的原点坐标)。

2.1.2 二维复平面的旋转

复数\(e^{i\theta}\)表示了一个沿着单元圆逆时针旋转角度为\(\theta\)的点的位置(欧拉公式),即\(e^{i\theta} = \cos\theta + i \sin \theta\),其中\(\cos\theta\)是水平方向的坐标(实部),\(\sin\theta\)是竖直方向的坐标(虚部)。

而两个点\(z_1 = e^{im\theta}\)和\(z_2 = e^{in\theta}\)之间的乘积如公式(2.3)表示:

\[\begin{equation} \begin{aligned} z_1 \times z_2 &= e^{im\theta} \times e^{in\theta}\\ & = \cos m\theta \cos n\theta + i^2 \sin m \theta \sin n \theta + i (\cos m \theta \sin n \theta + \sin m \theta \cos n \theta) \\ & = \cos (m+n)\theta + i \sin (m+n)\theta \\ & = e ^ {i (m+n) \theta} \end{aligned} \tag{2.3} \end{equation}\]

也就是说,二者的乘积实际上表示了两个点在复平面上的角度之和。

如果用\(z_2\)的共轭\(\bar{z_2} = e^{-i n \theta}\),就变成了公式(2.4)所示的角度之差:

\[z_1 \times \bar{z_2} = e^{i (m-n)\theta} \tag{2.4}\]

因此就能够建模出两个复数之间的相对性。


根据前置知识,假设二维平面上有一个点为\((x, y)\),则它旋转\(\theta\)后的坐标如公式(2.5)所示:

\[\begin{bmatrix} \cos \theta & - \sin \theta \\ \sin \theta & \cos \theta \end{bmatrix} \begin{bmatrix} x \\ y \end{bmatrix} = \begin{bmatrix} x \cos \theta - y \sin \theta \\ x \sin \theta + y \cos \theta \end{bmatrix} \tag{2.5}\]

同理,假设二维复平面有 \(z = x + i y\),则乘上旋转角度的复数变化\(e ^ {i\theta}\)后的复数如公式(2.6)所示:

\[\begin{equation} \begin{aligned} e^{i\theta} \cdot z & = (\cos \theta + i \sin \theta)(x + i y) \\ & = (x \cos \theta - y \sin \theta) + i (x \sin \theta + y \cos \theta) \end{aligned} \tag{2.6} \end{equation}\]

可见,公式(2.6)变化后的实部和虚部恰好和与公式(2.5)旋转后的坐标相对应,所以本质上一个复数乘以\(e^{i \theta}\),就是乘了一个旋转矩阵。

2.2 旋转位置编码

本来想写不少的,但发现已经有不少优秀的博客做了解释,可以直接参考,如:

简单来说,旋转位置编码就是实现了使每个token的位置可以直接被计算出来,就像绝对位置编码那样,同时又能在注意力计算时建模出token之间的相对位置关系。

在实际计算时,对于位置为\(i\)的token,将其特征向量分组(按照奇偶进行分组),从而有公式(2.7)所示的转换。

\[\text{RoPE}(x_i) = (x_i^{even} \cos i\theta - x_i^{odd} \sin i\theta) \oplus (x_i^{even} \sin i\theta + x_i ^{odd}\cos i\theta) \tag{2.7}\]

因此,当计算注意力相关性分数\(q_m^Tk_n\)时,可以得到含有相对位置信息\((m-n)\)的特征。

也就是说,RoPE将每个token的特征向量都看作是由若干个二维子空间组成的向量,然后对于每个二维子空间上都根据token所在位置\(i\)旋转角度\(i\theta\)。

而在计算相关性分数时,由于复数的性质(或者说旋转矩阵的特点),自然的就会得到任意两个token之间的位置信息,相当于把任意两个token看作是一个旋转了\((i-j)\theta\)度。