2023/05/11
2023/05/11
PyTorch Lightningを久しぶりに触ったらよく分からなくなってしまったので整理してみる.
PyTorchを基盤としたMLフレームワーク.
PyTorchそのままでは,データセットダウンロードや学習ループなどを1つの処理として書き連ねる必要がある.これをクラス/関数に分割することでコードを書きやすくする.
気づいたら似たような名前のパッケージが大量発生して訳が分からないことになっていた.初めにこれらのパッケージについて整理してみる.
pytorch_lightninglightninglightning.pytorch, lightning.fabric, lightning.app の3つのサブパッケージを持つ.lightning.pytorchlightning.fabriclightning.app単に"Lightning"と呼ぶと"PyTorch Lightning" か "Lightning"かの区別がつかなくなってしまった. この記事で扱うのは"PyTorch Lightning".
("PyTorch Lightning"でググると(旧)PyTorch Lightningのページが一番上に出てくるのでつらい)
パッケージ名が変わったので,インポート方法と推奨される別名も変わっている.
(旧)PyTorch Lightningが次の通りで,
import pytorch_lightning as pl import polars as pl # 衝突💥
そしてLightningは次の通り
import lightning as L import polars as pl
データフレームライブラリPolarsと衝突しなくなったので少し嬉しい.
PyTorch Lightningの主要(と思われる)コンポーネントはLightningModule, LightningDataModule, Trainerの3つ.
多分この3つの概要を知っておけば理解が早いはず.
ドキュメント: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html
nn.Moduleを継承していて,さらに
__init__(), setup())forward())training_step())validation_step())test_step())predict_step())configure_optimizer())のためのメソッドを実装できるようになっている.
__init__()とsetup()でモデル定義を行ったりする.
なぜ2つあるのかについてはあまり分かってない.(__init__()で実装すれば十分な気がする)
forward()で実装する.作成したモデルにバッチを入れてTensorを返す.
それぞれtraining_step(),validation_step(),test_step(),predict_step()で実装する.
training_step()ではミニバッチがbatchとして渡されるので,forward()などを用いて計算した後にlossを返す.
他のメソッドも大体同じだと思う(ちゃんと調べてない).
ドキュメント: https://lightning.ai/docs/pytorch/stable/data/datamodule.html
PyTorchのDatasetとDataLoaderをラップするコンポーネント.それぞれのメソッドで以下の処理を行う.
setup())train_dataloader())val_dataloader())test_dataloader())Datasetはデータの読み込みとTensorへの変換(+Data Augmentationなど),DataLoaderはDatasetからデータを読み込んでミニバッチにする役目のコンポーネントと覚えておくのが良さそう.
...ここまで書いて気づいたが,LightningDataModuleは公式ドキュメントの"Core API"として紹介されていない.もしかしたらあまり推奨されてないかもしれない.
(実際,LightningDataModuleを使わなくてもなんとかなる.)
ドキュメント: https://lightning.ai/docs/pytorch/stable/common/trainer.html
学習を回すためのコンポーネント.
最大エポック数やEarlyStoppingのためのコールバックなどを渡して作成し,LightningDataModuleとLightningModuleをfit()に渡せば学習を回してくれる.
trainer = pl.Trainer( max_epochs=100, callbacks=[ EarlyStopping(monitor="val_loss", patience=3), # val_lossが3回連続で更新されなければ終了 RichProgressBar(), RichModelSummary(), ], ) trainer.fit(compiled_model, datamodule=datamodule)
同様にvalidate/test用のメソッドも存在する.
また,LightningDataModuleではなくDataLoaderを直接渡しても良い.(公式ドキュメントはDataLoaderを直接渡している)