BatchNorm 和 LayerNorm 的区别

TL; DR

其实二者的区别一张图就可以说明了:

BatchNorm 和 LayerNorm 的运算方向

  • 对于 BatchNorm,它将一个 batch 里各个抽样特征的同个下标间做标准化。也就是作用在「样本」维度上。
  • 对于 LayerNorm,它在一个抽样的特征中做标准化。也就是作用在「通道」维度上。

BatchNorm

BatchNorm 的更加明确的定义为:给出一个 batch 的输入 xBx_B,其输出 yy 的计算方式为

μB=1mk=1mxkσB2=1mk=1m(xkμB)x^k=xkμBσB+ϵyk=γx^k+β\begin{aligned} \mu_B &= \frac{1}{m}\sum_{k=1}^mx_k \\ \sigma_B^2 &= \frac{1}{m}\sum_{k=1}^m(x_k-\mu_B) \\ \hat x_k &= \frac{x_k-\mu_B}{\sqrt{\sigma_B+\epsilon}} \\ y_k &= \gamma \hat x_k + \beta \end{aligned}

一言蔽之:将同 batch 内数据每一下标的分布规整为标准分布,并平移缩放到统一的分布上。两个过程互相拮抗,在期望达到 norm 的同时再保证网络的非线性性。

标准化技术的主要目的是去除网络输出的分布变化。如果一个深层网络中某一层网络的输出分布一直变化(也就是所谓的 Internal Covariate Shift),就会逼迫它的下层不断改变以适应新分布,这对于下层学习有不利影响。

更详细的说,Covariate Shift 是一种分布不一致学习问题的细分:对于源空间 S 和目标空间 T,虽然

PS(Y=yX=x)=PT(Y=yX=x),P_S(Y=y \mid X=x) = P_T(Y=y \mid X=x),

但是

PS(X)PT(X).P_S(X) \neq P_T(X).

在极深的网络中后者存在,各层的输入会因上层输出变换而发生改变,并且累积而不断扩大,像一只蝴蝶引起的风暴,也就是 Internal 的 Covariate Shift。

但是这个问题不应被如此简单的解决,也不会被简单的标准化解决,有人[1]认为 BatchNorm 的作用主要还是解决了梯度消失的问题:通过在非线性激活函数之前使用 BatchNorm,使得输出主要落在了函数的线性位置[2]

更详细的研究则可参见 How Does Batch Normalization Help Optimization。其中一个非常有趣的实验就是在 BatchNorm 输出上叠加分布随机变化的噪声,使之得到显然会大于无 BatchNorm 网络的输出分布移动。结果带有噪声的 BatchNorm 却仍然得到了更好的结果(下图)。这就进一步说明了一般情况下 BatchNorm 的性能与所谓的 Internal Convariate Shift 没什么关系,文章认为 BatchNorm 带来的性能提升原因是 BatchNorm 平滑了 loss。

对 BatchNorm 的结果叠加噪声扰动后仍然得到了极好的精度

BatchNorm 的 loss 更加平滑

既然 BatchNorm 的作用并不来自什么 Internal Convariate Shift,那也便不用再讲别的了。进一步的探索发现其实对网络输出做一些标准化都能得到类似的效果,无论是 l1-norm 还是 l2-norm 等都可以。不过具体到某个网络上的效果……请以实验为准吧。·


  1. https://www.zhihu.com/question/38102762/answer/85238569 ↩︎

  2. 但也有反应将 BN 放在激活函数后反而得到了好的效果,这就无从而知了。 ↩︎


BatchNorm 和 LayerNorm 的区别
https://blog.chenc.me/2023/01/25/batchnorm-and-layernorm/
作者
CC
发布于
2023年1月25日
许可协议