TensorFlow.js機器學習預測鳶尾花種類
一、加載IRIS數據集
創建index.html入口文件,跳轉到script主文件。
<script src="script.js"></script>
在script.js文件夾中利用預先準備好的腳本生成鳶尾花數據集,包括訓練集和驗證集,並打印查看。
import {getIrisData, IRIS_CLASSES} from "./data.js"; window.onload = () => { // 加載數據 const [xTrain, yTrain, xTest, yTest] = getIrisData(0.2); // 打印查看數據集 xTrain.print(); yTrain.print(); xTest.print(); yTest.print(); // 打印鳶尾花種類類別 console.log(IRIS_CLASSES); }
getIrisData(0.2):獲取數據集的時候,將20%的數據當成測試集,剩下的80%當成訓練集。
xTrain:訓練集的特征值。
yTrain:訓練集的目標值。
xTest:驗證集的特征值。
yTest:驗證集的目標值。
可以在控制臺查看到結果:
其中特征矩陣裡面的四個值分別表示:花萼的長度、花萼的寬度、花瓣的長度、花瓣的寬度。
目標值矩陣采用one-hot編碼形式。
二、定義模型結構
初始化一個神經網絡模型,為神經網絡模型添加兩層,配置模型的損失函數、激活函數、優化器、添加準確度度量。
// 定義網絡模型 const model = tf.sequential(); // 添加隱藏層 model.add(tf.layers.dense({ units: 10, inputShape: [xTrain.shape[1]], activation: 'relu' })); // 添加輸出層 model.add(tf.layers.dense({ units: 3, activation: 'softmax' })); // 配置模型 model.compile({ loss: "categoricalCrossentropy", optimizer: tf.train.adam(0.1), metrics: ['accuracy'] });
三、訓練模型並可視化
訓練結果需要等待,所以采用異步方式訓練。
await model.fit(xTrain, yTrain,{ epochs: 100, batchSize: 32, validationData: [xTest, yTest], callbacks: tfvis.show.fitCallbacks( {name: '訓練效果'}, ['loss', 'val_loss', 'acc', 'val_acc'], {callbacks: ['onEpochEnd']} ) });
訓練結果:
四、預測
編寫前端界面輸入待預測數據,使用訓練好的模型進行預測,將輸出的Tensor轉成普通數據並顯示。
在index.html中編寫form表單,用來輸入預測數據。
<form action="" onsubmit="predict(this); return false"> 花萼長度:<input type="text" name="a"><br> 花萼寬度:<input type="text" name="b"><br> 花瓣長度:<input type="text" name="c"><br> 花瓣寬度:<input type="text" name="d"><br> <button type="submit">預測</button> </form>
輸入數據的順序不能錯,因為我們訓練數據的順序就是花萼長度、花萼寬度、花瓣長度、花瓣寬度。
在Script.js中編寫predict預測函數。
window.predict = (form) => { // 將表單獲取的到數據轉成Tensor const input = tf.tensor([[ form.a.value * 1, form.b.value * 1, form.c.value * 1, form.d.value * 1, ]]); // 預測 const pred = model.predict(input); alert(`預測結果:${IRIS_CLASSES[pred.argMax(1).dataSync(0)]}`) }
預測結果:gif動圖有點模糊,可以自己動手試試看哦。
到此這篇關於TensorFlow.js機器學習預測鳶尾花種類的文章就介紹到這瞭,更多相關TensorFlow.js預測鳶尾花內容請搜索WalkonNet以前的文章或繼續瀏覽下面的相關文章希望大傢以後多多支持WalkonNet!
推薦閱讀:
- Tensorflow 實現線性回歸模型的示例代碼
- 前端AI機器學習在瀏覽器中訓練模型
- 基於tensorflow權重文件的解讀
- python進階TensorFlow神經網絡擬合線性及非線性函數
- tensorflow2.0教程之Keras快速入門