tensorflow中的shape的不完全指南
最近做实验, 因为tensorflow中Tensor的shape的问题大伤脑筋,特别是自己实现的较为复杂的网络层,例如对输入加一个RBF核,或者图卷积等等复杂的操作都需要写代码时对整个流程中tensor的shape有一定把握。否则就很容易遇到error或者bug。
下面就将tf中的shape好好梳理一下。
文中代码在Github上,用jupyter notebook打开。
1 | import numpy as np |
静态与动态
tf中的每个tensor
都有两种shape,一种是静态(static)shap,一种是动态(dynamic)shape。 静态形状是由我们的操作推断出的形状(inferred),而动态形状是运行时真实的形状。
稍微熟悉tf的同学可能会发现,tf中提供了两个获取tensor形状的函数,一个是tf.Tensor.get_shape
,另一个就是tf.shape
,这两个函数就和两种shape一一对应,get_shape
用于获取静态shape,而tf.shape
用来获取动态信息。
下面我们来看个例子:
1 | x = tf.placeholder(tf.int32, shape=[5]) |
[5]
[None]
Tensor("Shape_13:0", shape=(1,), dtype=int32)
Tensor("Shape_14:0", shape=(1,), dtype=int32)
可以看出来这里我们创建了两个tensor,x指定了具体的形状,y只确定了维度而没有给出具体的形状,我们使用get_shape
时就分别得到(5,)
和(?,)
的结果,?表示这一维的形状无法确定。而此时由于并没有运行,所以tf.shape
只是给出了一个tensor来表示形状,而这个tensor的值要运行了才能知道。
下面我们起一个session来运行一下,看一下它们的动态shape。
1 | sess = tf.Session() |
[5]
[3]
这样我们就获得了x和y的动态shape,相信大家都已经发现了,其实tf.shape
给出了的是一个tensor,因此需要运行,而因为涉及到运行时,所以tf.shape
,运行的答案中不会含有未知的“?”。 而get_shape
则是直接给出了形状的表示,需要注意的是,get_shape
给出的并不是一个list,而是一个TensorShape
,如有需要,可以用.as_list
来转换为list。
1 | print(type(x.get_shape())) |
<class 'tensorflow.python.framework.tensor_shape.TensorShape'>
[5]
改变形状
既然形状获取有两种方式,那么相应的,改变形状也有两个函数分别对应改变动态形状和静态形状。tf.Tensor.set_shape
会更新tensor的静态形状,常被用来提供额外的形状信息,而tf.reshape
则会创建一个新的具有不同形状的tensor。
1 | a = tf.placeholder(tf.int32, shape=[None]) |
before set shape: (?,)
after set shape: (4,)
而我们经常想做的是将一个tensor的实际形状改变,比如一个[3,3]的矩阵转换为一个[9,1]的向量。
1 | b = tf.placeholder(tf.int32, shape=[3,3]) |
(3, 3)
(9, 1)
举个栗子
我们经常会在神经网络中遇到tensor的乘法,此时往往我们使用tf.matmul
来完成,但是该函数只支持两个操作对象均为二维tensor(也就是矩阵)。有时我们会需要更高维tensor的乘法操作,例如一个NxMxP的tensor乘一个PxQ的矩阵,期望得到一个NxMxQ的tensor。下面我们就利用上面的芝士来创建一个更广义的乘法
1 | def NDmatmul(x, w): |