Several simple examples of `scatter_nd_update` are provided in the official TensorFlow documentation (accessible via ). However, those examples only show its usage on a 1-dimensional tensor. It took me quite a while to get it working on multi-dimensional tensors, and few examples of that usage can be found on the web. Below are three examples that I got working. Before those, here is the example from the official TensorFlow documentation.
Example 1 (from the TensorFlow documentation):
Say we want to update 4 scattered elements of a rank-1 tensor with 8 elements. In Python, that update would look like this:
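```python
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.scatter_nd_update)

ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
update = tf.scatter_nd_update(ref, indices, updates)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # ref must be initialized first
    print(sess.run(update))
```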
The resulting update to ref would look like this:
[1, 11, 3, 10, 9, 6, 7, 12]
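To see what the op computes in the rank-1 case, here is a minimal pure-Python sketch (not TensorFlow's implementation). The function name `scatter_update_1d` is hypothetical, and unlike the TensorFlow op, which updates the variable in place, it returns a new list.

```python
def scatter_update_1d(ref, indices, updates):
    """Return a copy of ref where ref[indices[k][0]] = updates[k] for each k."""
    out = list(ref)  # copy, so the input list is left unchanged
    for (i,), value in zip(indices, updates):  # each index is a 1-element list
        out[i] = value
    return out


result = scatter_update_1d([1, 2, 3, 4, 5, 6, 7, 8],
                           [[4], [3], [1], [7]],
                           [9, 10, 11, 12])
print(result)  # [1, 11, 3, 10, 9, 6, 7, 12]
```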
Next are three examples I wrote myself.
Example 2:
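```python
>>> import tensorflow as tf  # TensorFlow 1.x API
>>> sess = tf.InteractiveSession()  # default session, also used by ref.eval() below
>>> ref = tf.Variable(tf.ones([2, 3], tf.int32))
>>> updates = tf.constant([[0, 0, 0]])
>>> update = tf.scatter_nd_update(ref, [[0]], updates)
>>> init = tf.global_variables_initializer()
>>> sess.run(init)
>>> sess.run(update)
array([[0, 0, 0],
       [1, 1, 1]], dtype=int32)
```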
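When the index depth is smaller than the rank of `ref`, as in the example above, each index selects a whole slice (here a row) rather than a single element. The pure-Python sketch below mimics that behaviour on nested lists. The name `scatter_nd_update_ref` is hypothetical; the sketch returns a new list instead of updating in place and does no shape checking, so it only illustrates the semantics.

```python
import copy


def scatter_nd_update_ref(ref, indices, updates):
    """Sketch of scatter_nd_update semantics on nested Python lists.

    Each index of length K selects ref[i0][i1]...[iK-1]: a scalar if K equals
    the rank of ref, otherwise a slice. The matching entry of updates
    replaces it. A new nested list is returned; ref itself is not modified.
    """
    out = copy.deepcopy(ref)
    for index, value in zip(indices, updates):
        *prefix, last = index
        container = out
        for i in prefix:  # walk down to the parent of the target position
            container = container[i]
        container[last] = copy.deepcopy(value)  # scalar or whole slice
    return out


ones = [[1, 1, 1], [1, 1, 1]]
print(scatter_nd_update_ref(ones, [[0]], [[0, 0, 0]]))  # [[0, 0, 0], [1, 1, 1]]
```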
Example 3:
```python
>>> ref = tf.Variable(tf.ones([2, 3, 3], tf.int32))
>>> indices = tf.constant([[0, 1]])
>>> # updates = tf.constant([0, 0, 0])   # wrong: shape [3]
>>> updates = tf.constant([[0, 0, 0]])    # correct: shape [1, 3]
>>> update = tf.scatter_nd_update(ref, indices, updates)
>>> init = tf.global_variables_initializer()
>>> sess.run(init)
>>> print(ref.eval())
[[[1 1 1]
  [1 1 1]
  [1 1 1]]

 [[1 1 1]
  [1 1 1]
  [1 1 1]]]
>>> sess.run(update)
array([[[1, 1, 1],
        [0, 0, 0],
        [1, 1, 1]],

       [[1, 1, 1],
        [1, 1, 1],
        [1, 1, 1]]], dtype=int32)
```
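The "wrong" versus "correct" `updates` in the example above follows from the shape rule in the TensorFlow documentation: if `indices` has shape `[d_0, ..., d_{Q-2}, K]`, then `updates` must have shape `[d_0, ..., d_{Q-2}] + ref.shape[K:]`. The helper below is a hypothetical sketch that just computes that expected shape; for Example 3 it gives `[1, 3]`, which is why `[[0, 0, 0]]` works and `[0, 0, 0]` (shape `[3]`) does not.

```python
def expected_updates_shape(ref_shape, indices_shape):
    """Shape that `updates` must have: indices.shape[:-1] + ref.shape[K:].

    K = indices_shape[-1] is the index depth, i.e. how many leading
    dimensions of ref each index addresses.
    """
    depth = indices_shape[-1]
    if depth > len(ref_shape):
        raise ValueError("index depth exceeds the rank of ref")
    return list(indices_shape[:-1]) + list(ref_shape[depth:])


print(expected_updates_shape([2, 3, 3], [1, 2]))  # [1, 3]
```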
Example 4:
```python
>>> # continues the session above; ref is the [2, 3, 3] variable from Example 3
>>> updates = tf.constant([0])
>>> indices = tf.constant([[1, 0, 1]])
>>> init = tf.global_variables_initializer()
>>> sess.run(init)  # resets ref to all ones
>>> update = tf.scatter_nd_update(ref, indices, updates)
>>> sess.run(update)
array([[[1, 1, 1],
        [1, 1, 1],
        [1, 1, 1]],

       [[1, 0, 1],
        [1, 1, 1],
        [1, 1, 1]]], dtype=int32)
```
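In Example 4 the index depth equals the rank of `ref`, so each index names a single element and each update is a scalar. Conceptually, for a dense tensor stored in row-major (C) order, this is a write at one flat offset. The sketch below illustrates that mapping in pure Python; `flat_offset` is a hypothetical helper, and this is an illustration of the indexing, not a description of how TensorFlow implements the op.

```python
from math import prod


def flat_offset(shape, index):
    """Row-major (C-order) offset of a full-depth index into `shape`."""
    if len(index) != len(shape):
        raise ValueError("a full-depth index needs one entry per dimension")
    offset = 0
    for dim, i in zip(shape, index):
        if not 0 <= i < dim:
            raise IndexError(f"index {i} out of range for dimension {dim}")
        offset = offset * dim + i  # Horner-style accumulation of strides
    return offset


shape = [2, 3, 3]
flat = [1] * prod(shape)  # the 2x3x3 tensor of ones, flattened
for index, value in zip([[1, 0, 1]], [0]):  # indices/updates of Example 4
    flat[flat_offset(shape, index)] = value
print(flat_offset(shape, [1, 0, 1]))  # 10
```

Several elements can be set in one call the same way: `indices` of shape `[N, 3]` paired with `updates` of shape `[N]`.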
Reposted from: http://moqbi.baihongyu.com/