TensorFlow函数:tf.estimator.RunConfig
tf.estimator.RunConfig函数
RunConfig类
定义在:tensorflow/python/estimator/run_config.py.
该类指定Estimator运行的配置.
属性
- cluster_spec
- evaluation_master
- global_id_in_cluster
该global_id_in_cluster属性表示训练集群中的全局标识.
训练集群中的所有全局ID都是从递增的连续整数序列中分配的,第一个ID是0.
注意:任务ID(属性字段task_id)正在跟踪具有SAME任务类型的所有节点中的节点索引.例如,给定集群定义如下:
cluster = {'chief': ['host0:2222'], 'ps': ['host1:2222', 'host2:2222'], 'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
具有任务类型worker的节点可以具有id 0,1,2.具有任务类型ps的节点可以具有id,0,1.因此,task_id不是唯一的,但pair(task_type,task_id)可以唯一确定集群中的节点.
全局ID即该字段正在跟踪集群中所有节点之间的节点索引.它是唯一分配的.例如,对于上面给出的集群规范,全局id分配为:
task_type | task_id | global_id -------------------------------- chief | 0 | 0 worker | 0 | 1 worker | 1 | 2 worker | 2 | 3 ps | 0 | 4 ps | 1 | 5
返回:
一个整数ID.
- is_chief
- keep_checkpoint_every_n_hours
- keep_checkpoint_max
- log_step_count_steps
- master
- model_dir
- num_ps_replicas
- num_worker_replicas
- save_checkpoints_secs
- save_checkpoints_steps
- save_summary_steps
- service
返回定义的平台(在TF_CONFIG中)服务字典.
- session_config
- task_id
- task_type
- tf_random_seed
方法
__init__
__init__(
model_dir=None,
tf_random_seed=None,
save_summary_steps=100,
save_checkpoints_steps=_USE_DEFAULT,
save_checkpoints_secs=_USE_DEFAULT,
session_config=None,
keep_checkpoint_max=5,
keep_checkpoint_every_n_hours=10000,
log_step_count_steps=100
)
该方法用于构造一个RunConfig.
所有的分布式训练相关的属性cluster_spec,is_chief,master,num_worker_replicas,num_ps_replicas,task_id和task_type都是基于 TF_CONFIG 环境变量设置的,如果相关的信息存在.TF_CONFIG环境变量是具有属性JSON对象:cluster和task.
cluster是ClusterSpec的Python字典的JSON序列化版本,它将server_lib.py任务类型(通常是TaskType枚举之一)映射到任务地址列表.
task有两个属性:type和index,其中,type可以是cluster中任何类型的任务.当TF_CONFIG包含所述信息,则在该类上设置以下属性:
- cluster_spec:该属性从TF_CONFIG['cluster']解析,默认为{},如果存在,则在cluster_spec的chief属性中必须有且仅有一个节点.
- task_type:设置为TF_CONFIG['task']['type'];如果cluster_spec存在,则必须设置;如果cluster_spec没有设置,则必须是worker(默认值).
- task_id:设置为TF_CONFIG['task']['index'];如果cluster_spec存在,必须设置;如果cluster_spec未设置,则必须为0(默认值).
- master:master属性是通过在cluster_spec中查找task_type和task_id来确定的,默认为''.
- num_ps_replicas:是通过计算cluster_spec的ps属性中列出的节点数来设置的,默认为0.
- num_worker_replicas:是通过计算cluster_spec中的worker和chief属性中列出的节点数来设置的,默认为1.
- is_chief:是基于task_type和cluster确定的.
有一个带有task_type作为计算器的特殊节点,它不是(训练)cluster_spec的一部分,它处理分布式计算作业.
non-chief节点的例子:
cluster = {'chief': ['host0:2222'],
'ps': ['host1:2222', 'host2:2222'],
'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
os.environ['TF_CONFIG'] = json.dumps(
{'cluster': cluster,
'task': {'type': 'worker', 'index': 1}})
config = ClusterConfig()
assert config.master == 'host4:2222'
assert config.task_id == 1
assert config.num_ps_replicas == 2
assert config.num_worker_replicas == 4
assert config.cluster_spec == server_lib.ClusterSpec(cluster)
assert config.task_type == 'worker'
assert not config.is_chief
chief的例子:
cluster = {'chief': ['host0:2222'],
'ps': ['host1:2222', 'host2:2222'],
'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
os.environ['TF_CONFIG'] = json.dumps(
{'cluster': cluster,
'task': {'type': 'chief', 'index': 0}})
config = ClusterConfig()
assert config.master == 'host0:2222'
assert config.task_id == 0
assert config.num_ps_replicas == 2
assert config.num_worker_replicas == 4
assert config.cluster_spec == server_lib.ClusterSpec(cluster)
assert config.task_type == 'chief'
assert config.is_chief
evaluator节点示例(evaluator不是训练集群的一部分):
cluster = {'chief': ['host0:2222'],
'ps': ['host1:2222', 'host2:2222'],
'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
os.environ['TF_CONFIG'] = json.dumps(
{'cluster': cluster,
'task': {'type': 'evaluator', 'index': 0}})
config = ClusterConfig()
assert config.master == ''
assert config.evaluator_master == ''
assert config.task_id == 0
assert config.num_ps_replicas == 0
assert config.num_worker_replicas == 0
assert config.cluster_spec == {}
assert config.task_type == 'evaluator'
assert not config.is_chief
注意:如果save_checkpoints_steps或save_checkpoints_secs已设置,keep_checkpoint_max可能需要进行相应调整,特别是在分布式训练中.例如,设置save_checkpoints_secs为60而不进行调整keep_checkpoint_max(默认为5)会导致检查点在5分钟后被垃圾收集的情况.在分布式训练中,计算作业异步启动,可能无法加载或由于竞争条件而找到检查点.
参数:
- model_dir:保存模型参数,图表等的目录.如果有PathLike对象,路径将被解析;如果为None,则将使用Estimator设置的默认值.
- tf_random_seed:TensorFlow初始化器的随机种子,设置此值可以实现重播之间的一致性.
- save_summary_steps:每隔这么多步骤保存摘要.
- save_checkpoints_steps:每隔这么多步骤保存检查点,不能用save_checkpoints_secs指定.
- save_checkpoints_secs:每隔几秒钟保存检查点,不能用save_checkpoints_steps指定;如果save_checkpoints_steps和save_checkpoints_secs在构造函数中未设置,则默认设置为600秒;如果两个save_checkpoints_steps和save_checkpoints_secs为None,则检查站被禁用.
- session_config:用于设置会话参数的ConfigProto,或None.
- keep_checkpoint_max:要保留的最近检查点文件的最大数量.当新文件被创建时,旧文件被删除.如果为None或0,则保留所有检查点文件.默认为5(也就是保留5个最近的检查点文件.)
- keep_checkpoint_every_n_hours:要保存的每个检查点之间的小时数;默认值10,000小时有效地禁用该功能.
- log_step_count_steps:在培训期间将记录全局步骤/秒(global step/sec)的频率 (以全局步骤数表示).
可能引发的异常:
- ValueError:如果同时设置save_checkpoints_steps和save_checkpoints_secs.
replace
replace(**kwargs)
返回RunConfig的新实例替换指定属性.
仅允许替换以下列表中的属性:
- model_dir
- tf_random_seed
- save_summary_steps
- save_checkpoints_steps
- save_checkpoints_secs
- session_config
- keep_checkpoint_max
- keep_checkpoint_every_n_hours
- log_step_count_steps
另外,可以设置save_checkpoints_steps或者save_checkpoints_secs(不应该同时设置).
参数:
- **kwargs:使用新值命名属性的关键字.
可能引发的异常:
- ValueError:如果任何属性名kwargs不存在或不允许被替换,或同时设置save_checkpoints_steps和save_checkpoints_secs.
返回值:
一个RunConfig的新的实例.