阅读(4.3k) 书签 (0)

TensorFlow函数教程:tf.nn.static_state_saving_rnn

2019-02-14 14:54 更新

tf.nn.static_state_saving_rnn函数

别名:

  • tf.contrib.rnn.static_state_saving_rnn
  • tf.nn.static_state_saving_rnn
tf.nn.static_state_saving_rnn(
    cell,
    inputs,
    state_saver,
    state_name,
    sequence_length=None,
    scope=None
)

定义在:tensorflow/python/ops/rnn.py。

为时间截断的RNN计算接受状态保护程序的RNN。

参数:

  • cell:RNNCell的一个实例。
  • inputs:输入的长度为T的列表,每个输入都是一个具有shape [batch_size, input_size]的Tensor。
  • state_saver:一个状态保护程序对象,具有方法state和save_state。
  • state_name:Python字符串或字符串元组。与state_saver一起使用的名称。如果单元格返回状态元组(即,cell.state_size是一个元组),则state_name应该是与cell.state_size具有相同长度的字符串元组。否则它应该是一个单独的字符串。
  • sequence_length:(可选)int32 / int64向量,大小为[batch_size]。
  • scope:用于创建子图的VariableScope;默认为“rnn”。

返回:

(outputs, state)对,其中:

  • outputs是长度为T的输出列表(每个输入一个)
  • state是最终状态

可能引发的异常:

  • TypeError:如果cell不是RNNCell的实例。
  • ValueError:如果inputs是None或是一个空列表,或者state_name的arity和type与cell.state_size的不匹配。