生成式任务解码#
1、任务说明#
以机器翻译任务为例,假设待翻译的中文是 "我爱中国",翻译成的英文是 "I love China"。另外为了简化后续的说明,我们假设英文里面总共就三个单词:"I"、"love"、"China"。
生成式任务的理论公式如下:
2、暴力搜索(穷举)#
这个很容易理解,就是把所有可能的序列都列出来,并且计算出每种可能的联合概率,然后选取联合概率最大的那个序列。按照上面英文中总共有三个单词 "I"、"love"、"China" 假设,所有的序列如下所示,总共是 3^3 = 27:
I I I
I I love
I I China
I love I
I love love
I love China
I China I
I China love
I China China
love I I
love I love
love I China
love love I
love love love
love love China
love China I
love China love
love China China
China I I
China I love
China I China
China love I
China love love
China love China
China China I
China China love
China China China
典型的 "元素可以重复使用的组合问题"(刷个题,代码如下):
result = []
def backtrace(arr: List[str], tmp):
if len(tmp) == len(arr):
result.append(copy.deepcopy(tmp))
return
for item in arr:
tmp.append(item)
backtrace(arr, tmp)
tmp = tmp[:-1]
def solution(arr: List[str]):
backtrace(arr, [])
return result
3、Greedy Search#
贪心搜索也好理解,针对当前的位置,模型会预测出一个 logits,对该 logits 进行 softmax 操作之后,就能够得到整个 vocab 中每个字的概率值,直接选取概率值最大的那个字作为当前位置的预测结果。就是说并不考虑整个序列的联合概率,每个位置直接使用生成那个位置的结果时概率最大的那个字。
使用贪心搜索不能够保证最终得到的序列是最优的,但是速度会非常的快。
4、beam search(集束搜索)#
还是以待翻译的中文是 "我爱中国",翻译成的英文是 "I love China" 为例,举例说明。假设beam search的束宽(beam size)为k=2。
-
在 t1 时刻:
- 预测 "I"、"love"、"China" 的概率分别为 0.4、0.5、0.1。此时选取概率最大的两个作为结果,为 "I" 和 "love"。
-
在 t2 时刻:
-
将前一时刻输出的 "I" 作为输入时,预测 "I"、"love"、"China" 的概率分别为 0.3、0.6、0.1。将之前时刻序列的概率乘上当前时刻的概率,可以得到 "I I"、"I love"、"I China" 的概率分别为 0.12、0.24、0.04。
-
将前一时刻输出的 "love" 作为输入时,预测 "I"、"love"、"China" 的概率分别为 0.3、0.3、0.4。将之前时刻序列的概率乘上当前时刻的概率,可以得到 "love I"、"love love"、"love China" 的概率分别为 0.15、0.15、0.2。
-
从当前时刻的 6 个序列 "I I"、"I love"、"I China"、"love I"、"love love"、"love China" 中选择出概率最大的两个序列,为 "I love" 和 "love China"。
-
-
在 t3 时刻:
-
将前一时刻输出的 "I love" 作为输入时,预测 "I"、"love"、"China" 的概率分别为 0.2、0.1、0.7。将之前时刻序列的概率乘上当前时刻的概率,可以得到 "I love I"、"I love love"、"I love China" 的概率分别为 0.048、0.024、0.168。
-
将前一时刻输出的 "love China" 作为输入时,预测 "I"、"love"、"China" 的概率分别为 0.3、0.3、0.4。将之前时刻序列的概率乘上当前时刻的概率,可以得到 "love China I"、"love China love"、"love China China" 的概率分别为 0.06、0.06、0.08。
-
由于这里是最后一个字符,所以从当前时刻的 6 个序列 "I love I"、"I love love"、"I love China"、"love China I"、"love China love"、"love China China" 中选择出概率最大那个序列最为最终结果,为 "I love China"。
-
正常情况下还会有一个终止字符,当预测到那个终止字符时才算是结束,这里简化了。
-
所以,对于beam search算法有:
定义当前时刻为t,定义beam search的束宽(beam size)为k,待翻译的语言类型中总共有n个字。那么可知,在 t-1 时刻,会筛选出 k 种情况;在 t 时刻,会生成 k*n 种情况的概率,然后在这 k*n 个概率中选出概率最大的 k 个作为当前 t 时刻的预测结果。不断重复这个过程,直到最后一个时刻,直接选取 k*n 种情况中概率最大的那个作为结果。
5、Top-K Sampling#
6、Top-p Sampling#
7、解码时各种参数的作用#
温度、do_sample
链接:https://www.zhihu.com/question/415657741/answer/2430106609