TensorFlow 返回张量维度上最大值的索引
2020-07-16 11:03 更新
tf.arg_max
arg_max (
input ,
dimension ,
name = None
)
返回在张量维度上具有最大值的索引.
请注意,在关联的情况下,返回值的身份不能保证.
ARGS:
- input:张量.必须是下列类型之一:float32,float64,int64,int32,uint8,uint16,int16,int8,complex64,complex128,qint8,quint8,qint32,half.
- dimension:张量.必须是以下类型之一:int32,int64.当类型为 int32 时,应满足:0 <= dimension <rank(input).描述输入张量的哪个维度可以减少.对于向量,请使用dimension = 0.
- name:操作的名称(可选).
返回:
返回的张量类型为 int64.
示例
argmax的源代码
# pylint: disable=redefined-builtin
# TODO(aselle): deprecate arg_max
def argmax(input, axis=None, name=None, dimension=None):
if dimension is not None:
if axis is not None:
raise ValueError("Cannot specify both 'axis' and 'dimension'")
axis = dimension
elif axis is None:
axis = 0
return gen_math_ops.arg_max(input, axis, name)
如您所见,argmax在内部使用arg_max。另外,在代码中,我建议使用argmax,因为arg_max可能很快就会被弃用。