Python Opencv使用ann神經網絡識別手寫數字功能
opencv中也提供瞭一種類似於Keras的神經網絡,即為ann,這種神經網絡的使用方法與Keras的很接近。
關於mnist數據的解析,讀者可以自己從網上下載相應壓縮文件,用python自己編寫解析代碼,由於這裡主要研究knn算法,為瞭圖簡單,直接使用Keras的mnist手寫數字解析模塊。
本次代碼運行環境為:
python 3.6.8
opencv-python 4.4.0.46
opencv-contrib-python 4.4.0.46
下面的代碼為使用ann進行模型的訓練:
from keras.datasets import mnist from keras import utils import cv2 import numpy as np #opencv中ANN定義神經網絡層 def create_ANN(): ann=cv2.ml.ANN_MLP_create() #設置神經網絡層的結構 輸入層為784 隱藏層為80 輸出層為10 ann.setLayerSizes(np.array([784,64,10])) #設置網絡參數為誤差反向傳播法 ann.setTrainMethod(cv2.ml.ANN_MLP_BACKPROP) #設置激活函數為sigmoid ann.setActivationFunction(cv2.ml.ANN_MLP_SIGMOID_SYM) #設置訓練迭代條件 #結束條件為訓練30次或者誤差小於0.00001 ann.setTermCriteria((cv2.TermCriteria_EPS|cv2.TermCriteria_COUNT,100,0.0001)) return ann #計算測試數據上的識別率 def evaluate_acc(ann,test_images,test_labels): #采用的sigmoid激活函數,需要對結果進行置信度處理 #對於大於0.99的可以確定為1 對於小於0.01的可以確信為0 test_ret=ann.predict(test_images) #預測結果是一個元組 test_pre=test_ret[1] #可以直接最大值的下標 (10000,) test_pre=test_pre.argmax(axis=1) true_sum=(test_pre==test_labels) return true_sum.mean() if __name__=='__main__': #直接使用Keras載入的訓練數據(60000, 28, 28) (60000,) (train_images,train_labels),(test_images,test_labels)=mnist.load_data() #變換數據的形狀並歸一化 train_images=train_images.reshape(train_images.shape[0],-1)#(60000, 784) train_images=train_images.astype('float32')/255 test_images=test_images.reshape(test_images.shape[0],-1) test_images=test_images.astype('float32')/255 #將標簽變為one-hot形狀 (60000, 10) float32 train_labels=utils.to_categorical(train_labels) #測試數據標簽不用變為one-hot (10000,) test_labels=test_labels.astype(np.int) #定義神經網絡模型結構 ann=create_ANN() #開始訓練 ann.train(train_images,cv2.ml.ROW_SAMPLE,train_labels) #在測試數據上測試準確率 print(evaluate_acc(ann,test_images,test_labels)) #保存模型 ann.save('mnist_ann.xml') #加載模型 myann=cv2.ml.ANN_MLP_load('mnist_ann.xml')
訓練100次得到的準確率為0.9376,可以接著增加訓練次數或者提高神經網絡的層次結構深度來提高準確率。
使用ann神經網絡的模型結構非常小,因為隻是保存瞭權重參數。
可以看到整個模型文件的大小才1M,而svm的大小為十多兆,knn的為幾百兆,因此使用ann神經網絡更加適合部署在客戶端上。
接下來使用ann進行圖片的測試識別:
import cv2 import numpy as np if __name__=='__main__': #讀取圖片 img=cv2.imread('shuzi.jpg',0) img_sw=img.copy() #將數據類型由uint8轉為float32 img=img.astype(np.float32) #圖片形狀由(28,28)轉為(784,) img=img.reshape(-1,) #增加一個維度變為(1,784) img=img.reshape(1,-1) #圖片數據歸一化 img=img/255 #載入ann模型 ann=cv2.ml.ANN_MLP_load('minist_ann.xml') #進行預測 img_pre=ann.predict(img) #因為激活函數sigmoid,因此要進行置信度處理 ret=img_pre[1] ret[ret>0.9]=1 ret[ret<0.1]=0 print(ret) cv2.imshow('test',img_sw) cv2.waitKey(0)
運行程序,結果如下,可見該模型正確識別瞭數字0.
到此這篇關於Python Opencv使用ann神經網絡識別手寫數字的文章就介紹到這瞭,更多相關python opencv識別手寫數字內容請搜索WalkonNet以前的文章或繼續瀏覽下面的相關文章希望大傢以後多多支持WalkonNet!
推薦閱讀:
- Python實戰之MNIST手寫數字識別詳解
- TensorFlow教程Softmax邏輯回歸識別手寫數字MNIST數據集
- Python深度學習pytorch卷積神經網絡LeNet
- 由淺入深學習TensorFlow MNIST 數據集
- Python MNIST手寫體識別詳解與試練