0%

Attention原理及TensorFlow AttentionWrapper源码解析

本节来详细说明一下 Seq2Seq 模型中一个非常有用的 Attention 的机制,并结合 TensorFlow 中的 AttentionWrapper 来剖析一下其代码实现。

Seq2Seq

首先来简单说明一下 Seq2Seq 模型,如果搞过深度学习,想必一定听说过 Seq2Seq 模型,Seq2Seq 其实就是 Sequence to Sequence,也简称 S2S,也可以称之为 Encoder-Decoder 模型,这个模型的核心就是编码器(Encoder)和解码器(Decoder)组成的,架构雏形是在 2014 年由论文 Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation, Cho et al 提出的,后来 Sequence to Sequence Learning with Neural Networks, Sutskever et al 算是比较正式地提出了 Sequence to Sequence 的架构,后来 Neural Machine Translation by Jointly Learning to Align and Translate, Bahdanau et al 又提出了 Attention 机制,将 Seq2Seq 模型推上神坛,并横扫了非常多的任务,现在也非常广泛地用于机器翻译、对话生成、文本摘要生成等各种任务上,并取得了非常好的效果。 下面的图示意了 Seq2Seq 模型的基本架构: 可以看到图中有一个中间状态 $ c $ 向量,在 $ c $ 向量左侧的我们可以称之为编码器(Encoder),编码器这里示意的是 RNN 序列,另外 RNN 单元还可以使用 LSTM、GRU 等变体, 在编码器下方输入了 $ x_1 $、$ x_2 $、$ x_3 $、$ x_4 $,代表模型的输入内容,例如在翻译模型中可以分别代表“我爱中国”这四个字,这样经过序列处理,它就会得到最后的输出,我们将其表示为 $ c $ 向量,这样编码器的工作就完成了。在图中 $ c $ 向量的右侧部分我们可以称之为解码器(Decoder),它拿到编码器生成的 $ c $ 向量,然后再进行序列解码,得到输出结果 $ y_1 $、$ y_2 $、$ y_3 $,例如刚才输入的“我爱中国”四个字便被解码成了 “I love China”,这样就实现了翻译任务,以上就是最基本的 Seq2Seq 模型原理。 另外还有一种变体,$ c $ 向量在每次解码的时候都会作为解码器的输入,其实原理都是类似的,如图所示: 这种模型架构是通用的,所以它的适用场景也非常广泛。如机器翻译、对话生成、文本摘要、阅读理解、语音识别,也可以用在一些趣味场景中,如诗词生成、对联生成、代码生成、评论生成等等,效果都很不错。

Attention

