PyTorch實現MNIST數據集手寫數字識別詳情

前言:

本篇文章基於卷積神經網絡CNN,使用PyTorch實現MNIST數據集手寫數字識別。

一、PyTorch是什麼?

PyTorch 是一個 Torch7 團隊開源的 Python 優先的深度學習框架,提供兩個高級功能:

  • 強大的 GPU 加速 Tensor 計算(類似 numpy)
  • 構建基於 tape 的自動升級系統上的深度神經網絡

你可以重用你喜歡的 python 包,如 numpy、scipy 和 Cython ,在需要時擴展 PyTorch。

二、程序示例

下面案例可供運行參考

1.引入必要庫

import torchvision
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

2.下載數據集

這裡設置download=True,將會自動下載數據集,並存儲在./data文件夾。

train_data = torchvision.datasets.MNIST(root="./data",train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.MNIST(root="./data",train=False,transform=torchvision.transforms.ToTensor(),download=True)

3.加載數據集

batch_size=32表示每一個batch中包含32張手寫數字圖片,shuffle=True表示打亂測試集(data和target仍一一對應)

train_loader = DataLoader(train_data,batch_size=32,shuffle=True)
test_loader = DataLoader(test_data,batch_size=32,shuffle=False)

4.搭建CNN模型並實例化

class Net(torch.nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.con1 = torch.nn.Conv2d(1,10,kernel_size=5)
        self.con2 = torch.nn.Conv2d(10,20,kernel_size=5)
        self.pooling = torch.nn.MaxPool2d(2)
        self.fc = torch.nn.Linear(320,10)
    def forward(self,x):
        batch_size = x.size(0)
        x = F.relu(self.pooling(self.con1(x)))
        x = F.relu(self.pooling(self.con2(x)))
        x = x.view(batch_size,-1)
        x = self.fc(x)
        return x
#模型實例化        
model = Net()

5.交叉熵損失函數損失函數及SGD算法優化器

lossfun = torch.nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.5)

6.訓練函數

def train(epoch):
    running_loss = 0.0
    for i,(inputs,targets) in enumerate(train_loader,0):
        # inputs,targets = inputs.to(device),targets.to(device)
        opt.zero_grad()
        outputs = model(inputs)
        loss = lossfun(outputs,targets)
        loss.backward()
        opt.step()

        running_loss += loss.item()
        if i % 300 == 299:
            print('[%d,%d] loss:%.3f' % (epoch+1,i+1,running_loss/300))
            running_loss = 0.0

7.測試函數

def test():
    total = 0
    correct = 0
    with torch.no_grad():
        for (inputs,targets) in test_loader:
            # inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _,predicted = torch.max(outputs.data,dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    print(100*correct/total)

8.運行

if __name__ == '__main__':
    for epoch in range(20):
        train(epoch)
        test()

三、總結

到此這篇關於PyTorch實現MNIST數據集手寫數字識別詳情的文章就介紹到這瞭,更多相關PyTorch MNIST 內容請搜索WalkonNet以前的文章或繼續瀏覽下面的相關文章希望大傢以後多多支持WalkonNet!

推薦閱讀: