本节我们来尝试使用 TensorFlow 搭建一个双向 LSTM (Bi-LSTM) 深度学习模型来处理序列标注问题,主要目的是学习 Bi-LSTM 的用法。
Bi-LSTM
我们知道 RNN 是可以学习到文本上下文之间的联系的,输入是上文,输出是下文,但这样的结果是模型可以根据上文推出下文,而如果输入下文,想要推出上文就没有那么简单了,为了弥补这个缺陷,我们可以让模型从两个方向来学习,这就构成了双向 RNN。在某些任务中,双向 RNN 的表现比单向 RNN 要好,本文要实现的文本分词就是其中之一。不过本文使用的模型不是简单的双向 RNN,而是 RNN 的变种 — LSTM。 如图所示为 Bi-LSTM 的基本原理,输入层的数据会经过向前和向后两个方向推算,最后输出的隐含状态再进行 concat,再作为下一层的输入,原理其实和 LSTM 是类似的,就是多了双向计算和 concat 过程。
数据处理
本文的训练和测试数据使用的是已经做好序列标注的中文文本数据。序列标注,就是给一个汉语句子作为输入,以“BEMS”组成的序列串作为输出,然后再进行切词,进而得到输入句子的划分。其中,B 代表该字是词语中的起始字,M 代表是词语中的中间字,E 代表是词语中的结束字,S 则代表是单字成词。 这里的原始数据样例如下:
1 |
人/b 们/e 常/s 说/s 生/b 活/e 是/s 一/s 部/s 教/b 科/m 书/e |
这里一个字对应一个标注,我们首先需要对数据进行预处理,预处理的流程如下:
- 将句子切分
- 将句子的的标点符号去掉
- 将每个字及对应的标注切分
- 去掉长度为 0 的无效句子
首先我们将句子切分开来并去掉标点符号,代码实现如下:
1 |
# Read origin data |
这样我们就可以将句子切分开来并做好了清洗,接下来我们还需要把每个句子中的字及标注转为 Numpy 数组,便于下一步制作词表和数据集,代码实现如下:
1 |
import re |
这里我们利用正则 re 库的 findall() 方法将字及标注分开,并分别添加到 words 和 labels 数组中,运行效果如下:
1 |
Words Length 321533 Labels Length 321533 |
接下来我们有了这些数据就要开始制作词表了,词表制作起来无非就是输入词表和输出词表的不重复的正逆对应,制作词表的目的就是将输入的文字或标注转为 index,同时还能反向根据 index 获取对应的文字或标注,所以我们这里需要制作 word2id、id2word、tag2id、id2tag 四个字典。 为了解决 OOV 问题,我们还需要将无效字符也进行标注,这里我们统一取 0。制作时我们借助于 pandas 库的 Series 进行了去重和转换,另外还限制了每一句的最大长度,这里设置为 32,如果大于32,则截断,否则进行 padding,代码如下:
1 |
from itertools import chain |
这样我们就完成了 word2id、id2word、tag2id、id2tag 四个字典的制作,并制作好了 Numpy 数组类型的 data_x 和 data_y,这里 data_x 和 data_y 单句示例如下:
1 |
Data X Example: [8, 43, 320, 88, 36, 198, 7, 2, 41, 163, 124, 245, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] |
可以看到数据的 x 部分,原始文字和标注结果都转化成了词表中的 index,同时不够 32 个字符就以 0 补全。 接下来我们将其保存成 pickle 文件,以备训练和测试使用:
1 |
print('Starting pickle to file...') |
好,现在数据预处理部分就完成了。
构造模型
接下来我们就需要利用 pickle 文件中的数据来构建模型了,首先进行 pickle 文件的读取,然后将数据分为训练集、开发集、测试集,详细流程不再赘述,赋值为如下变量:
1 |
# Load data |
接下来我们使用 TensorFlow 自带的 Dataset 数据结构构造输入输出,利用 Dataset 我们可以构造一个 iterator 迭代器,每调用一次 get_next() 方法,我们就可以得到一个 batch,这里 Dataset 的初始化我们使用 from_tensor_slices() 方法,然后调用其 batch() 方法来初始化每个数据集的 batch_size,接着初始化同一个 iterator,并绑定到三个数据集上声明为三个 initializer,这样每调用 initializer,就会将 iterator 切换到对应的数据集上,代码实现如下:
1 |
# Train and dev dataset |
有了 Dataset 的 iterator,我们只需要调用一次 get_next() 方法即可得到 x 和 y_label 了,就不需要使用 placeholder 来声明了,代码如下:
1 |
# Input Layer |
接下来我们需要实现 embedding 层,调用 TensorFlow 的 embedding_lookup 即可实现,这里没有使用 Pre Train 的 embedding,代码实现如下:
1 |
# Embedding Layer |
接下来我们就需要实现双向 LSTM 了,这里我们要构造一个 2 层的 Bi-LSTM 网络,实现的时候我们首先需要声明 LSTM Cell 的列表,然后调用 stack_bidirectional_rnn() 方法即可:
1 |
cell_fw = [lstm_cell(FLAGS.num_units, keep_prob) for _ in range(FLAGS.num_layer)] |
这个方法内部是首先对每一层的 LSTM 进行正反向计算,然后对输出隐层进行 concat,然后输入下一层再进行计算,这里值得注意的地方是,我们不能把 LSTM Cell 提前组合成 MultiRNNCell 再调用 bidirectional_dynamic_rnn() 进行计算,这样相当于只有最后一层才进行 concat,是错误的。 现在我们得到的 output 就是 Bi-LSTM 的最后输出结果了。 接下来我们需要对输出结果进行一下 stack() 操作转化为一个 Tensor,然后将其 reshape() 一下,转化为 [-1, num_units * 2] 的 shape:
1 |
output = tf.stack(output, axis=1) |
这样我们再经过一层全连接网络将维度进行转换:
1 |
# Output Layer |
这样得到的最后的 y_predict 即为预测结果,shape 为 [batch_size],即每一句都得到了一个最可能的结果标注。 接下来我们需要计算一下准确率和 Loss,准确率其实就是比较 y_predict 和 y_label 的相似度,Loss 即为二者交叉熵:
1 |
# Reshape y_label |
这里计算交叉熵使用的是 sparse_softmax_cross_entropy_with_logits() 方法,Optimizer 使用的是 Adam。 最后指定训练过程和测试过程即可,训练过程如下:
1 |
for epoch in range(FLAGS.epoch_num): |
这里训练时首先调用了 train_initializer,将 iterator 指向训练数据,这样每调用一次 get_next(),x 和 y_label 就会被赋值为训练数据的一个 batch,接下来打印输出了 Loss,Accuracy 等内容。另外对于开发集来说,每次进行验证的时候也需要重新调用 dev_initializer,这样 iterator 会再次指向开发集,这样每调用一次 get_next(),x 和 y_label 就会被赋值为开发集的一个 batch,然后进行验证。 对于测试来说,我们可以计算其准确率,然后将测试的结果输出出来,代码实现如下:
1 |
sess.r test_initializer) |
这里打印输出了当前测试的准确率,然后得到了测试结果,然后再结合词表将测试的真正结果打印出来即可。
运行结果
在训练过程中,我们需要构建模型图,然后调用训练部分的代码进行训练,输出结果类似如下:
1 |
Global Step 0 Step 0 Train Loss 1.67181 Accuracy 0.1475 |
随着训练的进行,准确率可以达到 96% 左右。 在测试阶段,输出了当前模型的准确率及真实测试输出结果,输出结果类似如下:
1 |
Test step 0 Accuracy 0.946125 |
可见测试准确率在 95% 左右,对于测试数据,此处还输出了每句话的序列标注结果,如第一行结果中,“据”字对应的标注就是 s,代表单字成词,“新”字对应的标注是 b,代表词的起始,“华”字对应标注是 m,代表词的中间,“社”字对应的标注是 e,代表结束,这样 “据”、“新华社” 就可以被分成两个词了,可见还是有一定效果的。
结语
本节通过搭建一个 Bi-LSTM 网络实现了序列标注,并可实现分词,准确率可达到 95% 左右,但是最主要的还是学习 Bi-LSTM 的用法,本实例代码较多,部分代码已经省略,完整代码见:https://github.com/AIDeepLearning/BiLSTMWordBreaker。