资讯专栏INFORMATION COLUMN

GRU网络生成莎士比亚小说

joyvw / 364人阅读

摘要:介绍本文我们将使用网络来学习莎士比亚小说,模型通过学习可以生成与小说风格相似的文本,如图所示虽然有些句子并没有实际的意思目前我们的模型是基于概率,并不是理解语义,但是大多数单词都是有效的,文本结构也与我们训练的文本相似。

介绍

本文我们将使用GRU网络来学习莎士比亚小说,模型通过学习可以生成与小说风格相似的文本,如图所示:

虽然有些句子并没有实际的意思(目前我们的模型是基于概率,并不是理解语义),但是大多数单词都是有效的,文本结构也与我们训练的文本相似。
由于项目中使用到了Eager ExecutionGRU,所以我们先进行简单介绍:

Eager Execution

TensorflowEager Execution之前想要评估操作必须通过运行计算图"sess.run()"的方式来获取值,而使用Eager Execution可以立即评估操作。Eager Execution基于python流程控制并可以使用python的调试工具进行错误报告。

梯度计算:

先使用tf.GradientTape记录然后再计算梯度,示例如下:

# tfe = tf.contrib.eager
w = tfe.Variable([[1.0]])
with tf.GradientTape() as tape:
  loss = w * w

grad = tape.gradient(loss, w)

常用函数:
tfe.gradients_function:返回一个函数,该函数会计算其输入函数参数相对其参数的导数。
tfe.value_and_gradients_function:除了返回函数还会返回输入函数的值。

其它:

在训练大数据集的时候,Eager Execution 性能与Graph Execution相当,但在小数据集中Eager Execution会慢一些。
Eager Execution胜在开发和调试的便利性,但是在分布式训练,性能优化,生产部署方面Graph Execution更好。
在未调用tf.enable_eager_execution(开启后不能关闭)的情况下可以使用tfe.py_func启用Eager Execution

GRU

GRULSTM的一种变体,它将LSTM的遗忘门,输入门,输出门改为更新门(LSTM的遗忘门,输入门合并),重置门。参数少,收敛快,不过在数据量较大的时候LSTM的表现更好。下图是GRU网络结构和前向传播计算方法。

更新门:控制前一时刻的状态信息被带入到当前状态中的程度。
重置门:控制忽略前一时刻的状态信息,重置门的值越小说明忽略的越多(被写入的信息越少)。

GRU训练:

我们要学习的参数有Wr、Wz、Wh、Wo,其中Wr、Wz、Wh是和ht-1拼接而成,所以需要进行分割:

采用反向传播对损失函数的各参数求偏导:

中间参数为:

算出每个参数的偏导数之后就可以更新参数了。GRU通过门控机制选择性的保留特征,为长时传播提供了保证。正因为门控机制的有效,门卷积目前也很受欢迎,感兴趣的朋友可以阅读相关文献。

数据导入
import tensorflow as tf
import numpy as np
import os
import re
import random
import time

# 开启后不能关闭,只能重新启动新的python会话
tf.enable_eager_execution()

# 获取数据,你也可以使用其他数据集
path_to_file=tf.keras.utils.get_file("shakespeare.txt", "https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt")
text=open(path_to_file).read()

文字是不能直接放进模型的需要将其转换为对应的ID表示:

# 去除重复字符并排序
unique=sorted(set(text))

# enumerate 返回value,index
# 文本转id 
char2idx={value:idx for idx,value in enumerate(unique)}
# id转文本
idx2char={idx:value for idx,value in enumerate(unique)}

部分参数配置:

# 每次输入的最大文本长度,对应GRU模型的‘time_step’
max_length=100
vocab_size=len(unique)
# 词嵌入维度
embedding_dim=256
hidden_units=1024
BATCH_SIZE=64
BUFFER_SIZE=10000

获取ID表示的数据并创建标签

# 标签的定义方式如:
# data="ming"
# input="min" labels="ing"
input_text=[]
labels_text=[]

# 迭代获取‘max_length’个数据
for i in range(0,len(text)-max_length,max_length):
    inputs=text[i:i+max_length]
    labels=text[i+1:i+1+max_length]
    
    input_text.append([char2idx[i] for i in inputs])
    labels_text.append([char2idx[i] for i in labels])

dataset读取数据:

