阅读(15.9k) 书签 (0)

TensorFlow 返回张量的最大值索引

2020-07-16 10:59 更新

tf.argmax


argmax ( 
input ,
axis = None ,
name = None ,
dimension = None
)

定义在tensorflow/python/ops/math_ops.py.

参考指南:数学>序列比较和索引

返回在张量的坐标轴上具有的最大值的索引.

请注意,在关联的情况下,返回值的身份不能保证.

ARGS:

  • input:张量,必须是下列类型之一:float32,float64,int64,int32,uint8,uint16,int16,int8,complex64,complex128,qint8,quint8,qint32,half.
  • axis:张量,必须是以下类型之一:int32,int64.当类型是 int32 时,要满足:0 <= axis < rank(input),描述输入向量的哪个轴减少.对于矢量,使用 axis = 0.
  • name:操作的名称(可选).

返回:

返回张量的 int 64 类型.


代码示例:

  1. import tensorflow as tf
  2.  
  3. Vector = [1,1,2,5,3]           #定义一个向量
  4. X = [[1,3,2],[2,5,8],[7,5,9]]  #定义一个矩阵
  5.  
  6. with tf.Session() as sess:
  7.     a = tf.argmax(Vector, 0)
  8.     b = tf.argmax(X, 0)
  9.     c = tf.argmax(X, 1)
  10.     
  11.     print(sess.run(a))
  12.     print(sess.run(b))
  13.     print(sess.run(c))

运行结果: 

  1. 3
  2. [2 1 2]
  3. [1 2 2]