pytorch加載自己的圖片數據集的2種方法詳解
pytorch加載圖片數據集有兩種方法。
1.ImageFolder 適合於分類數據集,並且每一個類別的圖片在同一個文件夾, ImageFolder加載的數據集, 訓練數據為文件件下的圖片, 訓練標簽是對應的文件夾, 每個文件夾為一個類別
導入ImageFolder()包 from torchvision.datasets import ImageFolder
在Flower_Orig_dataset文件夾下有flower_orig 和 sunflower這兩個文件夾, 這兩個文件夾下放著同一個類別的圖片。 使用 ImageFolder 加載的圖片, 就會返回圖片信息和對應的label信息, 但是label信息是根據文件夾給出的, 如flower_orig就是標簽0, sunflower就是標簽1。
ImageFolder 加載數據集
1. 導入包和設置transform
import torch from torchvision import transforms, datasets import torch.nn as nn from torch.utils.data import DataLoader transforms = transforms.Compose([ transforms.Resize(256), # 將圖片短邊縮放至256,長寬比保持不變: transforms.CenterCrop(224), #將圖片從中心切剪成3*224*224大小的圖片 transforms.ToTensor() #把圖片進行歸一化,並把數據轉換成Tensor類型 ])
2.加載數據集: 將分類圖片的父目錄作為路徑傳遞給ImageFolder(), 並傳入transform。這樣就有瞭要加載的數據集, 之後就可以使用DataLoader加載數據, 並構建網絡訓練。
path = r'D:\數據集\Flower_Orig_dataset' data_train = datasets.ImageFolder(path, transform=transforms) data_loader = DataLoader(data_train, batch_size=64, shuffle=True) for i, data in enumerate(data_loader): images, labels = data print(images.shape) print(labels.shape) break
使用pytorch提供的Dataset類創建自己的數據集。
具體步驟:
1. 首先要有一個txt文件, 這個文件格式是: 圖片路徑 標簽. 這樣的格式, 所以使用os庫, 遍歷自己的圖片名, 並把標簽和圖片路徑寫入txt文件。
2. 有瞭這個txt文件, 我們就可以在類裡面構造我們的數據集.
2.1 把圖片路徑和圖片標簽分割開, 有兩個列表, 一個列表是圖片路徑名, 一個列表是標簽號, 有一點就是第 i 個圖片列表和 第 i 個標簽是對應的
3. 重寫__len__方法 和 __getitem__方法
3.1 getitem方法中, 獲得對應的圖片路徑,並用PIL庫讀取文件把圖片transfrom後, 在getitem函數中返回讀取的圖片和標簽即可
4.就可以構建數據集實例和加載數據集.
定義一個用來生成[ 圖片路徑 標簽] 這樣的txt文件函數
def make_txt(root, file_name, label): path = os.path.join(root, file_name) data = os.listdir(path) f = open(path+'\\'+'f.txt', 'w') for line in data: f.write(line+' '+str(label)+'\n') f.close() #調用函數生成兩個文件夾下的txt文件 make_txt(path, file_name='flower_orig', label=0) make_txt(path, file_name='sunflower', label=1)
將連個txt文件合並成一個txt文件,表示數據集所有的圖片和標簽
def link_txt(file1, file2): txt_list = [] path = r'D:\數據集\Flower_Orig_dataset\data.txt' f = open(path, 'a') f1 = open(file1, 'r') data1 = f1.readlines() for line in data1: txt_list.append(line) f2 = open(file2, 'r') data2 = f2.readlines() for line in data2: txt_list.append(line) for line in txt_list: f.write(line) f.close() f1.close() f2.close() #調用函數, 將兩個文件夾下的txt文件合並 file1 = r'D:\數據集\Flower_Orig_dataset\flower_orig\f.txt' file2 = r'D:\數據集\Flower_Orig_dataset\sunflower\f.txt' link_txt(file1=file1, file2=file2)
現在我們已經有瞭我們制作數據集所需要的txt文件, 接下來要做的即使繼承Dataset類, 來構建自己的數據集 , 別忘瞭前面說的 構建數據集步驟, 在__getitem__函數中, 需要拿到圖片路徑和標簽, 並且用PIL庫方法讀取圖片,對圖片進行transform轉換後,返回圖片信息和標簽信息
Dataset加載數據集
我們讀取圖片的根目錄, 在根目錄下有所有圖片的txt文件, 拿到txt文件後, 先讀取txt文件, 之後遍歷txt文件中的每一行, 首先去除掉尾部的換行符, 在以空格切分,前半部分是圖片名稱, 後半部分是圖片標簽, 當圖片名稱和根目錄結合,就得到瞭我們的圖片路徑 class MyDataset(Dataset): def __init__(self, img_path, transform=None): super(MyDataset, self).__init__() self.root = img_path self.txt_root = self.root + 'data.txt' f = open(self.txt_root, 'r') data = f.readlines() imgs = [] labels = [] for line in data: line = line.rstrip() word = line.split() imgs.append(os.path.join(self.root, word[1], word[0])) labels.append(word[1]) self.img = imgs self.label = labels self.transform = transform def __len__(self): return len(self.label) def __getitem__(self, item): img = self.img[item] label = self.label[item] img = Image.open(img).convert('RGB') #此時img是PIL.Image類型 label是str類型 if transforms is not None: img = self.transform(img) label = np.array(label).astype(np.int64) label = torch.from_numpy(label) return img, label
加載我們的數據集:
path = r'D:\數據集\Flower_Orig_dataset' dataset = MyDataset(path, transform=transform) data_loader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)
接下來我們就可以構建我們的網絡架構:
class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3,16,3) self.maxpool = nn.MaxPool2d(2,2) self.conv2 = nn.Conv2d(16,5,3) self.relu = nn.ReLU() self.fc1 = nn.Linear(55*55*5, 1200) self.fc2 = nn.Linear(1200,64) self.fc3 = nn.Linear(64,2) def forward(self,x): x = self.maxpool(self.relu(self.conv1(x))) #113 x = self.maxpool(self.relu(self.conv2(x))) #55 x = x.view(-1, self.num_flat_features(x)) x = self.relu(self.fc1(x)) x = self.relu(self.fc2(x)) x = self.fc3(x) return x def num_flat_features(self, x): size = x.size()[1:] num_features = 1 for s in size: num_features *= s return num_features
訓練我們的網絡:
model = Net() criterion = torch.nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.01) epochs = 10 for epoch in range(epochs): running_loss = 0.0 for i, data in enumerate(data_loader): images, label = data out = model(images) loss = criterion(out, label) optimizer.zero_grad() loss.backward() optimizer.step() running_loss += loss.item() if(i+1)%10 == 0: print('[%d %5d] loss: %.3f'%(epoch+1, i+1, running_loss/100)) running_loss = 0.0 print('finished train')
保存網絡模型(這裡不止是保存參數,還保存瞭網絡結構)
#保存模型 torch.save(net, 'model_name.pth') #保存的是模型, 不止是w和b權重值 # 讀取模型 model = torch.load('model_name.pth')
總結
到此這篇關於pytorch加載自己的圖片數據集的2種方法的文章就介紹到這瞭,更多相關pytorch加載圖片數據集內容請搜索WalkonNet以前的文章或繼續瀏覽下面的相關文章希望大傢以後多多支持WalkonNet!
推薦閱讀:
- 使用pytorch讀取數據集
- Pytorch DataLoader shuffle驗證方式
- pytorch深度神經網絡入門準備自己的圖片數據
- Pytorch深度學習之實現病蟲害圖像分類
- pytorch從csv加載自定義數據模板的操作