dataset=tf.data.Dataset.from_tensor_slices((input_text,output_text))
# drop_remainder:小于batch_size 是否删除,默认不删除
dataset=dataset.batch(BATCH_SIZE,drop_remainder=True)
创建模型

我们的模型包含三层:Embedding层,GRU层,全连接层。

class Model(tf.keras.Model):
    """
    GRU:重置门,更新门 LSTM:遗忘门,输入门,输出门
    GRU,参数少,容易收敛,数据量大的时候LSTM表现更好
    """
    def __init__(self,vocab_size,embedding_dim,units,batch_size):
        super(Model, self).__init__()

        self.units=units
        self.batch_size=batch_size
        self.embedding=tf.keras.layers.Embedding(
            input_dim=vocab_size,
            output_dim=embedding_dim
        )
        if tf.test.is_gpu_available:
            # 使用GPU加速训练
            self.gru=tf.keras.layers.CuDNNGRU(
                units=self.units,
                return_sequences=True,
                return_state=True,
                recurrent_initializer="glorot_uniform"
            )
        else:
            self.gru=tf.keras.layers.GRU(
                units=self.units,
                return_sequences=True,
                return_state=True,
                # 默认激活函数为:hard_sigmoid
                recurrent_activation="sigmoid",
                recurrent_initializer="glorot_uniform"
            )
        self.fc=tf.keras.layers.Dense(units=vocab_size)
    def __call__(self, x,hidden):
        x=self.embedding(x)
        
        # output:[batch_size,max_length,hidden_size]
        # states:[batch_size,hidden_size]
        output,states=self.gru(x,initial_state=hidden)

        # 转换至:(batch_size*max_length,hidden_size)
        output=tf.reshape(output,shape=(-1,output.shape[2]))
        
        # output:[batch_size*max_length,vocab_size]
        x=self.fc(output)

        return x,states
为什么要使用Embedding

Embedding将高纬离散向量转为低纬稠密的连续向量,并且表现出了向量间的相似性。

如图所示,one-hot表示只有一个位置是1,其余为0,当文字较多时维度将会非常的大,并且由于one-hot编码后的单词存在独立性,导致不能利用相似词汇进行学习。那么Embedding又是怎么做的呢?

使用Embedding的第一步是通过索引对句子进行编码,然后根据索引创建嵌入矩阵,这样我们使用嵌入矩阵替代one-hot编码向量。每个单词向量不再是由一个独立向量代替,而是替换成用于查找嵌入矩阵中向量的索引。

模型训练
# model初始化
model=Model(vocab_size,embedding_dim,hidden_units,BATCH_SIZE)
optimizer=tf.train.AdamOptimizer(learning_rate=0.001)

# 创建损失函数
def loss_fn(lables,preds):
    # 交叉熵损失函数在值域上边界依然可以保持较高的激活值
    return tf.losses.sparse_softmax_cross_entropy(
        labels=lables,
        logits=preds
    )

模型保存:

# 读取checkpoint需要重新定义图结构
checkpoint_dir = "./training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 model=model)

开始训练:

EPOCHS = 20

for epoch in range(EPOCHS):
    start = time.time()
    
    # 每迭代完成一次数据集重置hidden-state
    hidden = model.reset_states()
    
    for (batch, (inp, target)) in enumerate(dataset):
          # 使用GradientTape记录
          with tf.GradientTape() as tape:
              predictions, hidden = model(inp, hidden)
              
              target = tf.reshape(target, (-1,))
              loss = loss_function(target, predictions)
              
          grads = tape.gradient(loss, model.variables)
          # 更新
          optimizer.apply_gradients(zip(grads, model.variables))

          if batch % 100 == 0:
              print ("Epoch {} Batch {} Loss {:.4f}".format(epoch+1,
                                                            batch,
                                                            loss))
    # 每迭代5次数据集保存一次模型数据
    if (epoch + 1) % 5 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

读取保存的checkpoint文件:

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
预测

要指定输入字符以及希望模型生成的文本长度:

# 需要生成的文字长度
num_generate=1000

start_string="Q"
# 将输入字符转为对应ID表示
input_eval=[char2idx[s] for s in start_string]
# 扩展一维 batch_size
input_eval=tf.expand_dims(input_eval,0)

text_generated=""
# hidden state shape:(batch_size,rnn units)
# hidden 初始化
hidden=[tf.zeros((1,hidden_units))]

