pytorch 帶batch的tensor類型圖像顯示操作
項目場景
pytorch訓練時我們一般把數據集放到數據加載器裡,然後分批拿出來訓練。訓練前我們一般還要看一下訓練數據長啥樣,也就是訓練數據集可視化。
那麼如何顯示dataloader裡面帶batch的tensor類型的圖像呢?
顯示圖像
繪圖最常用的庫就是matplotlib:
pip install matplotlib
顯示圖像會用到matplotlib.pyplot.imshow方法。查閱官方文檔可知,該方法接收的圖像的通道數要放到後面:
數據加載器中數據的維度是[B, C, H, W],我們每次隻拿一個數據出來就是[C, H, W],而matplotlib.pyplot.imshow要求的輸入維度是[H, W, C],所以我們需要交換一下數據維度,把通道數放到最後面,這裡用到pytorch裡面的permute方法(transpose方法也行,不過要交換兩次,沒這個方便,numpy中的transpose方法倒是可以一次交換完成)
用法示例如下:
>>> x = torch.randn(2, 3, 5) >>> x.size() torch.Size([2, 3, 5]) >>> x.permute(1, 2, 0).size() torch.Size([3, 5, 2])
代碼示例
#%% 導入模塊 import torch import matplotlib.pyplot as plt from torchvision.utils import make_grid from torch.utils.data import DataLoader from torchvision import datasets, transforms #%% 下載數據集 train_file = datasets.MNIST( root='./dataset/', train=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]), download=True ) #%% 制作數據加載器 train_loader = DataLoader( dataset=train_file, batch_size=9, shuffle=True ) #%% 訓練數據可視化 images, labels = next(iter(train_loader)) print(images.size()) # torch.Size([9, 1, 28, 28]) plt.figure(figsize=(9, 9)) for i in range(9): plt.subplot(3, 3, i+1) plt.title(labels[i].item()) plt.imshow(images[i].permute(1, 2, 0), cmap='gray') plt.axis('off') plt.show()
這裡以mnist數據集為例,演示一下顯示效果。我這個代碼其實還有一點小問題。數據增強的時候我不是進行標準化瞭嘛,就是在第7行代碼:Normalize((0.1307,), (0.3081,))。
所以,如果你想查看訓練集的原始圖像,還得反標準化。
標準化:image = (image-mean)/std
反標準化:image = image*std+mean
我拿imagenet中的一個螞蟻和蜜蜂的子集做瞭一下實驗,標準化前後的區別還是很明顯的:
最終效果
補充:PIL,plt顯示tensor類型的圖像
該方法針對顯示Dataloader讀取的圖像
PIL 與plt中對應操作不同,但原理是一樣的,我試過用下方代碼Image的方法在plt上show失敗瞭,原因暫且不知。
# 方法1:Image.show() # transforms.ToPILImage()中有一句 # npimg = np.transpose(pic.numpy(), (1, 2, 0)) # 因此pic隻能是3-D Tensor,所以要用image[0]消去batch那一維 img = transforms.ToPILImage(image[0]) img.show() # 方法2:plt.imshow(ndarray) img = image[0] # plt.imshow()隻能接受3-D Tensor,所以也要用image[0]消去batch那一維 img = img.numpy() # FloatTensor轉為ndarray img = np.transpose(img, (1,2,0)) # 把channel那一維放到最後 # 顯示圖片 plt.imshow(img) plt.show() cnt += 1
以上為個人經驗,希望能給大傢一個參考,也希望大傢多多支持WalkonNet。
推薦閱讀:
- pytorch深度神經網絡入門準備自己的圖片數據
- Pytorch DataLoader shuffle驗證方式
- 超詳細PyTorch實現手寫數字識別器的示例代碼
- Pytorch實現圖像識別之數字識別(附詳細註釋)
- pytorch鎖死在dataloader(訓練時卡死)