通过上图我们可以发现,Encoder 把所有的输入序列编码成了一个 $ c $ 向量,然后使用 $ c $ 向量来进行解码,因此,$ c $ 向量中必须包含了原始序列中的所有信息,所以它的压力其实是很大的,而且由于 RNN 容易把前面的信息“忘记”掉,所以基本的 Seq2Seq 模型,对于较短的输入来说,效果还是可以接受的,但是在输入序列比较长的时候,$ c $ 向量存不下那么多信息,就会导致生成效果大大折扣。 Attention 机制解决了这个问题,它可以使得在输入文本长的时候精确率也不会有明显下降,它是怎么做的呢?既然一个 $ c $ 向量存不了,那么就引入多个 $ c $ 向量,称之为 $ c1 $、$ c_2 $、…、$ c_i $,在解码的时候,这里的 $ i $ 对应着 Decoder 的解码位次,每次解码就利用对应的 $ c_i $ 向量来解码,如图所示: 这里的每个 $ c_i $ 向量其实包含了当前所输出与输入序列各个部分重要性的相关的信息。不同的 $ c_i $ 向量里面包含的输入信息各部分的权重是不同的,先放一个示意图: 还是上面的例子,例如输入信息是“我爱中国”,输出的的理想结果应该是“I love China”,在解码的时候,应该首先需要解码出 “I” 这个字符,这时候会用到 $ c_1 $ 向量,而 $ c_1 $ 向量包含的信息中,“我”这个字的重要性更大,因此它便倾向解码输出 “I”,当解码第二个字的时候,会用到 $ c_2 $ 向量,而 $ c_2 $ 向量包含的信息中,“爱” 这个字的重要性更大,因此会解码输出 “love”,在解码第三个字的时候,会用到 $ c_3 $ 向量,而 $ c_3 $向量包含的信息中,”中国” 这两个字的权重都比较大,因此会解码输出 “China”。所以其实,Attention 注意力机制中的 $ c_i $ 向量记录了不同解码时刻应该更关注于哪部分输入数据,也实现了编码解码过程的对齐。经过实验发现,这种机制可以有效解决输入信息过长时导致信息解码效果不理想的问题,另外解码生成效果同时也有提升。 下面我们以 Bahdanau 提出的 Attention 为例来详细剖析一下 Attention 机制。 在没有引入 Attention 之前,Decoder 在某个时刻解码的时候实际上是依赖于三个部分的,首先我们知道 RNN 中,每次输出结果会依赖于隐层和输入,在 Seq2Seq 模型中,还需要依赖于 $ c $ 向量,所以这里我们设在 $ i $ 时刻,解码器解码的内容是 $ y_i $,上一次解码结果是 $ y{i-1} $,隐层输出是 $ st $,所以它们满足这样的关系: $$ y_i = g(y{i-1}, si, c) s_i = f(s{i-1}, y{i-1}, c) y_i = g(y{i-1}, si, c_i) s_i = f(s{i-1}, y{i-1}, c_i) $$ 所以,这里每次解码得出 $ y_i $ 时,都有与之对应的 $ c_i $ 向量。那么这个 $ c_i $ 向量又是怎么来的呢?实际上它是由编码器端每个时刻的隐含状态加权平均得到的,这里假设编码器端的的序列长度为 $ T_x $,序列位次用 $ j $ 来表示,编码器段每个时刻的隐含状态即为 $ h_1 $、$ h_2 $、…、$ h_j $、…、$ h{Tx} $,对于解码器的第 $ i $ 时刻,对应的 $ c_i $ 表示如下: $$ c_i = \sum{j=1}^{Tx} \alpha{ij}hj $$ 编码器输出的结果中,$ h_j $ 中包含了输入序列中的第 $ j $ 个词及前面的一些信息,如果是用了双向 RNN 的话,则包含的是第 $ j $ 个词即前后的一些词的信息,这里 $ \alpha{ij} $ 代表了分配的权重,这代表在生成第 i 个结果的时候,对于输入信息的各个阶段的 $ hj $ 的注意力分配是不同的。 当 $ a{ij} $ 的值越高,表示第 $ i $ 个输出在第 $ j $ 个输入上分配的注意力越多,这样就会导致在生成第 $ i $ 个输出的时候,受第 $ j $ 个输入的影响也就越大。 那么 $ a{ij} $ 又是怎么得来的呢?其实它就又关系到第 $ i-1 $ 个输出隐藏状态 $ s{i-1} $ 以及输入中的各个隐含状态 $ h_j $,公式表示如下: $$ \alpha{ij} = \frac {exp(e{ij})} {\sum{k=1}^{Tx} exp(e{ik})} e{ij} = a(s{i-1}, hj) = {v_a}^Ttanh(W_as{i-1} + Uah_j) $$ 这也就是说,这个权重就是 $ s{i-1} $ 和 $ hj $ 分别计算得到一个数值,然后再过一个 softmax 函数得到的,结果就是 $ \alpha{ij} $。 因此 $ ci $ 就可以表示为: $$ c_i = \sum{j=1}^{Tx} softmax(a(s{i-1}, h_j)) \cdot h_j $$ 以上便是整个 Attention 机制的推导过程。

TensorFlow AttentionWrapper

我们了解了基本原理,但真正离程序实现出来其实还是有很大差距的,接下来我们就结合 TensorFlow 框架来了解一下 Attention 的实现机制。 在 TensorFlow 中,Attention 的相关实现代码是在 tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py 文件中,这里面实现了两种 Attention 机制,分别是 BahdanauAttention 和 LuongAttention,其实现论文分别如下:

整个 attention_wrapper.py 文件中主要包含几个类,我们主要关注其中几个:

  • AttentionMechanism、_BaseAttentionMechanism、LuongAttention、BahdanauAttention 实现了 Attention 机制的逻辑。
    • AttentionMechanism 是 Attention 类的父类,继承了 object 类,内部没有任何实现。
    • _BaseAttentionMechanism 继承自 AttentionMechanism 类,定义了 Attention 机制的一些公共方法实现和属性。
    • LuongAttention、BahdanauAttention 均继承 _BaseAttentionMechanism 类,分别实现了上面两篇论文的 Attention 机制。
  • AttentionWrapperState 用来存储整个计算过程中的 state,和 RNN 中的 state 类似,只不过这里额外还存储了 attention、time 等信息。
  • AttentionWrapper 主要用于对封装 RNNCell,继承自 RNNCell,封装后依然是 RNNCell 的实例,可以构建一个带有 Attention 机制的 Decoder。
  • 另外还有一些公共方法,例如 hardmax、safe_cumpord 等。

