2023/05/11
2023/05/11
PyTorch Lightningを久しぶりに触ったらよく分からなくなってしまったので整理してみる.
PyTorchを基盤としたMLフレームワーク.
PyTorchそのままでは,データセットダウンロードや学習ループなどを1つの処理として書き連ねる必要がある.これをクラス/関数に分割することでコードを書きやすくする.
気づいたら似たような名前のパッケージが大量発生して訳が分からないことになっていた.初めにこれらのパッケージについて整理してみる.
pytorch_lightning
lightning
lightning.pytorch
, lightning.fabric
, lightning.app
の3つのサブパッケージを持つ.lightning.pytorch
lightning.fabric
lightning.app
単に"Lightning"と呼ぶと"PyTorch Lightning" か "Lightning"かの区別がつかなくなってしまった. この記事で扱うのは"PyTorch Lightning".
("PyTorch Lightning"でググると(旧)PyTorch Lightningのページが一番上に出てくるのでつらい)
パッケージ名が変わったので,インポート方法と推奨される別名も変わっている.
(旧)PyTorch Lightningが次の通りで,
import pytorch_lightning as pl import polars as pl # 衝突💥
そしてLightningは次の通り
import lightning as L import polars as pl
データフレームライブラリPolarsと衝突しなくなったので少し嬉しい.
PyTorch Lightningの主要(と思われる)コンポーネントはLightningModule
, LightningDataModule
, Trainer
の3つ.
多分この3つの概要を知っておけば理解が早いはず.
ドキュメント: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html
nn.Module
を継承していて,さらに
__init__()
, setup()
)forward()
)training_step()
)validation_step()
)test_step()
)predict_step()
)configure_optimizer()
)のためのメソッドを実装できるようになっている.
__init__()
とsetup()
でモデル定義を行ったりする.
なぜ2つあるのかについてはあまり分かってない.(__init__()
で実装すれば十分な気がする)
forward()
で実装する.作成したモデルにバッチを入れてTensorを返す.
それぞれtraining_step()
,validation_step()
,test_step()
,predict_step()
で実装する.
training_step()
ではミニバッチがbatch
として渡されるので,forward()
などを用いて計算した後にlossを返す.
他のメソッドも大体同じだと思う(ちゃんと調べてない).
ドキュメント: https://lightning.ai/docs/pytorch/stable/data/datamodule.html
PyTorchのDataset
とDataLoader
をラップするコンポーネント.それぞれのメソッドで以下の処理を行う.
setup()
)train_dataloader()
)val_dataloader()
)test_dataloader()
)Datasetはデータの読み込みとTensorへの変換(+Data Augmentationなど),DataLoaderはDatasetからデータを読み込んでミニバッチにする役目のコンポーネントと覚えておくのが良さそう.
...ここまで書いて気づいたが,LightningDataModule
は公式ドキュメントの"Core API"として紹介されていない.もしかしたらあまり推奨されてないかもしれない.
(実際,LightningDataModule
を使わなくてもなんとかなる.)
ドキュメント: https://lightning.ai/docs/pytorch/stable/common/trainer.html
学習を回すためのコンポーネント.
最大エポック数やEarlyStoppingのためのコールバックなどを渡して作成し,LightningDataModule
とLightningModule
をfit()
に渡せば学習を回してくれる.
trainer = pl.Trainer( max_epochs=100, callbacks=[ EarlyStopping(monitor="val_loss", patience=3), # val_lossが3回連続で更新されなければ終了 RichProgressBar(), RichModelSummary(), ], ) trainer.fit(compiled_model, datamodule=datamodule)
同様にvalidate/test用のメソッドも存在する.
また,LightningDataModule
ではなくDataLoader
を直接渡しても良い.(公式ドキュメントはDataLoader
を直接渡している)