阅读(15.1k) 书签 (0)

将TensorFlow张量沿一个维度串联

2018-09-12 16:15 更新

tf.concat

concat ( 
    values , 
    axis , 
    name = 'concat' 
)

定义在:tensorflow/python/ops/array_ops.py.

参见指南:张量变换>张量的分割和连接

将张量沿一个维度串联.

将张量值的列表与维度轴串联在一起.如果 values[i].shape = [D0, D1, ... Daxis(i), ...Dn],则连接结果有形状.

[D0, D1, ... Raxis, ...Dn]

Raxis = sum(Daxis(i))

也就是说,输入张量的数据将沿轴维度连接.
输入张量的维数必须匹配, 并且除坐标轴外的所有维度必须相等.

例如:

T1 =  [ [ 1 , 2 , 3 ] , [ 4 , 5 , 6 ] ] 
T2 =  [ [ 7 , 8 , 9 ] , [ 10 , 11 , 12 ] ] 
tf.concat([T1 ,T2] ,0) == >  [[1 , 2 ,3 ],[4 ,5 ,6],[7 ,8 ,9],[10 ,11,12]] 
tf.concat([T1 ,T2] ,1) == >  [[ 1 ,2 ,3 ,7 ,8 ,9 ],[4 ,5 ,6,10 ,11 ,12]]

#张量 t3 的形状[2,3] 
#张量 t4 的形状[2,3] 
tf.shape(tf.concat([ t3 , t4 ] , 0 )) == >  [ 4 , 3 ] 
tf.shape( tf.concat([t3 ,t4 ] , 1 )) == >  [ 2 , 6 ]
注意:如果沿着新轴连接,请考虑使用堆栈.例如:
tf.concat ([ tf.expand_dims (t ,axis) for t in tensors] ,axis)

可以重写为

tf.stack(tensors,axis = axis)

ARGS:

  • values:张量对象或单个张量列表.
  • axis:0 维 int32 张量,要连接的维度.
  • name:操作的名称(可选).

返回:

由输入张量的连接引起的张量.