pytorch中nn.Flatten()函數詳解及示例
torch.nn.Flatten(start_dim=1, end_dim=- 1)
作用:將連續的維度范圍展平為張量。 經常在nn.Sequential()中出現,一般寫在某個神經網絡模型之後,用於對神經網絡模型的輸出進行處理,得到tensor類型的數據。
有倆個參數,start_dim和end_dim,分別表示開始的維度和終止的維度,默認值分別是1和-1,其中1表示第一維度,-1表示最後的維度。結合起來看意思就是從第一維度到最後一個維度全部給展平為張量。(註意:數據的維度是從0開始的,也就是存在第0維度,第一維度並不是真正意義上的第一個)
同理,如果我這麼寫:
self.flat = nn.Flatten(start_dim=2, end_dim=3)
那麼意思就是從第二維度開始,到第三維度全部給展平,也就是將2、3兩個維度展平。
官網給出的示例:
input = torch.randn(32, 1, 5, 5) # With default parameters m = nn.Flatten() output = m(input) output.size() #torch.Size([32, 25]) # With non-default parameters m = nn.Flatten(0, 2) output = m(input) output.size() #torch.Size([160, 5])
#開頭的代碼是註釋
整段代碼的意思是:給定一個維度為(32,1,5,5)的隨機數據。
1.先使用一次nn.Flatten(),使用默認參數:
m = nn.Flatten()
也就是說從第一維度展平到最後一個維度,數據的維度是從0開始的,第一維度實際上是數據的第二個位置代表的維度,也就是樣例中的1。
因此進行展平後的結果也就是[32,1×5×5]➡[32,25]
2.接著再使用一次指定參數的nn.Flatten(),即
m = nn.Flatten(0, 2)
也就是說從第0維度展平到第2維度,0~2,對應的也就是前三個維度。
因此結果就是[32×1×5,5]➡[160,5]
因此進行展平後的結果也就是[32,1*5*5]➡[32,25]
示例1
卷積公式
import torch import torch.nn as nn input = torch.randn(32, 1, 5, 5) m = nn.Sequential( nn.Conv2d(1, 32, 5, 1, 1), # 通過卷積,得到torch.size([32, 32, 3, 3] nn.Flatten()) output = m(input) print(output.size()) >> torch.Size([32, 288])
示例2
import torch import torch.nn as nn input = torch.randn(32, 1, 5, 5) m = nn.Sequential( nn.Conv2d(1, 32, 5, 1, 1), # 通過卷積,得到torch.size([32, 32, 3, 3] nn.Flatten(start_dim=0)) output = m(input) print(output.size()) >>torch.Size([9216])
總結
到此這篇關於pytorch中nn.Flatten()函數詳解的文章就介紹到這瞭,更多相關pytorch nn.Flatten()函數詳解內容請搜索WalkonNet以前的文章或繼續瀏覽下面的相關文章希望大傢以後多多支持WalkonNet!
推薦閱讀:
- Pytorch中torch.flatten()和torch.nn.Flatten()實例詳解
- 支持PyTorch的einops張量操作神器用法示例詳解
- pytorch中常用的乘法運算及相關的運算符(@和*)
- PyTorch零基礎入門之構建模型基礎
- Pytorch 統計模型參數量的操作 param.numel()