プログレスバーで表示するシリーズ第三弾はKeras!
学習の進捗をtqdmでかっこよく表示するスニペットを紹介するよ。
TensorBoardも超いいけどこういうシンプルな進捗表示もアリだよね。
これまでのプログレスバーシリーズ
TensorFlowのバージョン
2.7.0
インストール
tqdmのインストールがまだの人はインストールしておこう。
pip install tqdmスニペット
単なる関数じゃなくてkeras.callbacks.Callbackを継承したクラスを作るのがポイント。
epochとstepそれぞれに対して2つプログレスバーを表示しているよ。
from typing import Dict, Optional
from collections import OrderedDict
from tensorflow import keras
from tqdm.auto import tqdm
class KerasProgressBarCallback(keras.callbacks.Callback):
description: Optional[str]
pbar_epoch: tqdm
pbar_step: tqdm
def __init__(self, description: Optional[str] = None):
self.description = description
self.pbar_epoch = tqdm(desc=self.description)
self.pbar_step = tqdm(desc="steps")
def __update_postfix(self, pbar: tqdm, logs: Dict[str, float]):
postfix = OrderedDict([(key, str(value)) for key, value in logs.items()])
pbar.set_postfix(ordered_dict=postfix, refresh=False)
def on_train_begin(self, logs={}):
total_epoch = self.params.get("epochs")
total_step = self.params.get("steps")
self.pbar_epoch.reset(total=total_epoch)
self.pbar_step.reset(total=total_step)
def on_train_end(self, logs={}):
self.pbar_epoch.close()
self.pbar_step.close()
def on_epoch_begin(self, epoch, logs={}):
self.pbar_step.reset()
def on_epoch_end(self, epoch, logs={}):
self.__update_postfix(self.pbar_epoch, logs)
self.pbar_epoch.update(1)
def on_batch_end(self, step, logs={}):
self.__update_postfix(self.pbar_step, logs)
self.pbar_step.update(1)on_*_{begin,end}の形式でタイミングに従ったメソッドが呼ばれるからそこにプログレスバーを制御するコードを差し込んでいく感じ。
カスタムコールバック
on_predict_batch_beginみたいに推論時に動くフックも書けたりするよ。
カスタムコールバックについてはこのページに詳しく載ってるからチェックしてみてね。
使い方
fit()する時のcallbacksオプションに渡して使う。
model.fit(
x_train,
y_train,
validation_data=(x_valid, y_valid),
epochs=100,
batch_size=2048,
callbacks=[KerasProgressBarCallback(description="MLP")],
verbose=0,
)デフォルトの出力は必要ないからverbose=0を指定しよう。