torch.utils.data.DataLoader與迭代器轉換操作
在做實驗時,我們常常會使用用開源的數據集進行測試。而Pytorch中內置瞭許多數據集,這些數據集我們常常使用DataLoader
類進行加載。
如下面這個我們使用DataLoader
類加載torch.vision
中的FashionMNIST
數據集。
from torch.utils.data import DataLoader from torchvision import datasets from torchvision.transforms import ToTensor import matplotlib.pyplot as plt training_data = datasets.FashionMNIST( root="data", train=True, download=True, transform=ToTensor() ) test_data = datasets.FashionMNIST( root="data", train=False, download=True, transform=ToTensor() )
我們接下來定義Dataloader對象用於加載這兩個數據集:
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True) test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
那麼這個train_dataloader
究竟是什麼類型呢?
print(type(train_dataloader)) # <class 'torch.utils.data.dataloader.DataLoader'>
我們可以將先其轉換為迭代器類型。
print(type(iter(train_dataloader)))# <class 'torch.utils.data.dataloader._SingleProcessDataLoaderIter'>
然後再使用next(iter(train_dataloader))
從迭代器裡取數據,如下所示:
train_features, train_labels = next(iter(train_dataloader)) print(f"Feature batch shape: {train_features.size()}") print(f"Labels batch shape: {train_labels.size()}") img = train_features[0].squeeze() label = train_labels[0] plt.imshow(img, cmap="gray") plt.show() print(f"Label: {label}")
可以看到我們成功獲取瞭數據集中第一張圖片的信息,控制臺打印:
Feature batch shape: torch.Size([64, 1, 28, 28]) Labels batch shape: torch.Size([64]) Label: 2
圖片可視化顯示如下:
不過有讀者可能就會產生疑問,很多時候我們並沒有將DataLoader類型強制轉換成迭代器類型呀,大多數時候我們會寫如下代碼:
for train_features, train_labels in train_dataloader: print(train_features.shape) # torch.Size([64, 1, 28, 28]) print(train_features[0].shape) # torch.Size([1, 28, 28]) print(train_features[0].squeeze().shape) # torch.Size([28, 28]) img = train_features[0].squeeze() label = train_labels[0] plt.imshow(img, cmap="gray") plt.show() print(f"Label: {label}")
可以看到,該代碼也能夠正常迭代訓練數據,前三個樣本的控制臺打印輸出為:
torch.Size([64, 1, 28, 28]) torch.Size([1, 28, 28]) torch.Size([28, 28]) Label: 7 torch.Size([64, 1, 28, 28]) torch.Size([1, 28, 28]) torch.Size([28, 28]) Label: 4 torch.Size([64, 1, 28, 28]) torch.Size([1, 28, 28]) torch.Size([28, 28]) Label: 1
那麼為什麼我們這裡沒有顯式將Dataloader
轉換為迭代器類型呢,其實是Python語言for循環的一種機制,一旦我們用for … in …句式來迭代一個對象,那麼Python
解釋器就會偷偷地自動幫我們創建好迭代器,也就是說
for train_features, train_labels in train_dataloader:
實際上等同於
for train_features, train_labels in iter(train_dataloader):
更進一步,這實際上等同於
train_iterator = iter(train_dataloader) try: while True: train_features, train_labels = next(train_iterator) except StopIteration: pass
推而廣之,我們在用Python迭代直接迭代列表時:
for x in [1, 2, 3, 4]:
其實Python解釋器已經為我們隱式轉換為迭代器瞭:
list_iterator = iter([1, 2, 3, 4]) try: while True: x = next(list_iterator) except StopIteration: pass
到此這篇關於torch.utils.data.DataLoader
與迭代器轉換操作的文章就介紹到這瞭,更多相關torch.utils.data.DataLoader與迭代器轉換內容請搜索WalkonNet以前的文章或繼續瀏覽下面的相關文章希望大傢以後多多支持WalkonNet!
推薦閱讀:
- 解決Pytorch dataloader時報錯每個tensor維度不一樣的問題
- 我對PyTorch dataloader裡的shuffle=True的理解
- pyTorch深度學習softmax實現解析
- pytorch 搭建神經網路的實現
- python中的Pytorch建模流程匯總