Pytorch訓練模型得到輸出後計算F1-Score 和AUC的操作
1、計算F1-Score
對於二分類來說,假設batch size 大小為64的話,那麼模型一個batch的輸出應該是torch.size([64,2]),所以首先做的是得到這個二維矩陣的每一行的最大索引值,然後添加到一個列表中,同時把標簽也添加到一個列表中,最後使用sklearn中計算F1的工具包進行計算,代碼如下
import numpy as np import sklearn.metrics import f1_score prob_all = [] lable_all = [] for i, (data,label) in tqdm(train_data_loader): prob = model(data) #表示模型的預測輸出 prob = prob.cpu().numpy() #先把prob轉到CPU上,然後再轉成numpy,如果本身在CPU上訓練的話就不用先轉成CPU瞭 prob_all.extend(np.argmax(prob,axis=1)) #求每一行的最大值索引 label_all.extend(label) print("F1-Score:{:.4f}".format(f1_score(label_all,prob_all)))
2、計算AUC
計算AUC的時候,本次使用的是sklearn中的roc_auc_score () 方法
輸入參數:
y_true
:真實的標簽。形狀 (n_samples,) 或 (n_samples, n_classes)。二分類的形狀 (n_samples,1),而多標簽情況的形狀 (n_samples, n_classes)。
y_score
:目標分數。形狀 (n_samples,) 或 (n_samples, n_classes)。二分類情況形狀 (n_samples,1),“分數必須是具有較大標簽的類的分數”,通俗點理解:模型打分的第二列。舉個例子:模型輸入的得分是一個數組 [0.98361117 0.01638886],索引是其類別,這裡 “較大標簽類的分數”,指的是索引為 1 的分數:0.01638886,也就是正例的預測得分。
average='macro'
:二分類時,該參數可以忽略。用於多分類,’ micro ‘:將標簽指標矩陣的每個元素看作一個標簽,計算全局的指標。’ macro ‘:計算每個標簽的指標,並找到它們的未加權平均值。這並沒有考慮標簽的不平衡。’ weighted ‘:計算每個標簽的指標,並找到它們的平均值,根據支持度 (每個標簽的真實實例的數量) 進行加權。
sample_weight=None
:樣本權重。形狀 (n_samples,),默認 = 無。
max_fpr=None
:
multi_class='raise'
:(多分類的問題在下一篇文章中解釋)
labels=None
:
輸出:
auc
:是一個 float 的值。
import numpy as np import sklearn.metrics import roc_auc_score prob_all = [] lable_all = [] for i, (data,label) in tqdm(train_data_loader): prob = model(data) #表示模型的預測輸出 prob_all.extend(prob[:,1].cpu().numpy()) #prob[:,1]返回每一行第二列的數,根據該函數的參數可知,y_score表示的較大標簽類的分數,因此就是最大索引對應的那個值,而不是最大索引值 label_all.extend(label) print("AUC:{:.4f}".format(roc_auc_score(label_all,prob_all)))
補充:pytorch訓練模型的一些坑
1. 圖像讀取
opencv的python和c++讀取的圖像結果不一致,是因為python和c++采用的opencv版本不一樣,從而使用的解碼庫不同,導致讀取的結果不同。
2. 圖像變換
PIL和pytorch的圖像resize操作,與opencv的resize結果不一樣,這樣會導致訓練采用PIL,預測時采用opencv,結果差別很大,尤其是在檢測和分割任務中比較明顯。
3. 數值計算
pytorch的torch.exp與c++的exp計算,10e-6的數值時候會有10e-3的誤差,對於高精度計算需要特別註意,比如
兩個輸入5.601597, 5.601601, 經過exp計算後變成270.85862343143174, 270.85970686809225
以上為個人經驗,希望能給大傢一個參考,也希望大傢多多支持WalkonNet。
推薦閱讀:
- 回歸預測分析python數據化運營線性回歸總結
- 人工智能-Python實現嶺回歸
- Python之Sklearn使用入門教程
- python生成器generator:深度學習讀取batch圖片的操作
- Python實現DBSCAN聚類算法並樣例測試