基於keras中訓練數據的幾種方式對比(fit和fit_generator)
一、train_on_batch
model.train_on_batch(batchX, batchY)
train_on_batch函數接受單批數據,執行反向傳播,然後更新模型參數,該批數據的大小可以是任意的,即,它不需要提供明確的批量大小,屬於精細化控制訓練模型,大部分情況下我們不需要這麼精細,99%情況下使用fit_generator訓練方式即可,下面會介紹。
二、fit
model.fit(x_train, y_train, batch_size=32, epochs=10)
fit的方式是一次把訓練數據全部加載到內存中,然後每次批處理batch_size個數據來更新模型參數,epochs就不用多介紹瞭。這種訓練方式隻適合訓練數據量比較小的情況下使用。
三、fit_generator
利用Python的生成器,逐個生成數據的batch並進行訓練,不占用大量內存,同時生成器與模型將並行執行以提高效率。例如,該函數允許我們在CPU上進行實時的數據提升,同時在GPU上進行模型訓練
接口如下:
fit_generator(self, generator, steps_per_epoch, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_q_size=10, workers=1, pickle_safe=False, initial_epoch=0)
generator
:生成器函數
steps_per_epoch
:整數,當生成器返回steps_per_epoch次數據時,計一個epoch結束,執行下一個epoch。也就是一個epoch下執行多少次batch_size。
epochs
:整數,控制數據迭代的輪數,到瞭就結束訓練。
callbacks=None, list,list中的元素為keras.callbacks.Callback對象,在訓練過程中會調用list中的回調函數
舉例:
def generate_arrays_from_file(path): while True: with open(path) as f: for line in f: # create numpy arrays of input data # and labels, from each line in the file x1, x2, y = process_line(line) yield ({'input_1': x1, 'input_2': x2}, {'output': y}) model.fit_generator(generate_arrays_from_file('./my_folder'), steps_per_epoch=10000, epochs=10)
補充:keras.fit_generator()屬性及取值
如下所示:
fit_generator(self, generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)
通過Python generator產生一批批的數據用於訓練模型。generator可以和模型並行運行,例如,可以使用CPU生成批數據同時在GPU上訓練模型。
參數:
generator
:一個generator或Sequence實例,為瞭避免在使用multiprocessing時直接復制數據。
steps_per_epoch
:從generator產生的步驟的總數(樣本批次總數)。通常情況下,應該等於數據集的樣本數量除以批量的大小。
epochs
:整數,在數據集上迭代的總數。
works
:在使用基於進程的線程時,最多需要啟動的進程數量。
use_multiprocessing
:佈爾值。當為True時,使用基於基於過程的線程。
例如:
datagen = ImageDataGenator(...) model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size), epochs=epochs, validation_data=(x_test, y_test), workers=4)
以上為個人經驗,希望能給大傢一個參考,也希望大傢多多支持WalkonNet。
推薦閱讀:
- python之tensorflow手把手實例講解貓狗識別實現
- python生成器generator:深度學習讀取batch圖片的操作
- pytorch鎖死在dataloader(訓練時卡死)
- Python實戰之MNIST手寫數字識別詳解
- 詳解TensorFlow訓練網絡兩種方式