TensorFlow.js機器學習預測鳶尾花種類

一、加載IRIS數據集

創建index.html入口文件,跳轉到script主文件。

<script src="script.js"></script>

在script.js文件夾中利用預先準備好的腳本生成鳶尾花數據集,包括訓練集和驗證集,並打印查看。

import {getIrisData, IRIS_CLASSES} from "./data.js";
window.onload = () => {
    // 加載數據
    const [xTrain, yTrain, xTest, yTest] = getIrisData(0.2);
    // 打印查看數據集
    xTrain.print();
    yTrain.print();
    xTest.print();
    yTest.print();
    // 打印鳶尾花種類類別
    console.log(IRIS_CLASSES);
}

getIrisData(0.2):獲取數據集的時候,將20%的數據當成測試集,剩下的80%當成訓練集。

xTrain:訓練集的特征值。

yTrain:訓練集的目標值。

xTest:驗證集的特征值。

yTest:驗證集的目標值。

可以在控制臺查看到結果:

其中特征矩陣裡面的四個值分別表示:花萼的長度、花萼的寬度、花瓣的長度、花瓣的寬度。

目標值矩陣采用one-hot編碼形式。

二、定義模型結構

初始化一個神經網絡模型,為神經網絡模型添加兩層,配置模型的損失函數、激活函數、優化器、添加準確度度量。

// 定義網絡模型
const model = tf.sequential();
// 添加隱藏層
model.add(tf.layers.dense({
    units: 10,
    inputShape: [xTrain.shape[1]],
    activation: 'relu'
}));
// 添加輸出層 
model.add(tf.layers.dense({
    units: 3,
    activation: 'softmax'
}));
// 配置模型
model.compile({
    loss: "categoricalCrossentropy",
    optimizer: tf.train.adam(0.1),
    metrics: ['accuracy']
});

三、訓練模型並可視化

訓練結果需要等待,所以采用異步方式訓練。

await model.fit(xTrain, yTrain,{
    epochs: 100,
    batchSize: 32,
    validationData: [xTest, yTest],
    callbacks: tfvis.show.fitCallbacks(
         {name: '訓練效果'},
         ['loss', 'val_loss', 'acc', 'val_acc'],
         {callbacks: ['onEpochEnd']}
    )
}); 

訓練結果:

四、預測

編寫前端界面輸入待預測數據,使用訓練好的模型進行預測,將輸出的Tensor轉成普通數據並顯示。

在index.html中編寫form表單,用來輸入預測數據。

<form action="" onsubmit="predict(this); return false">
    花萼長度:<input type="text" name="a"><br>
    花萼寬度:<input type="text" name="b"><br>
    花瓣長度:<input type="text" name="c"><br>
    花瓣寬度:<input type="text" name="d"><br>
    <button type="submit">預測</button>
</form>

輸入數據的順序不能錯,因為我們訓練數據的順序就是花萼長度、花萼寬度、花瓣長度、花瓣寬度。

在Script.js中編寫predict預測函數。

 window.predict = (form) => {
        // 將表單獲取的到數據轉成Tensor
        const input = tf.tensor([[
            form.a.value * 1,
            form.b.value * 1,
            form.c.value * 1,
            form.d.value * 1,
        ]]);
        // 預測
        const pred = model.predict(input);
        alert(`預測結果:${IRIS_CLASSES[pred.argMax(1).dataSync(0)]}`)
    }

預測結果:gif動圖有點模糊,可以自己動手試試看哦。

到此這篇關於TensorFlow.js機器學習預測鳶尾花種類的文章就介紹到這瞭,更多相關TensorFlow.js預測鳶尾花內容請搜索WalkonNet以前的文章或繼續瀏覽下面的相關文章希望大傢以後多多支持WalkonNet!

推薦閱讀: