解決pytorch 保存模型遇到的問題

今天用pytorch保存模型時遇到bug

Can’t pickle <class ‘torch._C._VariableFunctions’>

在google上查找原因,發現是保存時保存瞭整個模型的原因,而模型中有一些自定義的參數

torch.save(model,save_path) 改為 torch.save(model.state_dict(),save_path)

然後載入模型也做相應的更改就好瞭

補充:pytorch訓練模型的一些坑

1. 圖像讀取

opencv的python和c++讀取的圖像結果不一致,是因為python和c++采用的opencv版本不一樣,從而使用的解碼庫不同,導致讀取的結果不同。

2. 圖像變換

PIL和pytorch的圖像resize操作,與opencv的resize結果不一樣,這樣會導致訓練采用PIL,預測時采用opencv,結果差別很大,尤其是在檢測和分割任務中比較明顯。

3. 數值計算

pytorch的torch.exp與c++的exp計算,10e-6的數值時候會有10e-3的誤差,對於高精度計算需要特別註意,比如

兩個輸入5.601597, 5.601601, 經過exp計算後變成270.85862343143174, 270.85970686809225

以上為個人經驗,希望能給大傢一個參考,也希望大傢多多支持WalkonNet。如有錯誤或未考慮完全的地方,望不吝賜教。

推薦閱讀: