CoAtNet實戰之對植物幼苗圖像進行分類(pytorch)

前言

雖然Transformer在CV任務上有非常強的學習建模能力,但是由於缺少瞭像CNN那樣的歸納偏置,所以相比於CNN,Transformer的泛化能力就比較差。因此,如果隻有Transformer進行全局信息的建模,在沒有預訓練(JFT-300M)的情況下,Transformer在性能上很難超過CNN(VOLO在沒有預訓練的情況下,一定程度上也是因為VOLO的Outlook Attention對特征信息進行瞭局部感知,相當於引入瞭歸納偏置)。既然CNN有更強的泛化能力,Transformer具有更強的學習能力,那麼,為什麼不能將Transformer和CNN進行一個結合呢?

谷歌的最新模型CoAtNet做瞭卷積 + Transformer的融合,在ImageNet-1K數據集上取得88.56%的成績。今天我們就用CoAtNet實現植物幼苗的分類。

論文

github復現

項目結構

數據集

數據集選用植物幼苗分類,總共12類。數據集連接如下:

鏈接 提取碼:q060

在工程的根目錄新建data文件夾,獲取數據集後,將trian和test解壓放到data文件夾下面,如下圖:

安裝庫,並導入需要的庫

安裝完成後,導入到項目中。

import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from dataset.dataset import SeedlingData
from torch.autograd import Variable
from models.coatnet import coatnet_0

設置全局參數

設置使用GPU,設置學習率、BatchSize、epoch等參數

# 設置全局參數
modellr = 1e-4
BATCH_SIZE = 16
EPOCHS = 50
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

數據預處理

數據處理比較簡單,沒有做復雜的嘗試,有興趣的可以加入一些處理。

# 數據預處理

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

])
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

數據讀取

然後我們在dataset文件夾下面新建 init.py和dataset.py,在mydatasets.py文件夾寫入下面的代碼:

說一下代碼的核心邏輯。

第一步 建立字典,定義類別對應的ID,用數字代替類別。

第二步 在__init__裡面編寫獲取圖片路徑的方法。測試集隻有一層路徑直接讀取,訓練集在train文件夾下面是類別文件夾,先獲取到類別,再獲取到具體的圖片路徑。然後使用sklearn中切分數據集的方法,按照7:3的比例切分訓練集和驗證集。

第三步 在__getitem__方法中定義讀取單個圖片和類別的方法,由於圖像中有位深度32位的,所以我在讀取圖像的時候做瞭轉換。

代碼如下:

# coding:utf8
import os
from PIL import Image
from torch.utils import data
from torchvision import transforms as T
from sklearn.model_selection import train_test_split
 
Labels = {'Black-grass': 0, 'Charlock': 1, 'Cleavers': 2, 'Common Chickweed': 3,
          'Common wheat': 4, 'Fat Hen': 5, 'Loose Silky-bent': 6, 'Maize': 7, 'Scentless Mayweed': 8,
          'Shepherds Purse': 9, 'Small-flowered Cranesbill': 10, 'Sugar beet': 11}
 
 
class SeedlingData (data.Dataset):
 
    def __init__(self, root, transforms=None, train=True, test=False):
        """
        主要目標: 獲取所有圖片的地址,並根據訓練,驗證,測試劃分數據
        """
        self.test = test
        self.transforms = transforms
 
        if self.test:
            imgs = [os.path.join(root, img) for img in os.listdir(root)]
            self.imgs = imgs
        else:
            imgs_labels = [os.path.join(root, img) for img in os.listdir(root)]
            imgs = []
            for imglable in imgs_labels:
                for imgname in os.listdir(imglable):
                    imgpath = os.path.join(imglable, imgname)
                    imgs.append(imgpath)
            trainval_files, val_files = train_test_split(imgs, test_size=0.3, random_state=42)
            if train:
                self.imgs = trainval_files
            else:
                self.imgs = val_files
 
    def __getitem__(self, index):
        """
        一次返回一張圖片的數據
        """
        img_path = self.imgs[index]
        img_path=img_path.replace("\\",'/')
        if self.test:
            label = -1
        else:
            labelname = img_path.split('/')[-2]
            label = Labels[labelname]
        data = Image.open(img_path).convert('RGB')
        data = self.transforms(data)
        return data, label
 
    def __len__(self):
        return len(self.imgs)

然後我們在train.py調用SeedlingData讀取數據 ,記著導入剛才寫的dataset.py(from mydatasets import SeedlingData)

# 讀取數據
dataset_train = SeedlingData('data/train', transforms=transform, train=True)
dataset_test = SeedlingData("data/train", transforms=transform_test, train=False)
# 導入數據
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

設置模型

  • 設置loss函數為nn.CrossEntropyLoss()。
  • 設置模型為coatnet_0,修改最後一層全連接輸出改為12。
  • 優化器設置為adam。
  • 學習率調整策略改為餘弦退火
# 實例化模型並且移動到GPU
criterion = nn.CrossEntropyLoss()

model_ft = coatnet_0()
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 12)
model_ft.to(DEVICE)
# 選擇簡單暴力的Adam優化器,學習率調低
optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,T_max=20,eta_min=1e-9)
# 定義訓練過程

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    sum_loss = 0
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = Variable(data).to(device), Variable(target).to(device)
        output = model(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print_loss = loss.data.item()
        sum_loss += print_loss
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                       100. * (batch_idx + 1) / len(train_loader), loss.item()))
    ave_loss = sum_loss / len(train_loader)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))


# 驗證過程
def val(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    with torch.no_grad():
        for data, target in test_loader:
            data, target = Variable(data).to(device), Variable(target).to(device)
            output = model(data)
            loss = criterion(output, target)
            _, pred = torch.max(output.data, 1)
            correct += torch.sum(pred == target)
            print_loss = loss.data.item()
            test_loss += print_loss
        correct = correct.data.item()
        acc = correct / total_num
        avgloss = test_loss / len(test_loader)
        print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            avgloss, correct, len(test_loader.dataset), 100 * acc))


# 訓練

for epoch in range(1, EPOCHS + 1):
    train(model_ft, DEVICE, train_loader, optimizer, epoch)
    cosine_schedule.step()
    val(model_ft, DEVICE, test_loader)
torch.save(model_ft, 'model.pth')

測試

測試集存放的目錄如下圖:

第一步 定義類別,這個類別的順序和訓練時的類別順序對應,一定不要改變順序!!!!

classes = ('Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed',
           'Common wheat', 'Fat Hen', 'Loose Silky-bent',
           'Maize', 'Scentless Mayweed', 'Shepherds Purse', 'Small-flowered Cranesbill', 'Sugar beet')

第二步 定義transforms,transforms和驗證集的transforms一樣即可,別做數據增強。

transform_test = transforms.Compose([
         transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

第三步 加載model,並將模型放在DEVICE裡。

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load("model.pth")
model.eval()
model.to(DEVICE)

第四步 讀取圖片並預測圖片的類別,在這裡註意,讀取圖片用PIL庫的Image。不要用cv2,transforms不支持。

path = 'data/test/'
testList = os.listdir(path)
for file in testList:
    img = Image.open(path + file)
    img = transform_test(img)
    img.unsqueeze_(0)
    img = Variable(img).to(DEVICE)
    out = model(img)
    # Predict
    _, pred = torch.max(out.data, 1)
    print('Image Name:{},predict:{}'.format(file, classes[pred.data.item()]))

測試完整代碼:

import torch.utils.data.distributed
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
import os

classes = ('Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed',
           'Common wheat', 'Fat Hen', 'Loose Silky-bent',
           'Maize', 'Scentless Mayweed', 'Shepherds Purse', 'Small-flowered Cranesbill', 'Sugar beet')
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load("model.pth")
model.eval()
model.to(DEVICE)

path = 'data/test/'
testList = os.listdir(path)
for file in testList:
    img = Image.open(path + file)
    img = transform_test(img)
    img.unsqueeze_(0)
    img = Variable(img).to(DEVICE)
    out = model(img)
    # Predict
    _, pred = torch.max(out.data, 1)
    print('Image Name:{},predict:{}'.format(file, classes[pred.data.item()]))

運行結果:

以上就是CoAtNet實戰之對植物幼苗圖像進行分類(pytorch)的詳細內容,更多關於CoAtNet 植物幼苗圖像分類的資料請關註WalkonNet其它相關文章!

推薦閱讀: