阅读(11.3k) 书签 (0)

TensorFlow函数:tf.strided_slice

2018-03-20 14:01 更新

tf.strided_slice函数

tf.strided_slice(
    input_,
    begin,
    end,
    strides=None,
    begin_mask=0,
    end_mask=0,
    ellipsis_mask=0,
    new_axis_mask=0,
    shrink_axis_mask=0,
    var=None,
    name=None
)

定义在:tensorflow/python/ops/array_ops.py

参见指南:张量变换>分割和连接

提取张量的一个分段切片(广义 python 数组索引).

而不是直接调用这个操作,大多数用户会想要使用 NumPy 的风格的切片语法(例如,tensor[..., 3:4:-1, tf.newaxis, 3]),它们通过 tf.Tensor.getitemtf.Variable.getitem来支持.此运算的接口是切片语法的低级编码.

粗略地说,这个运算从给定的 input_ 张量中提取一个尺寸 (end-begin)/stride 的片段.从 begin 片段指定的位置开始,继续添加 stride 索引,直到所有维度都不小于 end.请注意,步幅可能是负值,这会导致反向切片.

给定一个 Python 的切片 input[spec0, spec1, ..., specn],这个函数将被调用如下.

begin,end 与 strides 将是长度 n 的向量.一般 n 不等于 input_ 张量的等级.

在每个掩码字段(begin_mask、end_mask、ellipsis_mask、new_axis_mask、shrink_axis_mask)中,第 i 位将对应于第 i 个规范.

如果设置了 begin_mask 的第 i 位,则忽略 begin[i],并使用该维度中的最大范围来代替.end_mask 类似地工作,除了结束范围.

foo[5:,:,:3] 在 7x8x9 张量上相当于 foo[5:7,0:8,0:3].foo[::-1] 反转形状为 8 的张量.

如果设置了 ellipsis_mask 的第 i 位,则会在其他维度之间插入所需的许多未指定维度.ellipsis_mask 中只允许有一个非零位.

例如,foo[3:5,...,4:5] 在一个形状 10x3x3x10 张量就相当于:foo[3:5,:,:,4:5],并且:foo[3:5,...] 相当于 foo[3:5,:,:,:].

如果设置了 new_axis_mask 的第 i 位,则 begin,end 和 stride 被忽略,并且在输出张量中的该点处添加新的长度 1 维.

例如,foo[:4, tf.newaxis, :2] 会产生一个形状 (4, 1, 2)张量.

如果设置了 shrink_axis_mask 的第 i 位,则意味着第 i 规范将维度缩小 1.begin[i],end[i] 和 strides[i] 必须意味着维度中的尺寸 1 的切片.例如,在Python中,可能 foo[:, 3, :] 会导致 shrink_axis_mask 等于 2.

注:begin 和 end都是零索引,strides 条目必须非零.

t = tf.constant([[[1, 1, 1], [2, 2, 2]],
                 [[3, 3, 3], [4, 4, 4]],
                 [[5, 5, 5], [6, 6, 6]]])
tf.strided_slice(t, [1, 0, 0], [2, 1, 3], [1, 1, 1])  # [[[3, 3, 3]]]
tf.strided_slice(t, [1, 0, 0], [2, 2, 3], [1, 1, 1])  # [[[3, 3, 3],
                                                      #   [4, 4, 4]]]
tf.strided_slice(t, [1, -1, 0], [2, -3, 3], [1, -1, 1])  # [[[4, 4, 4],
                                                         #   [3, 3, 3]]]

函数参数:

  • input_:一个 Tensor.
  • begin:一个 int32 或 int64 Tensor.
  • end:一个 int32 或 int64 Tensor.
  • strides:一个 int32 或 int64 Tensor.
  • begin_mask:一个 int32 mask.
  • end_mask:一个 int32 mask.
  • ellipsis_mask:一个 int32 mask.
  • new_axis_mask:一个 int32 mask.
  • shrink_axis_mask:一个 int32 mask.
  • var:与 input_None 对应的变量
  • name:操作的名称(可选).

函数返回值:

tf.strided_slice函数返回一个与 input 具有相同的类型的 Tensor.