seq2seq


If the keras and tensorflow libraries are not installed,

install them with pip install keras tensorflow

If you are using a conda virtual environment, use

conda install -c conda-forge keras

conda install -c conda-forge tensorflow
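
A note on an assumption about the environment: this notebook uses standalone Keras 2.x with the TensorFlow 1.x backend, hence the "Using TensorFlow backend." banner below. Under TensorFlow 2 the equivalent imports would come from tensorflow.keras instead:

# tf.keras equivalents under TensorFlow 2 (sketch, not used in this notebook)
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, Dense, Embedding
from tensorflow.keras.optimizers import Adam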

In [1]:
from keras.models import Model
from keras.layers import Input, LSTM, Dense, Embedding
from keras.optimizers import Adam
import numpy as np
Using TensorFlow backend.
In [2]:
batch_size = 64        # training batch size
epochs = 9             # number of training epochs
latent_dim = 256       # dimensionality of the LSTM hidden state
embedding_size = 128   # dimensionality of the character embeddings
file_name = '../input/poetry.txt'  # path to the poetry corpus

The following code processes the raw data.

The training data for seq2seq consists of input/target pairs, i.e. input and target.

The task shown here is poetry couplet completion: the input is the first line of a couplet, and the target is the second line.

We first build the vocabularies input_vocab and target_vocab over all input and target sentences.

Second, decoding needs a start character <BOS> and an end character <EOS>; here we use the tab character '\t' and the newline character '\n' respectively.
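
As a minimal sketch of the pair construction (using a made-up sample line in the corpus format assumed by the parsing below: an ASCII comma between the two halves and a trailing '。'):

line = '白日依山尽,黄河入海流。'          # hypothetical sample line
first, second = line.strip().split(',')
input_text = first                        # '白日依山尽'
target_text = '\t' + second[:-1] + '\n'   # drop the trailing '。', wrap in <BOS>/<EOS>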

In [3]:
input_texts = []
target_texts = []
input_vocab = set()
target_vocab = set()
with open(file_name, 'r', encoding='utf-8') as f:
    lines = f.readlines()
for line in lines:
    # split the poem line on the comma
    line_sp = line.strip().split(',')
    # skip lines that do not contain a comma
    if len(line_sp) < 2:
        continue
    # the first half is input_text, the second half is target_text
    input_text, target_text = line_sp[0], line_sp[1]
    # drop the trailing punctuation of the second half and wrap it
    # in the start character '\t' and the end character '\n'
    target_text = '\t' + target_text[:-1] + '\n'
    input_texts.append(input_text)
    target_texts.append(target_text)
    # collect the input-side and target-side vocabularies
    for ch in input_text:
        input_vocab.add(ch)
    for ch in target_text:
        target_vocab.add(ch)

# build the char-to-index dictionaries and their inverses
input_vocab = dict([(char, i) for i, char in enumerate(input_vocab)])
target_vocab = dict([(char, i) for i, char in enumerate(target_vocab)])
reverse_input_char_index = dict((i, char) for char, i in input_vocab.items())
reverse_target_char_index = dict((i, char) for char, i in target_vocab.items())

# input-side vocabulary size
encoder_vocab_size = len(input_vocab)
# length of the longest input sentence
encoder_len = max([len(sentence) for sentence in input_texts])
# target-side vocabulary size
decoder_vocab_size = len(target_vocab)
# length of the longest target sentence
decoder_len = max([len(sentence) for sentence in target_texts])

print(encoder_vocab_size)
print(encoder_len)
print(decoder_vocab_size)
print(decoder_len)
4767
7
4816
9

The following code builds the training arrays.

The training data consists of three parts: the encoder input, the decoder input, and the decoder target, i.e. encoder_input, decoder_input, and decoder_target.

While building them we also convert each character to its index in the corresponding vocabulary. Note that decoder_target is decoder_input shifted left by one step: the character the decoder reads as input at step t is exactly what it must predict at step t-1.
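
A minimal sketch of that one-step shift, using a hypothetical toy vocabulary:

toy_vocab = {'\t': 0, 'A': 1, 'B': 2, '\n': 3}
target_text = '\tAB\n'
decoder_input  = [toy_vocab[ch] for ch in target_text]   # [0, 1, 2, 3]
decoder_target = decoder_input[1:]                       # [1, 2, 3]
# at step t the decoder reads decoder_input[t] and must predict decoder_target[t]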

In [4]:
# index arrays, zero-padded to the maximum lengths
encoder_input_data = np.zeros((len(input_texts), encoder_len), dtype='int')
decoder_input_data = np.zeros((len(input_texts), decoder_len), dtype='int')
# the extra trailing dimension of 1 matches what Keras'
# sparse_categorical_crossentropy expects for 3D model outputs
decoder_target_data = np.zeros((len(input_texts), decoder_len, 1), dtype='int')

for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
    for t, char in enumerate(input_text):
        encoder_input_data[i, t] = input_vocab[char]
    for t, char in enumerate(target_text):
        decoder_input_data[i, t] = target_vocab[char]
        # the target at step t-1 is the input at step t (one-step shift)
        if t > 0:
            decoder_target_data[i, t - 1, 0] = target_vocab[char]
            
print(encoder_input_data.shape)
print(decoder_input_data.shape)
print(decoder_target_data.shape)
(72514, 7)
(72514, 9)
(72514, 9, 1)

The following code builds the model.

In [5]:
# encoder input layer
encoder_inputs = Input(shape=(None,))
# encoder embedding layer
encoder_embedding = Embedding(input_dim=encoder_vocab_size, output_dim=embedding_size, trainable=True)(encoder_inputs)
# encoder LSTM layer
encoder = LSTM(latent_dim, return_state=True)
# with return_state=True the encoder LSTM returns a triple (encoder_outputs, state_h, state_c):
# encoder_outputs is the output at the final time step (since return_sequences is False),
# state_h and state_c are the hidden state and cell state at the final time step
encoder_outputs, state_h, state_c = encoder(encoder_embedding)
# state_h and state_c will serve as the initial state of the decoder LSTM;
# this is how the state vector is passed from encoder to decoder
encoder_states = [state_h, state_c]

# decoder network

# decoder input layer
decoder_inputs = Input(shape=(None,))
# decoder embedding layer
decoder_embedding = Embedding(input_dim=decoder_vocab_size, output_dim=embedding_size, trainable=True)(decoder_inputs)
# decoder LSTM layer
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
# the decoder LSTM also returns a triple, but here we only need the first element,
# the full output sequence; its initial state is set to the encoder states
decoder_outputs, _, _ = decoder_lstm(decoder_embedding, initial_state=encoder_states)
# a dense softmax layer turns each decoder output into a probability
# distribution over the target vocabulary
decoder_dense = Dense(decoder_vocab_size, activation='softmax')
# decoder output layer
decoder_outputs = decoder_dense(decoder_outputs)
# the full model: given the encoder and decoder inputs, produce the decoder outputs
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
# note: in later Keras versions the lr argument is named learning_rate
model.compile(optimizer=Adam(lr=0.001), loss='sparse_categorical_crossentropy')
model.summary()
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, None)         0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, None)         0                                            
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, None, 128)    610176      input_1[0][0]                    
__________________________________________________________________________________________________
embedding_2 (Embedding)         (None, None, 128)    616448      input_2[0][0]                    
__________________________________________________________________________________________________
lstm_1 (LSTM)                   [(None, 256), (None, 394240      embedding_1[0][0]                
__________________________________________________________________________________________________
lstm_2 (LSTM)                   [(None, None, 256),  394240      embedding_2[0][0]                
                                                                 lstm_1[0][1]                     
                                                                 lstm_1[0][2]                     
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, None, 4816)   1237712     lstm_2[0][0]                     
==================================================================================================
Total params: 3,252,816
Trainable params: 3,252,816
Non-trainable params: 0
__________________________________________________________________________________________________
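
As a sanity check, these counts follow from the standard parameter formulas:

embedding_1: 4767 * 128 = 610,176
embedding_2: 4816 * 128 = 616,448
each LSTM:   4 * ((128 + 256) * 256 + 256) = 394,240   (four gates; input size 128, 256 units)
dense_1:     (256 + 1) * 4816 = 1,237,712              (weights plus biases)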

Train the model.

In [6]:
model.fit([encoder_input_data, decoder_input_data], decoder_target_data, batch_size=batch_size, epochs=epochs, validation_split=0.2)
Train on 58011 samples, validate on 14503 samples
Epoch 1/9
58011/58011 [==============================] - 15s 255us/step - loss: 3.9766 - val_loss: 3.7704
Epoch 2/9
58011/58011 [==============================] - 13s 227us/step - loss: 3.6544 - val_loss: 3.5933
Epoch 3/9
58011/58011 [==============================] - 13s 227us/step - loss: 3.4701 - val_loss: 3.4549
Epoch 4/9
58011/58011 [==============================] - 13s 228us/step - loss: 3.3023 - val_loss: 3.3684
Epoch 5/9
58011/58011 [==============================] - 13s 227us/step - loss: 3.1695 - val_loss: 3.3098
Epoch 6/9
58011/58011 [==============================] - 13s 229us/step - loss: 3.0556 - val_loss: 3.2697
Epoch 7/9
58011/58011 [==============================] - 13s 228us/step - loss: 2.9514 - val_loss: 3.2489
Epoch 8/9
58011/58011 [==============================] - 14s 235us/step - loss: 2.8551 - val_loss: 3.2304
Epoch 9/9
58011/58011 [==============================] - 13s 230us/step - loss: 2.7646 - val_loss: 3.2255
Out[6]:
<keras.callbacks.callbacks.History at 0x7faba9e4ada0>

You may well wonder why the code below builds models again. The reason is that a seq2seq model does not behave identically at training time and at generation time.

During training the decoder has a predetermined input: we feed it the correct second line to guide its learning. Concretely, whatever the decoder emitted at the previous time step, the input at the current time step is the given ground-truth character.

This training scheme is called teacher forcing.

At generation time, however, there is no predetermined decoder input: the decoder's output at the previous time step becomes its input at the current time step, and the sentence is generated iteratively.

During training, model is the entire seq2seq model: a black box that produces the output given encoder_input and decoder_input.

At generation time we have no decoder_input, so we split the black box into two, an encoder and a decoder, which makes the iterative procedure easier to drive.
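
The control-flow difference between the two regimes can be shown with a dummy one-step decoder (hypothetical, for illustration only; the real encoder_model and decoder_model are built in the next cell):

def decoder_step(token, state):
    # stand-in for one decoder LSTM step followed by argmax
    return (token + 1) % 4, state

ground_truth = [1, 2, 3]   # toy target indices
state = None

# teacher forcing (training): the input at step t comes from the ground truth
for t in range(3):
    pred, state = decoder_step(ground_truth[t], state)

# free running (generation): the input at step t is the previous prediction
x = 0                      # toy <BOS> index
for t in range(3):
    x, state = decoder_step(x, state)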

In [7]:
# first black box: the encoder, which maps encoder_inputs to the encoder states
encoder_model = Model(encoder_inputs, encoder_states)
# second black box: the decoder
# it takes three inputs: the two initial state vectors and the text generated so far
decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
# it produces three outputs: the two updated state vectors and the per-step outputs,
# of which the last step's output is used to predict the next character
decoder_outputs, state_h, state_c = decoder_lstm(decoder_embedding, initial_state=decoder_states_inputs)
decoder_states = [state_h, state_c]
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = Model([decoder_inputs] + decoder_states_inputs, [decoder_outputs] + decoder_states)

The following code implements the iterative decoding.

Suppose we have already generated the first n characters: we feed them in to obtain the (n+1)-th character, feed the n+1 characters in to obtain the (n+2)-th, and so on. (In practice, because the LSTM state summarizes the prefix, each step only feeds in the single most recent character together with the current state vectors, as the code shows.)

In [8]:
def decode_sequence(input_seq):
    # run the first line through the encoder; the resulting state vectors
    # become the decoder's initial state
    states_value = encoder_model.predict(input_seq)
    # the initial decoder input is the start character '\t'
    target_seq = np.zeros((1, 1))
    target_seq[0, 0] = target_vocab['\t']

    stop_condition = False
    decoded_sentence = ''
    # iterative decoding
    while not stop_condition:
        # feed the current decoder input and state vectors into the decoder
        # to get the prediction for the next step and the updated state vectors
        output_tokens, h, c = decoder_model.predict([target_seq] + states_value)
        # greedily pick the most probable character as the next output
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = reverse_target_char_index[sampled_token_index]
        decoded_sentence += sampled_char
        # stop once we hit the end character or the sentence exceeds decoder_len
        if (sampled_char == '\n' or len(decoded_sentence) > decoder_len):
            stop_condition = True
        # otherwise update the decoder input and state vectors for the next step
        target_seq = np.zeros((1, 1))
        target_seq[0, 0] = sampled_token_index
        states_value = [h, c]

    return decoded_sentence
In [9]:
for seq_index in range(200, 300):
    input_seq = encoder_input_data[seq_index: seq_index + 1]
    decoded_sentence = decode_sequence(input_seq)
    print('-')
    print('Input sentence:', input_texts[seq_index])
    print('Decoded sentence:', decoded_sentence)
-
Input sentence: 辞枝枝暂起
Decoded sentence: 远雁入寒林

-
Input sentence: 向日终难托
Decoded sentence: 何人问此生

-
Input sentence: 只待纤纤手
Decoded sentence: 无人问旧游

-
Input sentence: 骏骨饮长泾
Decoded sentence: 金壶自可论

-
Input sentence: 细纹连喷聚
Decoded sentence: 红粉拂尘埃

-
Input sentence: 水光鞍上侧
Decoded sentence: 山月照寒沙

-
Input sentence: 翻似天池里
Decoded sentence: 空濛碧树新

-
Input sentence: 阶兰凝曙霜
Decoded sentence: 水色含清辉

-
Input sentence: 露浓晞晚笑
Decoded sentence: 风暖叶初来

-
Input sentence: 细叶凋轻翠
Decoded sentence: 清风入洞庭

-
Input sentence: 还持今岁色
Decoded sentence: 不觉夜钟催

-
Input sentence: 秋露凝高掌
Decoded sentence: 寒山月照明

-
Input sentence: 参差丽双阙
Decoded sentence: 落日照金茎

-
Input sentence: 仙驭随轮转
Decoded sentence: 金炉压柳条

-
Input sentence: 临波光定彩
Decoded sentence: 落日照金台

-
Input sentence: 还当葵霍志
Decoded sentence: 不觉别离情

-
Input sentence: 半月无双影
Decoded sentence: 寒山独自伤

-
Input sentence: 摧藏千里态
Decoded sentence: 落日照青山

-
Input sentence: 促节萦红袖
Decoded sentence: 清风入翠微

-
Input sentence: 驶弹风响急
Decoded sentence: 落日照金台

-
Input sentence: 空余关陇恨
Decoded sentence: 不觉别离情

-
Input sentence: 驱马出辽阳
Decoded sentence: 青山空有心

-
Input sentence: 对敌六奇举
Decoded sentence: 还应问所思

-
Input sentence: 斩鲸澄碧海
Decoded sentence: 高枕入青山

-
Input sentence: 昔去兰萦翠
Decoded sentence: 还应上苑回

-
Input sentence: 云芝浮碎叶
Decoded sentence: 山月照寒沙

-
Input sentence: 回首长安道
Decoded sentence: 空山独自伤

-
Input sentence: 四时运灰琯
Decoded sentence: 不觉寒松里

-
Input sentence: 送寒余雪尽
Decoded sentence: 春色满山川

-
Input sentence: 焰听风来动
Decoded sentence: 风吹不可留

-
Input sentence: 镇下千行泪
Decoded sentence: 人间不是家

-
Input sentence: 九龙蟠焰动
Decoded sentence: 万里一开颜

-
Input sentence: 即此流高殿
Decoded sentence: 无人问旧游

-
Input sentence: 上弦明月半
Decoded sentence: 万里一帆孤

-
Input sentence: 落雁带书惊
Decoded sentence: 秋风入洞房

-
Input sentence: 初秋玉露清
Decoded sentence: 不觉寒松色

-
Input sentence: 隔云时乱影
Decoded sentence: 高树落花迟

-
Input sentence: 岸曲丝阴聚
Decoded sentence: 风吹叶叶黄

-
Input sentence: 还将眉里翠
Decoded sentence: 不觉梦中人

-
Input sentence: 贞条障曲砌
Decoded sentence: 红粉拂香茵

-
Input sentence: 拂牖分龙影
Decoded sentence: 开花拂绮罗

-
Input sentence: 散影玉阶柳
Decoded sentence: 含香满碧林

-
Input sentence: 微形藏叶里
Decoded sentence: 落日照青山

-
Input sentence: 盘根直盈渚
Decoded sentence: 清净不能变

-
Input sentence: 舒华光四海
Decoded sentence: 落日照金河

-
Input sentence: 近谷交萦蕊
Decoded sentence: 开门对月明

-
Input sentence: 径细无全磴
Decoded sentence: 山深不见人

-
Input sentence: 疾风知劲草
Decoded sentence: 不觉别离情

-
Input sentence: 勇夫安识义
Decoded sentence: 不觉老夫名

-
Input sentence: 太液仙舟迥
Decoded sentence: 登临御史骢

-
Input sentence: 未晓征车度
Decoded sentence: 孤灯对月明

-
Input sentence: 烟生遥岩隐
Decoded sentence: 水色上林峦

-
Input sentence: 连山惊鸟乱
Decoded sentence: 远水落花前

-
Input sentence: 醽醁胜兰生
Decoded sentence: 金闺怨秋色

-
Input sentence: 千日醉不醒
Decoded sentence: 一杯无不知

-
Input sentence: 雪耻酬百王
Decoded sentence: 风吹不可绊

-
Input sentence: 昔乘匹马去
Decoded sentence: 何处是归期

-
Input sentence: 近日毛虽暖
Decoded sentence: 孤云远不归

-
Input sentence: 温渚停仙跸
Decoded sentence: 青山白首新

-
Input sentence: 路曲回轮影
Decoded sentence: 山连白日斜

-
Input sentence: 暖溜惊湍驶
Decoded sentence: 寒山月影斜

-
Input sentence: 林黄疏叶下
Decoded sentence: 山月照寒山

-
Input sentence: 眺听良无已
Decoded sentence: 人间不得知

-
Input sentence: 停轩观福殿
Decoded sentence: 清跸上龙城

-
Input sentence: 法轮含日转
Decoded sentence: 金鼎出金台

-
Input sentence: 翠烟香绮阁
Decoded sentence: 金缕拂花开

-
Input sentence: 幡虹遥合彩
Decoded sentence: 玉树荫春风

-
Input sentence: 萧然登十地
Decoded sentence: 不觉有风尘

-
Input sentence: 日宫开万仞
Decoded sentence: 山月照明光

-
Input sentence: 花盖飞团影
Decoded sentence: 山禽落日斜

-
Input sentence: 绮霞遥笼帐
Decoded sentence: 红蕉叶色黄

-
Input sentence: 寥廓烟云表
Decoded sentence: 清风入洞庭

-
Input sentence: 今宵冬律尽
Decoded sentence: 不觉夜钟催

-
Input sentence: 花余凝地雪
Decoded sentence: 山色上楼台

-
Input sentence: 绶吐芽犹嫩
Decoded sentence: 山中见有人

-
Input sentence: 薄红梅色冷
Decoded sentence: 高枕绿云生

-
Input sentence: 送迎交两节
Decoded sentence: 遥见白云端

-
Input sentence: 九日正乘秋
Decoded sentence: 相逢不见家

-
Input sentence: 泛桂迎尊满
Decoded sentence: 含香满紫微

-
Input sentence: 长房萸早熟
Decoded sentence: 高树带寒烟

-
Input sentence: 何藉龙沙上
Decoded sentence: 不见山下人

-
Input sentence: 四郊秦汉国
Decoded sentence: 万里一帆飞

-
Input sentence: 阊阖雄里閈
Decoded sentence: 玉箸出尘埃

-
Input sentence: 贯渭称天邑
Decoded sentence: 登楼望九州

-
Input sentence: 金门披玉馆
Decoded sentence: 玉箸出金台

-
Input sentence: 眷言君失德
Decoded sentence: 不觉有风尘

-
Input sentence: 政烦方改篆
Decoded sentence: 人事不知名

-
Input sentence: 阿房久已灭
Decoded sentence: 玉箸自相侵

-
Input sentence: 欲厌东南气
Decoded sentence: 无人问此生

-
Input sentence: 有隋政昏虐
Decoded sentence: 不知身不知

-
Input sentence: 先圣按剑起
Decoded sentence: 清光满太明

-
Input sentence: 饮马河洛竭
Decoded sentence: 一杯空自清

-
Input sentence: 克敌睿图就
Decoded sentence: 清光动清漪

-
Input sentence: 顾惭嗣宝历
Decoded sentence: 不知身未知

-
Input sentence: 幸过翦鲸地
Decoded sentence: 不知身在斯

-
Input sentence: 汉家重东郡
Decoded sentence: 白首一沾衣

-
Input sentence: 黎庶既蕃殖
Decoded sentence: 不知心自知

-
Input sentence: 远别初首路
Decoded sentence: 相逢无处时

-
Input sentence: 课成应第一
Decoded sentence: 应是故人期

-
Input sentence: 北风吹同云
Decoded sentence: 白首无人心