阅读(15.5k) 书签 (0)

TensorFlow函数:tf.estimator.TrainSpec

2018-05-08 12:04 更新

tf.estimator.TrainSpec函数

TrainSpec类

定义在:tensorflow/python/estimator/training.py.

train_and_evaluate调用的“train”部分的配置.

TrainSpec确定训练的输入数据以及持续时间.可选的钩子(hook)在不同训练阶段运行.

属性

  • hooks

    字段号2的别名

  • input_fn

    字段号0的别名

  • max_steps

    字段号1的别名

方法

__new__

@ staticmethod 
__new__ ( 
    cls , 
    input_fn , 
    max_steps = None , 
    hooks = None 
)

创建一个已经经过验证的TrainSpec实例.

参数:

  • input_fn:训练输入函数返回一个元祖:features - Tensor或名为Tensor字符串特征的字典,labels - Tensor或带有标签的Tensor字典.
  • max_steps:Int.用于训练模型的总步骤的正数.如果为None,则一直训练.训练input_fn预计不会产生OutOfRangeError或StopIteration异常.
  • hooks:在训练过程中对所有workers(包括chief)运行的tf.train.SessionRunHook对象进行可迭代处理.

返回值:

tf.estimator.TrainSpec函数返回一个经过验证的TrainSpec对象.

可能引发的异常:

  • ValueError:如果任何输入参数无效.
  • TypeError:如果任何参数不是预期的类型.