阅读(9.1k) 书签 (0)

TensorFlow函数:tf.unstack

2018-04-14 11:05 更新

tf.unstack函数

tf.unstack(
    value,
    num=None,
    axis=0,
    name='unstack'
)

定义在:tensorflow/python/ops/array_ops.py.

参见指南:张量变换>分割和连接

将秩为 R 的张量的给定维度出栈为秩为 (R-1) 的张量.

通过沿 axis 维度将 num 张量从 value 中分离出来.如果没有指定 num(默认值),则从 value 的形状推断.如果 value.shape[axis] 不知道,则引发 ValueError.

例如,给定一个具有形状 (A, B, C, D) 的张量.

  • 如果 axis == 0,那么 output 中的第 i 个张量就是切片 value[i, :, :, :],并且 output 中的每个张量都具有形状 (B, C, D).(请注意,出栈的维度已经消失,不像split).
  • 如果 axis == 1,那么 output 中的第 i 个张量就是切片 value[:, i, :, :],并且 output 中的每个张量都具有形状 (A, C, D).

这与堆栈(stack.)相反,numpy 相当于:

tf.unstack(x, n) = np.unstack(x)

函数参数:

  • value:一个要出栈的秩 R > 0 的 Tensor.
  • num:一个 int,维度 axis 的长度,如果为 None(默认值),则自动推断.
  • axis:一个 int,沿着这个轴出栈,默认为第一维,负值环绕,所以有效范围是 [-R, R).
  • name:操作的名称(可选).

函数返回值:

tf.unstac k函数从 value 中出栈的 Tensor 对象列表.

可能引发的异常:

  • ValueError:如果 num 没有指定并且无法推断.
  • ValueError:如果 axis 超出范围 [-R,R).