TensorFlow 创建case
tf.case
case ( pred_fn_pairs , default , exclusive = False , strict = False , name = 'case' )
定义在:tensorflow/python/ops/control_flow_ops.py
参见指南:控制流程>控制流程操作
创建案例操作.
pred_fn_pairs 参数是字典或大小对的列表.每对都包含一个布尔标量张量和一个可调用的 python, 如果布尔计算结果为 True, 则创建要返回的数量.默认值是一个可调用的生成张量列表.pred_fn_pairs 中的所有 callables 以及默认值都应返回相同的张量和类型.
如果 exclusive = = true, 则计算所有谓词, 如果有多个谓词的计算结果为 true, 则引发异常.如果 exclusive = = False, 则执行停止是计算结果为 True 的第一个谓词, 并立即返回相应函数生成的张量.如果没有任何谓词计算为 True, 则此操作返回默认情况下生成的张量.
tf.case 支持在 tensorflow.python.util.nest 中实现的嵌套结构.所有的 callables 必须返回列表、元组和/或命名元组的相同 (可能是嵌套的) 值结构.单例列表和元组是唯一的例外: 当由可调用返回时, 它们被隐式解压到单个值.通过 strict = True 来禁用此行为.
如果使用无序字典 pred_fn_pairs, 则不保证条件测试的顺序.但是, 该顺序保证是确定性的, 以便在条件分支中创建的变量在运行时按固定顺序创建.
示例1:伪码:
if (x < y) return 17; else return 23;
表达式:
f1 = lambda: tf.constant(17) f2 = lambda: tf.constant(23) r = case([(tf.less(x, y), f1)], default=f2)
示例2:伪码:
if (x < y && x > z) raise OpError("Only one predicate may evaluate true"); if (x < y) return 17; else if (x > z) return 23; else return -1;
表达式:
def f1(): return tf.constant(17) def f2(): return tf.constant(23) def f3(): return tf.constant(-1) r = case({tf.less(x, y): f1, tf.greater(x, z): f2}, default=f3, exclusive=True)
ARGS:
- pred_fn_pairs:字典或一组布尔标量张量和一个可调用的列表,其返回张量列表.
- default:可以返回张量列表的可调用函数.
- exclusive:如果允许最多一个谓词可以评估,则为真True.
- strict:启用/禁用“严格”模式的布尔值.
- name:此操作的名称(可选).
返回:
由谓词计算为 True 的第一对返回的张量,或者默认情况下返回的张量 (如果没有).
注意:
- TypeError:当 pred_fn_pairs 不是列表/字典.
- TypeError:当 pred_fn_pairs 是一个列表,但不包含2元组.
- TypeError:当 fns [i] 对任何 i 不可调用, 或者默认不可调用.