阅读(9.3k) 书签 (0)

TensorFlow 创建case

2018-09-06 10:36 更新

tf.case

case (  
    pred_fn_pairs ,  
    default ,  
    exclusive = False ,  
    strict = False ,  
    name = 'case' 
)

定义在:tensorflow/python/ops/control_flow_ops.py

参见指南:控制流程>控制流程操作

创建案例操作.

pred_fn_pairs 参数是字典或大小对的列表.每对都包含一个布尔标量张量和一个可调用的 python, 如果布尔计算结果为 True, 则创建要返回的数量.默认值是一个可调用的生成张量列表.pred_fn_pairs 中的所有 callables 以及默认值都应返回相同的张量和类型.
如果 exclusive = = true, 则计算所有谓词, 如果有多个谓词的计算结果为 true, 则引发异常.如果 exclusive = = False, 则执行停止是计算结果为 True 的第一个谓词, 并立即返回相应函数生成的张量.如果没有任何谓词计算为 True, 则此操作返回默认情况下生成的张量.

tf.case 支持在 tensorflow.python.util.nest 中实现的嵌套结构.所有的 callables 必须返回列表、元组和/或命名元组的相同 (可能是嵌套的) 值结构.单例列表和元组是唯一的例外: 当由可调用返回时, 它们被隐式解压到单个值.通过 strict = True 来禁用此行为.
如果使用无序字典 pred_fn_pairs, 则不保证条件测试的顺序.但是, 该顺序保证是确定性的, 以便在条件分支中创建的变量在运行时按固定顺序创建.

示例1:伪码:

if (x < y) return 17;
else return 23;

表达式:

f1 = lambda: tf.constant(17)
f2 = lambda: tf.constant(23)
r = case([(tf.less(x, y), f1)], default=f2)

示例2:伪码:

if (x < y && x > z) raise OpError("Only one predicate may evaluate true");
if (x < y) return 17;
else if (x > z) return 23;
else return -1;

表达式:

def f1(): return tf.constant(17)
def f2(): return tf.constant(23)
def f3(): return tf.constant(-1)
r = case({tf.less(x, y): f1, tf.greater(x, z): f2},
         default=f3, exclusive=True)

ARGS:

  • pred_fn_pairs:字典或一组布尔标量张量和一个可调用的列表,其返回张量列表.
  • default:可以返回张量列表的可调用函数.
  • exclusive:如果允许最多一个谓词可以评估,则为真True.
  • strict:启用/禁用“严格”模式的布尔值.
  • name:此操作的名称(可选).

返回:

由谓词计算为 True 的第一对返回的张量,或者默认情况下返回的张量 (如果没有).

注意:

  • TypeError:当 pred_fn_pairs 不是列表/字典.
  • TypeError:当 pred_fn_pairs 是一个列表,但不包含2元组.
  • TypeError:当 fns [i] 对任何 i 不可调用, 或者默认不可调用.