Local -> Cloud -> Local で分割する3分割 (Triadic) Split Computing
- 3分割することで、通信はすべて中間層出力の特徴ベクトルで行われるため、通信の安全性が高まる
- 生成AIモデルはサイズが大きいので、Splitしてサイズを小さくして多くのデバイスでロード可能にする
- LLM実装は、MetaのOpen Source LLMであるLLaMa-2 または LLaMa を使用
- Diffusion model実装には、Stable Diffusion XLを使用
- 推論時の推論レイヤを正しく分割するための、モデルのforwardメソッドのoverride(
src/models.py
のFirstLlamaModel
などのforward
メソッド内でコメントアウトすることで実装) - メモリ使用量削減のため、不要なレイヤを Identity レイヤで置き換える(
src/models.py
のFirstLlamaModel
などのreplace_unused_layers_with_identity
メソッドを実装)
main.py
: メインプログラムsrc/cloud.py
: クラウドクラス(first modelとthird modelを推論)src/edge.py
: エッジクラス(second modelを推論)src/base.py
: クラウドサーバ・エッジサーバの継承元クラスsrc/split_models.py
: 分割用のLLMクラスであるFirstLlamaModel
・FirstLlamaForCausalLM
・SecondLlamaModel
・SecondLlamaForCausalLM
・ThirdLlamaModel
・ThirdLlamaForCausalLM
が定義されているsrc/utils.py
: 推論のためのutilstorchinfo_summary_log/
: 分割したLLMのtorchinfo.summary
の結果
main.py
の first split layer の index の集合 first_split_layer_indices
と second split layer の index の集合 second_split_layer_indices
を変更して、
python3 main.py
初回時はPre-trainedモデルのダウンロードが必要。
LLaMa-2 では、https://note.com/npaka/n/n79eebc29366d の3.1の利用申請と3.2の huggingface-cli login
をする必要がある。