基於PyTorch實現一個簡單的CNN圖像分類器
pytorch中文網:https://www.pytorchtutorial.com/
pytorch官方文檔:https://pytorch.org/docs/stable/index.html
一. 加載數據
Pytorch的數據加載一般是用torch.utils.data.Dataset與torch.utils.data.Dataloader兩個類聯合進行。我們需要繼承Dataset來定義自己的數據集類,然後在訓練時用Dataloader加載自定義的數據集類。
1. 繼承Dataset類並重寫關鍵方法
pytorch的dataset類有兩種:Map-style datasets和Iterable-style datasets。前者是我們常用的結構,而後者是當數據集難以(或不可能)進行隨機讀取時使用。在這裡我們實現Map-style dataset。
繼承torch.utils.data.Dataset後,需要重寫的方法有:__len__與__getitem__方法,其中__len__方法需要返回所有數據的數量,而__getitem__則是要依照給出的數據索引獲取對應的tensor類型的Sample,除瞭這兩個方法以外,一般還需要實現__init__方法來初始化一些變量。話不多說,直接上代碼。
''' 包括瞭各種數據集的讀取處理,以及圖像相關處理方法 ''' from torch.utils.data import Dataset import torch import os import cv2 from Config import mycfg import random import numpy as np class ImageClassifyDataset(Dataset): def __init__(self, imagedir, labelfile, classify_num, train=True): ''' 這裡進行一些初始化操作。 ''' self.imagedir = imagedir self.labelfile = labelfile self.classify_num = classify_num self.img_list = [] # 讀取標簽 with open(self.labelfile, 'r') as fp: lines = fp.readlines() for line in lines: filepath = os.path.join(self.imagedir, line.split(";")[0].replace('\\', '/')) label = line.split(";")[1].strip('\n') self.img_list.append((filepath, label)) if not train: self.img_list = random.sample(self.img_list, 50) def __len__(self): return len(self.img_list) def __getitem__(self, item): ''' 這個函數是關鍵,通過item(索引)來取數據集中的數據, 一般來說在這裡才將圖像數據加載入內存,之前存的是圖像的保存路徑 ''' _int_label = int(self.img_list[item][1]) # label直接用0,1,2,3,4...表示不同類別 label = torch.tensor(_int_label,dtype=torch.long) img = self.ProcessImgResize(self.img_list[item][0]) return img, label def ProcessImgResize(self, filename): ''' 對圖像進行一些預處理 ''' _img = cv2.imread(filename) _img = cv2.resize(_img, (mycfg.IMG_WIDTH, mycfg.IMG_HEIGHT), interpolation=cv2.INTER_CUBIC) _img = _img.transpose((2, 0, 1)) _img = _img / 255 _img = torch.from_numpy(_img) _img = _img.to(torch.float32) return _img
有一些的數據集類一般還會傳入一個transforms函數來構造一個圖像預處理序列,傳入transforms函數的一個好處是作為參數傳入的話可以對一些非本地數據集中的數據進行操作(比如直接通過torchvision獲取的一些預存數據集CIFAR10等等),除此之外就是torchvision.transforms裡面有一些預定義的圖像操作函數,可以直接像拼積木一樣拼成一個圖像處理序列,很方便。我這裡因為是用我自己下載到本地的數據集,而且比較簡單就直接用自己的函數來操作瞭。
2. 使用Dataloader加載數據
實例化自定義的數據集類ImageClassifyDataset後,將其傳給DataLoader作為參數,得到一個可遍歷的數據加載器。可以通過參數batch_size控制批處理大小,shuffle控制是否亂序讀取,num_workers控制用於讀取數據的線程數量。
from torch.utils.data import DataLoader from MyDataset import ImageClassifyDataset dataset = ImageClassifyDataset(imagedir, labelfile, 10) dataloader = DataLoader(dataset, batch_size=5, shuffle=True,num_workers=5) for index, data in enumerate(dataloader): print(index) # batch索引 print(data) # 一個batch的{img,label}
二. 模型設計
在這裡隻討論深度學習模型的設計,pytorch中的網絡結構是一層一層疊出來的,pytorch中預定義瞭許多可以通過參數控制的網絡層結構,比如Linear、CNN、RNN、Transformer等等具體可以查閱官方文檔中的torch.nn部分。
設計自己的模型結構需要繼承torch.nn.Module這個類,然後實現其中的forward方法,一般在__init__中設定好網絡模型的一些組件,然後在forward方法中依據輸入輸出順序拼裝組件。
''' 包括瞭各種模型、自定義的loss計算方法、optimizer ''' import torch.nn as nn class Simple_CNN(nn.Module): def __init__(self, class_num): super(Simple_CNN, self).__init__() self.class_num = class_num self.conv1 = nn.Sequential( nn.Conv2d( # input: 3,400,600 in_channels=3, out_channels=8, kernel_size=5, stride=1, padding=2 ), nn.Conv2d( in_channels=8, out_channels=16, kernel_size=5, stride=1, padding=2 ), nn.AvgPool2d(2), # 16,400,600 --> 16,200,300 nn.BatchNorm2d(16), nn.LeakyReLU(), nn.Conv2d( in_channels=16, out_channels=16, kernel_size=5, stride=1, padding=2 ), nn.Conv2d( in_channels=16, out_channels=8, kernel_size=5, stride=1, padding=2 ), nn.AvgPool2d(2), # 8,200,300 --> 8,100,150 nn.BatchNorm2d(8), nn.LeakyReLU(), nn.Conv2d( in_channels=8, out_channels=8, kernel_size=3, stride=1, padding=1 ), nn.Conv2d( in_channels=8, out_channels=1, kernel_size=3, stride=1, padding=1 ), nn.AvgPool2d(2), # 1,100,150 --> 1,50,75 nn.BatchNorm2d(1), nn.LeakyReLU() ) self.line = nn.Sequential( nn.Linear( in_features=50 * 75, out_features=self.class_num ), nn.Softmax() ) def forward(self, x): x = self.conv1(x) x = x.view(-1, 50 * 75) y = self.line(x) return y
上面我定義的模型中包括卷積組件conv1和全連接組件line,卷積組件中包括瞭一些卷積層,一般是按照{卷積層、池化層、激活函數}的順序拼接,其中我還在激活函數之前添加瞭一個BatchNorm2d層對上層的輸出進行正則化以免傳入激活函數的值過小(梯度消失)或過大(梯度爆炸)。
在拼接組件時,由於我全連接層的輸入是一個一維向量,所以需要將卷積組件中最後的50 × 75 50\times 7550×75大小的矩陣展平成一維的再傳入全連接層(x.view(-1,50*75))
三. 訓練
實例化模型後,網絡模型的訓練需要定義損失函數與優化器,損失函數定義瞭網絡輸出與標簽的差距,依據不同的任務需要定義不同的合適的損失函數,而優化器則定義瞭神經網絡中的參數如何基於損失來更新,目前神經網絡最常用的優化器就是SGD(隨機梯度下降算法) 及其變種。
在我這個簡單的分類器模型中,直接用的多分類任務最常用的損失函數CrossEntropyLoss()以及優化器SGD。
self.cnnmodel = Simple_CNN(mycfg.CLASS_NUM) self.criterion = nn.CrossEntropyLoss() # 交叉熵,標簽應該是0,1,2,3...的形式而不是獨熱的 self.optimizer = optim.SGD(self.cnnmodel.parameters(), lr=mycfg.LEARNING_RATE, momentum=0.9)
訓練過程其實很簡單,使用dataloader依照batch讀出數據後,將input放入網絡模型中計算得到網絡的輸出,然後基於標簽通過損失函數計算Loss,並將Loss反向傳播回神經網絡(在此之前需要清理上一次循環時的梯度),最後通過優化器更新權重。訓練部分代碼如下:
for each_epoch in range(mycfg.MAX_EPOCH): running_loss = 0.0 self.cnnmodel.train() for index, data in enumerate(self.dataloader): inputs, labels = data outputs = self.cnnmodel(inputs) loss = self.criterion(outputs, labels) self.optimizer.zero_grad() # 清理上一次循環的梯度 loss.backward() # 反向傳播 self.optimizer.step() # 更新參數 running_loss += loss.item() if index % 200 == 199: print("[{}] loss: {:.4f}".format(each_epoch, running_loss/200)) running_loss = 0.0 # 保存每一輪的模型 model_name = 'classify-{}-{}.pth'.format(each_epoch,round(all_loss/all_index,3)) torch.save(self.cnnmodel,model_name) # 保存全部模型
四. 測試
測試和訓練的步驟差不多,也就是讀取模型後通過dataloader獲取數據然後將其輸入網絡獲得輸出,但是不需要進行反向傳播的等操作瞭。比較值得註意的可能就是準確率計算方面有一些小技巧。
acc = 0.0 count = 0 self.cnnmodel = torch.load('mymodel.pth') self.cnnmodel.eval() for index, data in enumerate(dataloader_eval): inputs, labels = data # 5,3,400,600 5,10 count += len(labels) outputs = cnnmodel(inputs) _,predict = torch.max(outputs, 1) acc += (labels == predict).sum().item() print("[{}] accurancy: {:.4f}".format(each_epoch, acc / count))
我這裡采用的是保存全部模型並加載全部模型的方法,這種方法的好處是在使用模型時可以完全將其看作一個黑盒,但是在模型比較大時這種方法會很費事。此時可以采用隻保存參數不保存網絡結構的方法,在每一次使用模型時需要讀取參數賦值給已經實例化的模型:
torch.save(cnnmodel.state_dict(), "my_resnet.pth") cnnmodel = Simple_CNN() cnnmodel.load_state_dict(torch.load("my_resnet.pth"))
結語
至此整個流程就說完瞭,是一個小白級的圖像分類任務流程,因為前段時間一直在做android方面的事,所以有點生疏瞭,就寫瞭這篇博客記錄一下,之後應該還會寫一下seq2seq以及image caption任務方面的模型構造與訓練過程,完整代碼之後也會統一放到github上給大傢做參考。
以上就是基於PyTorch實現一個簡單的CNN圖像分類器的詳細內容,更多關於PyTorch實現CNN圖像分類器的資料請關註WalkonNet其它相關文章!
推薦閱讀:
- 超詳細PyTorch實現手寫數字識別器的示例代碼
- Pytorch深度學習之實現病蟲害圖像分類
- pytorch教程網絡和損失函數的可視化代碼示例
- 淺談Pytorch 定義的網絡結構層能否重復使用
- Pytorch實現ResNet網絡之Residual Block殘差塊