Pytorch中torch.stack()函數的深入解析
一. torch.stack()函數解析
1. 函數說明:
1.1 官網:torch.stack(),函數定義及參數說明如下圖所示:
1.2 函數功能
沿一個新維度對輸入一系列張量進行連接,序列中所有張量應為相同形狀,stack 函數返回的結果會新增一個維度。也即是把多個2維的張量湊成一個3維的張量;多個3維的湊成一個4維的張量…以此類推,也就是在增加新的維度上面進行堆疊。
1.3 參數列表
- tensors :為一系列輸入張量,類型為turple和List
- dim :新增維度的(下標)位置,當dim = -1時默認最後一個維度;范圍必須介於 0 到輸入張量的維數之間,默認是dim=0,在第0維進行連接
- 返回值:輸出新增維度後的張量
2. 代碼舉例
2.1 dim = 0 : 在第0維進行連接,相當於在行上進行組合(輸入張量為一維,輸出張量為兩維)
import torch #二維輸入張量a,b a = torch.tensor([1, 2, 3]) b = torch.tensor([11, 22, 33]) c = torch.stack([a, b],dim=0)#在第0維進行連接,相當於在行上進行組合(輸入張量為一維,輸出張量為兩維) print(a) print(b) print(c)
輸出結果如下:
tensor([1, 2, 3])
tensor([11, 22, 33])
tensor([[ 1, 2, 3],
[11, 22, 33]])
2.2 dim = 1 :在第1維進行連接,相當於在對應行上面對列元素進行組合(輸入張量為一維,輸出張量為兩維)
import torch #二維輸入張量a,b a = torch.tensor([1, 2, 3]) b = torch.tensor([11, 22, 33]) c = torch.stack([a, b],dim=1)#在第1維進行連接,相當於在對應行上面對列元素進行組合(輸入張量為一維,輸出張量為兩維) print(a) print(b) print(c)
輸出結果如下:
tensor([1, 2, 3])
tensor([11, 22, 33])
tensor([[ 1, 11],
[ 2, 22],
[ 3, 33]])
2.3 dim=0:表示在第0維進行連接,相當於在通道維度上進行組合(輸入張量為兩維,輸出張量為三維),註意:此處輸入張量維度為二維,因此dim最大隻能為2。
import torch #二維輸入張量a,b a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]]) c = torch.stack([a, b],dim=0)#在第0維進行連接,相當於在通道維度上進行組合(輸入張量為兩維,輸出張量為三維) print(a) print(b) print(c)
輸出結果如下所示:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
tensor([[11, 22, 33],
[44, 55, 66],
[77, 88, 99]])
tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],[[11, 22, 33],
[44, 55, 66],
[77, 88, 99]]])
2.4 dim=1:表示在第1維進行連接,相當於對相應通道中每個行進行組合,註意:此處輸入張量維度為二維,因此dim最大隻能為2。
import torch #二維輸入張量a,b a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]]) c = torch.stack([a, b], 1)#在第1維進行連接,相當於對相應通道中每個行進行組合 print(a) print(b) print(c)
輸出結果如下所示:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
tensor([[11, 22, 33],
[44, 55, 66],
[77, 88, 99]])
tensor([[[ 1, 2, 3],
[11, 22, 33]],[[ 4, 5, 6],
[44, 55, 66]],[[ 7, 8, 9],
[77, 88, 99]]])
2.5 dim=2:表示在第2維進行連接,相當於對相應行中每個列元素進行組合,註意:此處輸入張量維度為二維,因此dim最大隻能為2。
import torch #二維輸入張量a,b a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]]) c = torch.stack([a, b], 2)#在第2維進行連接,相當於對相應行中每個列元素進行組合 print(a) print(b) print(c)
輸出結果如下所示:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
tensor([[11, 22, 33],
[44, 55, 66],
[77, 88, 99]])
tensor([[[ 1, 11],
[ 2, 22],
[ 3, 33]],[[ 4, 44],
[ 5, 55],
[ 6, 66]],[[ 7, 77],
[ 8, 88],
[ 9, 99]]])
2.6 dim=3:表示在第3維進行連接,相當於對相應行中每個列元素進行組合(輸入維度大小為3維,因此dim=3最後一維始終代表為列),註意:此處輸入張量維度為三維,因此dim最大隻能為3。
import torch #三維輸入張量a,b a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]]) b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]]) c = torch.stack([a, b], 3)#表示在第3維進行連接,相當於對相應行中每個列元素進行組合(最後一維是第三維,始終代表為列) print(a) print(b) print(c)
輸出結果如下所示:
tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],[[10, 20, 30],
[40, 50, 60],
[70, 80, 90]]])
tensor([[[ 11, 22, 33],
[ 44, 55, 66],
[ 77, 88, 99]],[[110, 220, 330],
[440, 550, 660],
[770, 880, 990]]])
tensor([[[[ 1, 11],
[ 2, 22],
[ 3, 33]],[[ 4, 44],
[ 5, 55],
[ 6, 66]],[[ 7, 77],
[ 8, 88],
[ 9, 99]]],[[[ 10, 110],
[ 20, 220],
[ 30, 330]],[[ 40, 440],
[ 50, 550],
[ 60, 660]],[[ 70, 770],
[ 80, 880],
[ 90, 990]]]])
2.7 dim=4 (錯誤維度:因為此處輸入張量維度為三維,所以dim最大隻能為3,此處維度為4,因此會報錯)
import torch #三維輸入張量a,b a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]]) b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]]) c = torch.stack([a, b], 4) print(a) print(b) print(c)
輸出錯誤:
IndexError: Dimension out of range (expected to be in range of [-4, 3], but got 4)
總結
到此這篇關於Pytorch中torch.stack()函數的文章就介紹到這瞭,更多相關Pytorch torch.stack()函數內容請搜索WalkonNet以前的文章或繼續瀏覽下面的相關文章希望大傢以後多多支持WalkonNet!
推薦閱讀:
- 聊聊Pytorch torch.cat與torch.stack的區別
- pytorch教程之Tensor的值及操作使用學習
- pytorch下的unsqueeze和squeeze的用法說明
- pytorch中的廣播語義
- pytorch中[…, 0]的用法說明