Beam Search 算法及代码解读

我没想到第二个暴搜 tag 居然会给一篇深度学习的文章。

一般的暴搜题目没必要记,有奇妙特性和剪枝的暴搜重点往往便不再是暴搜了。以及,以前不爱写题解。结果就是占据暴搜 tag 的第二篇文章居然是深度学习……

彻头彻尾的暴搜。

Beam Search 一般使用在 seq2seq 任务上。由于输入输出不定长,所以往往是让模型采取一种循环的方式来输出 seq。

简单来讲,

  1. 我们有一个字典 D\mathcal D,模型根据输入 SS 先输出句子 TT 中第一个单词为字典中每个单词的置信。
  2. 现在按照置信度高低采纳第一个词为 T1T_1。然后,我们会把 TT 再次输入模型,模型依照 SSTT 计算 TT 的下一个单词的置信。
  3. 我们再采纳一个词加入到 TT 的最后。
  4. 如此循环往复,直到模型输出终结符。

那么现在问题就在于,我们其实是想让模型给一个最大化 P(TS)P(T \mid S)TT。但是我们无法证明概率最大的 TT 前缀也总是概率最大的,上述步骤就无法保证正确。

为了得到最优解,理论上我们要遍历所有的解空间,枚举 TT 的长度和每一个单词。但这显然无法接受。所以最后就有了一个折中的办法:选置信度最大的前 kk 个作为备选答案的前缀来扩展。

这就是 Beam Search 了。

代码

这东西的思想实在太简单,太粗暴,一句话就能讲明白。它的难度其实主要是怎么充分地并行化。

下面给出一段 Pytorch 实现,不是我写的,只是我感觉细节处理得挺好。

我在必要的位置做了注释。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
entry_length=67, temperature=1., stop_token: str = '.'):
"""
beam_size: 就是对搜索空间的限制,上文中的 k。
prompt: 和这段代码的任务相关,不必过分关注。可以简单理解为模型输入。
embed: 同样和代码的任务相关,不必过分关注。
entry_length: 句子的最大长度。
temperature: 用于控制模型输出置信度的数值大小。我没明白这是干嘛的。
stop_token: 终结符。
"""

# 注意这段代码假设输入 batch size 为 1。若想改成更大 size,恐怕要费点功夫。有空我会试试。

model.eval()
stop_token_index = tokenizer.encode(stop_token)[0]
tokens = None
scores = None
device = next(model.parameters()).device
# 待选句子的长度
seq_lengths = torch.ones(beam_size, device=device)
# 待选句子是否已经结尾
is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
with torch.no_grad():
# 和这段代码原本的任务相关,不必过分关注。
if embed is not None:
generated = embed
else:
if tokens is None:
tokens = torch.tensor(tokenizer.encode(prompt))
tokens = tokens.unsqueeze(0).to(device)
generated = model.gpt.transformer.wte(tokens)

for i in range(entry_length):
# 将所有待选句子(或者第一次只有输入)扔进模型,计算置信度
# Tip: 这个模型的输入和输出是放在一起输入的。依照你的模型设计,你可能需要做一些大的调整
outputs = model.gpt(inputs_embeds=generated)
logits = outputs.logits
# 取出新添单词的 confidence
logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
logits = logits.softmax(-1).log()
# 如果是第一轮循环,我们并没有待选句子
if scores is None:
# 排序,选出前 k 个最好的单词
scores, next_tokens = logits.topk(beam_size, -1)
# 将 size 为 1 的输入复制 k 份,好能够把 k 个待选单词分别加在后面
generated = generated.expand(beam_size, *generated.shape[1:])
next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
# 拼接句子
if tokens is None:
tokens = next_tokens
else:
tokens = tokens.expand(beam_size, *tokens.shape[1:])
tokens = torch.cat((tokens, next_tokens), dim=1)
else:
# 将已经结束的句子封堵住。具体再往下看。
logits[is_stopped] = -float(np.inf)
# 并不是字面上的让所有已结束句子的标签 0 为 100% 置信,只是为了在接下来的过程里保留每个已结束句子的一个副本
logits[is_stopped, 0] = 0
# P(T[:k]|S)=P(T|S,T[:k-1])*P(T[:k-1]|S),因为已经取过 log 了所以用加法
# 注意维度操作,计算后,就是 k 个句子后接分别接每组 k 个单词的概率,维度为 (k, k)
scores_sum = scores[:, None] + logits
# 将所有未结束句子的长度 +1s
seq_lengths[~is_stopped] += 1
# 对句子长度归一化,防止句子长度影响概率判断
scores_sum_average = scores_sum / seq_lengths[:, None]
# 对 k*k 个待选新句子取前 k 个最优
# Tip: 我们不想在这里继续扩展已结束句子,这就是之前两句代码的作用。
scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
# 计算 k 个新句子从哪个旧句子而来
next_tokens_source = next_tokens // scores_sum.shape[1]
seq_lengths = seq_lengths[next_tokens_source]
# 计算 k 个新句子的新加单词是什么
next_tokens = next_tokens % scores_sum.shape[1]
next_tokens = next_tokens.unsqueeze(1)
# (下方代码就很容易理解了)
tokens = tokens[next_tokens_source]
tokens = torch.cat((tokens, next_tokens), dim=1)
generated = generated[next_tokens_source]
scores = scores_sum_average * seq_lengths
is_stopped = is_stopped[next_tokens_source]
next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
generated = torch.cat((generated, next_token_embed), dim=1)
is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
if is_stopped.all():
break
scores = scores / seq_lengths
output_list = tokens.cpu().numpy()
output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
order = scores.argsort(descending=True)
output_texts = [output_texts[i] for i in order]
return output_texts

Beam Search 算法及代码解读
https://blog.chenc.me/2022/12/07/beam-search-note/
作者
CC
发布于
2022年12月7日
许可协议