解決pytorch 的state_dict()拷貝問題
先說結論
model.state_dict()
是淺拷貝,返回的參數仍然會隨著網絡的訓練而變化。
應該使用deepcopy(model.state_dict())
,或將參數及時序列化到硬盤。
再講故事,前幾天在做一個模型的交叉驗證訓練時,通過model.state_dict()保存瞭每一組交叉驗證模型的參數,後根據效果選擇準確率最佳的模型load回去,結果每一次都是最後一個模型,從地址來看,每一個保存的state_dict()都具有不同的地址,但進一步發現state_dict()下的各個模型參數的地址是共享的,而我又使用瞭in-place的方式重置模型參數,進而導致瞭上述問題。
補充:pytorch中state_dict的理解
在PyTorch中,state_dict是一個Python字典對象(在這個有序字典中,key是各層參數名,value是各層參數),包含模型的可學習參數(即權重和偏差,以及bn層的的參數) 優化器對象(torch.optim)也具有state_dict,其中包含有關優化器狀態以及所用超參數的信息。
其實看瞭如下代碼的輸出應該就懂瞭
import torch import torch.nn as nn import torchvision import numpy as np from torchsummary import summary # Define model class TheModelClass(nn.Module): def __init__(self): super(TheModelClass, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x # Initialize model model = TheModelClass() # Initialize optimizer optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # Print model's state_dict print("Model's state_dict:") for param_tensor in model.state_dict(): print(param_tensor,"\t", model.state_dict()[param_tensor].size()) # Print optimizer's state_dict print("Optimizer's state_dict:") for var_name in optimizer.state_dict(): print(var_name, "\t", optimizer.state_dict()[var_name])
輸出如下:
Model's state_dict: conv1.weight torch.Size([6, 3, 5, 5]) conv1.bias torch.Size([6]) conv2.weight torch.Size([16, 6, 5, 5]) conv2.bias torch.Size([16]) fc1.weight torch.Size([120, 400]) fc1.bias torch.Size([120]) fc2.weight torch.Size([84, 120]) fc2.bias torch.Size([84]) fc3.weight torch.Size([10, 84]) fc3.bias torch.Size([10]) Optimizer's state_dict: state {} param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2238501264336, 2238501329800, 2238501330016, 2238501327136, 2238501328576, 2238501329728, 2238501327928, 2238501327064, 2238501330808, 2238501328288]}]
我是剛接觸深度學西的小白一個,希望大佬可以為我指出我的不足,此博客僅為自己的筆記!!!!
補充:pytorch保存模型時報錯***object has no attribute ‘state_dict’
定義瞭一個類BaseNet並實例化該類:
net=BaseNet()
保存net時報錯 object has no attribute ‘state_dict’
torch.save(net.state_dict(), models_dir)
原因是定義類的時候不是繼承nn.Module類,比如:
class BaseNet(object): def __init__(self):
把類定義改為
class BaseNet(nn.Module): def __init__(self): super(BaseNet, self).__init__()
以上為個人經驗,希望能給大傢一個參考,也希望大傢多多支持WalkonNet。如有錯誤或未考慮完全的地方,望不吝賜教。
推薦閱讀:
- pytorch模型的保存和加載、checkpoint操作
- 關於Pytorch中模型的保存與遷移問題
- pytorch實現加載保存查看checkpoint文件
- pytorch 預訓練模型讀取修改相關參數的填坑問題
- Pytorch 統計模型參數量的操作 param.numel()