LightGBMの進捗をプログレスバーで表示する

2022/02/28

学習の進捗がずらずらっと標準出力に出てきちゃうの邪魔だよね。特にNotebook環境の時はすごい見にくくなるし。

今日はPythonライブラリのtqdmを使っていい感じのプログレスバーで進捗を表示するスニペットを紹介するよ。

1イテレーションにかかる時間もサクっと確認できて便利!

LightGBMのバージョン

3.3.2

スニペット

tqdmを初めて使う時はpipやcondaでインストールしておこう。

pip install tqdm

LightGBMのtrainには関数を指定することでイテレーション毎に任意の処理を実行出来るcallbacksオプションがある。

そこにtqdmで作ったプログレスバーを更新する処理を挟んで実現する仕組み。

# str | Noneって書きたいところだけど3.7系でもいけるようにOptionalを使う
from typing import Optional
from collections import OrderedDict
from lightgbm.callback import CallbackEnv
from tqdm.auto import tqdm

class LgbmProgressBarCallback:
    description: Optional[str]
    pbar: tqdm

    def __init__(self, description: Optional[str] = None):
        self.description = description
        self.pbar = tqdm()

    def __call__(self, env: CallbackEnv):

        # 初回だけProgressBarを初期化する
        is_first_iteration: bool = env.iteration == env.begin_iteration

        if is_first_iteration:
            total: int = env.end_iteration - env.begin_iteration
            self.pbar.reset(total=total)
            self.pbar.set_description(self.description, refresh=False)

        # valid_setsの評価結果を更新
        if len(env.evaluation_result_list) > 0:
            # OrderedDictにしないと表示順がバラバラになって若干見にくい
            postfix = OrderedDict(
                [
                    (f"{entry[0]}:{entry[1]}", str(entry[2]))
                    for entry in env.evaluation_result_list
                ]
            )
            self.pbar.set_postfix(ordered_dict=postfix, refresh=False)

        # 進捗を1進める
        self.pbar.update(1)
        self.pbar.refresh()

使う時はこう。

booster = lgbm.train(
    {**param, "verbosity": -1},
    dataset_train,
    valid_sets=dataset_valid,
    num_boost_round=10,
    callbacks=[
      LgbmProgressBarCallback(description="Model A"),
    ],
)

これでイテレーションごとにプログレスバーの進捗が更新されるようになるよ。

verbosity

進捗以外のお知らせは普通に出力されちゃうから出来るだけ減らしたい場合は併せて"verbosity": -1を指定しておくといいかも。

Info系の出力が無くなってWarn Error系だけになるよ。

あと冒頭のfrom tqdm.auto import tqdmっていう部分、tqdmはこうやって書くとNotebook環境を自動で検知してリッチなウィジェットを表示してくれるんだけどVSCodeのNotebookだとうまくいかなかった。

該当部分のソースコードを読んでみるとVSCodeはサポート外ということでわざとauto時のウィジェット表示をしないようにしてるみたい。

from tqdm.notebook import tqdmってやれば普通にウィジェット表示にできるからVSCodeの人は試してみてね。

CallbackEnvの中身

callbacksに渡した関数にはCallbackEnvっていう形式のオブジェクトが引数として与えられる。

ここには学習中のBoosterインスタンスなんかも含まれるから今回みたいな進捗の表示以外にも色々なことが出来るよ。

バージョン3.3.2の時点のCallbackEnvの中身はこんな感じ。

プロパティ説明
iterationint現在のiterationのindex
begin_iterationint学習を開始したiterationのindex
end_iterationint総iteration数
modelBoosterBoosterインスタンス
paramsDict学習パラメータ
evaluation_result_listListvalid_setsに対する評価結果

modelとparamsも参照できるから特定のタイミングでmodelを保存したりparamsに変更を加えて学習時の振る舞いをカスタムしたりも出来るね。

iterationbegin_iterationがindexなのに対してend_iterationはlengthなことに注意。

例えばnum_boost_round = 10の時にはbegin_iterationは0, iterationは0~9, end_iterationは10になる。

begin_iterationが0じゃないケースは既存のBoosterで学習を再開した時とかだと思う。

evaluation_result_listには(valid_name, metric, value, is_higher_better)形式のTupleが複数格納されている感じ。

# evaluation_result_listの例
[
  ("valid_0", "binary_logloss", 0.034567, False),
  ("valid_0", "auc", 0.77, True),
  ("valid_1", "binary_logloss", 0.023456,  False),
  ("valid_1", "auc", 0.83, True),
]

lightgbm.callbackモジュールの中にも公式で提供されているコールバック関数が何個か入ってて、実装はすごくシンプルで分かりやすいから興味があったら参考にしてみるといいかも。

https://github.com/microsoft/LightGBM/blob/master/python-package/lightgbm/callback.py

  • log_evaluation
    • valid_setsに対する評価結果をロガーに出力する
    • 学習パラメータのverbose_evalは廃止になるからこっち使ってねってことみたい
  • record_evaluation
    • 指定したDictにevaluation_result_listの内容を転記する
  • reset_parameter
    • 初回iteration終了時に指定したパラメータを上書きする。iterationを引数にとる関数を渡せるから進捗に応じた learning_rate の調整に利用するような想定なのかも
  • early_stopping
    • 指定したiteration回数以上valid_setsに対するmetricが向上しなかったら学習を打ち切る
    • こちらも学習パラメータのearly_stopping_roundよりこっち使ってねって感じっぽい

みんな自分だけのコールバックを作ろう!

自分はどうにかして learning_rate をうまく調整して学習時間を短縮しつつ精度もギリギリまで確保するようなコールバックを作ろうとしてるんだけど全然うまくいってない😇

いい感じのが出来たら教えてね〜🙋‍♂️