如何正确地锁住一个 Normalization Layer?

在之前的一篇文章中我们介绍过关于 BN 和 LN 的一些小细节,在那里提到了这类 Layer 一般会在训练阶段统计输入数据的分布信息,并将该信息使用在推理阶段。随着近期 CV 侧深度学习也从 fine-tune 逐渐走向了直接 freeze backbone,我觉得是时候进一步明确 Norm Layer 在训练阶段和测试阶段的行为细节了。

本文将以 PyTorch 的 BatchNorm 为例。

统计信息是如何更新的

BatchNorm 可以描述为以下过程:

y=xE(x)Var(x)+ϵγ+βy=\frac{x-E(x)}{\sqrt{\mathrm{Var}(x)+\epsilon}}\gamma+\beta

其中 xx 为输入,γ\gammaβ\beta 是两个可学习参数。

我们来看一看 PyTorch 如何实现的 BatchNorm。BatchNorm1d 类仅仅是检查了输入的维度是否符合要求,_BatchNorm 中描述了逻辑,而 NormBase 则包含所有的可训练参数,我们先检查它。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class _NormBase(Module):
def __init__(
self,
num_features: int,
eps: float = 1e-5,
momentum: float = 0.1,
affine: bool = True,
track_running_stats: bool = True,
device=None,
dtype=None
) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.affine = affine
self.track_running_stats = track_running_stats
if self.affine:
self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
if self.track_running_stats:
self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
self.running_mean: Optional[Tensor]
self.running_var: Optional[Tensor]
self.register_buffer('num_batches_tracked',
torch.tensor(0, dtype=torch.long,
**{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
self.num_batches_tracked: Optional[Tensor]
else:
self.register_buffer("running_mean", None)
self.register_buffer("running_var", None)
self.register_buffer("num_batches_tracked", None)
self.reset_parameters()

可以看到,此类注册了几个值得注意的参数:

  • weight
  • bias
  • running_mean
  • running_var

前两个实际上就是 gamma 和 beta。[1]比较有趣的东西是 running_meanrunning_mean,在 _BatchNorm 中描述了这两个参数的作用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class _BatchNorm(_NormBase):
# [DELETED]
def forward(self, input: Tensor) -> Tensor:
self._check_input_dim(input)

# [DELETED]

r"""
Decide whether the mini-batch stats should be used for normalization rather than the buffers.
Mini-batch stats are used in training mode, and in eval mode when buffers are None.
"""
if self.training:
bn_training = True
else:
bn_training = (self.running_mean is None) and (self.running_var is None)

r"""
Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
used for normalization (i.e. in eval mode when buffers are not None).
"""
return F.batch_norm(
input,
# If buffers are not to be tracked, ensure that they won't be updated
self.running_mean
if not self.training or self.track_running_stats
else None,
self.running_var if not self.training or self.track_running_stats else None,
self.weight,
self.bias,
bn_training,
exponential_average_factor,
self.eps,
)

简单来说,BatchNorm 的 E(x)E(x)Var(x)\mathrm{Var}(x) 的来源有三种:

  • 当你特意地关闭更新开关时:来自于当前的 mini-batch。
  • 当你没做什么特别的事情:来自于一个可学习参数,这个参数在所有的 mini-batch 上平滑更新
    • x^(1m)x^+mx\hat x \leftarrow (1-m)\hat x+mx
  • 当你没做什么特别的事情,并且在推理阶段时:来自于训练阶段的可学习参数。

所以 BatchNorm 并不是简单的将数据以当前 batch 做标准化,通常情况会在相对更 global 的均值方差上归一化。

我们是如何锁住 Backbone 的

一般我们是这么做的:

1
2
3
4
5
6
7
8
9
10
11
# set lr as 0
opti=Optimizer([
{
"params": "<params you want to freeze>",
"lr": 0,
}
])

# or set require_grad, which is more common in practice.
for param in model.parameters():
param.requires_grad_(False)

发现问题了吗?BN 的平滑更新并非梯度优化,所以并没有被锁住。大多数情况这不是人们所期望的。

正确做法是同时将 BN 层切换到 eval 状态:

1
bn.eval()

  1. https://github.com/pytorch/pytorch/issues/16149 ↩︎


如何正确地锁住一个 Normalization Layer?
https://blog.chenc.me/2023/07/13/how-to-correctly-freeze-norm-layer/
作者
CC
发布于
2023年7月13日
许可协议