阅读(13.7k) 书签 (0)

TensorFlow函数:tf.where

2018-04-19 10:47 更新

tf.where函数

tf.where(
    condition,
    x=None,
    y=None,
    name=None
)

定义在:tensorflow/python/ops/array_ops.py.

请参阅指南:控制流程>比较运算符,数学函数>序列比较和索引

根据condition返回x或y中的元素.

如果x和y都为None,则该操作将返回condition中true元素的坐标.坐标以二维张量返回,其中第一维(行)表示真实元素的数量,第二维(列)表示真实元素的坐标.请记住,输出张量的形状可以根据输入中的真实值的多少而变化.索引以行优先顺序输出.

如果两者都不是None,则x和y必须具有相同的形状.如果x和y是标量,则condition张量必须是标量.如果x和y是更高级别的矢量,则condition必须是大小与x的第一维度相匹配的矢量,或者必须具有与x相同的形状.

condition张量作为一个可以选择的掩码(mask),它根据每个元素的值来判断输出中的相应元素/行是否应从 x (如果为 true) 或 y (如果为 false)中选择.

如果condition是向量,则x和y是更高级别的矩阵,那么它选择从x和y复制哪个行(外部维度).如果condition与x和y具有相同的形状,那么它将选择从x和y复制哪个元素.

函数参数:

  • condition:一个bool类型的张量(Tensor).
  • x:可能与condition具有相同形状的张量;如果condition的秩是1,则x可能有更高的排名,但其第一维度必须匹配condition的大小.
  • y:与x具有相同的形状和类型的张量.
  • name:操作的名称(可选).

返回值:

如果它们不是None,则返回与x,y具有相同类型与形状的张量;张量具有形状(num_true, dim_size(condition)).

可能引发的异常:

  • ValueError:当一个x或y正好不是None.