解決Pytorch修改預訓練模型時遇到key不匹配的情況
一、Pytorch修改預訓練模型時遇到key不匹配
最近想著修改網絡的預訓練模型vgg.pth,但是發現當我加載預訓練模型權重到新建的模型並保存之後。
在我使用新賦值的網絡模型時出現瞭key不匹配的問題
#加載後保存(未修改網絡) base_weights = torch.load(args.save_folder + args.basenet) ssd_net.vgg.load_state_dict(base_weights) torch.save(ssd_net.state_dict(), args.save_folder + 'ssd_base' + '.pth')
# 將新保存的網絡代替之前的預訓練模型 ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes']) net = ssd_net ... if args.resume: ... else: base_weights = torch.load(args.save_folder + args.basenet) #args.basenet為ssd_base.pth print('Loading base network...') ssd_net.vgg.load_state_dict(base_weights)
此時會如下出錯誤:
Loading base network…
Traceback (most recent call last):
File “train.py”, line 264, in
train()
File “train.py”, line 110, in train
ssd_net.vgg.load_state_dict(base_weights)
…
RuntimeError: Error(s) in loading state_dict for ModuleList:
Missing key(s) in state_dict: “0.weight”, “0.bias”, … “33.weight”, “33.bias”.
Unexpected key(s) in state_dict: “vgg.0.weight”, “vgg.0.bias”, … “vgg.33.weight”, “vgg.33.bias”.
說明之前的預訓練模型 key參數為”0.weight”, “0.bias”,但是經過加載保存之後變為瞭”vgg.0.weight”, “vgg.0.bias”
我認為是因為本身的模型定義文件裡self.vgg = nn.ModuleList(base)這一句。
現在的問題是因為自己定義保存的模型key參數多瞭一個前綴。
可以通過如下語句進行修改,並加載
from collections import OrderedDict #導入此模塊 base_weights = torch.load(args.save_folder + args.basenet) print('Loading base network...') new_state_dict = **OrderedDict()** for k, v in base_weights.items(): name = k[4:] # remove `vgg.`,即隻取vgg.0.weights的後面幾位 new_state_dict[name] = v ssd_net.vgg.load_state_dict(new_state_dict)
此時就不會再出錯瞭。
參考瞭這個篇。修改一下就可以應用到自己的模型啦。
//www.jb51.net/article/214214.htm
二、pytorch加載預訓練模型遇到的問題:KeyError: ‘bn1.num_batches_tracked‘
最近在使用pytorch1.0加載resnet預訓練模型時,遇到的一個問題,在此記錄一下。
KeyError: ‘layer1.0.bn1.num_batches_tracked’
其實是使用的版本的問題,pytorch0.4.1之後在BN層加入瞭track_running_stats這個參數,
這個參數的作用如下:
訓練時用來統計訓練時的forward過的min-batch數目,每經過一個min-batch, track_running_stats+=1
如果沒有指定momentum, 則使用1/num_batches_tracked 作為因數來計算均值和方差(running mean and variance).
其實,這個參數沒啥用.但因為官方提供的預訓練模型是pytorch0.3版本訓練出來的,因此沒有這個參數.
所以,隻要過濾一下預訓練權重字典中的關鍵字即可,‘num_batches_tracked’.代碼例子,如下.
有問題的代碼:
def load_specific_param(self, state_dict, param_name, model_path): param_dict = torch.load(model_path) for i in state_dict: key = param_name + '.' + i state_dict[i].copy_(param_dict[key]) del param_dict
對’num_batches_tracked進行過濾:
def load_specific_param(self, state_dict, param_name, model_path): param_dict = torch.load(model_path) param_dict = {k: v for k, v in param_dict.items() if 'num_batches_tracked' not in k} for i in state_dict: key = param_name + '.' + i if 'num_batches_tracked' in key: continue state_dict[i].copy_(param_dict[key]) del param_dict
以上為個人經驗,希望能給大傢一個參考,也希望大傢多多支持WalkonNet。
推薦閱讀:
- 解決pytorch 的state_dict()拷貝問題
- pytorch模型的保存和加載、checkpoint操作
- 關於Pytorch中模型的保存與遷移問題
- pytorch 預訓練模型讀取修改相關參數的填坑問題
- Pytorch模型參數的保存和加載