pytorch加載自己的圖片數據集的2種方法詳解

pytorch加載圖片數據集有兩種方法。

1.ImageFolder 適合於分類數據集,並且每一個類別的圖片在同一個文件夾, ImageFolder加載的數據集, 訓練數據為文件件下的圖片, 訓練標簽是對應的文件夾, 每個文件夾為一個類別

導入ImageFolder()包
from torchvision.datasets import ImageFolder

在Flower_Orig_dataset文件夾下有flower_orig 和 sunflower這兩個文件夾, 這兩個文件夾下放著同一個類別的圖片。 使用 ImageFolder 加載的圖片, 就會返回圖片信息和對應的label信息, 但是label信息是根據文件夾給出的, 如flower_orig就是標簽0, sunflower就是標簽1。

ImageFolder 加載數據集

1. 導入包和設置transform

import torch
from torchvision import transforms, datasets
import torch.nn as nn
from torch.utils.data import DataLoader
 
transforms = transforms.Compose([
    transforms.Resize(256),    # 將圖片短邊縮放至256,長寬比保持不變:
    transforms.CenterCrop(224),   #將圖片從中心切剪成3*224*224大小的圖片
    transforms.ToTensor()          #把圖片進行歸一化,並把數據轉換成Tensor類型
]) 

2.加載數據集: 將分類圖片的父目錄作為路徑傳遞給ImageFolder(), 並傳入transform。這樣就有瞭要加載的數據集, 之後就可以使用DataLoader加載數據, 並構建網絡訓練。

path = r'D:\數據集\Flower_Orig_dataset'
data_train = datasets.ImageFolder(path, transform=transforms)
data_loader = DataLoader(data_train, batch_size=64, shuffle=True)
for i, data in enumerate(data_loader):
    images, labels = data
    print(images.shape)
    print(labels.shape)
    break

使用pytorch提供的Dataset類創建自己的數據集。

具體步驟:

1.  首先要有一個txt文件, 這個文件格式是: 圖片路徑  標簽.  這樣的格式, 所以使用os庫, 遍歷自己的圖片名, 並把標簽和圖片路徑寫入txt文件。

2. 有瞭這個txt文件, 我們就可以在類裡面構造我們的數據集.

2.1    把圖片路徑和圖片標簽分割開, 有兩個列表, 一個列表是圖片路徑名, 一個列表是標簽號, 有一點就是第 i 個圖片列表和 第 i 個標簽是對應的

3. 重寫__len__方法  和  __getitem__方法

3.1 getitem方法中, 獲得對應的圖片路徑,並用PIL庫讀取文件把圖片transfrom後, 在getitem函數中返回讀取的圖片和標簽即可

4.就可以構建數據集實例和加載數據集.

 定義一個用來生成[ 圖片路徑 標簽] 這樣的txt文件函數

def make_txt(root, file_name, label):
    path = os.path.join(root, file_name)
    data = os.listdir(path)
    f = open(path+'\\'+'f.txt', 'w')
    for line in data:
        f.write(line+' '+str(label)+'\n')
    f.close()
#調用函數生成兩個文件夾下的txt文件
make_txt(path, file_name='flower_orig', label=0)
make_txt(path, file_name='sunflower', label=1)

將連個txt文件合並成一個txt文件,表示數據集所有的圖片和標簽

def link_txt(file1, file2):
    txt_list = []
    path = r'D:\數據集\Flower_Orig_dataset\data.txt'
 
    f = open(path, 'a')
 
    f1 = open(file1, 'r')
    data1 = f1.readlines()
    for line in data1:
        txt_list.append(line)
 
    f2 = open(file2, 'r')
    data2 = f2.readlines()
    for line in data2:
        txt_list.append(line)
 
    for line in txt_list:
        f.write(line)
 
    f.close()
    f1.close()
    f2.close()
 
#調用函數, 將兩個文件夾下的txt文件合並
file1 = r'D:\數據集\Flower_Orig_dataset\flower_orig\f.txt'
file2 = r'D:\數據集\Flower_Orig_dataset\sunflower\f.txt'
link_txt(file1=file1, file2=file2)

現在我們已經有瞭我們制作數據集所需要的txt文件, 接下來要做的即使繼承Dataset類, 來構建自己的數據集 , 別忘瞭前面說的 構建數據集步驟, 在__getitem__函數中, 需要拿到圖片路徑和標簽, 並且用PIL庫方法讀取圖片,對圖片進行transform轉換後,返回圖片信息和標簽信息

Dataset加載數據集

我們讀取圖片的根目錄, 在根目錄下有所有圖片的txt文件, 拿到txt文件後, 先讀取txt文件, 之後遍歷txt文件中的每一行, 首先去除掉尾部的換行符, 在以空格切分,前半部分是圖片名稱, 後半部分是圖片標簽, 當圖片名稱和根目錄結合,就得到瞭我們的圖片路徑   
class MyDataset(Dataset):
    def __init__(self, img_path, transform=None):
        super(MyDataset, self).__init__()
        self.root = img_path
 
        self.txt_root = self.root + 'data.txt'
        f = open(self.txt_root, 'r')
        data = f.readlines()
 
        imgs = []
        labels = []
        for line in data:
            line = line.rstrip()
            word = line.split()
            imgs.append(os.path.join(self.root, word[1], word[0]))
 
            labels.append(word[1])
        self.img = imgs
        self.label = labels
        self.transform = transform
 
    def __len__(self):
        return len(self.label)
 
    def __getitem__(self, item):
        img = self.img[item]
        label = self.label[item]
 
        img = Image.open(img).convert('RGB')
 
        #此時img是PIL.Image類型   label是str類型
 
        if transforms is not None:
            img = self.transform(img)
 
        label = np.array(label).astype(np.int64)
        label = torch.from_numpy(label)
        
        return img, label

 加載我們的數據集:

path = r'D:\數據集\Flower_Orig_dataset'
dataset = MyDataset(path, transform=transform)
 
data_loader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)

接下來我們就可以構建我們的網絡架構:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3,16,3)
        self.maxpool = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(16,5,3)
 
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(55*55*5, 1200)
        self.fc2 = nn.Linear(1200,64)
        self.fc3 = nn.Linear(64,2)
 
    def forward(self,x):
        x = self.maxpool(self.relu(self.conv1(x)))    #113
        x = self.maxpool(self.relu(self.conv2(x)))    #55
        x = x.view(-1, self.num_flat_features(x))
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    
    def num_flat_features(self, x):
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
 
        return num_features
 

 訓練我們的網絡:

model = Net()
 
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
 
 
epochs = 10
for epoch in range(epochs):
    running_loss = 0.0
    for i, data in enumerate(data_loader):
        images, label = data
 
        out = model(images)
 
        loss = criterion(out, label)
 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
        running_loss += loss.item()
        if(i+1)%10 == 0:
            print('[%d  %5d]   loss: %.3f'%(epoch+1, i+1, running_loss/100))
            running_loss = 0.0
 
print('finished train')

 保存網絡模型(這裡不止是保存參數,還保存瞭網絡結構)

#保存模型
torch.save(net, 'model_name.pth')   #保存的是模型, 不止是w和b權重值
 
# 讀取模型
model = torch.load('model_name.pth')

總結

到此這篇關於pytorch加載自己的圖片數據集的2種方法的文章就介紹到這瞭,更多相關pytorch加載圖片數據集內容請搜索WalkonNet以前的文章或繼續瀏覽下面的相關文章希望大傢以後多多支持WalkonNet!

推薦閱讀: