TensorFlow函数:tf.estimator.TrainSpec
2018-05-08 12:04 更新
tf.estimator.TrainSpec函数
TrainSpec类
定义在:tensorflow/python/estimator/training.py.
train_and_evaluate调用的“train”部分的配置.
TrainSpec确定训练的输入数据以及持续时间.可选的钩子(hook)在不同训练阶段运行.
属性
- hooks
字段号2的别名
- input_fn
字段号0的别名
- max_steps
字段号1的别名
方法
__new__
@ staticmethod
__new__ (
cls ,
input_fn ,
max_steps = None ,
hooks = None
)
创建一个已经经过验证的TrainSpec实例.
参数:
- input_fn:训练输入函数返回一个元祖:features - Tensor或名为Tensor字符串特征的字典,labels - Tensor或带有标签的Tensor字典.
- max_steps:Int.用于训练模型的总步骤的正数.如果为None,则一直训练.训练input_fn预计不会产生OutOfRangeError或StopIteration异常.
- hooks:在训练过程中对所有workers(包括chief)运行的tf.train.SessionRunHook对象进行可迭代处理.
返回值:
tf.estimator.TrainSpec函数返回一个经过验证的TrainSpec对象.
可能引发的异常:
- ValueError:如果任何输入参数无效.
- TypeError:如果任何参数不是预期的类型.