opencv3機器學習之EM算法示例詳解
引言
不同於其它的機器學習模型,EM算法是一種非監督的學習算法,它的輸入數據事先不需要進行標註。相反,該算法從給定的樣本集中,能計算出高斯混和參數的最大似然估計。也能得到每個樣本對應的標註值,類似於kmeans聚類(輸入樣本數據,輸出樣本數據的標註)。實際上,高斯混和模型GMM和kmeans都是EM算法的應用。
在opencv3.0中,EM算法的函數是trainEM,函數原型為:
bool trainEM(InputArray samples, OutputArray logLikelihoods=noArray(),OutputArray labels=noArray(),OutputArray probs=noArray())
四個參數:
samples
: 輸入的樣本,一個單通道的矩陣。從這個樣本中,進行高斯混和模型估計。
logLikelihoods
: 可選項,輸出一個矩陣,裡面包含每個樣本的似然對數值。
labels
: 可選項,輸出每個樣本對應的標註。
probs
: 可選項,輸出一個矩陣,裡面包含每個隱性變量的後驗概率
這個函數沒有輸入參數的初始化值,是因為它會自動執行kmeans算法,將kmeans算法得到的結果作為參數初始化。
這個trainEM函數實際把E步驟和M步驟都包含進去瞭,我們也可以對兩個步驟分開執行,OPENCV3.0中也提供瞭分別執行的函數:
bool trainE(InputArray samples, InputArray means0, InputArray covs0=noArray(), InputArray weights0=noArray(), OutputArray logLikelihoods=noArray(), OutputArray labels=noArray(), OutputArray probs=noArray())
bool trainM(InputArray samples, InputArray probs0, OutputArray logLikelihoods=noArray(), OutputArray labels=noArray(), OutputArray probs=noArray())
trainEM函數的功能和kmeans差不多,都是實現自動聚類,輸出每個樣本對應的標註值。但它比kmeans還多出一個功能,就是它還能起到訓練分類器的作用,用於後續新樣本的預測。
預測函數原型為:
Vec2d predict2(InputArray sample, OutputArray probs) const
sample
: 待測樣本
probs
: 和上面一樣,一個可選的輸出值,包含每個隱性變量的後驗概率
返回一個Vec2d類型的數,包括兩個元素的double向量,第一個元素為樣本的似然對數值,第二個元素為最大可能混和分量的索引值。
在本文中,我們用兩個實例來學習opencv中的EM算法的應用。
一、opencv3.0中自帶的例子
既包括聚類trianEM,也包括預測predict2
代碼:
#include "stdafx.h" #include "opencv2/opencv.hpp" #include <iostream> using namespace std; using namespace cv; using namespace cv::ml; //使用EM算法實現樣本的聚類及預測 int main() { const int N = 4; //分成4類 const int N1 = (int)sqrt((double)N); //定義四種顏色,每一類用一種顏色表示 const Scalar colors[] = { Scalar(0, 0, 255), Scalar(0, 255, 0), Scalar(0, 255, 255), Scalar(255, 255, 0) }; int i, j; int nsamples = 100; //100個樣本點 Mat samples(nsamples, 2, CV_32FC1); //樣本矩陣,100行2列,即100個坐標點 Mat img = Mat::zeros(Size(500, 500), CV_8UC3); //待測數據,每一個坐標點為一個待測數據 samples = samples.reshape(2, 0); //循環生成四個類別樣本數據,共樣本100個,每類樣本25個 for (i = 0; i < N; i++) { Mat samples_part = samples.rowRange(i*nsamples / N, (i + 1)*nsamples / N); //設置均值 Scalar mean(((i%N1) + 1)*img.rows / (N1 + 1), ((i / N1) + 1)*img.rows / (N1 + 1)); //設置標準差 Scalar sigma(30, 30); randn(samples_part, mean, sigma); //根據均值和標準差,隨機生成25個正態分佈坐標點作為樣本 } samples = samples.reshape(1, 0); // 訓練分類器 Mat labels; //標註,不需要事先知道 Ptr<EM> em_model = EM::create(); em_model->setClustersNumber(N); em_model->setCovarianceMatrixType(EM::COV_MAT_SPHERICAL); em_model->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 300, 0.1)); em_model->trainEM(samples, noArray(), labels, noArray()); //對每個坐標點進行分類,並根據類別用不同的顏色畫出 Mat sample(1, 2, CV_32FC1); for (i = 0; i < img.rows; i++) { for (j = 0; j < img.cols; j++) { sample.at<float>(0) = (float)j; sample.at<float>(1) = (float)i; //predict2返回的是double值,用cvRound進行四舍五入得到整型 //此處返回的是兩個值Vec2d,取第二個值作為樣本標註 int response = cvRound(em_model->predict2(sample, noArray())[1]); Scalar c = colors[response]; //為不同類別設定顏色 circle(img, Point(j, i), 1, c*0.75, FILLED); } } //畫出樣本點 for (i = 0; i < nsamples; i++) { Point pt(cvRound(samples.at<float>(i, 0)), cvRound(samples.at<float>(i, 1))); circle(img, pt, 2, colors[labels.at<int>(i)], FILLED); } imshow("EM聚類結果", img); waitKey(0); return 0; }
結果:
二、trainEM實現自動聚類進行圖片目標檢測
隻用trainEM實現自動聚類功能,進行圖片中的目標檢測
代碼:
#include "stdafx.h" #include "opencv2/opencv.hpp" #include <iostream> using namespace std; using namespace cv; using namespace cv::ml; int main() { const int MAX_CLUSTERS = 5; Vec3b colorTab[] = { Vec3b(0, 0, 255), Vec3b(0, 255, 0), Vec3b(255, 100, 100), Vec3b(255, 0, 255), Vec3b(0, 255, 255) }; Mat data, labels; Mat pic = imread("d:/woman.png"); for (int i = 0; i < pic.rows; i++) for (int j = 0; j < pic.cols; j++) { Vec3b point = pic.at<Vec3b>(i, j); Mat tmp = (Mat_<float>(1, 3) << point[0], point[1], point[2]); data.push_back(tmp); } int N =3; //聚成3類 Ptr<EM> em_model = EM::create(); em_model->setClustersNumber(N); em_model->setCovarianceMatrixType(EM::COV_MAT_SPHERICAL); em_model->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 300, 0.1)); em_model->trainEM(data, noArray(), labels, noArray()); int n = 0; //顯示聚類結果,不同的類別用不同的顏色顯示 for (int i = 0; i < pic.rows; i++) for (int j = 0; j < pic.cols; j++) { int clusterIdx = labels.at<int>(n); pic.at<Vec3b>(i, j) = colorTab[clusterIdx]; n++; } imshow("pic", pic); waitKey(0); return 0; }
測試圖片
測試結果:
以上就是opencv3機器學習之EM算法的詳細內容,更多關於opencv3 EM算法的資料請關註WalkonNet其它相關文章!