pytorch交叉熵損失函數的weight參數的使用
首先
必須將權重也轉為Tensor的cuda格式;
然後
將該class_weight作為交叉熵函數對應參數的輸入值。
class_weight = torch.FloatTensor([0.13859937, 0.5821059, 0.63871904, 2.30220396, 7.1588294, 0]).cuda()
補充:關於pytorch的CrossEntropyLoss的weight參數
首先這個weight參數比想象中的要考慮的多
你可以試試下面代碼
import torch import torch.nn as nn inputs = torch.FloatTensor([0,1,0,0,0,1]) outputs = torch.LongTensor([0,1]) inputs = inputs.view((1,3,2)) outputs = outputs.view((1,2)) weight_CE = torch.FloatTensor([1,1,1]) ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE) loss = ce(inputs,outputs) print(loss)
tensor(1.4803)
這裡的手動計算是:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803
加權呢?
import torch import torch.nn as nn inputs = torch.FloatTensor([0,1,0,0,0,1]) outputs = torch.LongTensor([0,1]) inputs = inputs.view((1,3,2)) outputs = outputs.view((1,2)) weight_CE = torch.FloatTensor([1,2,3]) ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE) loss = ce(inputs,outputs) print(loss)
tensor(1.6075)
手算發現,並不是單純的那權重相乘:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113
而是
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075
發現瞭麼,加權後,除以的是權重的和,不是數目的和。
我們再驗證一遍:
import torch import torch.nn as nn inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5]) outputs = torch.LongTensor([0,1,2,2]) inputs = inputs.view((1,3,4)) outputs = outputs.view((1,4)) weight_CE = torch.FloatTensor([1,2,3]) ce = nn.CrossEntropyLoss(weight=weight_CE) # ce = nn.CrossEntropyLoss(ignore_index=255) loss = ce(inputs,outputs) print(loss)
tensor(1.5472)
手算:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
loss3 = 0 + ln(e2 + e0 + e0) = 2.2395
loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943
求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472
可能有人對loss的CE計算過程有疑問,我這裡細致寫寫交叉熵的計算過程,就拿最後一個例子的loss4的計算說明
以上為個人經驗,希望能給大傢一個參考,也希望大傢多多支持WalkonNet。
推薦閱讀:
- Pytorch數據類型與轉換(torch.tensor,torch.FloatTensor)
- PyTorch中Tensor和tensor的區別及說明
- pytorch常用數據類型所占字節數對照表一覽
- pytorch 如何打印網絡回傳梯度
- pytorch 實現二分類交叉熵逆樣本頻率權重