阅读(16.1k) 书签 (0)

TensorFlow 指标(contrib)

2019-01-31 18:11 更新

评估指标和汇总统计的操作.

API

该模块提供计算流式指标的函数:以动态价值计算的指标 Tensors。每个指标声明返回一个“value_tensor”,一个返回指标当前值的幂等运算,还有一个“update_op”,这是一个从当前Tensors 测量值累加信息的操作,并返回 “ value_tensor”。

要使用这些指标中的任何一个,只需声明指标,update_op 重复调用即可将数据累加到所需数量的 Tensor 值(通常每个都是单个批次),最后评估 value_tensor。例如,要使用streaming_mean:

value = ...
mean_value, update_op = tf.contrib.metrics.streaming_mean(values)
sess.run(tf.local_variables_initializer())

for i in range(number_of_batches):
  print('Mean after batch %d: %f' % (i, update_op.eval())
print('Final Mean: %f' % mean_value.eval())

每个指标函数将节点添加到图表中,该图保持计算指标值所需的状态以及实际执行计算的一组操作。每个指标评估由三个步骤组成

  • 初始化:初始化指标标准状态
  • 聚合:更新指标标准状态的值
  • 完成:计​​算最终的指标值

在上面的例子中,调用 streaming_mean 创建一对状态变量,它们将包含(1)运行总和和(2)总和中样本数的计数。因为流式指标使用局部变量,所以初始化阶段通过运行tf.local_variables_initializer() 返回的操作来执行的。它将 sum 和 count 变量设置为零。

接下来,通过 values 适当地检查状态变量的当前状态和递增来执行聚合.此步骤通过运行由指标数返回的 update_op 来执行。

最后,通过评估 “value_tensor”

实际上,我们通常希望评估多个批次和多个指标。为此,我们只需要多次运行指标计算操作:

labels = ...
predictions = ...
accuracy, update_op_acc = tf.contrib.metrics.streaming_accuracy(
    labels, predictions)
error, update_op_error = tf.contrib.metrics.streaming_mean_absolute_error(
    labels, predictions)

sess.run(tf.local_variables_initializer())
for batch in range(num_batches):
  sess.run([update_op_acc, update_op_error])

accuracy, error = sess.run([accuracy, error])

请注意,在不同输入上多次评估相同的指标时,必须指定每个指标的范围,以避免将结果累加在一起:

labels = ...
predictions0 = ...
predictions1 = ...

accuracy0 = tf.contrib.metrics.accuracy(labels, predictions0, name='preds0')
accuracy1 = tf.contrib.metrics.accuracy(labels, predictions1, name='preds1')

某些指标(如 streaming_mean 或 streaming_accuracy)可以通过 weights 参数加权。权重张量的大小必须作为指标的加权平均值的标签,并且要和预测张量和结果相同。

TensorFlow 指标“操作”

  • tf.contrib.metrics.streaming_accuracy
  • tf.contrib.metrics.streaming_mean
  • tf.contrib.metrics.streaming_recall
  • tf.contrib.metrics.streaming_recall_at_thresholds
  • tf.contrib.metrics.streaming_precision
  • tf.contrib.metrics.streaming_precision_at_thresholds
  • tf.contrib.metrics.streaming_auc
  • tf.contrib.metrics.streaming_recall_at_k
  • tf.contrib.metrics.streaming_mean_absolute_error
  • tf.contrib.metrics.streaming_mean_iou
  • tf.contrib.metrics.streaming_mean_relative_error
  • tf.contrib.metrics.streaming_mean_squared_error
  • tf.contrib.metrics.streaming_mean_tensor
  • tf.contrib.metrics.streaming_root_mean_squared_error
  • tf.contrib.metrics.streaming_covariance
  • tf.contrib.metrics.streaming_pearson_correlation
  • tf.contrib.metrics.streaming_mean_cosine_distance
  • tf.contrib.metrics.streaming_percentage_less
  • tf.contrib.metrics.streaming_sensitivity_at_specificity
  • tf.contrib.metrics.streaming_sparse_average_precision_at_k
  • tf.contrib.metrics.streaming_sparse_precision_at_k
  • tf.contrib.metrics.streaming_sparse_precision_at_top_k
  • tf.contrib.metrics.streaming_sparse_recall_at_k
  • tf.contrib.metrics.streaming_specificity_at_sensitivity
  • tf.contrib.metrics.streaming_concat
  • tf.contrib.metrics.streaming_false_negatives
  • tf.contrib.metrics.streaming_false_negatives_at_thresholds
  • tf.contrib.metrics.streaming_false_positives
  • tf.contrib.metrics.streaming_false_positives_at_thresholds
  • tf.contrib.metrics.streaming_true_negatives
  • tf.contrib.metrics.streaming_true_negatives_at_thresholds
  • tf.contrib.metrics.streaming_true_positives
  • tf.contrib.metrics.streaming_true_positives_at_thresholds
  • tf.contrib.metrics.auc_using_histogram
  • tf.contrib.metrics.accuracy
  • tf.contrib.metrics.aggregate_metrics
  • tf.contrib.metrics.aggregate_metric_map
  • tf.contrib.metrics.confusion_matrix

TensorFlow 设置“操作”

  • tf.contrib.metrics.set_difference
  • tf.contrib.metrics.set_intersection
  • tf.contrib.metrics.set_size
  • tf.contrib.metrics.set_union