Pytorch深度學習gather一些使用問題解決方案
問題場景描述
我在復現Faster-RCNN模型的過程中遇到這樣一個問題:
有一個張量,它的形狀是 (128, 21, 4)
roi_loc.shape = (128, 21, 4)
與之對應的還有一個label數據
gt_label.shape = (128)
我現在的需求是將label當作第一個張量在dim=1上的索引,將其中的數據拿出來。
具體來說就是,現在有128個樣本數據,每個樣本中有21個長度為4的向量。label也是128個,每個值代表取出21個向量中的哪一個。
問題的思考
我嘗試瞭很多辦法,包括佈爾索引,index_select方法等,最後發現都不適用(也有可能我沒用好)。最後利用gather API解決瞭這個問題。
這個API的說明我看瞭很多遍都沒看懂,我相信絕大部分讀者也是因為看不懂這個說明才來這兒的。
下面我給出自己的一些理解:
gather的說明
gather所需要的第一個參數是待索引的數據,在我們的問題中 roi_loc就是這個input。第二個參數dim,是你的索引數據要作用在哪個軸上,正如前面所言,我們想索引第二個軸(dim=1).
最難理解的是index,index就是我們想要用來索引的張量,對應的是label。可是label不能直接拿來用,得先做一定的變換,這也就是gather的難點。
我們先從簡單的情況來看
input和gather必須在維度上相同,假設數據還是3 * 3,index也是1 * 3的(註意這裡是二維的)
此時row至多取值0,col至多取值為2
如果我要對dim=0索引
那麼data[0][0] = data[index[0][0]] [0] = data[1][0] = 2
data[0][1] = data[index[0][1]] [1] = data[0][1] = 5
data[0][2] = data[index[0][2]][2] = data[2][2] = 9
上面的過程可以描述為,第一列的元素我想選第二行的,第二列的元素我想選第一行的,第三列的元素我想選第三行的。
可以發現因為index是1 * 3的,所以最後的輸出也是31* 3,即輸出張量的shape取決於index的shape
以上過程我相信讀者好好體悟應該可以理解。
問題的解決
回到我們的問題
roi_loc.shape = (128, 21, 4),gt_label.shape = (128)
我們想索引dim=1,最後的結果應該是(128, 4)
由上面的說明可以知道,input和index的dimension首先得相同
idx = gt_roi_labels.unsqueeze(-1).unsqueeze(-1) idx.shape = (128, 1, 1)
又因為我們想要輸出的結果得是(128, 4),所以得讓idx在最後一個軸上重復4次
idx = idx.repeat_interleave(-1, 4) idx.shape = (128, 1, 4)
現在就可以利用gather在dim=1上索引瞭
result = roi_loc.gather(1, idx) result.shape = (128, 1, 4)
最後將長度為1的軸壓縮(本身這個軸的出現是為瞭滿足input和index維度一樣的要求)
result = result.squeeze(1) result.shape(128, 4)
以上就是Pytorch深度學習gather一些使用問題解決方案的詳細內容,更多關於Pytorch學習gather使用問題的資料請關註WalkonNet其它相關文章!
推薦閱讀:
- pytorch下的unsqueeze和squeeze的用法說明
- 人工智能學習Pytorch教程Tensor基本操作示例詳解
- 淺談pytorch中stack和cat的及to_tensor的坑
- pytorch tensor計算三通道均值方式
- 解析Pytorch中的torch.gather()函數