下面我们以 BahdanauAttention 为例来说明 Attention 机制及 AttentionWrapper 的实现。

BahdanauAttention

首先我们来介绍 BahdanauAttention 类的具体原理。 首先我们来看下它的初始化方法:

1
2
3
4
5
6
7
8
9
def __init__(self,
num_units,
memory,
memory_sequence_length=None,
normalize=False,
probability_fn=None,
score_mask_value=None,
dtype=None,
name="BahdanauAttention"):

这里一共接受八个参数,下面一一进行说明:

  • numunits:神经元节点数,我们知道在计算 $ e{ij} $ 的时候,需要使用 $ s_{i-1} $ 和 $ h_j $ 来进行计算,而二者的维度可能并不是统一的,需要进行变换和统一,所以这里就有了 $ W_a $ 和 $ U_a $ 这两个系数,所以在代码中就是用 num_units 来声明了一个全连接 Dense 网络,用于统一二者的维度,以便于下一步的计算:
1
2
query_layer=layers_core.Dense(num_units, name="query_layer", use_bias=False, dtype=dtype)
memory_layer=layers_core.Dense(num_units, name="memory_layer", use_bias=False, dtype=dtype)

这里我们可以看到声明了一个 querylayer 和 memory_layer,分别和 $ s{i-1} $ 及 $ h_j $ 做全连接变换,统一维度。

  • memory:The memory to query; usually the output of an RNN encoder. 即解码时用到的上文信息,维度需要是 [batch_size, max_time, context_dim]。这时我们观察一下父类 _BaseAttentionMechanism 的初始化方法,实现如下:
1
2
3
4
5
6
7
8
with ops.name_scope(
name, "BaseAttentionMechanismInit", nest.flatten(memory)):
self._values = _prepare_memory(
memory, memory_sequence_length,
check_inner_dims_defined=check_inner_dims_defined)
self._keys = (
self.memory_layer(self._values) if self.memory_layer
else self._values)

这里通过 _prepare_memory() 方法对 memory 进行处理,然后调用 memory_layer 对 memory 进行全连接维度变换,变换成 [batch_size, max_time, num_units]。

  • memory_sequence_length:Sequence lengths for the batch entries in memory. 即 memory 变量的长度信息,类似于 dynamic_rnn 中的 sequence_length,被 _prepare_memory() 方法调用处理 memory 变量,进行 mask 操作:
1
2
3
4
5
6
7
seq_len_mask = array_ops.sequence_mask(
memory_sequence_length,
maxlen=array_ops.shape(nest.flatten(memory)[0])[1],
dtype=nest.flatten(memory)[0].dtype)
seq_len_batch_size = (
memory_sequence_length.shape[0].value
or array_ops.shape(memory_sequence_length)[0])
  • normalize:Whether to normalize the energy term. 即是否要实现标准化,方法出自论文:Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks, Salimans, et al
  • probability_fn:A callable function which converts the score to probabilities. 计算概率时的函数,必须是一个可调用的函数,默认使用 softmax(),还可以指定 hardmax() 等函数。
  • score_mask_value:The mask value for score before passing into probability_fn. The default is -inf. Only used if memory_sequence_length is not None. 在使用 probability_fn 计算概率之前,对 score 预先进行 mask 使用的值,默认是负无穷。但这个只有在 memory_sequence_length 参数定义的时候有效。
  • dtype:The data type for the query and memory layers of the attention mechanism. 数据类型,默认是 float32。
  • name:Name to use when creating ops,自定义名称。

接下来类里面定义了一个 call() 方法:

1
2
3
4
5
6
def __call__(self, query, previous_alignments):
with variable_scope.variable_scope(None, "bahdanau_attention", [query]):
processed_query = self.query_layer(query) if self.query_layer else query
score = _bahdanau_score(processed_query, self._keys, self._normalize)
alignments = self._probability_fn(score, previous_alignments)
return alignments

