今回は深層学習の自然言語モデルである、BERTについて色々触っていたのでその記録を簡単にまとめたいと思う。
ひょんなことから、BERTの事前学習からやらなければならなかったので、実際に用いたライブラリや環境、学習に必要な時間などについて書いていく。
BERTのライブラリ
BERTのライブラリとしては、huggingfaceのTransformersを利用。
huggingface.co
特に、Transformersを用いたMasked Language Modelのサンプルコードを参考に作成した。
github.com
ベースライブラリ(と言って良いのかな?)はPytorchを採用。TensorFlowの方が色々とやってくれてるぽい感じで楽そうではあったが、深層学習やBERTに初挑戦の身としては、逆に何をやっているのか分かりづらかったので、比較的細部までいじれるPytorchを採用した。
学習環境
AWSの色々なインスタンスタイプで学習をおこなった。その時に得た知見を整理していく。
AMI
Deep Learning AMI (Ubuntu 18.04) Version 57.0 - ami-0da3b2f58df4b61cb
を利用。最近はコンテナを使う方が主流らしいが一通り学習を流すだけだったので慣れているAMIを選択した。
Instance Type
さまざまなインスタンスタイプでの推定学習時間を計測した。今回は3日程度である程度学習を完了させる必要があったため、少しデータ件数やモデルのサイズを変更して試した。
inf1.2xlarge
最初はinf1.xlargeで試そうとしたが、そもそもtransfomersのライブラリインストールがうまくいかなかった。vCPUが1つだとインストール時にハングアウトしてしまい、失敗するようだ。そのため、以降でも基本的にvCPU数2以上のインスタンスを採用している。
こちらは以下の条件で、学習に必要な時間が 6300h over だった。。。。
- 学習データ: 約 550万件
- epoch数: 10
- bach per device: 16
p2系
CPUのみだとやはり辛いことがわかったので、GPU系のインスタンスに切り替えた。
しかし、結果としてP2系ではPytorchがうまくGPUを認識してくれず、学習の実行まではできなかった。一応、Tesla系のライブラリを入れ直したりして、nvidia-smiではGPUが表示されるようになったのでOSレベルでは認識していたようだが、Pytorchからは相変わらず認識されなかった。
※後々調べてみたら、以下のページが見つかった。P2系は推奨ではないらしい?
docs.aws.amazon.com
p3.8xlarge
実際は、G4dn系を試した後なので時間軸的には前後するが同一系列なので並べて記載する。
以下の条件でおおよそ 60h程度だった。
- 学習データ: 約 550万件
- epoch数: 10
- bach per device: 16
さらに条件を変更し、以下であれば4h程度。つまり1epochあたり1h。
- 学習データ: 約 120万件
- epoch数: 4
- bert encoder layer: 8
- bach per device: 24
bert_encoder_layerの数を12から8に減らして計算量を削減し、また、メモリに余裕ができたのでbatchサイズを増やして学習時間を短縮させた。
g4dn.8xlarge
以下の条件でおおよそ630hだった。
- 学習データ: 約 550万件
- epoch数: 10
- bach per device: 16
g4dn.12xlarge
以下の条件でおおよそ 100hだった。
- 学習データ: 約 300万件
- epoch数: 10
- bach per device: 16
GPU数が4つになっているため、単純に4倍速くなっている。
最初はmulti-GPU環境でプログラムがうまく動かなかった。原因は os.environ['CUDA_LAUNCH_BLOCKING'] = "1" を指定していたためだった。
HuggingfaceのTrainerを使っていたので、自動的にmulti-GPU環境に対応はしていたが、blockingモードで動かしていたために、並列実行がハングアウトしていたらしい。
今回学んだこと
- Instanceのファミリーを変えると起動に失敗(?)する
- p3系を使う場合は、instance capacityに注意
- 発生したのが土日だったので、土日に合わせて誰かが大規模に動かしたのかもしれない?特に今回は、データ専用のEBSを付け替えながらの実験だったため、availability-zoneを固定していたのも発生頻発した原因かもしれない。
- GLIBCXXエラーが派生する場合がある
- transformersをimportする際に発生した。こちらの記事を参考に解決した。なんとなく、transformersというよりはcuda周りが原因の気がしないでもない。
- データ件数、GPU数は線形で学習時間に影響する。
- データ件数もGPU数も顕著に学習時間にn倍の影響を及ぼしていた。(データ件数半分にしたら学習時間も半分)
- CUDA_LAUNCH_BLOCKING はmulti GPU環境では悪影響
- わかってしまえば当たり前のことだが、結構ハマった。特に、学習時間を正確に計測したいならつけるべしみたいな記事をよくみていたのでつけるのが普通なのかな?程度の認識だった。