for i in range(num_generate):
    precit,hidden=model(input_eval,hidden)
    # 注:这里batch_size == 1
    # 代码参考,很好理解:
    # output = tf.transpose(output,[1,0,2])
    # last = tf.gather(output,int(output.get_shape()[0]-1)
    predict_id=tf.argmax(predict[-1]).numpy()
    # 将前一时刻的输出作为下一时刻的输入,一直到迭代完成
    input_eval=tf.expand_dims(predict_id,0)
    # 转换成对应字符
    text_generated+=idx2char[predict_id]
print(start_string+text_generated)
总结

GRU网路作为LSTM网路的变体,参数少收敛快。Eager模式下代码简洁,调试便利虽然比Graph Execution功能逊色,但胜在便利性。RNN现在很多项目都会结合注意力机制使用,效果很好。注意力简单来说就是对输入不再是同等看待,而是根据权重值大小来区别训练。

本文内容部分参考Yash Katariya,在此表示感谢。

文章版权归作者所有,未经允许请勿转载,若此文章存在违规行为,您可以联系管理员删除。

转载请注明本文地址:https://www.ucloud.cn/yun/42763.html

相关文章

  • GRU网络生成莎士比小说

    摘要:介绍本文我们将使用网络来学习莎士比亚小说,模型通过学习可以生成与小说风格相似的文本,如图所示虽然有些句子并没有实际的意思目前我们的模型是基于概率,并不是理解语义,但是大多数单词都是有效的,文本结构也与我们训练的文本相似。 介绍 本文我们将使用GRU网络来学习莎士比亚小说,模型通过学习可以生成与小说风格相似的文本,如图所示:showImg(https://segmentfault.com...

    genedna 评论0 收藏0
  • GRU网络生成莎士比小说

    摘要:介绍本文我们将使用网络来学习莎士比亚小说,模型通过学习可以生成与小说风格相似的文本,如图所示虽然有些句子并没有实际的意思目前我们的模型是基于概率,并不是理解语义,但是大多数单词都是有效的,文本结构也与我们训练的文本相似。 介绍 本文我们将使用GRU网络来学习莎士比亚小说,模型通过学习可以生成与小说风格相似的文本,如图所示:showImg(https://segmentfault.com...

    Coly 评论0 收藏0
  • 美丽的神经网络:13种细胞构筑的深度学习世界

    摘要:网络所有的神经元都与另外的神经元相连每个节点功能都一样。训练的方法是将每个神经元的值设定为理想的模式,然后计算权重。输入神经元在网络整体更新后会成为输入神经元。的训练和运行过程与十分相似将输入神经元设定为固定值,然后任网络自己变化。 新的神经网络架构随时随地都在出现,要时刻保持还有点难度。要把所有这些缩略语指代的网络(DCIGN,IiLSTM,DCGAN,知道吗?)都弄清,一开始估计还无从下...

    zsirfs 评论0 收藏0
  • 神经网络

    摘要:通过将神经元的值设置为希望的模式来训练网络,之后可以计算权重。输入神经元在完整网络更新结束时变成输出神经元。在某种程度上,这类似于峰值神经网络,并不是所有的神经元始终都在发射并且点的生物合理性得分。 随着新的神经网络架构不时出现,很难跟踪这些架构。知道所有缩写(DCIGN,BiLSTM,DCGAN,任何人?)起初可能有点压倒性。 所以我决定编写一个包含许多这些体系结构的备忘单。这些大多...

    Anonymous1 评论0 收藏0
  • 难以置信!LSTM和GRU的解析从未如此清晰

    摘要:作为解决方案的和和是解决短时记忆问题的解决方案,它们具有称为门的内部机制,可以调节信息流。随后,它可以沿着长链序列传递相关信息以进行预测,几乎所有基于递归神经网络的技术成果都是通过这两个网络实现的。和采用门结构来克服短时记忆的影响。 短时记忆RNN 会受到短时记忆的影响。如果一条序列足够长,那它们将很难将信息从较早的时间步传送到后面的时间步。 因此,如果你正在尝试处理一段文本进行预测,RNN...

    MrZONT 评论0 收藏0

发表评论

0条评论

joyvw

|高级讲师

TA的文章

阅读更多
最新活动
阅读需要支付1元查看
<