Skip to main content

tf.gather用法

tf.gather(params, indices, validate_indices=None, name=None, axis=0)
Gather slices from `params` axis `axis` according to `indices`.
从'params'的'axis'维根据'indices'的参数值获取切片。就是在axis维根据indices取某些值。
测试代码:
import tensorflow as tf
import numpy as np
print("\n先测试一维张量\n")
t=np.random.randint(1,10,5)
g1=tf.gather(t,[2,1,4])
sess=tf.Session()
print(t)
print(sess.run(g1))
print("\n再测试二维张量\n")
t=np.random.randint(1,10,[4,5])
g2=tf.gather(t,[1,2,2],axis=0)
g3=tf.gather(t,[1,2,2],axis=1)
print(t)
print(sess.run(g2))
print(sess.run(g3))

结果如下:
先测试一维张量 [7 4 7 1 3] [7 4 3] 再测试二维张量 [[5 5 7 4 3] [8 7 6 5 2] [6 9 4 4 8] [7 3 3 2 2]] [[8 7 6 5 2] [6 9 4 4 8] [6 9 4 4 8]] [[5 7 7] [7 6 6] [9 4 4] [3 3 3]]


Comments

Popular posts from this blog

Session Run的用法

feed_dict参数的作用是替换图中的某个tensor的值。例如: a = tf.add(2, 5)                        #a=7 b = tf.multiply(a, 3)                 #b=3*7=21 with tf.Session() as sess:     print(sess.run(b))     replace_dict = {a:15}           #用15代替b算式中的a     print(sess.run(b, feed_dict = replace_dict)) --------------------- 输出如下: 21 45