Pytorch模型參數的保存和加載

一、前言

在模型訓練完成後,我們需要保存模型參數值用於後續的測試過程。由於保存整個模型將耗費大量的存儲,故推薦的做法是隻保存參數,使用時隻需在建好模型的基礎上加載。

通常來說,保存的對象包括網絡參數值、優化器參數值、epoch值等。本文將簡單介紹保存和加載模型參數的方法,同時也給出保存整個模型的方法供大傢參考。

二、參數保存

在這裡我們使用 torch.save() 函數保存模型參數:

import torch
path = './model.pth'
torch.save(model.state_dict(), path)

model——指定義的模型實例變量,如model=net( )

state_dict()——state_dict( )是一個可以輕松地保存、更新、修改和恢復的python字典對象, 對於model來說,表示模型的每一層的權重及偏置等參數信息;對於 optimizer 來說,其包含瞭優化器的狀態以及被使用的超參數(如lr, momentum,weight_decay等)

path——path是保存參數的路徑,一般設置為 path='./model.pth' , path='./model.pkl'等形式。

此外,如果想保存某一次訓練采用的optimizer、epochs等信息,可將這些信息組合起來構成一個字典保存起來:

import torch
path = './model.pth'
state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
torch.save(state, path)

三、參數的加載

使用 load_state_dict()函數加載參數到模型中, 當僅保存瞭模型參數,而沒有optimizer、epochs等信息時:

model.load_state_dict(torch.load(path))

model——事先定義好的跟原模型一致的模型

path——之前保存的模型參數文件

如若保存瞭optimizer、epochs等信息,我們這樣載入信息:

# 使用torch.load()函數將文件中字典信息載入 state_dict 變量中
state_dict = torch.load(path)
# 分佈加載參數到模型和優化器
model.load_state_dict(state_dict['model'])
optimizer.load_state_dict(state_dict['optimizer'])
epoch = state_dict(['epoch'])

我們還可以在每n個epoch後保存一次參數,以觀察不同迭代次數模型的表現此時我們可設置不同的path,如 path='./model' + str(epoch) +'.pth',這樣,不同epoch的參數就能保存在不同的文件中。

四、保存和加載整個模型

使用上文提到的方法即可:

torch.save(model, path)
model = torch.load(path)

五、總結

pytorch中state_dict()和load_state_dict()函數配合使用可以實現狀態的獲取與重載,load()和save()函數配合使用可以實現參數的存儲與讀取。掌握對應的函數使用方法就可以遊刃有餘地進行運用。

到此這篇關於Pytorch模型參數的保存和加載的文章就介紹到這瞭,更多相關Pytorch模型參數保存內容請搜索WalkonNet以前的文章或繼續瀏覽下面的相關文章希望大傢以後多多支持WalkonNet!

推薦閱讀: