Python Opencv使用ann神經網絡識別手寫數字功能

opencv中也提供瞭一種類似於Keras的神經網絡,即為ann,這種神經網絡的使用方法與Keras的很接近。
關於mnist數據的解析,讀者可以自己從網上下載相應壓縮文件,用python自己編寫解析代碼,由於這裡主要研究knn算法,為瞭圖簡單,直接使用Keras的mnist手寫數字解析模塊。
本次代碼運行環境為:
python 3.6.8
opencv-python 4.4.0.46
opencv-contrib-python 4.4.0.46

下面的代碼為使用ann進行模型的訓練:

from keras.datasets import mnist
from keras import utils
import cv2
import numpy as np
#opencv中ANN定義神經網絡層
def create_ANN():
    ann=cv2.ml.ANN_MLP_create()
    #設置神經網絡層的結構 輸入層為784 隱藏層為80 輸出層為10
    ann.setLayerSizes(np.array([784,64,10]))
    #設置網絡參數為誤差反向傳播法
    ann.setTrainMethod(cv2.ml.ANN_MLP_BACKPROP)
    #設置激活函數為sigmoid
    ann.setActivationFunction(cv2.ml.ANN_MLP_SIGMOID_SYM)
    #設置訓練迭代條件 
    #結束條件為訓練30次或者誤差小於0.00001
    ann.setTermCriteria((cv2.TermCriteria_EPS|cv2.TermCriteria_COUNT,100,0.0001))
    return ann
#計算測試數據上的識別率
def evaluate_acc(ann,test_images,test_labels):
    #采用的sigmoid激活函數,需要對結果進行置信度處理 
    #對於大於0.99的可以確定為1 對於小於0.01的可以確信為0
    test_ret=ann.predict(test_images)
    #預測結果是一個元組
    test_pre=test_ret[1]
    #可以直接最大值的下標 (10000,)
    test_pre=test_pre.argmax(axis=1)
    true_sum=(test_pre==test_labels)
    return true_sum.mean()
if __name__=='__main__':
    #直接使用Keras載入的訓練數據(60000, 28, 28) (60000,)
    (train_images,train_labels),(test_images,test_labels)=mnist.load_data()
    #變換數據的形狀並歸一化
    train_images=train_images.reshape(train_images.shape[0],-1)#(60000, 784)
    train_images=train_images.astype('float32')/255
    test_images=test_images.reshape(test_images.shape[0],-1)
    test_images=test_images.astype('float32')/255
    #將標簽變為one-hot形狀 (60000, 10) float32
    train_labels=utils.to_categorical(train_labels)
    #測試數據標簽不用變為one-hot (10000,)
    test_labels=test_labels.astype(np.int)
    
    #定義神經網絡模型結構
    ann=create_ANN()
    #開始訓練    
    ann.train(train_images,cv2.ml.ROW_SAMPLE,train_labels)
    #在測試數據上測試準確率
    print(evaluate_acc(ann,test_images,test_labels))
    
    #保存模型
    ann.save('mnist_ann.xml')
    #加載模型
    myann=cv2.ml.ANN_MLP_load('mnist_ann.xml')

訓練100次得到的準確率為0.9376,可以接著增加訓練次數或者提高神經網絡的層次結構深度來提高準確率。
使用ann神經網絡的模型結構非常小,因為隻是保存瞭權重參數。

在這裡插入圖片描述

可以看到整個模型文件的大小才1M,而svm的大小為十多兆,knn的為幾百兆,因此使用ann神經網絡更加適合部署在客戶端上。
接下來使用ann進行圖片的測試識別:

import cv2
import numpy as np
if __name__=='__main__':
    #讀取圖片
    img=cv2.imread('shuzi.jpg',0)
    img_sw=img.copy()
    #將數據類型由uint8轉為float32
    img=img.astype(np.float32)
    #圖片形狀由(28,28)轉為(784,)
    img=img.reshape(-1,)
    #增加一個維度變為(1,784)
    img=img.reshape(1,-1)
    #圖片數據歸一化
    img=img/255
    #載入ann模型
    ann=cv2.ml.ANN_MLP_load('minist_ann.xml')
    #進行預測
    img_pre=ann.predict(img)
    #因為激活函數sigmoid,因此要進行置信度處理
    ret=img_pre[1]
    ret[ret>0.9]=1
    ret[ret<0.1]=0
    print(ret)
    cv2.imshow('test',img_sw)
    cv2.waitKey(0)

運行程序,結果如下,可見該模型正確識別瞭數字0.

在這裡插入圖片描述

到此這篇關於Python Opencv使用ann神經網絡識別手寫數字的文章就介紹到這瞭,更多相關python opencv識別手寫數字內容請搜索WalkonNet以前的文章或繼續瀏覽下面的相關文章希望大傢以後多多支持WalkonNet!

推薦閱讀: