TensorFlow函数教程:tf.nn.static_state_saving_rnn
2019-02-14 14:54 更新
tf.nn.static_state_saving_rnn函数
别名:
- tf.contrib.rnn.static_state_saving_rnn
- tf.nn.static_state_saving_rnn
tf.nn.static_state_saving_rnn(
cell,
inputs,
state_saver,
state_name,
sequence_length=None,
scope=None
)
定义在:tensorflow/python/ops/rnn.py。
为时间截断的RNN计算接受状态保护程序的RNN。
参数:
- cell:RNNCell的一个实例。
- inputs:输入的长度为T的列表,每个输入都是一个具有shape [batch_size, input_size]的Tensor。
- state_saver:一个状态保护程序对象,具有方法state和save_state。
- state_name:Python字符串或字符串元组。与state_saver一起使用的名称。如果单元格返回状态元组(即,cell.state_size是一个元组),则state_name应该是与cell.state_size具有相同长度的字符串元组。否则它应该是一个单独的字符串。
- sequence_length:(可选)int32 / int64向量,大小为[batch_size]。
- scope:用于创建子图的VariableScope;默认为“rnn”。
返回:
(outputs, state)对,其中:
- outputs是长度为T的输出列表(每个输入一个)
- state是最终状态
可能引发的异常:
- TypeError:如果cell不是RNNCell的实例。
- ValueError:如果inputs是None或是一个空列表,或者state_name的arity和type与cell.state_size的不匹配。