这里首先定义了 processed_query,这里也是通过 query_layer 过了一个全连接网络,将最后一维统一成 num_units,然后调用了 bahdanau_score() 方法,这个方法是比较重要的,主要用来计算公式中的 $ e{ij} $,传入的参数是 processed_query 以及上文中提及的 keys 变量,二者一个代表了 $ s{i-1} $,一个代表了 $ h_j $,_bahdanau_score() 方法实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def _bahdanau_score(processed_query, keys, normalize):
dtype = processed_query.dtype
# Get the number of hidden units from the trailing dimension of keys
num_units = keys.shape[2].value or array_ops.shape(keys)[2]
# Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting.
processed_query = array_ops.expand_dims(processed_query, 1)
v = variable_scope.get_variable(
"attention_v", [num_units], dtype=dtype)
if normalize:
# Scalar used in weight normalization
g = variable_scope.get_variable(
"attention_g", dtype=dtype,
initializer=math.sqrt((1. / num_units)))
# Bias added prior to the nonlinearity
b = variable_scope.get_variable(
"attention_b", [num_units], dtype=dtype,
initializer=init_ops.zeros_initializer())
# normed_v = g * v / ||v||
normed_v = g * v * math_ops.rsqrt(
math_ops.reduce_sum(math_ops.square(v)))
return math_ops.reduce_sum(normed_v * math_ops.tanh(keys + processed_query + b), [2])
else:
return math_ops.reduce_sum(v * math_ops.tanh(keys + processed_query), [2])

这里其实就是实现了 keys 和 processedquery 的加和,如果指定了 normalize 的话还需要进行额外的 normalize,结果就是公式中的 $ e{ij} $,在 TensorFlow 中常用 score 变量表示。 接下来再回到 call() 方法中,这里得到了 score 变量,接下来可以对齐求 softmax() 操作,得到 $ \alpha_{ij} $:

1
alignments = self._probability_fn(score, previous_alignments)

这就代表了在 $ i $ 时刻,Decoder 的时候对 Encoder 得到的每个 $ hj $ 的权重大小比例,在 TensorFlow 中常用 alignments 变量表示。 所以综上所述,BahdanauAttention 就是初始化时传入 num_units 以及 Encoder Outputs,然后调时传入 query 用即可得到权重变量 alignments。

AttentionWrapperState

接下来我们再看下 AttentionWrapperState 这个类,这个类其实比较简单,就是定义了 Attention 过程中可能需要保存的变量,如 cell_state、attention、time、alignments 等内容,同时也便于后期的可视化呈现,代码实现如下:

1
2
3
4
class AttentionWrapperState(
collections.namedtuple("AttentionWrapperState",
("cell_state", "attention", "time", "alignments",
"alignment_history"))):

可见它就是继承了 namedtuple 这个数据结构,其实整个 AttentionWrapperState 就像声明了一个结构体,可以传入需要的字段生成这个对象。

AttentionWrapper

了解了 Attention 机制及 BahdanauAttention 的原理之后,最后我们再来了解一下 AttentionWrapper,可能你用过很多其他的 Wrapper,如 DropoutWrapper、ResidualWrapper 等等,它们其实都是 RNNCell 的实例,其实 AttentionWrapper 也不例外,它对 RNNCell 进行了封装,封装后依然还是 RNNCell 的实例。一个普通的 RNN 模型,你要加入 Attention,只需要在 RNNCell 外面套一层 AttentionWrapper 并指定 AttentionMechanism 的实例就好了。而且如果要更换 AttentionMechanism,只需要改变 AttentionWrapper 的参数就好了,这可谓对 Attention 的实现架构完全解耦,配置非常灵活,TF 大法好! 接下来我们首先来看下它的初始化方法,其参数是这样的:

1
2
3
4
5
6
7
8
9
def __init__(self,
cell,
attention_mechanism,
attention_layer_size=None,
alignment_history=False,
cell_input_fn=None,
output_attention=True,
initial_cell_state=None,
name=None):

下面对参数进行一一说明:

  • cell:An instance of RNNCell. RNNCell 的实例,这里可以是单个的 RNNCell,也可以是多个 RNNCell 组成的 MultiRNNCell。
  • attention_mechanism:即 AttentionMechanism 的实例,如 BahdanauAttention 对象,另外可以是多个 AttentionMechanism 组成的列表。
  • attention_layer_size:是数字或者数字做成的列表,如果是 None(默认),直接使用加权计算后得到的 Attention 作为输出,如果不是 None,那么 Attention 结果还会和 Output 进行拼接并做线性变换再输出。其代码实现如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
if attention_layer_size is not None:
attention_layer_sizes = tuple(attention_layer_size if isinstance(attention_layer_size, (list, tuple)) else (attention_layer_size,))
if len(attention_layer_sizes) != len(attention_mechanisms):
raise ValueError("If provided, attention_layer_size must contain exactly one integer per attention_mechanism, saw: %d vs %d" % (len(attention_layer_sizes), len(attention_mechanisms)))
self._attention_layers = tuple(layers_core.Dense(attention_layer_size, name="attention_layer", use_bias=False, dtype=attention_mechanisms[i].dtype) for i, attention_layer_size in enumerate(attention_layer_sizes))
self._attention_layer_size = sum(attention_layer_sizes)
else:
self._attention_layers = None
self._attention_layer_size = sum(attention_mechanism.values.get_shape()[-1].value for attention_mechanism in attention_mechanisms)

