TensorFlow函数:tf.scatter_nd_update
2018-01-09 10:16 更新
tf.scatter_nd_update 函数
scatter_nd_update(
ref,
indices,
updates,
use_locking=True,
name=None
)
请参阅指南:变量>稀疏变量更新
根据indices对给定变量中的单个值或切片应用稀疏updates.
在这个函数中,ref是一个秩为P的Tensor,indices是一个秩为Q的Tensor.
indices必须是整数张量,包含索引到ref.它一定有形状:[d_0, ..., d_{Q-2}, K],并且是:0<K<=P.
indices(具有长度K)的最内部维度对应于沿着ref的K维度的元素(if K = P)或切片(if K < P)的索引.
updates是具有形状的秩为Q-1+P-K的Tensor:
[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]
例如,假设我们想把4个分散的元素更新为一个rank-1张量到8个元素.在Python中,该更新将如下所示:
ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
indices = tf.constant([[4], [3], [1] ,[7]])
updates = tf.constant([9, 10, 11, 12])
update = tf.scatter_nd_update(ref, indices, updates)
with tf.Session() as sess:
print sess.run(update)
对ref的结果更新如下所示:
[1, 11, 3, 10, 9, 6, 7, 12]
请参阅tf.scatter_nd有关如何更新切片的更多详细信息.
函数参数
- ref:一个可变的Tensor;应该来自一个变量节点.
- indices:一个Tensor.必须是以下类型之一:int32,int64;索引到ref的一个张量.
- updates:一个Tensor.必须与ref具有相同的类型;要添加到ref更新值的张量.
- use_locking:可选的bool;如果为True,则赋值将受锁定的保护;否则行为是不确定的,但可能表现出较少的争用.
- name:操作的名称(可选).
函数返回值
使用tf.scatter_nd_update函数能够返回一个可变的Tensor.与ref有相同的类型;与 ref 一样,返回为希望在更新完成后使用更新的值的操作的方便性.