阅读(21.1k) 书签 (0)

TensorFlow:tf.map_fn函数

2018-10-30 17:51 更新
函数:tf.map_fn
map_fn(
    fn,
    elems,
    dtype=None,
    parallel_iterations=10,
    back_prop=True,
    swap_memory=False,
    infer_shape=True,
    name=None
)

定义在:tensorflow/python/ops/functional_ops.py.

参见指南:高阶函数>高阶运算符

从0维度的 elems 中解压的张量列表上的映射.

map_fn 的最简单版本反复地将可调用的 fn 应用于从第一个到最后一个的元素序列.这些元素由 elems 解压缩的张量构成.dtype 是 fn 的返回值的数据类型.如果与elems 的数据类型不同,用户必须提供 dtype.

假设 elems 被打包成 values、张量列表.结果张量的形状是:[values.shape[0]] + fn(values[0]).shape.

这种方法也允许 fn 的多元 elems 和输出.如果 elems 是(可能是嵌套的)列表或元素的张量,则这些张量中的每一个必须具有匹配的第一(unpack)维度.签名fn可能匹配的结构elems.也就是说,如果 elems 是:(t1, [t2, t3, [t4, t5]]),则 fn 的适当签名为:fn = lambda (t1, [t2, t3, [t4, t5]]):.

此外,fn 可能会发出与其输入不同的结构.例如,fn 可能看起来像:fn = lambda t1: return (t1 + 1, t1 - 1).在这种情况下,dtype 参数不是可选的:dtype 必须是与 fn的输出匹配的类型或(可能是嵌套的)元组.

要将函数操作应用于 SparseTensor 的非零元素,建议使用以下方法之一.首先,如果函数可以表示为 TensorFlow ops,请使用:

result = SparseTensor(input.indices, fn(input.values), input.dense_shape)

但是,如果该函数不能作为 TensorFlow op 表示,则使用:

result = SparseTensor(
  input.indices, map_fn(fn, input.values), input.dense_shape)

参数:

  • fn:可调用的执行.它接受一个参数,它将具有与之相同的(可能嵌套的)结构 elems.其输出必须具有与 dtype 相同的结构(如果提供了),否则它必须具有与elems 相同的结构.
  • elems:张量或(可能是嵌套的)张量序列,其中的每一个都将沿着它们的第一维度进行解压.生成的切片的嵌套序列将应用于 fn.
  • dtype:(可选)fn 的输出类型.如果 fn 返回与 elems 结构不同的张量结构,则 dtype 不是可选的,并且必须具有与 fn 的输出相同的结构.
  • parallel_iterations:(可选)允许并行运行的迭代次数.
  • back_prop:(可选)True 允许支持反向传播.
  • swap_memory:(可选)True 可实现 GPU-CPU 内存交换.
  • infer_shape:(可选)False 禁用对一致输出形状的测试.
  • name:(可选)返回的张量的名称前缀.

返回值:

该函数返回张量或(可能是嵌套的)张量序列.每个张量都将 fn 的结果应用到从第一个维度的 elems,从第一个到最后一个.

可能发生的异常:

  • TypeError:如果 fn 不是可调用或 fn 的输出的结构和 dtype 不匹配,或者 elems 是 SparseTensor.
  • ValueError:如果 fn 的输出长度 和 dtype 不匹配.

例子:

elems = np.array([1, 2, 3, 4, 5, 6])
squares = map_fn(lambda x: x * x, elems)
# squares == [1, 4, 9, 16, 25, 36]

elems = (np.array([1, 2, 3]), np.array([-1, 1, -1]))
alternate = map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)
# alternate == [-1, 2, -3]

elems = np.array([1, 2, 3])
alternates = map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
# alternates[0] == [1, 2, 3]
# alternates[1] == [-1, -2, -3]