TensorFlow处理RNN参数变量
2018-09-01 15:50 更新
tf.contrib.cudnn_rnn.RNNParamsSaveable
tf.contrib.cudnn_rnn.RNNParamsSaveable 类
定义在:tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
用于处理 RNN 参数变量的 SaveableObject 实现.
方法
__init__
__init__ (
params_to_canonical ,
canonical_to_params ,
param_variables ,
name = 'params_canonical'
)
创建一个 RNNParamsSaveable 对象.
RNNParams 可以在检查点文件中保存/恢复,用于以规范格式保存/恢复权重和偏置参数,其中参数逐层保存为张量.对于每个层,偏差张量在重量张量之后被保存.恢复时,用户可以根据需要命名 param_variables,并将权重和偏差张量恢复到这些变量.
对于 CudnnRNNRelu 或 CudnnRNNTanh,每个层的每个权重和每个偏移量都有两个张量:张量0被用于从前一层输入,张量1用于循环输入.
对于 CudnnLSTM,每个层的每个权重和每个偏移量有8个张量;张量0-3被用于从前一层输入;张量4-7用于循环输入;张量0和4用于输入门;张量1和5忘记门;张量2和6新的存储门; 张量3和7是输出门.
对于 CudnnGRU,每个层的每个权重和每个偏移量有6张张量;张量0-2被用于从前一层输入;张量3-5用于循环输入;张量0和3用于复位门;张量1和4更新门;张量2和5新的存储门.
ARGS:
- params_to_canonical:一种函数, 用于将参数从特定格式转换为 cuDNN 或其他 RNN ops 转换到规范格式._CudnnRNN params_to_canonical () 应在这里提供.
- canonical_to_params:用于将参数从规范格式转换为 cuDNN 或其他 RNN ops 的特定格式的函数.函数必须返回一个标量 (如 cuDNN) 或元组.此函数可以是 _CudnnRN.
- param_variables:特定窗体中参数的变量列表.对于 cuDNN RNN ops,这是一个单一的加权和偏见合并变量;对于其他 RNN ops, 这可能是多个未或部分合并的变量, 分别用于权重和偏差.
- name:RNNParamsSaveable 对象的名称.
restore
restore(
restored_tensors ,
restored_shapes
)