Pytorch模型參數的保存和加載
一、前言
在模型訓練完成後,我們需要保存模型參數值用於後續的測試過程。由於保存整個模型將耗費大量的存儲,故推薦的做法是隻保存參數,使用時隻需在建好模型的基礎上加載。
通常來說,保存的對象包括網絡參數值、優化器參數值、epoch值等。本文將簡單介紹保存和加載模型參數的方法,同時也給出保存整個模型的方法供大傢參考。
二、參數保存
在這裡我們使用 torch.save() 函數保存模型參數:
import torch path = './model.pth' torch.save(model.state_dict(), path)
model——指定義的模型實例變量,如model=net( )
state_dict()——state_dict( )是一個可以輕松地保存、更新、修改和恢復的python字典對象, 對於model來說,表示模型的每一層的權重及偏置等參數信息;對於 optimizer 來說,其包含瞭優化器的狀態以及被使用的超參數(如lr, momentum,weight_decay等)
path——path是保存參數的路徑,一般設置為 path='./model.pth' , path='./model.pkl'等形式。
此外,如果想保存某一次訓練采用的optimizer、epochs等信息,可將這些信息組合起來構成一個字典保存起來:
import torch path = './model.pth' state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch} torch.save(state, path)
三、參數的加載
使用 load_state_dict()函數加載參數到模型中, 當僅保存瞭模型參數,而沒有optimizer、epochs等信息時:
model.load_state_dict(torch.load(path))
model——事先定義好的跟原模型一致的模型
path——之前保存的模型參數文件
如若保存瞭optimizer、epochs等信息,我們這樣載入信息:
# 使用torch.load()函數將文件中字典信息載入 state_dict 變量中 state_dict = torch.load(path) # 分佈加載參數到模型和優化器 model.load_state_dict(state_dict['model']) optimizer.load_state_dict(state_dict['optimizer']) epoch = state_dict(['epoch'])
我們還可以在每n個epoch後保存一次參數,以觀察不同迭代次數模型的表現。此時我們可設置不同的path,如 path='./model' + str(epoch) +'.pth',這樣,不同epoch的參數就能保存在不同的文件中。
四、保存和加載整個模型
使用上文提到的方法即可:
torch.save(model, path) model = torch.load(path)
五、總結
pytorch中state_dict()和load_state_dict()函數配合使用可以實現狀態的獲取與重載,load()和save()函數配合使用可以實現參數的存儲與讀取。掌握對應的函數使用方法就可以遊刃有餘地進行運用。
到此這篇關於Pytorch模型參數的保存和加載的文章就介紹到這瞭,更多相關Pytorch模型參數保存內容請搜索WalkonNet以前的文章或繼續瀏覽下面的相關文章希望大傢以後多多支持WalkonNet!
推薦閱讀:
- pytorch模型的保存和加載、checkpoint操作
- 關於Pytorch中模型的保存與遷移問題
- Pytorch 中的optimizer使用說明
- pytorch實現加載保存查看checkpoint文件
- Pytorch中的學習率衰減及其用法詳解