我要努力工作,加油!

tensorflow学习,循环神经网络(RNN)相关的函数简介(25)

tensorflow 向日葵智能 15℃ 0评论

本节,将介绍 tensorflow 实现循环神经网络 RNN 的主要函数。

实现 RNN 的基本单元 RNNCell


RNNCell 是 tensorflow 中的循环神经网络的基本单元,它是一个抽象类,本身不能实例化。它的两个子类,一个 BasicRNNCell,另一个BasicLSTMCell,分别对应经典循环神经网络,和长短记忆循环神经网络。

学习 RNNCell 要重点关注三个地方:

  • 类方法 call
  • 类属性 state_size
  • 类属性 output_size

简单的说,call方法就是用来计算隐状态的。关于隐状态可以参考前面两节(RNNLSTM)。而state_sizeoutput_size则表示隐状态的大小和输出向量的大小。

output, next_state = call(input, state)
通常 input 的形状是 [batch_size, input_size],所以隐状态的形状 [batch_size, state_size],输出形状[batch_size, output_size]。

定义经典 RNN 单元的方法

rnnCell = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
print rnnCell.state_size
# 应 state_size = 128

定义 LSTM 单元的方法

lstmCell = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)
print lstmCell.state_size
# 应 state_size = LSTMStateTuple(c=128, h=128)

多层循环神经网络:MultiRNNCell


很多时候,单层 RNN 的能力有限,需要多层 RNN,在 tensorflow 中,可以使用 tf.nn.rnn_cell.MultiRNNCell 函数建立多层的 RNN,下面是一个示例小 demo

import tensorflow as tf
import numpy as np

# 创建单个cell并堆叠多层
def get_a_cell(lstm_size, keep_prob):
    rnn = tf.nn.rnn_cell.BasicRNNCell(num_units=128)
    return rnn
# 建立 3 层
cell = tf.nn.rnn_cell.MultiRNNCell([get_a_cell() for _ in range(3)])

这里 cell 的 state_size 为 (128,128,128),表示有 3 个隐状态,每个隐状态大小为 128。

MultiRNNCell 也是 RNNCell 的子类,所以它也有 call 方法,和 state_size, output_size 属性。

使用 dynamic_rnn 展开时间维度


对于单个 RNNCell,使用它的 call 方法进行运算时,只在序列时间是前进了一步。如使用 x1,h0 计算得到 h1,根据 x2,h1 计算得到 h2等。如果序列长度为 n,就需要调用 n 次 call 函数。tensorflow 提供了 tf.nn.dynamic_rnn 函数,等价于调用 n 次 call 函数。即通过 {h0, x1, x2, x3, …} 直接得到 {h1, h2, …}

outputs, state = tf.nn.dynamic_rnn(cell, inputs)

至此,建立循环神经网络的几个比较重要的 tensorflow 函数就介绍完了,下一节将尝试建立 RNN 网络,训练其作诗。

本节主要参考《21个项目玩转深度学习》。

转载请注明:xrk智能 » tensorflow学习,循环神经网络(RNN)相关的函数简介(25)

喜欢 (3)or分享 (0)
发表我的评论
取消评论

表情

Hi,您需要填写昵称和邮箱!

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址
(1)个小伙伴在吐槽