for i, attention_mechanism in enumerate(self._attention_mechanisms):
attention, alignments = _compute_attention(attention_mechanism, cell_output, previous_alignments[i], self._attention_layers[i] if self._attention_layers else None)
alignment_history = previous_alignment_history[i].write(state.time, alignments) if self._alignment_history else ()
  • alignment_history:即是否将之前的 alignments 存储到 state 中,以便于后期进行可视化展示。
  • cell_input_fn:将 Input 进行处理的方式,默认会将上一步的 Attention 进行 拼接操作,以免造成重复关注同样的内容。代码调用如下:
1
cell_inputs = self._cell_input_fn(inputs, state.attention)
  • output_attention:是否将 Attention 返回,如果是 False 则返回 Output,否则返回 Attention,默认是 True。
  • initial_cell_state:计算时的初始状态。
  • name:自定义名称。

AttentionWrapper 的核心方法在它的 call() 方法,即类似于 RNNCell 的 call() 方法,AttentionWrapper 类对其进行了重载,代码实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def call(self, inputs, state):
# Step 1
cell_inputs = self._cell_input_fn(inputs, state.attention)
# Step 2
cell_state = state.cell_state
cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
# Step 3
if self._is_multi:
previous_alignments = state.alignments
previous_alignment_history = state.alignment_history
else:
previous_alignments = [state.alignments]
previous_alignment_history = [state.alignment_history]
all_alignments = []
all_attentions = []
all_histories = []
for i, attention_mechanism in enumerate(self._attention_mechanisms):
attention, alignments = _compute_attention(attention_mechanism, cell_output, previous_alignments[i], self._attention_layers[i] if self._attention_layers else None)
alignment_history = previous_alignment_history[i].write(state.time, alignments) if self._alignment_history else ()
all_alignments.append(alignments)
all_histories.append(alignment_history)
all_attentions.append(attention)
# Step 4
attention = array_ops.concat(all_attentions, 1)
# Step 5
next_state = AttentionWrapperState(
time=state.time + 1,
cell_state=next_cell_state,
attention=attention,
alignments=self._item_or_tuple(all_alignments),
alignment_history=self._item_or_tuple(all_histories))
# Step 6
if self._output_attention:
return attention, next_state
else:
return cell_output, next_state

在这里将一些异常判断代码去除了,以便于结构看得更清晰。 首先在第一步中,调用了 _cell_input_fn() 方法,对 inputs 和 state.attention 变量进行处理,默认是使用 concat() 函数拼接,作为当前时间步的输入。因为可能前一步的 Attention 可能对当前 Attention 有帮助,以免让模型连续两次将注意力放在同一个地方。 在第二步中,其实就是调用了普通的 RNNCell 的 call() 方法,得到输出和下一步的状态。 第三步中,这时得到的输出其实并没有用上 AttentionMechanism 中的 alignments 信息,所以当前的输出信息中我们并没有跟 Encoder 的信息做 Attention,所以这里还需要调用 _compute_attention() 方法进行权重的计算,其方法实现如下:

1
2
3
4
5
6
7
8
9
10
def _compute_attention(attention_mechanism, cell_output, previous_alignments, attention_layer):
alignments = attention_mechanism(cell_output, previous_alignments=previous_alignments)
expanded_alignments = array_ops.expand_dims(alignments, 1)
context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
context = array_ops.squeeze(context, [1])
if attention_layer is not None:
attention = attention_layer(array_ops.concat([cell_output, context], 1))
else:
attention = context
return attention, alignments

这个方法接收四个参数,其中 attentionmechanism 就是 AttentionMechanism 的实例,cell_output 就是当前 Output,previous_alignments 是上步的 alignments 信息,调用 attention_mechanism 计算之后就会得到当前步的 alignments 信息了,即 $ \alpha{ij} $。接下来再利用 alignments 信息进行加权运算,得到 attention 信息,即 $ c_{i} $,最后将二者返回。 在第四步中,就是将 attention 结果每个时间步进行 concat,得到 attention vector。 第五步中,声明 AttentionWrapperState 作为下一步的状态。 第六步,判断是否要输出 Attention,如果是,输出 Attention 及下一步状态,否则输出 Outputs 及下一步状态。 好,以上便是整个 AttentionWrapper 源码解析过程,了解了源码之后,再做模型优化的话就非常得心应手了。

参考来源