Pytorch數據讀取之Dataset和DataLoader知識總結
一、前言
確保安裝
- scikit-image
- numpy
二、Dataset
一個例子:
# 導入需要的包 import torch import torch.utils.data.dataset as Dataset import numpy as np # 編造數據 Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]]) Label = np.asarray([[0], [1], [0], [2]]) # 數據[1,2],對應的標簽是[0],數據[3,4],對應的標簽是[1] #創建子類 class subDataset(Dataset.Dataset): #初始化,定義數據內容和標簽 def __init__(self, Data, Label): self.Data = Data self.Label = Label #返回數據集大小 def __len__(self): return len(self.Data) #得到數據內容和標簽 def __getitem__(self, index): data = torch.Tensor(self.Data[index]) label = torch.IntTensor(self.Label[index]) return data, label # 主函數 if __name__ == '__main__': dataset = subDataset(Data, Label) print(dataset) print('dataset大小為:', dataset.__len__()) print(dataset.__getitem__(0)) print(dataset[0])
輸出的結果
我們有瞭對Dataset的一個整體的把握,再來分析裡面的細節:
#創建子類 class subDataset(Dataset.Dataset):
創建子類時,繼承的時Dataset.Dataset,不是一個Dataset。因為Dataset是module模塊,不是class類,所以需要調用module裡的class才行,因此是Dataset.Dataset!
len和getitem這兩個函數,前者給出數據集的大小**,後者是用於查找數據和標簽。是最重要的兩個函數,我們後續如果要對數據做一些操作基本上都是再這兩個函數的基礎上進行。
三、DatasetLoader
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_works=0, clollate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)
功能:構建可迭代的數據裝載器;
dataset:Dataset類,決定數據從哪裡讀取及如何讀取;數據集的路徑
batchsize:批大小;
num_works:是否多進程讀取數據;隻對於CPU
shuffle:每個epoch是否打亂;
drop_last:當樣本數不能被batchsize整除時,是否舍棄最後一批數據;
Epoch:所有訓練樣本都已輸入到模型中,稱為一個Epoch;
Iteration:一批樣本輸入到模型中,稱之為一個Iteration;
Batchsize:批大小,決定一個Epoch中有多少個Iteration;
還是舉一個實例:
import torch import torch.utils.data.dataset as Dataset import torch.utils.data.dataloader as DataLoader import numpy as np Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]]) Label = np.asarray([[0], [1], [0], [2]]) #創建子類 class subDataset(Dataset.Dataset): #初始化,定義數據內容和標簽 def __init__(self, Data, Label): self.Data = Data self.Label = Label #返回數據集大小 def __len__(self): return len(self.Data) #得到數據內容和標簽 def __getitem__(self, index): data = torch.Tensor(self.Data[index]) label = torch.IntTensor(self.Label[index]) return data, label if __name__ == '__main__': dataset = subDataset(Data, Label) print(dataset) print('dataset大小為:', dataset.__len__()) print(dataset.__getitem__(0)) print(dataset[0]) #創建DataLoader迭代器,相當於我們要先定義好前面說的Dataset,然後再用Dataloader來對數據進行一些操作,比如是否需要打亂,則shuffle=True,是否需要多個進程讀取數據num_workers=4,就是四個進程 dataloader = DataLoader.DataLoader(dataset,batch_size= 2, shuffle = False, num_workers= 4) for i, item in enumerate(dataloader): #可以用enumerate來提取出裡面的數據 print('i:', i) data, label = item #數據是一個元組 print('data:', data) print('label:', label)
四、將Dataset數據和標簽放在GPU上(代碼執行順序出錯則會有bug)
這部分可以直接去看博客:Dataset和DataLoader
總結下來時有兩種方法解決
1.如果在創建Dataset的類時,定義__getitem__方法的時候,將數據轉變為GPU類型。則需要將Dataloader裡面的參數num_workers設置為0,因為這個參數是對於CPU而言的。如果數據改成瞭GPU,則隻能單進程。如果是在Dataloader的部分,先多個子進程讀取,再轉變為GPU,則num_wokers不用修改。就是上述__getitem__部分的代碼,移到Dataloader部分。
2.不過一般來講,數據集和標簽不會像我們上述編輯的那麼簡單。一般再kaggle上的標簽都是存在CSV這種文件中。需要pandas的配合。
這個進階可以看:WRITING CUSTOM DATASETS, DATALOADERS AND TRANSFORMS,他是用人臉圖片作為數據和人臉特征點作為標簽。
到此這篇關於Pytorch數據讀取之Dataset和DataLoader知識總結的文章就介紹到這瞭,更多相關詳解Dataset和DataLoader內容請搜索WalkonNet以前的文章或繼續瀏覽下面的相關文章希望大傢以後多多支持WalkonNet!
推薦閱讀:
- Pytorch如何加載自己的數據集(使用DataLoader讀取Dataset)
- 我對PyTorch dataloader裡的shuffle=True的理解
- Pytorch數據讀取與預處理該如何實現
- Pytorch DataLoader shuffle驗證方式
- pytorch DataLoader的num_workers參數與設置大小詳解