Pointer Network 的一些理解

本文对李宏毅老师课程 Pointer Network 的一些理解和总结。

什么是 Pointer Network?

Pointer Network 最早出现在凸包问题的解决上。作者为了解决凸包问题,“丧心病狂” 地硬生生的训练了一个 Neural Network 来解决这个问题。

凸包问题

这个 Neural Network 的输入是所有点的坐标,输出是构成凸包的点的合集。整体架构如下图所示。

Neural Network "train" 凸包问题

那么这个 Network 可以用 seq2seq 的模式么?

答案是不行的,因为,我们并不知道输出的数据的多少。更具体地说,就是在 encoder 阶段,我们只知道这个凸包问题的输入,但是在 decoder 阶段,我们不知道我们一共可以输出多少个值。举例来说就是,第一次我们的输入是 50 个点,我们的输出可以是 0-50 (0 表示 END);第二次我们的输入是 100 个点,我们的输出依然是 0-50, 这样的话,我们就没办法输出 51-100 的点了。

为了解决这个问题,我们可以引入 Attention 机制。

这里的 Attention 跟别的地方的 Attention 并不一样,这里的 Attention 是根据输入的点去算一个分布,然后将 arg max 的点作为输出,也就是当前这个凸包问题的解的点,一直 train 下去,知道输出的点是 END。
为了避免理解上出现偏差,这里再用一个例子解释一遍。
模型一共是 encoder 和 decoder 两个部分。 刚开始 encoder 将所有输入训练完毕, 在 decoder 部分,刚开始输入一个 $z_0$ 作为初始变量,然后 $z_0$ 会去跟 encoder 中所有的点做一个 attention operation, 得到的结果是一个分布,搁在平时我们会对这个分布做 weigted sum,但是在这里,我们对这个分布做一个 arg max 取出概率最大的点(假设这个点是 1)。然后这个点$(x_1, y_1)$ 和 $z_0$ 又会作为新一轮输入,扔进 decoder 里面得到 $z_1$,然后 $z_1$ 又会跟 encoder 中所有的点做一个 attention operation 得到另一个新的分布,同样再取出概率最大的点,如此反复,直到取到的点是 0 (0 表示 END)

Pointer Network 有什么用?

个人的理解是 Pointer Network 有三个好处

  1. 提供了一种新视角去理解 Attention,把 Attention 作为一种求分布的手段。
  2. 对于输出字典长度不固定问题提供了一种新的解决方案。
  3. 将输入作为输出的一种补充手段,让输出部分可以更好的引入输入部分的信息。

所以,接下来就有一篇 Application 是结合 Pointer Network 去做 Summarization。 Get To The Point: Summarization with Pointer-Generator Networks

Pointer Network 解决 Summarization。