本文对李宏毅老师课程 Pointer Network 的一些理解和总结。
什么是 Pointer Network?
Pointer Network 最早出现在凸包问题的解决上。作者为了解决凸包问题,“丧心病狂” 地硬生生的训练了一个 Neural Network 来解决这个问题。
这个 Neural Network 的输入是所有点的坐标,输出是构成凸包的点的合集。整体架构如下图所示。
那么这个 Network 可以用 seq2seq 的模式么?
答案是不行的,因为,我们并不知道输出的数据的多少。更具体地说,就是在 encoder 阶段,我们只知道这个凸包问题的输入,但是在 decoder 阶段,我们不知道我们一共可以输出多少个值。举例来说就是,第一次我们的输入是 50 个点,我们的输出可以是 0-50 (0 表示 END);第二次我们的输入是 100 个点,我们的输出依然是 0-50, 这样的话,我们就没办法输出 51-100 的点了。
为了解决这个问题,我们可以引入 Attention 机制。
这里的 Attention 跟别的地方的 Attention 并不一样,这里的 Attention 是根据输入的点去算一个分布,然后将 arg max 的点作为输出,也就是当前这个凸包问题的解的点,一直 train 下去,知道输出的点是 END。
为了避免理解上出现偏差,这里再用一个例子解释一遍。
模型一共是 encoder 和 decoder 两个部分。 刚开始 encoder 将所有输入训练完毕, 在 decoder 部分,刚开始输入一个 $z_0$ 作为初始变量,然后 $z_0$ 会去跟 encoder 中所有的点做一个 attention operation, 得到的结果是一个分布,搁在平时我们会对这个分布做 weigted sum,但是在这里,我们对这个分布做一个 arg max 取出概率最大的点(假设这个点是 1)。然后这个点$(x_1, y_1)$ 和 $z_0$ 又会作为新一轮输入,扔进 decoder 里面得到 $z_1$,然后 $z_1$ 又会跟 encoder 中所有的点做一个 attention operation 得到另一个新的分布,同样再取出概率最大的点,如此反复,直到取到的点是 0 (0 表示 END)
Pointer Network 有什么用?
个人的理解是 Pointer Network 有三个好处
- 提供了一种新视角去理解 Attention,把 Attention 作为一种求分布的手段。
- 对于输出字典长度不固定问题提供了一种新的解决方案。
- 将输入作为输出的一种补充手段,让输出部分可以更好的引入输入部分的信息。
所以,接下来就有一篇 Application 是结合 Pointer Network 去做 Summarization。 Get To The Point: Summarization with Pointer-Generator Networks