阅读(23.7k) 书签 (0)

TensorFlow分割:tf.slice函数

2018-01-26 10:28 更新

tf.slice 函数

slice(
    input_,
    begin,
    size,
    name=None
)

定义在:tensorflow/python/ops/array_ops.py.

参见指南:张量变换>分割和连接

从张量中提取切片.

此操作从由begin指定位置开始的张量input中提取一个尺寸size的切片.切片size被表示为张量形状,其中size[i]是你想要分割的input的第i维的元素的数量.切片的起始位置(begin)表示为每个input维度的偏移量.换句话说,begin[i]是你想从中分割出来的input的“第i个维度”的偏移量.

请注意,tf.Tensor.__getitem__通常是执行切片的python方式,因为它允许您写foo[3:7, :-2],而不是tf.slice([3, 0], [4, foo.get_shape()[1]-2]).

begin是基于零的;size是一个基础.如果size[i]是-1,则维度i中的所有其余元素都包含在切片中.换句话说,这相当于设置:

size[i] = input.dim_size(i) - begin[i]

该操作要求:

0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n]

例如:

t = tf.constant([[[1, 1, 1], [2, 2, 2]],
                 [[3, 3, 3], [4, 4, 4]],
                 [[5, 5, 5], [6, 6, 6]]])
tf.slice(t, [1, 0, 0], [1, 1, 3])  # [[[3, 3, 3]]]
tf.slice(t, [1, 0, 0], [1, 2, 3])  # [[[3, 3, 3],
                                   #   [4, 4, 4]]]
tf.slice(t, [1, 0, 0], [2, 1, 3])  # [[[3, 3, 3]],
                                   #  [[5, 5, 5]]]

函数参数

  • input_:一个Tensor.
  • begin:一个int32或int64类型的Tensor.
  • size:一个int32或int64类型的Tensor.
  • name:操作的名称(可选).

函数返回

tf.slice函数返回与input具有相同类型的Tensor.