Commit 730b268
Add LLaMa as an example. (#722)
* Add LLaMa as an example.

* Bugfix.

* Switch to using safetensors.

* Properly switch to half-precision.

* Get the text sampling to work.
LaurentMazare authored May 24, 2023
1 parent 59fb161 commit 730b268
Showing 5 changed files with 562 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -16,3 +16,4 @@ __pycache__
*.safetensors
*.so
*.dylib
llama-tokenizer.json
6 changes: 6 additions & 0 deletions Cargo.toml
@@ -30,6 +30,8 @@ safetensors = "0.3.0"
cpython = { version = "0.7.1", optional = true }
regex = { version = "1.6.0", optional = true }
image = { version = "0.24.5", optional = true }
clap = { version = "4.2.4", features = ["derive"], optional = true }
serde_json = { version = "1.0.96", optional = true }

[dev-dependencies]
anyhow = "1"
@@ -54,3 +56,7 @@ required-features = ["rl-python"]
[[example]]
name = "stable-diffusion"
required-features = ["regex"]

[[example]]
name = "llama"
required-features = ["regex", "clap", "serde_json"]
63 changes: 63 additions & 0 deletions examples/llama/convert_checkpoint.py
@@ -0,0 +1,63 @@
# Adapted from https://github.com/Lightning-AI/lit-llama/blob/main/scripts/convert_checkpoint.py
import sys
import torch
from typing import Dict
from pathlib import Path
from safetensors.torch import save_file

def convert_state_dict(state_dict: Dict[str, torch.Tensor], dtype: torch.dtype = torch.float16) -> Dict[str, torch.Tensor]:
    print("start conv")

    def get_and_remove(key):
        v = state_dict[key].to(dtype)
        del state_dict[key]
        return v

    converted = {}
    converted["transformer.wte.weight"] = get_and_remove("tok_embeddings.weight")
    converted["lm_head.weight"] = get_and_remove("output.weight")
    converted["transformer.ln_f.scale"] = get_and_remove("norm.weight")

    for layer_idx in sorted(set([k.split(".")[1] for k in state_dict if k.startswith("layers")])):
        print(layer_idx)

        # attention
        # the wq, wk, wv from the FB model are stacked in our model as c_attn
        converted[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = torch.cat(
            (
                get_and_remove(f"layers.{layer_idx}.attention.wq.weight"),
                get_and_remove(f"layers.{layer_idx}.attention.wk.weight"),
                get_and_remove(f"layers.{layer_idx}.attention.wv.weight"),
            )
        )
        converted[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = get_and_remove(
            f"layers.{layer_idx}.attention.wo.weight"
        )
        # mlp
        converted[f"transformer.h.{layer_idx}.mlp.c_fc1.weight"] = get_and_remove(
            f"layers.{layer_idx}.feed_forward.w1.weight"
        )
        converted[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = get_and_remove(
            f"layers.{layer_idx}.feed_forward.w2.weight"
        )
        converted[f"transformer.h.{layer_idx}.mlp.c_fc2.weight"] = get_and_remove(
            f"layers.{layer_idx}.feed_forward.w3.weight"
        )
        # rms norm
        converted[f"transformer.h.{layer_idx}.rms_1.scale"] = get_and_remove(f"layers.{layer_idx}.attention_norm.weight")
        converted[f"transformer.h.{layer_idx}.rms_2.scale"] = get_and_remove(f"layers.{layer_idx}.ffn_norm.weight")
    return converted

def convert_weights(llama_ckpt, *, output_st: Path = Path("llama.safetensors"), dtype: str = "float16") -> None:
    dt = getattr(torch, dtype, None)
    if not isinstance(dt, torch.dtype):
        raise ValueError(f"{dtype} is not a valid dtype.")
    checkpoint = torch.load(llama_ckpt, map_location="cpu")
    converted = convert_state_dict(checkpoint, dtype=dt)
    del checkpoint
    save_file(converted, output_st)

if __name__ == "__main__":
    if len(sys.argv) != 2:
        raise ValueError(f"usage: convert_checkpoint.py ..../LLaMA/7B/consolidated.00.pth")
    convert_weights(sys.argv[1])
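For reference, a minimal sketch of driving the conversion from Python rather than the command line, using only the convert_weights signature defined above; the checkpoint path and the assumption that the script is importable as convert_checkpoint are illustrative:

from pathlib import Path
from convert_checkpoint import convert_weights

# Convert a consolidated LLaMA checkpoint into a single float16 safetensors
# file in the current working directory; adjust the input path to match your
# local copy of the weights.
convert_weights(
    "LLaMA/7B/consolidated.00.pth",
    output_st=Path("llama.safetensors"),
    dtype="float16",
)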
