pytorch深度神經網絡入門準備自己的圖片數據

正文

圖片數據一般有兩種情況:

1、所有圖片放在一個文件夾內,另外有一個txt文件顯示標簽。

2、不同類別的圖片放在不同的文件夾內,文件夾就是圖片的類別。

針對這兩種不同的情況,數據集的準備也不相同,第一種情況可以自定義一個Dataset,第二種情況直接調用torchvision.datasets.ImageFolder來處理。下面分別進行說明:

一、所有圖片放在一個文件夾內

這裡以mnist數據集的10000個test為例, 我先把test集的10000個圖片保存出來,並生著對應的txt標簽文件。

先在當前目錄創建一個空文件夾mnist_test, 用於保存10000張圖片,接著運行代碼:

import torch
import torchvision
import matplotlib.pyplot as plt
from skimage import io
mnist_test= torchvision.datasets.MNIST(
    './mnist', train=False, download=True
)
print('test set:', len(mnist_test))
f=open('mnist_test.txt','w')
for i,(img,label) in enumerate(mnist_test):
    img_path="./mnist_test/"+str(i)+".jpg"
    io.imsave(img_path,img)
    f.write(img_path+' '+str(label)+'\n')
f.close()

經過上面的操作,10000張圖片就保存在mnist_test文件夾裡瞭,並在當前目錄下生成瞭一個mnist_test.txt的文件,大致如下:

前期工作就裝備好瞭,接著就進入正題瞭:

from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from PIL import Image
def default_loader(path):
    return Image.open(path).convert('RGB')
class MyDataset(Dataset):
    def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
        fh = open(txt, 'r')
        imgs = []
        for line in fh:
            line = line.strip('\n')
            line = line.rstrip()
            words = line.split()
            imgs.append((words[0],int(words[1])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = self.loader(fn)
        if self.transform is not None:
            img = self.transform(img)
        return img,label
    def __len__(self):
        return len(self.imgs)
train_data=MyDataset(txt='mnist_test.txt', transform=transforms.ToTensor())
data_loader = DataLoader(train_data, batch_size=100,shuffle=True)
print(len(data_loader))
def show_batch(imgs):
    grid = utils.make_grid(imgs)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
    plt.title('Batch from dataloader')
for i, (batch_x, batch_y) in enumerate(data_loader):
    if(i<4):
        print(i, batch_x.size(),batch_y.size())
        show_batch(batch_x)
        plt.axis('off')
        plt.show()

自定義瞭一個MyDataset, 繼承自torch.utils.data.Dataset。然後利用torch.utils.data.DataLoader將整個數據集分成多個批次。

二、不同類別的圖片放在不同的文件夾內

同樣先準備數據,這裡以flowers數據集為例

提取 鏈接: https://pan.baidu.com/s/1dcAsOOZpUfWNYR77JGXPHA?pwd=mwg6 

花總共有五類,分別放在5個文件夾下。大致如下圖:

我的路徑是d:/flowers/.

數據準備好瞭,就開始準備Dataset吧,這裡直接調用torchvision裡面的ImageFolder

import torch
import torchvision
from torchvision import transforms, utils
import matplotlib.pyplot as plt
img_data = torchvision.datasets.ImageFolder('D:/bnu/database/flower',
                                            transform=transforms.Compose([
                                                transforms.Scale(256),
                                                transforms.CenterCrop(224),
                                                transforms.ToTensor()])
                                            )
print(len(img_data))
data_loader = torch.utils.data.DataLoader(img_data, batch_size=20,shuffle=True)
print(len(data_loader))
def show_batch(imgs):
    grid = utils.make_grid(imgs,nrow=5)
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
    plt.title('Batch from dataloader')
for i, (batch_x, batch_y) in enumerate(data_loader):
    if(i<4):
        print(i, batch_x.size(), batch_y.size())
        show_batch(batch_x)
        plt.axis('off')
        plt.show()

以上就是pytorch深度神經網絡入門準備自己的圖片數據的詳細內容,更多關於pytorch圖片數據準備的資料請關註WalkonNet其它相關文章!

推薦閱讀: