“ Flatten”在Keras中的作用是什么?
计算机/软件
0 513
0

我试图了解该Flatten功能在Keras中的作用。

下面是我的代码,它是一个简单的两层网络。它接收(3,2)的二维数据,并输出(1,4)的一维数据:

model = Sequential()
model.add(Dense(16, input_shape=(3, 2)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')
x = np.array([[[1, 2], [3, 4], [5, 6]]])
y = model.predict(x)print y.shape

输出y为(1、4)的向量。但是,如果我删除该Flatten行,则会输出(1、3、4)的y。

根据我对神经网络的了解,该model.add(Dense(16, input_shape=(3, 2)))函数正在创建一个具有16个节点的隐藏的全连接层。每一个节点都连接到3x2输入元素中的每个值。因此,该第一层的输出处的16个节点已经“平坦”。因此,第一层的输出应为(1、16)。然后,第二层将此作为输入,并输出形状为(1、4)的数据。

因此,如果第一层的输出已经“平坦”并且为(1,16),为什么我需要使用Flatten行?

收藏
2021-02-02 10:58 更新 karry •  2148
共 1 个回答
高赞 时间
0

如果你阅读的Keras文档条目Dense,将会看到以下调用:

Dense(16, input_shape=(5,3))

这样将形成一个具有3个输入和16个输出的Dense网络,因此,如果D(X)将三维向量转换为16d向量,则输出将得到一个向量序列:[D(x[0,:])、D(x[1,:])、.、D(x[4,:])],(5,16)的向量。为了实现你指定的行为,可以先将Flatten输入到15维向量,然后应用Dense:

model = Sequential()
model.add(Flatten(input_shape=(3, 2)))
model.add(Dense(16))
model.add(Activation('relu'))
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')

在这里有一个可以解释的图像:

Via:https://stackoverflow.com/a/43237727/14964791

收藏
2021-02-02 11:28 更新 anna •  2853