プログレスバーで表示するシリーズ第三弾はKeras!
学習の進捗をtqdm
でかっこよく表示するスニペットを紹介するよ。
TensorBoardも超いいけどこういうシンプルな進捗表示もアリだよね。
これまでのプログレスバーシリーズ
TensorFlowのバージョン
2.7.0
インストール
tqdm
のインストールがまだの人はインストールしておこう。
pip install tqdm
スニペット
単なる関数じゃなくてkeras.callbacks.Callback
を継承したクラスを作るのがポイント。
epoch
とstep
それぞれに対して2つプログレスバーを表示しているよ。
from typing import Dict, Optional
from collections import OrderedDict
from tensorflow import keras
from tqdm.auto import tqdm
class KerasProgressBarCallback(keras.callbacks.Callback):
description: Optional[str]
pbar_epoch: tqdm
pbar_step: tqdm
def __init__(self, description: Optional[str] = None):
self.description = description
self.pbar_epoch = tqdm(desc=self.description)
self.pbar_step = tqdm(desc="steps")
def __update_postfix(self, pbar: tqdm, logs: Dict[str, float]):
postfix = OrderedDict([(key, str(value)) for key, value in logs.items()])
pbar.set_postfix(ordered_dict=postfix, refresh=False)
def on_train_begin(self, logs={}):
total_epoch = self.params.get("epochs")
total_step = self.params.get("steps")
self.pbar_epoch.reset(total=total_epoch)
self.pbar_step.reset(total=total_step)
def on_train_end(self, logs={}):
self.pbar_epoch.close()
self.pbar_step.close()
def on_epoch_begin(self, epoch, logs={}):
self.pbar_step.reset()
def on_epoch_end(self, epoch, logs={}):
self.__update_postfix(self.pbar_epoch, logs)
self.pbar_epoch.update(1)
def on_batch_end(self, step, logs={}):
self.__update_postfix(self.pbar_step, logs)
self.pbar_step.update(1)
on_*_{begin,end}
の形式でタイミングに従ったメソッドが呼ばれるからそこにプログレスバーを制御するコードを差し込んでいく感じ。
カスタムコールバック
on_predict_batch_begin
みたいに推論時に動くフックも書けたりするよ。
カスタムコールバックについてはこのページに詳しく載ってるからチェックしてみてね。
使い方
fit()
する時のcallbacks
オプションに渡して使う。
model.fit(
x_train,
y_train,
validation_data=(x_valid, y_valid),
epochs=100,
batch_size=2048,
callbacks=[KerasProgressBarCallback(description="MLP")],
verbose=0,
)
デフォルトの出力は必要ないからverbose=0
を指定しよう。