diff --git a/.gitignore b/.gitignore
index 3e4ae83..1a0ff3f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,3 +2,5 @@ experiments/**
 outfiles/**
 text_generation/
 **/__pycache__/
+wandb/**
+methods/baselines/topk/bce928f38989812b69c6f8e3a86763e004387d16/**
diff --git a/data/input.txt b/data/input.txt
new file mode 100644
index 0000000..02b5945
--- /dev/null
+++ b/data/input.txt
@@ -0,0 +1,3 @@
+The United States of America (USA or U.S.A.), commonly known as the United States (US or U.S.) or America, is a country primarily located in North America.
+The Arsenal Football Club, commonly known as Arsenal, is an English professional football club based in Holloway, North London.
+The three primary colors are
diff --git a/eval_ppl.py b/eval_ppl.py
index 9a5b774..559d1c7 100644
--- a/eval_ppl.py
+++ b/eval_ppl.py
@@ -1,11 +1,23 @@
 from methods import init_tensor_saver
-from configure_model import get_h2o_args, get_topk_args, get_spar_args, get_pca_args, get_save_tensor_args
-from configure_model import get_modifier
+from methods.common.configure_model import get_h2o_args, get_topk_args, get_spar_args, get_pca_args, get_save_tensor_args
+from methods.common.configure_model import get_modifier
+from methods import init_logger, finish_logger
+import methods
 
 import argparse
 import os
+
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
+#LM_HARNESS_VALID_TASKS = ["hellaswag", "winogrande", "gsm8k", "mmlu", "truthfulqa_mc2", "arc_challenge"]
+LM_HARNESS_TASKS = {
+    "mmlu" : "acc,none",
+    "gsm8k" : "exact_match,strict-match",
+    "hellaswag" : "acc_norm,none",
+    "winogrande" : "acc,none",
+    "truthfulqa_mc2" : "acc,none",
+    "arc_challenge" : "acc_norm,none"
+}
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -14,7 +26,9 @@
     parser.add_argument("--sequence-length", type=int, default=4096, help="sequence length")
     parser.add_argument("--use-axonn", action='store_true', default=False, help="shard a model using AxoNN")
     parser.add_argument("--lm-harness-eval", action='store_true', default=False, help="use lm harness eval")
-    parser.add_argument("--dataset", type=str, default="wikitext-test", help="which dataset to use for ppl eval")
+    parser.add_argument("--dataset", type=str, default="wikitext-test", help="dataset - wikitext, bookcorpus, c4")
+    parser.add_argument("--use-wandb", action='store_true', default=False, help="use wandb")
+    #parser.add_argument("--task", type=str, default="perplexity", help="task - perplexity, ")
 
     parser = get_h2o_args(parser)
     parser = get_topk_args(parser)
@@ -26,6 +40,8 @@
     if args.save_tensors:
         init_tensor_saver(args.tensors_dir)
 
+    init_logger(args)
+
     modifier_method = get_modifier(args)
     if modifier_method is None:
         raise ValueError("Modifier method not found")
@@ -33,10 +49,8 @@
     print (modifier_method)
 
     cache = None
-    if args.use_topk:
-        modifier_method(args.top_k)
-    elif args.use_h2o:
-        modifier_method(args.heavy_ratio)
+    if args.use_topk or args.use_h2o or args.use_pca_topk:
+        modifier_method(args)
     elif args.use_sparq or args.use_spark:
         modifier_method(args.top_r, args.top_k)
     elif args.use_spar_hat:
@@ -44,20 +58,31 @@
     elif args.use_pca:
         modifier_method(args.top_r)
         args.use_axonn = False
-    elif args.use_pca_topk:
-        modifier_method(args.top_r, args.top_k)
-
+
     if args.lm_harness_eval:
         import lm_eval
+        from lm_perplexity_eval import evaluate
+        model = evaluate(model_id=args.model_id,
+                         dataset=args.dataset,
+                         sequence_length=args.sequence_length,
+                         use_axonn=args.use_axonn,
+                         past_key_values=cache,
+                         axonn_low_level_api=True,
+                         return_model=True)
         results = lm_eval.simple_evaluate(
             model = "hf",
-            model_args=f"pretrained={args.model_id}",
-            tasks = ["copa", "rte", "openbookqa", "mathqa", "winogrande", "hellaswag"],
-            #tasks = ["hellaswag"],
+            #model_args=f"pretrained={args.model_id}",
+            #model_args={"pretrained": model, "parallelize": True},
+            model_args={"pretrained": model},
+            tasks = LM_HARNESS_TASKS.keys(),
             log_samples=False,
+            batch_size=16
         )
-        print(results["results"])
+        if results is not None:
+            print(results["results"])
+            if methods.LOGGER is not None:
+                methods.LOGGER.log_lm_harness_results(LM_HARNESS_TASKS, results["results"])
     else:
         from lm_perplexity_eval import evaluate
         print(args.use_axonn)
@@ -69,3 +94,7 @@
                         axonn_low_level_api=True)
 
         print(ppl)
+        if methods.LOGGER is not None:
+            methods.LOGGER.log_ppl(ppl)
+
+    finish_logger()
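A minimal sketch (not part of the patch) of how the LM_HARNESS_TASKS mapping above is meant to be consumed: each value is the metric key that lm-eval-harness reports for that task, so one headline number per task can be pulled out of results["results"] before logging. The results-dict layout is assumed to follow lm-eval-harness's standard output.

    def extract_lm_harness_metrics(results, task_to_metric):
        # results is the dict returned by lm_eval.simple_evaluate(...),
        # task_to_metric is a mapping like LM_HARNESS_TASKS above.
        metrics = {}
        for task, metric_key in task_to_metric.items():
            task_results = results["results"].get(task, {})
            if metric_key in task_results:
                metrics[task] = task_results[metric_key]
        return metrics
    # e.g. {"mmlu": 0.62, "gsm8k": 0.35, ...} -- the kind of per-task dict a wandb logger would record.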
diff --git a/eval_ppl_old.py b/eval_ppl_old.py
index a7f6b6d..afeb42a 100644
--- a/eval_ppl_old.py
+++ b/eval_ppl_old.py
@@ -8,8 +8,8 @@
     #make_gemma_attention_top_k
     make_gptneox_attention_top_k
 )
-from configure_model import get_h2o_args, get_topk_args, get_spar_args, get_pca_args, get_save_tensor_args
-from configure_model import get_modifier
+from methods.common.configure_model import get_h2o_args, get_topk_args, get_spar_args, get_pca_args, get_save_tensor_args
+from methods.common.configure_model import get_modifier
 from methods import SparHatCache
 
 import argparse
diff --git a/examples/h2o-llama.sh b/examples/h2o-llama.sh
deleted file mode 100644
index 01451b0..0000000
--- a/examples/h2o-llama.sh
+++ /dev/null
@@ -1,13 +0,0 @@
-#!/bin/bash
-
-set -x
-sbatch examples/submit_h2o.sh meta-llama/Llama-2-7b-hf llama 4096 0.0625
-sbatch examples/submit_h2o.sh meta-llama/Llama-2-7b-hf llama 4096 0.125
-sbatch examples/submit_h2o.sh meta-llama/Llama-2-7b-hf llama 4096 0.25
-sbatch examples/submit_h2o.sh meta-llama/Llama-2-13b-hf llama 4096 0.0625
-#sbatch examples/submit_h2o.sh meta-llama/Llama-2-13b-hf llama 4096 0.125
-#sbatch examples/submit_h2o.sh meta-llama/Llama-2-13b-hf llama 4096 0.25
-sbatch examples/submit_h2o.sh meta-llama/Llama-2-70b-hf llama 4096 0.0625
-#sbatch examples/submit_h2o.sh meta-llama/Llama-2-70b-hf llama 4096 0.125
-#sbatch examples/submit_h2o.sh meta-llama/Llama-2-70b-hf llama 4096 0.25
-set +x
diff --git a/examples/h2o-mistral.sh b/examples/h2o-mistral.sh
deleted file mode 100644
index cc5c98b..0000000
--- a/examples/h2o-mistral.sh
+++ /dev/null
@@ -1,8 +0,0 @@
-#!/bin/bash
-
-set -x
-sbatch examples/submit_h2o_1gpu.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.125
-sbatch examples/submit_h2o_1gpu.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.25
-sbatch examples/submit_h2o_1gpu.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.125 --lm-harness-eval
-sbatch examples/submit_h2o_1gpu.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.25 --lm-harness-eval
-set +x
diff --git a/examples/h2o-run.sh b/examples/h2o-run.sh
deleted file mode 100644
index fb7d52e..0000000
--- a/examples/h2o-run.sh
+++ /dev/null
@@ -1,20 +0,0 @@
-#!/bin/bash
-
-set -x
-#sbatch examples/submit_h2o_opt.sh facebook/opt-1.3b opt 2048 0.125
-#sbatch examples/submit_h2o_opt.sh facebook/opt-2.7b opt 2048 0.125
-#sbatch examples/submit_h2o_opt.sh facebook/opt-6.7b opt 2048 0.125
-#sbatch examples/submit_h2o_opt.sh facebook/opt-13b opt 2048 0.125
-#sbatch examples/submit_h2o_opt.sh facebook/opt-30b opt 2048 0.125
-#sbatch examples/submit_h2o_opt.sh facebook/opt-1.3b opt 2048 0.25
-#sbatch examples/submit_h2o_opt.sh facebook/opt-2.7b opt 2048 0.25
-#sbatch examples/submit_h2o_opt.sh facebook/opt-6.7b opt 2048 0.25
-#sbatch examples/submit_h2o_opt.sh facebook/opt-13b opt 2048 0.25
-#sbatch examples/submit_h2o_opt.sh facebook/opt-30b opt 2048 0.25
-sbatch examples/submit_h2o_opt.sh facebook/opt-1.3b opt 2048 0.0625
-sbatch examples/submit_h2o_opt.sh facebook/opt-2.7b opt 2048 0.0625
-sbatch examples/submit_h2o_opt.sh facebook/opt-6.7b opt 2048 0.0625
-sbatch examples/submit_h2o_opt.sh facebook/opt-13b opt 2048 0.0625
-sbatch examples/submit_h2o_opt.sh facebook/opt-30b opt 2048 0.0625
-set +x
-
diff --git a/examples/h2o/h2o-llama.sh b/examples/h2o/h2o-llama.sh
new file mode 100644
index 0000000..0dbf7f8
--- /dev/null
+++ b/examples/h2o/h2o-llama.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+
+set -x
+sbatch examples/h2o/submit_h2o.sh meta-llama/Llama-2-7b-hf llama 4096 0.125
+sbatch examples/h2o/submit_h2o.sh meta-llama/Llama-2-7b-hf llama 4096 0.25
+sbatch examples/h2o/submit_h2o.sh meta-llama/Llama-2-7b-hf llama 4096 0.125 --lm-harness-eval
+sbatch examples/h2o/submit_h2o.sh meta-llama/Llama-2-7b-hf llama 4096 0.25 --lm-harness-eval
+set +x
diff --git a/examples/h2o/h2o-mistral.sh b/examples/h2o/h2o-mistral.sh
new file mode 100644
index 0000000..162e5f8
--- /dev/null
+++ b/examples/h2o/h2o-mistral.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+
+set -x
+sbatch examples/h2o/submit_h2o_1gpu.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.125
+sbatch examples/h2o/submit_h2o_1gpu.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.25
+sbatch examples/h2o/submit_h2o_1gpu.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.125 --lm-harness-eval
+sbatch examples/h2o/submit_h2o_1gpu.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.25 --lm-harness-eval
+set +x
diff --git a/examples/h2o/h2o-pythia.sh b/examples/h2o/h2o-pythia.sh
new file mode 100644
index 0000000..82d4ec0
--- /dev/null
+++ b/examples/h2o/h2o-pythia.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+
+set -x
+sbatch examples/h2o/submit_h2o_1gpu.sh EleutherAI/pythia-6.9b gptneox 2048 0.125
+sbatch examples/h2o/submit_h2o_1gpu.sh EleutherAI/pythia-6.9b gptneox 2048 0.25
+sbatch examples/h2o/submit_h2o_1gpu.sh EleutherAI/pythia-6.9b gptneox 2048 0.125 --lm-harness-eval
+sbatch examples/h2o/submit_h2o_1gpu.sh EleutherAI/pythia-6.9b gptneox 2048 0.25 --lm-harness-eval
+set +x
diff --git a/examples/submit_h2o.sh b/examples/h2o/submit_h2o.sh
similarity index 96%
rename from examples/submit_h2o.sh
rename to examples/h2o/submit_h2o.sh
index 237c3b2..6c46477 100644
--- a/examples/submit_h2o.sh
+++ b/examples/h2o/submit_h2o.sh
@@ -6,6 +6,9 @@
 #SBATCH --account=m4641_g
 #SBATCH --ntasks-per-node=4
 #SBATCH --time=03:00:00
+#SBATCH -J h2o
+#SBATCH --output=outfiles/%x-%j.out
+
 
 # Runs a "10B" parameter model
diff --git a/examples/submit_h2o_1gpu.sh b/examples/h2o/submit_h2o_1gpu.sh
similarity index 96%
rename from examples/submit_h2o_1gpu.sh
rename to examples/h2o/submit_h2o_1gpu.sh
index 42815af..36ccd75 100644
--- a/examples/submit_h2o_1gpu.sh
+++ b/examples/h2o/submit_h2o_1gpu.sh
@@ -6,6 +6,9 @@
 #SBATCH --account=m4641_g
 #SBATCH --ntasks-per-node=1
 #SBATCH --time=10:00:00
+#SBATCH -J h2o
+#SBATCH --output=outfiles/%x-%j.out
+
 
 # Runs a "10B" parameter model
diff --git a/examples/pca-llama.sh b/examples/pca-llama.sh
deleted file mode 100644
index bc53b74..0000000
--- a/examples/pca-llama.sh
+++ /dev/null
@@ -1,21 +0,0 @@
-#!/bin/bash
-
-set -x
-#sbatch examples/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 128 --lm-harness-eval
-#sbatch examples/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 64 --lm-harness-eval
-#sbatch examples/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 32 --lm-harness-eval
-#sbatch examples/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 128
-#sbatch examples/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 64
-#sbatch examples/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 32
-sbatch examples/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 96
-sbatch examples/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 0.97 --lm-harness-eval
-sbatch examples/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 0.95 --lm-harness-eval
-sbatch examples/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 0.90 --lm-harness-eval
-sbatch examples/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 0.85 --lm-harness-eval
-sbatch examples/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 0.75 --lm-harness-eval
-#sbatch examples/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 0.97
-#sbatch examples/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 0.95
-#sbatch examples/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 0.90
-#sbatch examples/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 0.85
-#sbatch examples/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 0.75
-set +x
diff --git a/examples/pca-mistral.sh b/examples/pca-mistral.sh
deleted file mode 100644
index 1879476..0000000
--- a/examples/pca-mistral.sh
+++ /dev/null
@@ -1,22 +0,0 @@
-#!/bin/bash
-
-set -x
-#sbatch examples/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 128 --lm-harness-eval
-sbatch examples/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 64 --lm-harness-eval
-sbatch examples/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 32 --lm-harness-eval
-#sbatch examples/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 128
-#sbatch examples/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 64
-#sbatch examples/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 32
-#sbatch examples/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 96
-#sbatch examples/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.97 --lm-harness-eval
-#sbatch examples/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.95 --lm-harness-eval
-#sbatch examples/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.90 --lm-harness-eval
-#sbatch examples/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.85 --lm-harness-eval
-sbatch examples/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.70 --lm-harness-eval
-#sbatch examples/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.75 --lm-harness-eval
-#sbatch examples/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.97
-#sbatch examples/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.95
-#sbatch examples/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.90
-#sbatch examples/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.85
-#sbatch examples/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.75
-set +x
diff --git a/examples/pca-topk-llama.sh b/examples/pca-topk-llama.sh
deleted file mode 100644
index a261010..0000000
--- a/examples/pca-topk-llama.sh
+++ /dev/null
@@ -1,13 +0,0 @@
-#!/bin/bash
-
-set -x
-sbatch examples/submit_pca_topk.sh meta-llama/Llama-2-7b-hf llama 4096 64 2048
-sbatch examples/submit_pca_topk.sh meta-llama/Llama-2-7b-hf llama 4096 64 1024
-sbatch examples/submit_pca_topk.sh meta-llama/Llama-2-7b-hf llama 4096 64 512
-sbatch examples/submit_pca_topk.sh meta-llama/Llama-2-7b-hf llama 4096 32 2048
-sbatch examples/submit_pca_topk.sh meta-llama/Llama-2-7b-hf llama 4096 32 1024
-sbatch examples/submit_pca_topk.sh meta-llama/Llama-2-7b-hf llama 4096 32 512
-sbatch examples/submit_pca_topk.sh meta-llama/Llama-2-7b-hf llama 4096 16 2048
-sbatch examples/submit_pca_topk.sh meta-llama/Llama-2-7b-hf llama 4096 16 1024
-sbatch examples/submit_pca_topk.sh meta-llama/Llama-2-7b-hf llama 4096 16 512
-set +x
diff --git a/examples/pca-topk-mistral.sh b/examples/pca-topk-mistral.sh
deleted file mode 100644
index 52e1d39..0000000
--- a/examples/pca-topk-mistral.sh
+++ /dev/null
@@ -1,13 +0,0 @@
-#!/bin/bash
-
-set -x
-sbatch examples/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 64 2048
-sbatch examples/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 64 1024
-sbatch examples/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 64 512
-sbatch examples/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 32 2048
-sbatch examples/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 32 1024
-sbatch examples/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 32 512
-sbatch examples/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 16 2048
-sbatch examples/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 16 1024
-sbatch examples/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 16 512
-set +x
diff --git a/examples/pca/pca-llama.sh b/examples/pca/pca-llama.sh
new file mode 100644
index 0000000..cb0e9c3
--- /dev/null
+++ b/examples/pca/pca-llama.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+set -x
+#sbatch examples/pca/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 128 --lm-harness-eval
+#sbatch examples/pca/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 64 --lm-harness-eval
+#sbatch examples/pca/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 32 --lm-harness-eval
+#sbatch examples/pca/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 128
+#sbatch examples/pca/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 64
+#sbatch examples/pca/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 32
+sbatch examples/pca/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 96
+sbatch examples/pca/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 0.97 --lm-harness-eval
+sbatch examples/pca/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 0.95 --lm-harness-eval
+sbatch examples/pca/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 0.90 --lm-harness-eval
+sbatch examples/pca/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 0.85 --lm-harness-eval
+sbatch examples/pca/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 0.75 --lm-harness-eval
+#sbatch examples/pca/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 0.97
+#sbatch examples/pca/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 0.95
+#sbatch examples/pca/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 0.90
+#sbatch examples/pca/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 0.85
+#sbatch examples/pca/submit_pca.sh meta-llama/Llama-2-7b-hf llama 4096 0.75
+set +x
diff --git a/examples/pca/pca-mistral.sh b/examples/pca/pca-mistral.sh
new file mode 100644
index 0000000..0ea6f94
--- /dev/null
+++ b/examples/pca/pca-mistral.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+set -x
+#sbatch examples/pca/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 128 --lm-harness-eval
+sbatch examples/pca/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 64 --lm-harness-eval
+sbatch examples/pca/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 32 --lm-harness-eval
+#sbatch examples/pca/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 128
+#sbatch examples/pca/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 64
+#sbatch examples/pca/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 32
+#sbatch examples/pca/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 96
+#sbatch examples/pca/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.97 --lm-harness-eval
+#sbatch examples/pca/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.95 --lm-harness-eval
+#sbatch examples/pca/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.90 --lm-harness-eval
+#sbatch examples/pca/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.85 --lm-harness-eval
+sbatch examples/pca/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.70 --lm-harness-eval
+#sbatch examples/pca/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.75 --lm-harness-eval
+#sbatch examples/pca/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.97
+#sbatch examples/pca/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.95
+#sbatch examples/pca/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.90
+#sbatch examples/pca/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.85
+#sbatch examples/pca/submit_pca.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.75
+set +x
diff --git a/examples/submit_pca.sh b/examples/pca/submit_pca.sh
similarity index 100%
rename from examples/submit_pca.sh
rename to examples/pca/submit_pca.sh
diff --git a/examples/pca_topk/pca-topk-llama.sh b/examples/pca_topk/pca-topk-llama.sh
new file mode 100644
index 0000000..2028c09
--- /dev/null
+++ b/examples/pca_topk/pca-topk-llama.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+
+MODEL=$1
+MODEL_TYPE=$2
+SEQLEN=$3
+
+set -x
+sbatch examples/pca_topk/submit_pca_topk.sh ${MODEL} ${MODEL_TYPE} ${SEQLEN} 64 0.5
+#sbatch examples/pca_topk/submit_pca_topk.sh ${MODEL} ${MODEL_TYPE} ${SEQLEN} 64 0.25
+#sbatch examples/pca_topk/submit_pca_topk.sh ${MODEL} ${MODEL_TYPE} ${SEQLEN} 64 0.125
+#
+sbatch examples/pca_topk/submit_pca_topk.sh ${MODEL} ${MODEL_TYPE} ${SEQLEN} 32 0.5
+#sbatch examples/pca_topk/submit_pca_topk.sh ${MODEL} ${MODEL_TYPE} ${SEQLEN} 32 0.25
+#sbatch examples/pca_topk/submit_pca_topk.sh ${MODEL} ${MODEL_TYPE} ${SEQLEN} 32 0.125
+#
+sbatch examples/pca_topk/submit_pca_topk.sh ${MODEL} ${MODEL_TYPE} ${SEQLEN} 16 0.5
+#sbatch examples/pca_topk/submit_pca_topk.sh ${MODEL} ${MODEL_TYPE} ${SEQLEN} 16 0.25
+#sbatch examples/pca_topk/submit_pca_topk.sh ${MODEL} ${MODEL_TYPE} ${SEQLEN} 16 0.125
+#
+#sbatch examples/pca_topk/submit_pca_topk.sh ${MODEL} ${MODEL_TYPE} ${SEQLEN} 64 0.5 --lm-harness-eval
+#sbatch examples/pca_topk/submit_pca_topk.sh ${MODEL} ${MODEL_TYPE} ${SEQLEN} 64 0.25 --lm-harness-eval
+#sbatch examples/pca_topk/submit_pca_topk.sh ${MODEL} ${MODEL_TYPE} ${SEQLEN} 64 0.125 --lm-harness-eval
+#
+#sbatch examples/pca_topk/submit_pca_topk.sh ${MODEL} ${MODEL_TYPE} ${SEQLEN} 32 0.5 --lm-harness-eval
+#sbatch examples/pca_topk/submit_pca_topk.sh ${MODEL} ${MODEL_TYPE} ${SEQLEN} 32 0.25 --lm-harness-eval
+#sbatch examples/pca_topk/submit_pca_topk.sh ${MODEL} ${MODEL_TYPE} ${SEQLEN} 32 0.125 --lm-harness-eval
+#
+#sbatch examples/pca_topk/submit_pca_topk.sh ${MODEL} ${MODEL_TYPE} ${SEQLEN} 16 0.5 --lm-harness-eval
+#sbatch examples/pca_topk/submit_pca_topk.sh ${MODEL} ${MODEL_TYPE} ${SEQLEN} 16 0.25 --lm-harness-eval
+#sbatch examples/pca_topk/submit_pca_topk.sh ${MODEL} ${MODEL_TYPE} ${SEQLEN} 16 0.125 --lm-harness-eval
+set +x
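Usage note (assumption, not part of the patch): unlike the per-model scripts it replaces, examples/pca_topk/pca-topk-llama.sh is parameterized by positional arguments, so a typical invocation would be along the lines of `bash examples/pca_topk/pca-topk-llama.sh meta-llama/Llama-2-7b-hf llama 4096`; the model id, model type, and sequence length here are illustrative values taken from the other example scripts.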
diff --git a/examples/pca_topk/pca-topk-mistral.sh b/examples/pca_topk/pca-topk-mistral.sh
new file mode 100644
index 0000000..cf69ad0
--- /dev/null
+++ b/examples/pca_topk/pca-topk-mistral.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+set -x
+sbatch examples/pca_topk/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 64 0.5
+sbatch examples/pca_topk/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 64 0.25
+sbatch examples/pca_topk/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 64 0.125
+
+sbatch examples/pca_topk/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 32 0.5
+sbatch examples/pca_topk/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 32 0.25
+sbatch examples/pca_topk/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 32 0.125
+
+sbatch examples/pca_topk/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 16 0.5
+sbatch examples/pca_topk/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 16 0.25
+sbatch examples/pca_topk/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 16 0.125
+
+sbatch examples/pca_topk/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 64 0.5 --lm-harness-eval
+sbatch examples/pca_topk/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 64 0.25 --lm-harness-eval
+sbatch examples/pca_topk/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 64 0.125 --lm-harness-eval
+
+sbatch examples/pca_topk/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 32 0.5 --lm-harness-eval
+sbatch examples/pca_topk/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 32 0.25 --lm-harness-eval
+sbatch examples/pca_topk/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 32 0.125 --lm-harness-eval
+
+sbatch examples/pca_topk/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 16 0.5 --lm-harness-eval
+sbatch examples/pca_topk/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 16 0.25 --lm-harness-eval
+sbatch examples/pca_topk/submit_pca_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 16 0.125 --lm-harness-eval
+set +x
diff --git a/examples/pca_topk/pca-topk-pythia.sh b/examples/pca_topk/pca-topk-pythia.sh
new file mode 100644
index 0000000..ee77da8
--- /dev/null
+++ b/examples/pca_topk/pca-topk-pythia.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+set -x
+sbatch examples/pca_topk/submit_pca_topk.sh EleutherAI/pythia-6.9b gptneox 2048 64 0.5
+sbatch examples/pca_topk/submit_pca_topk.sh EleutherAI/pythia-6.9b gptneox 2048 64 0.25
+sbatch examples/pca_topk/submit_pca_topk.sh EleutherAI/pythia-6.9b gptneox 2048 64 0.125
+
+sbatch examples/pca_topk/submit_pca_topk.sh EleutherAI/pythia-6.9b gptneox 2048 32 0.5
+sbatch examples/pca_topk/submit_pca_topk.sh EleutherAI/pythia-6.9b gptneox 2048 32 0.25
+sbatch examples/pca_topk/submit_pca_topk.sh EleutherAI/pythia-6.9b gptneox 2048 32 0.125
+
+sbatch examples/pca_topk/submit_pca_topk.sh EleutherAI/pythia-6.9b gptneox 2048 16 0.5
+sbatch examples/pca_topk/submit_pca_topk.sh EleutherAI/pythia-6.9b gptneox 2048 16 0.25
+sbatch examples/pca_topk/submit_pca_topk.sh EleutherAI/pythia-6.9b gptneox 2048 16 0.125
+
+sbatch examples/pca_topk/submit_pca_topk.sh EleutherAI/pythia-6.9b gptneox 2048 64 0.5 --lm-harness-eval
+sbatch examples/pca_topk/submit_pca_topk.sh EleutherAI/pythia-6.9b gptneox 2048 64 0.25 --lm-harness-eval
+sbatch examples/pca_topk/submit_pca_topk.sh EleutherAI/pythia-6.9b gptneox 2048 64 0.125 --lm-harness-eval
+
+sbatch examples/pca_topk/submit_pca_topk.sh EleutherAI/pythia-6.9b gptneox 2048 32 0.5 --lm-harness-eval
+sbatch examples/pca_topk/submit_pca_topk.sh EleutherAI/pythia-6.9b gptneox 2048 32 0.25 --lm-harness-eval
+sbatch examples/pca_topk/submit_pca_topk.sh EleutherAI/pythia-6.9b gptneox 2048 32 0.125 --lm-harness-eval
+
+sbatch examples/pca_topk/submit_pca_topk.sh EleutherAI/pythia-6.9b gptneox 2048 16 0.5 --lm-harness-eval
+sbatch examples/pca_topk/submit_pca_topk.sh EleutherAI/pythia-6.9b gptneox 2048 16 0.25 --lm-harness-eval
+sbatch examples/pca_topk/submit_pca_topk.sh EleutherAI/pythia-6.9b gptneox 2048 16 0.125 --lm-harness-eval
+set +x
diff --git a/examples/pca_topk/pca-topk-tinyllama.sh b/examples/pca_topk/pca-topk-tinyllama.sh
new file mode 100644
index 0000000..6600e33
--- /dev/null
+++ b/examples/pca_topk/pca-topk-tinyllama.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+set -x
+#sbatch examples/pca_topk/submit_pca_topk.sh TinyLlama/TinyLlama-1.1B-Chat-v1.0 llama 2048 32 0.5
+#sbatch examples/pca_topk/submit_pca_topk.sh TinyLlama/TinyLlama-1.1B-Chat-v1.0 llama 2048 32 0.25
+#sbatch examples/pca_topk/submit_pca_topk.sh TinyLlama/TinyLlama-1.1B-Chat-v1.0 llama 2048 32 0.125
+#
+#sbatch examples/pca_topk/submit_pca_topk.sh TinyLlama/TinyLlama-1.1B-Chat-v1.0 llama 2048 16 0.5
+#sbatch examples/pca_topk/submit_pca_topk.sh TinyLlama/TinyLlama-1.1B-Chat-v1.0 llama 2048 16 0.25
+#sbatch examples/pca_topk/submit_pca_topk.sh TinyLlama/TinyLlama-1.1B-Chat-v1.0 llama 2048 16 0.125
+#
+#sbatch examples/pca_topk/submit_pca_topk.sh TinyLlama/TinyLlama-1.1B-Chat-v1.0 llama 2048 8 0.5
+#sbatch examples/pca_topk/submit_pca_topk.sh TinyLlama/TinyLlama-1.1B-Chat-v1.0 llama 2048 8 0.25
+#sbatch examples/pca_topk/submit_pca_topk.sh TinyLlama/TinyLlama-1.1B-Chat-v1.0 llama 2048 8 0.125
+
+sbatch examples/pca_topk/submit_pca_topk.sh TinyLlama/TinyLlama-1.1B-Chat-v1.0 llama 2048 32 0.5 --lm-harness-eval
+sbatch examples/pca_topk/submit_pca_topk.sh TinyLlama/TinyLlama-1.1B-Chat-v1.0 llama 2048 32 0.25 --lm-harness-eval
+sbatch examples/pca_topk/submit_pca_topk.sh TinyLlama/TinyLlama-1.1B-Chat-v1.0 llama 2048 32 0.125 --lm-harness-eval
+
+sbatch examples/pca_topk/submit_pca_topk.sh TinyLlama/TinyLlama-1.1B-Chat-v1.0 llama 2048 16 0.5 --lm-harness-eval
+sbatch examples/pca_topk/submit_pca_topk.sh TinyLlama/TinyLlama-1.1B-Chat-v1.0 llama 2048 16 0.25 --lm-harness-eval
+sbatch examples/pca_topk/submit_pca_topk.sh TinyLlama/TinyLlama-1.1B-Chat-v1.0 llama 2048 16 0.125 --lm-harness-eval
+
+sbatch examples/pca_topk/submit_pca_topk.sh TinyLlama/TinyLlama-1.1B-Chat-v1.0 llama 2048 8 0.5 --lm-harness-eval
+sbatch examples/pca_topk/submit_pca_topk.sh TinyLlama/TinyLlama-1.1B-Chat-v1.0 llama 2048 8 0.25 --lm-harness-eval
+sbatch examples/pca_topk/submit_pca_topk.sh TinyLlama/TinyLlama-1.1B-Chat-v1.0 llama 2048 8 0.125 --lm-harness-eval
+set +x
diff --git a/examples/submit_pca_topk.sh b/examples/pca_topk/submit_pca_topk.sh
similarity index 68%
rename from examples/submit_pca_topk.sh
rename to examples/pca_topk/submit_pca_topk.sh
index 222baf6..892c869 100644
--- a/examples/submit_pca_topk.sh
+++ b/examples/pca_topk/submit_pca_topk.sh
@@ -1,11 +1,14 @@
 #!/bin/bash
 #SBATCH --qos=regular
-#SBATCH --constraint=gpu
+#SBATCH --constraint=gpu&hbm80g
 #SBATCH -N 1
-#SBATCH --gpus-per-node=1
+#SBATCH --gpus-per-node=4
 #SBATCH --account=m4641_g
-#SBATCH --ntasks-per-node=1
-#SBATCH --time=00:50:00
+#SBATCH --ntasks-per-node=4
+#SBATCH --time=02:30:00
+#SBATCH -J pca_topk
+#SBATCH --output=outfiles/%x-%j.out
+
 
 # Runs a "10B" parameter model
@@ -34,6 +37,10 @@ export HF_HOME=${HF_HOME:-"$SCRATCH/hf_cache"}
 export TRANSFORMERS_HOME=${TRANSFORMERS_HOME:-"$SCRATCH/hf_cache"}
 export HF_DATASETS_CACHE=${HF_DATASETS_CACHE:-"$SCRATCH/hf_cache"}
 
+export WANDB_DIR="$SCRATCH/InferenceData/wandb"
+export WANDB_CACHE_DIR="$SCRATCH/.cache/wandb"
+export WANDB_CONFIG_DIR="$SCRATCH/.cache/wandb_config"
+
 MODEL=$1
 MODEL_TYPE=$2
 SEQ_LEN=$3
@@ -41,10 +48,16 @@ MODEL_NAME=$(echo "$MODEL" | cut -d'/' -f2)
 TOPR=$4
 TOPK=$5
 EVAL=$6
+WANDB=true
 
 OUT_FILE_PATH="experiments/exp-pca-topk/${MODEL_NAME}"
 mkdir -p $OUT_FILE_PATH
 
+WANDB_ARGS=""
+if [ "$WANDB" = true ]; then
+    WANDB_ARGS="--use-wandb"
+fi
+
 echo "Model: ${MODEL}"
 echo "Model Name: ${MODEL_NAME}"
 echo "Sequence Length: ${SEQ_LEN}"
@@ -53,8 +66,14 @@ echo "Running model ${MODEL} with PCA Attention and top-r ${TOPR} and top-k ${TO
 
 run_cmd="srun -C gpu -N ${NNODES} -n ${GPUS} -c 32 --cpu-bind=cores --gpus-per-node=4 ./set_env_vars_slurm.sh python -u eval_ppl.py --use-axonn --sequence-length ${SEQ_LEN}\
     --model-id ${MODEL} --model-type ${MODEL_TYPE}\
+    ${WANDB_ARGS}\
     --use-pca-topk --top-r ${TOPR} --top-k ${TOPK} ${EVAL} | tee ${OUT_FILE_PATH}/out_${MODEL_NAME}_${TOPR}_${TOPK}${EVAL}.out 2>&1"
 
+#run_cmd="srun -N 1 ./set_env_vars_slurm.sh python -u eval_ppl.py --use-axonn --sequence-length ${SEQ_LEN}\
+#    --model-id ${MODEL} --model-type ${MODEL_TYPE}\
+#    ${WANDB_ARGS}\
+#    --use-pca-topk --top-r ${TOPR} --top-k ${TOPK} ${EVAL} | tee ${OUT_FILE_PATH}/out_${MODEL_NAME}_${TOPR}_${TOPK}${EVAL}.out 2>&1"
+
 echo ${run_cmd}
 eval ${run_cmd}
diff --git a/examples/saver/saver-keys.sh b/examples/saver/saver-keys.sh
new file mode 100644
index 0000000..7e08e4d
--- /dev/null
+++ b/examples/saver/saver-keys.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+MODEL_ID=$1
+MODEL_TYPE=$2
+
+echo "MODEL_ID: $MODEL_ID"
+echo "MODEL_TYPE: $MODEL_TYPE"
+
+set -x
+sbatch examples/saver/submit_saver.sh ${MODEL_ID} ${MODEL_TYPE} 2048 1 wikitext-valid
+sbatch examples/saver/submit_saver.sh ${MODEL_ID} ${MODEL_TYPE} 2048 1 bookcorpus
+sbatch examples/saver/submit_saver.sh ${MODEL_ID} ${MODEL_TYPE} 2048 1 c4
+set +x
diff --git a/examples/saver/submit_saver.sh b/examples/saver/submit_saver.sh
new file mode 100644
index 0000000..65fedb9
--- /dev/null
+++ b/examples/saver/submit_saver.sh
@@ -0,0 +1,72 @@
+#!/bin/bash
+#SBATCH --qos=regular
+#SBATCH --constraint=gpu&hbm80g
+#SBATCH -N 1
+#SBATCH --gpus-per-node=4
+#SBATCH --account=m4641_g
+#SBATCH --ntasks-per-node=4
+#SBATCH --time=03:00:00
+#SBATCH -J saver
+#SBATCH --output=outfiles/%x-%j.out
+
+
+# Runs a "10B" parameter model
+
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+
+
+#NNODES=$SLURM_JOB_NUM_NODES
+NNODES=$SLURM_JOB_NUM_NODES
+GPUS=$(( NNODES * 4 ))
+export MASTER_ADDR=$(hostname)
+export MASTER_PORT=29500
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+export NCCL_NET_GDR_LEVEL=PHB
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+export CUDA_VISIBLE_DEVICES=3,2,1,0
+export NCCL_CROSS_NIC=1
+export NCCL_SOCKET_IFNAME=hsn
+export NCCL_NET="AWS Libfabric"
+export FI_CXI_RDZV_THRESHOLD=0
+export FI_CXI_RDZV_GET_MIN=0
+export FI_CXI_OFLOW_BUF_SIZE=1073741824
+export FI_CXI_OFLOW_BUF_COUNT=1
+export WORLD_SIZE=$GPUS
+
+
+export HF_HOME="$SCRATCH/hf_cache"
+export TRANSFORMERS_HOME="$SCRATCH/hf_cache"
+export HF_DATASETS_CACHE="$SCRATCH/hf_cache"
+
+MODEL=$1
+MODEL_TYPE=$2
+SEQ_LEN=$3
+MODEL_NAME=$(echo "$MODEL" | cut -d'/' -f2)
+TOPK=$4
+DATASET=$5
+SAVE=true
+
+OUT_FILE_PATH="experiments/exp-saver/${MODEL_NAME}"
+mkdir -p $OUT_FILE_PATH
+
+echo "Model: ${MODEL}"
+echo "Model Name: ${MODEL_NAME}"
+echo "Sequence Length: ${SEQ_LEN}"
+echo "Output Path: ${OUT_FILE_PATH}"
+echo "Running model ${MODEL} for saving with top-k ${TOPK}"
+
+SAVE_ARGS=""
+if [ "$SAVE" = true ]; then
+    OUT_TENSOR_DATA_PATH="${SCRATCH}/InferenceData/topk/${MODEL_NAME}/${TOPK}/${DATASET}/"
+    mkdir -p $OUT_TENSOR_DATA_PATH
+    SAVE_ARGS="--save-tensors --tensors-dir ${OUT_TENSOR_DATA_PATH}"
+fi
+
+run_cmd="srun -C gpu -N ${NNODES} -n ${GPUS} -c 32 --cpu-bind=cores --gpus-per-node=4 ./set_env_vars_slurm.sh python -u eval_ppl.py --use-axonn --sequence-length ${SEQ_LEN}\
+    --model-id ${MODEL} --model-type ${MODEL_TYPE} --dataset ${DATASET}\
+    ${SAVE_ARGS}\
+    --dataset ${DATASET}\
+    --use-axonn --use-topk --top-k ${TOPK}| tee ${OUT_FILE_PATH}/out_${MODEL_NAME}_${TOPK}.out 2>&1"
+
+echo ${run_cmd}
+eval ${run_cmd}
diff --git a/examples/topk-llama.sh b/examples/topk-llama.sh
deleted file mode 100644
index ab3ca81..0000000
--- a/examples/topk-llama.sh
+++ /dev/null
@@ -1,11 +0,0 @@
-#!/bin/bash
-
-set -x
-#sbatch examples/submit_topk.sh meta-llama/Llama-2-7b-hf llama 4096 0.125
-#sbatch examples/submit_h2o.sh meta-llama/Llama-2-7b-hf llama 4096 0.25
-#sbatch examples/submit_h2o.sh meta-llama/Llama-2-13b-hf llama 4096 0.125
-#sbatch examples/submit_h2o.sh meta-llama/Llama-2-13b-hf llama 4096 0.25
-#sbatch examples/submit_h2o.sh meta-llama/Llama-2-70b-hf llama 4096 0.125
-sbatch examples/submit_topk.sh meta-llama/Llama-2-70b-hf llama 4096 1024
-sbatch examples/submit_topk.sh meta-llama/Llama-2-70b-hf llama 4096 4096
-set +x
diff --git a/examples/topk-mistral.sh b/examples/topk-mistral.sh
deleted file mode 100644
index fe4ea75..0000000
--- a/examples/topk-mistral.sh
+++ /dev/null
@@ -1,15 +0,0 @@
-#!/bin/bash
-
-set -x
-#sbatch examples/submit_topk.sh meta-llama/Llama-2-7b-hf llama 4096 0.125
-#sbatch examples/submit_h2o.sh meta-llama/Llama-2-7b-hf llama 4096 0.25
-#sbatch examples/submit_h2o.sh meta-llama/Llama-2-13b-hf llama 4096 0.125
-#sbatch examples/submit_h2o.sh meta-llama/Llama-2-13b-hf llama 4096 0.25
-#sbatch examples/submit_h2o.sh meta-llama/Llama-2-70b-hf llama 4096 0.125
-#sbatch examples/submit_topk.sh meta-llama/Llama-2-70b-hf llama 4096 1024
-#sbatch examples/submit_topk.sh meta-llama/Llama-2-70b-hf llama 4096 4096
-
-#sbatch examples/submit_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.125 --lm-harness-eval
-sbatch examples/submit_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.25 --lm-harness-eval
-sbatch examples/submit_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.50 --lm-harness-eval
-set +x
${SEQ_LEN}" +echo "Output Path: ${OUT_FILE_PATH}" +echo "Running model ${MODEL} with top-k ${TOPK}" + +SAVE_ARGS="" +if [ "$SAVE" = true ]; then + OUT_TENSOR_DATA_PATH="${SCRATCH}/InferenceData/topk/${MODEL_NAME}/${TOPK}/${DATASET}/" + mkdir -p $OUT_TENSOR_DATA_PATH + SAVE_ARGS="--save-tensors --tensors-dir ${OUT_TENSOR_DATA_PATH}" +fi + +run_cmd="srun -C gpu -N ${NNODES} -n ${GPUS} -c 32 --cpu-bind=cores --gpus-per-node=4 ./set_env_vars_slurm.sh python -u eval_ppl.py --use-axonn --sequence-length ${SEQ_LEN}\ + --model-id ${MODEL} --model-type ${MODEL_TYPE} --dataset ${DATASET}\ + ${SAVE_ARGS}\ + --dataset ${DATASET} + --use-axonn --use-topk --top-k ${TOPK} ${EVAL}| tee ${OUT_FILE_PATH}/out_${MODEL_NAME}_${TOPK}${EVAL}.out 2>&1" + +echo ${run_cmd} +eval ${run_cmd} diff --git a/examples/submit_topk.sh b/examples/topk/submit_topk_phi.sh similarity index 86% rename from examples/submit_topk.sh rename to examples/topk/submit_topk_phi.sh index c295803..753c7b5 100644 --- a/examples/submit_topk.sh +++ b/examples/topk/submit_topk_phi.sh @@ -14,7 +14,7 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1 NNODES=$SLURM_JOB_NUM_NODES -GPUS=$(( NNODES * 4 )) +GPUS=$(( NNODES * 1 )) export MASTER_ADDR=$(hostname) export MASTER_PORT=29500 export CUDA_DEVICE_MAX_CONNECTIONS=1 @@ -29,7 +29,6 @@ export FI_CXI_RDZV_GET_MIN=0 export FI_CXI_OFLOW_BUF_SIZE=1073741824 export FI_CXI_OFLOW_BUF_COUNT=1 - export HF_HOME="$SCRATCH/hf_cache" export TRANSFORMERS_HOME="$SCRATCH/hf_cache" export HF_DATASETS_CACHE="$SCRATCH/hf_cache" @@ -39,13 +38,14 @@ MODEL_TYPE=$2 SEQ_LEN=$3 MODEL_NAME=$(echo "$MODEL" | cut -d'/' -f2) TOPK=$4 -EVAL=$5 +DATASET=$5 +EVAL=$6 SAVE=true OUT_FILE_PATH="experiments/exp-topk/${MODEL_NAME}" mkdir -p $OUT_FILE_PATH -OUT_TENSOR_DATA_PATH="${SCRATCH}/InferenceData/topk/${MODEL_NAME}/${TOPK}/wikitext/" +OUT_TENSOR_DATA_PATH="${SCRATCH}/InferenceData/topk/${MODEL_NAME}/${TOPK}/${DATASET}/" mkdir -p $OUT_TENSOR_DATA_PATH echo "Model: ${MODEL}" @@ -59,9 +59,10 @@ if [ "$SAVE" = true ]; then SAVE_ARGS="--save-tensors --tensors-dir ${OUT_TENSOR_DATA_PATH}" fi -run_cmd="srun -C gpu -N ${NNODES} -n ${GPUS} -c 32 --cpu-bind=cores --gpus-per-node=4 ./set_env_vars_slurm.sh python -u eval_ppl.py --use-axonn --sequence-length ${SEQ_LEN}\ - --model-id ${MODEL} --model-type ${MODEL_TYPE} --dataset wikitext-valid\ +run_cmd="srun -C gpu -N ${NNODES} -n ${GPUS} -c 32 --cpu-bind=cores --gpus-per-node=4 python -u eval_ppl.py --sequence-length ${SEQ_LEN}\ + --model-id ${MODEL} --model-type ${MODEL_TYPE}\ ${SAVE_ARGS}\ + --dataset ${DATASET} --use-topk --top-k ${TOPK} ${EVAL}| tee ${OUT_FILE_PATH}/out_${MODEL_NAME}_${TOPK}${EVAL}.out 2>&1" echo ${run_cmd} diff --git a/examples/topk/topk-llama.sh b/examples/topk/topk-llama.sh new file mode 100644 index 0000000..d8616c3 --- /dev/null +++ b/examples/topk/topk-llama.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +set -x +#sbatch examples/topk/submit_topk.sh meta-llama/Llama-2-7b-hf llama 4096 0.125 +#sbatch examples/topk/submit_h2o.sh meta-llama/Llama-2-7b-hf llama 4096 0.25 +#sbatch examples/topk/submit_h2o.sh meta-llama/Llama-2-13b-hf llama 4096 0.125 +#sbatch examples/topk/submit_h2o.sh meta-llama/Llama-2-13b-hf llama 4096 0.25 +#sbatch examples/topk/submit_h2o.sh meta-llama/Llama-2-70b-hf llama 4096 0.125 +sbatch examples/topk/submit_topk.sh meta-llama/Llama-2-70b-hf llama 4096 1024 +sbatch examples/topk/submit_topk.sh meta-llama/Llama-2-70b-hf llama 4096 4096 +set +x diff --git a/examples/topk/topk-mistral.sh b/examples/topk/topk-mistral.sh new file mode 100644 index 0000000..8345992 --- 
diff --git a/examples/topk/topk-mistral.sh b/examples/topk/topk-mistral.sh
new file mode 100644
index 0000000..8345992
--- /dev/null
+++ b/examples/topk/topk-mistral.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+set -x
+sbatch examples/topk/submit_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 1 wikitext-test
+sbatch examples/topk/submit_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.25 wikitext-test
+sbatch examples/topk/submit_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.50 wikitext-test
+
+sbatch examples/topk/submit_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 1 wikitext-test --lm-harness-eval
+sbatch examples/topk/submit_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.25 wikitext-test --lm-harness-eval
+sbatch examples/topk/submit_topk.sh mistralai/Mistral-7B-v0.1 mistral 4096 0.50 wikitext-test --lm-harness-eval
+set +x
diff --git a/examples/topk/topk-pythia.sh b/examples/topk/topk-pythia.sh
new file mode 100644
index 0000000..ca62546
--- /dev/null
+++ b/examples/topk/topk-pythia.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+set -x
+sbatch examples/topk/submit_topk.sh EleutherAI/pythia-6.9b gptneox 2048 1 wikitext-test
+sbatch examples/topk/submit_topk.sh EleutherAI/pythia-6.9b gptneox 2048 0.25 wikitext-test
+sbatch examples/topk/submit_topk.sh EleutherAI/pythia-6.9b gptneox 2048 0.50 wikitext-test
+
+sbatch examples/topk/submit_topk.sh EleutherAI/pythia-6.9b gptneox 2048 1 wikitext-test --lm-harness-eval
+sbatch examples/topk/submit_topk.sh EleutherAI/pythia-6.9b gptneox 2048 0.25 wikitext-test --lm-harness-eval
+sbatch examples/topk/submit_topk.sh EleutherAI/pythia-6.9b gptneox 2048 0.50 wikitext-test --lm-harness-eval
+set +x
diff --git a/helper/downloadllama.py b/helper/downloadllama.py
index 0bd26d8..fcd733e 100644
--- a/helper/downloadllama.py
+++ b/helper/downloadllama.py
@@ -1,7 +1,8 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 #from datasets import load_dataset
 
-model_id = "huggyllama/llama-30b"
+#model_id = "huggyllama/llama-30b"
+model_id = "mistralai/Mixtral-8x22B-v0.1"
 
 model = AutoModelForCausalLM.from_pretrained(model_id)
 tokenizer = AutoTokenizer.from_pretrained(model_id)
parser.add_argument("--batch-size", type=int, default=32, help="Batch Size") + parser.add_argument("--prompt-length", type=int, default=1988, help="Batch Size") + parser.add_argument("--gen-length", type=int, default=32, help="Batch Size") + parser.add_argument("--seed", type=int, default=1234, help="Seed") + parser.add_argument("--use-optimized-code", action='store_true', default=False) + + return parser + +def load_prompts(tokenizer, batch_size, prompt_length): + dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + encodings = tokenizer("\n\n".join(dataset["text"]), return_tensors="pt") + total_tokens = encodings.input_ids.shape[1] + start_index = min(random.randint(0, total_tokens), total_tokens - batch_size * prompt_length) + input_ids = encodings.input_ids[:, start_index : start_index + batch_size * prompt_length].reshape(batch_size, prompt_length) + return input_ids + +if __name__ == "__main__": + parser = setup_parser() + args = parser.parse_args() + model_id = args.model_id + dtype = torch.float32 + + init_everything() + set_seed(args.seed) + + if args.method == "pca-topk": + args.top_k = args.prompt_length + args.top_r = 128 + args.rotary_type = "postrotary" + + if args.use_optimized_code: + from methods.pca_topk.modify_llama_optimized import make_llama_attention_pca_topk + else: + from methods.pca_topk.modify_llama import make_llama_attention_pca_topk + + make_llama_attention_pca_topk(args) + + with parallelize(model_id): + model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype).to('cuda') + + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenized_prompts = load_prompts(tokenizer, args.batch_size, args.prompt_length) + detokenized_prompts = tokenizer.batch_decode(tokenized_prompts) + + total_generated_tokens = 0 + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + generations = [] + + input_ids = tokenized_prompts.cuda() + with torch.autocast(device_type='cuda', dtype=dtype): + outputs = model.generate(input_ids, do_sample=True, max_new_tokens=args.gen_length, num_beams=4) + + end_event.record() + generated_tokens = outputs.numel() - input_ids.numel() + total_generated_tokens += generated_tokens + + torch.cuda.synchronize() + total_time = start_event.elapsed_time(end_event) + tput = total_generated_tokens * 1000 / total_time + + output_ids = outputs[:, args.prompt_length:] + detokenized_generations = tokenizer.batch_decode(output_ids) + + if torch.distributed.get_rank() == 0: + for prompt, generation in zip(detokenized_prompts, detokenized_generations): + print(f"{OKBLUE}[PROMPT]: {prompt}{ENDC}") + print(f"{OKGREEN}[GENERATION]: = {generation}{ENDC}") + print("=====") + print(f"Tput = {tput} generated tokens / second") + diff --git a/methods/__init__.py b/methods/__init__.py index e112969..43c74ab 100644 --- a/methods/__init__.py +++ b/methods/__init__.py @@ -22,11 +22,27 @@ #from .pca_topk.modify_mistral import make_mistral_attention_pca_topk as make_mistral_attention_pca_topk from .common.saver import TensorSaver as TensorSaver +from .common.logger import WandbLogger as WandbLogger +from .common.logger import NoOpLogger as NoOpLogger G_TENSOR_SAVER = None +LOGGER = None def init_tensor_saver(tensor_dir): global G_TENSOR_SAVER G_TENSOR_SAVER = TensorSaver(tensor_dir) +def init_logger(args): + global LOGGER + if args.use_wandb: + LOGGER = WandbLogger(args) + else: + LOGGER = NoOpLogger(args) + +def finish_logger(): + global LOGGER + if LOGGER is not None: 
diff --git a/methods/__init__.py b/methods/__init__.py
index e112969..43c74ab 100644
--- a/methods/__init__.py
+++ b/methods/__init__.py
@@ -22,11 +22,27 @@
 #from .pca_topk.modify_mistral import make_mistral_attention_pca_topk as make_mistral_attention_pca_topk
 
 from .common.saver import TensorSaver as TensorSaver
+from .common.logger import WandbLogger as WandbLogger
+from .common.logger import NoOpLogger as NoOpLogger
 
 G_TENSOR_SAVER = None
+LOGGER = None
 
 def init_tensor_saver(tensor_dir):
     global G_TENSOR_SAVER
     G_TENSOR_SAVER = TensorSaver(tensor_dir)
 
+def init_logger(args):
+    global LOGGER
+    if args.use_wandb:
+        LOGGER = WandbLogger(args)
+    else:
+        LOGGER = NoOpLogger(args)
+
+def finish_logger():
+    global LOGGER
+    if LOGGER is not None:
+        LOGGER.finish()
+
+
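A minimal sketch (assumption) of methods/common/logger.py, which is imported above but is not part of this diff; the callers only rely on the interface below, and the wandb project name is a placeholder rather than a value from the repo:

    import wandb

    class NoOpLogger:
        def __init__(self, args): pass
        def log_ppl(self, ppl): pass
        def log_lm_harness_results(self, tasks, results): pass
        def finish(self): pass

    class WandbLogger(NoOpLogger):
        def __init__(self, args):
            # Placeholder project name; the run config records the CLI arguments.
            self.run = wandb.init(project="inference-eval", config=vars(args))
        def log_ppl(self, ppl):
            self.run.log({"perplexity": ppl})
        def log_lm_harness_results(self, tasks, results):
            # tasks maps task name -> metric key (see LM_HARNESS_TASKS in eval_ppl.py);
            # results is the per-task dict results["results"] from lm-eval-harness.
            self.run.log({task: results[task][metric] for task, metric in tasks.items() if task in results})
        def finish(self):
            self.run.finish()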
diff --git a/methods/baselines/h2o/modify_gptneox.py b/methods/baselines/h2o/modify_gptneox.py
new file mode 100644
index 0000000..48d2b4d
--- /dev/null
+++ b/methods/baselines/h2o/modify_gptneox.py
@@ -0,0 +1,98 @@
+from typing import List, Optional, Tuple, Union
+import math
+import warnings
+from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXAttention, apply_rotary_pos_emb
+from transformers.cache_utils import Cache
+import torch
+from torch import nn
+import torch.nn.functional as F
+from functools import partial
+
+from .external.h2o_utils import local_heavy_hitter_mask
+import methods
+
+
+def get_h2o_attn(args):
+    def modified_attn(self, query, key, value, attention_mask=None, head_mask=None):
+        # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
+        # compute causal mask from causal mask buffer
+        batch_size, num_attention_heads, query_length, attn_head_size = query.size()
+        key_length = key.size(-2)
+
+        # dynamically increase the causal mask with the key length, if needed.
+        if key_length > self.bias.shape[-1]:
+            self._init_bias(key_length, device=key.device)
+        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+
+        query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
+        key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
+
+        attn_scores = torch.zeros(
+            batch_size * num_attention_heads,
+            query_length,
+            key_length,
+            dtype=query.dtype,
+            device=key.device,
+        )
+        attn_scores = torch.baddbmm(
+            attn_scores,
+            query,
+            key.transpose(1, 2),
+            beta=1.0,
+            alpha=self.norm_factor,
+        )
+        attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
+
+        mask_value = torch.finfo(attn_scores.dtype).min
+        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+        mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device)
+
+
+
+        attn_scores = torch.where(causal_mask, attn_scores, mask_value)
+
+        if attention_mask is not None:
+            # Apply the attention mask
+            attn_scores = attn_scores + attention_mask
+        else:
+            # Create the attention mask if it is not provided
+            attention_mask = torch.where(causal_mask, torch.tensor(0.0).to(attn_scores.dtype), mask_value)
+
+        ### Heavy + Recent
+        heavy_budget = int(args.heavy_ratio * attn_scores.shape[-1])
+        recent_budget = int(args.heavy_ratio * attn_scores.shape[-1])
+
+        # Heavy Hitter Mask
+        if heavy_budget > 0:
+            mask_bottom = local_heavy_hitter_mask(attn_scores, heavy_budget) # Default: No padding applied to input
+        else:
+            mask_bottom = torch.zeros_like(attn_scores, dtype=torch.bool)
+
+        ones = torch.ones_like(attn_scores, dtype=torch.bool)
+        ones = torch.triu(ones, diagonal=-recent_budget)
+        mask_bottom = torch.logical_or(mask_bottom, ones)
+
+        mask_bottom = torch.tril(mask_bottom, diagonal=0)
+
+        # mask_bottom = ones
+        attn_scores[~mask_bottom] = torch.min(attention_mask)
+
+        attn_weights = nn.functional.softmax(attn_scores, dim=-1)
+        attn_weights = attn_weights.to(value.dtype)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attn_weights = attn_weights * head_mask
+
+        attn_weights = self.attention_dropout(attn_weights)
+
+        attn_output = torch.matmul(attn_weights, value)
+        return attn_output, attn_weights
+    return modified_attn
+
+def make_gptneox_attention_h2o(args):
+    #TODO: Maybe we should not use fractions here to be consistent with other methods
+    print ("Modifying GPT NeoX Attention -> H2O")
+    print (f"Heavy and Recent Ratio:{args.heavy_ratio}")
+    GPTNeoXAttention._attn = get_h2o_attn(args)
diff --git a/methods/baselines/h2o/modify_llama.py b/methods/baselines/h2o/modify_llama.py
index 5fd29c7..d88658f 100644
--- a/methods/baselines/h2o/modify_llama.py
+++ b/methods/baselines/h2o/modify_llama.py
@@ -10,7 +10,7 @@
 
 from .external.h2o_utils import local_heavy_hitter_mask
 
-def get_h2o_forward(heavy_ratio):
+def get_h2o_forward(args):
     def modified_forward(
         self,
         hidden_states: torch.Tensor,
@@ -89,8 +89,8 @@ def modified_forward(
             attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
 
         ### Heavy + Recent
-        heavy_budget = int(heavy_ratio * attn_weights.shape[-1])
-        recent_budget = int(heavy_ratio * attn_weights.shape[-1])
+        heavy_budget = int(args.heavy_ratio * attn_weights.shape[-1])
+        recent_budget = int(args.heavy_ratio * attn_weights.shape[-1])
 
         # Heavy Hitter Mask
         if heavy_budget > 0:
@@ -135,8 +135,8 @@ def modified_forward(
         return attn_output, attn_weights, past_key_value
     return modified_forward
 
-def make_llama_attention_h2o(hr):
+def make_llama_attention_h2o(args):
     #TODO: Maybe we should not use fractions here to be consistent with other methods
     print ("Modifying Llama Attention -> H2O")
-    print (f"Heavy and Recent Ratio:{hr}")
-    LlamaAttention.forward = get_h2o_forward(hr)
+    print (f"Heavy and Recent Ratio:{args.heavy_ratio}")
+    LlamaAttention.forward = get_h2o_forward(args)
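A rough sketch (assumption) of what local_heavy_hitter_mask, imported from the external H2O reference code and not shown in this diff, is expected to do: for each query row, keep the keys that have received the largest accumulated attention mass. The real implementation differs in detail; this is only to make the budgeted masking above easier to follow.

    import torch

    def local_heavy_hitter_mask_sketch(attn_scores, heavy_budget):
        # attn_scores: (..., q_len, k_len) pre-softmax scores with the causal mask already applied.
        probs = torch.softmax(attn_scores, dim=-1)
        accumulated = probs.sum(dim=-2, keepdim=True)          # total mass each key received
        _, top_idx = accumulated.topk(heavy_budget, dim=-1)    # indices of the "heavy hitter" keys
        mask = torch.zeros_like(attn_scores, dtype=torch.bool)
        mask.scatter_(-1, top_idx.expand(*attn_scores.shape[:-1], heavy_budget), True)
        return mask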
diff --git a/methods/baselines/h2o/modify_mistral.py b/methods/baselines/h2o/modify_mistral.py
index 2d0beac..b087192 100644
--- a/methods/baselines/h2o/modify_mistral.py
+++ b/methods/baselines/h2o/modify_mistral.py
@@ -10,7 +10,7 @@
 
 from .external.h2o_utils import local_heavy_hitter_mask
 
-def get_h2o_forward(heavy_ratio):
+def get_h2o_forward(args):
     def modified_forward(
         self,
         hidden_states: torch.Tensor,
@@ -73,8 +73,8 @@ def modified_forward(
 
 
         ### Heavy + Recent
-        heavy_budget = int(heavy_ratio * attn_weights.shape[-1])
-        recent_budget = int(heavy_ratio * attn_weights.shape[-1])
+        heavy_budget = int(args.heavy_ratio * attn_weights.shape[-1])
+        recent_budget = int(args.heavy_ratio * attn_weights.shape[-1])
 
         # Heavy Hitter Mask
         if heavy_budget > 0:
@@ -113,8 +113,8 @@ def modified_forward(
 
         return attn_output, attn_weights, past_key_value
     return modified_forward
 
-def make_mistral_attention_h2o(hr):
+def make_mistral_attention_h2o(args):
     #TODO: Maybe we should not use fractions here to be consistent with other methods
     print ("Modifying Mistral Attention -> H2O")
-    print (f"Heavy and Recent Ratio:{hr}")
-    MistralAttention.forward = get_h2o_forward(hr)
+    print (f"Heavy and Recent Ratio:{args.heavy_ratio}")
+    MistralAttention.forward = get_h2o_forward(args)
diff --git a/methods/baselines/h2o/modify_opt.py b/methods/baselines/h2o/modify_opt.py
index 104d278..5edeb62 100644
--- a/methods/baselines/h2o/modify_opt.py
+++ b/methods/baselines/h2o/modify_opt.py
@@ -11,7 +11,7 @@
 
 from .external.h2o_utils import local_heavy_hitter_mask
 
-def get_h2o_forward(heavy_ratio):
+def get_h2o_forward(args):
     def modified_forward(
         self,
         hidden_states: torch.Tensor,
@@ -87,8 +87,8 @@ def modified_forward(
         )
 
         ### Heavy + Recent
-        heavy_budget = int(heavy_ratio * attn_weights.shape[-1])
-        recent_budget = int(heavy_ratio * attn_weights.shape[-1])
+        heavy_budget = int(args.heavy_ratio * attn_weights.shape[-1])
+        recent_budget = int(args.heavy_ratio * attn_weights.shape[-1])
 
         # Heavy Hitter Mask
         if heavy_budget > 0:
@@ -155,8 +155,8 @@ def modified_forward(
 
     return modified_forward
 
-def make_opt_attention_h2o(hr):
+def make_opt_attention_h2o(args):
     #TODO: Maybe we should not use fractions here to be consistent with other methods
     print ("Modifying OPT Attention -> H2O")
-    print (f"Heavy and Recent Ratio:{hr}")
-    OPTAttention.forward = get_h2o_forward(hr)
+    print (f"Heavy and Recent Ratio:{args.heavy_ratio}")
+    OPTAttention.forward = get_h2o_forward(args)
diff --git a/methods/baselines/h2o_hf_opt/modify_llama.py b/methods/baselines/h2o_hf_opt/modify_llama.py
index a3699f3..60063fe 100644
--- a/methods/baselines/h2o_hf_opt/modify_llama.py
+++ b/methods/baselines/h2o_hf_opt/modify_llama.py
@@ -11,7 +11,7 @@
 
 from .h2o_utils import local_heavy_hitter_mask
 
-def get_hfopth2o_forward(heavy_ratio):
+def get_hfopth2o_forward(args):
     def modified_forward(
         self,
         hidden_states: torch.Tensor,
@@ -90,8 +90,8 @@ def modified_forward(
             attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
 
         ### Heavy + Recent
-        heavy_budget = int(heavy_ratio * attn_weights.shape[-1])
-        recent_budget = int(heavy_ratio * attn_weights.shape[-1])
+        heavy_budget = int(args.heavy_ratio * attn_weights.shape[-1])
+        recent_budget = int(args.heavy_ratio * attn_weights.shape[-1])
 
         # Heavy Hitter Mask
         alpha = None
@@ -109,27 +109,13 @@ def modified_forward(
 
         # Now we change the mask_bottom tensor to have 0 in place of True and -inf in place of False
         mask_bottom = torch.where(mask_bottom, torch.tensor(0.0, dtype=torch.float32), torch.tensor(-float('inf'), dtype=torch.float32)).to(attention_mask.dtype)
-        #mask_bottom = mask_bottom.float().neg_().add_(1).mul_(-float('inf'))
-
-        # mask_bottom = ones
         attn_weights = attn_weights + mask_bottom
-        #attn_weights[~mask_bottom] = torch.min(attention_mask)
-
-        #alpha = torch.sum(mask_bottom, dim=-1, keepdim=True)/mask_bottom.shape[-1];
-
-        #mask_bottom = ~mask_bottom
-        mask_bottom = torch.tril(mask_bottom, diagonal=0)
-
-        mask_bottom = (mask_bottom / (torch.sum(mask_bottom, dim=-1, keepdim=True) + 1e-8))
 
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)
-        #if alpha is not None:
-        #    attn_output = alpha * attn_output + (1-alpha) * torch.mean(attn_output, dim=-2, keepdim=True)
-
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             raise ValueError(
                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
@@ -153,8 +139,8 @@ def modified_forward(
 
         return attn_output, attn_weights, past_key_value
     return modified_forward
 
-def make_llama_attention_h2o(hr):
+def make_llama_attention_h2o(args):
     #TODO: Maybe we should not use fractions here to be consistent with other methods
     print ("Modifying Llama Attention -> HF Optimised H2O")
-    print (f"Heavy and Recent Ratio:{hr}")
-    LlamaAttention.forward = get_hfopth2o_forward(hr)
+    print (f"Heavy and Recent Ratio:{args.heavy_ratio}")
+    LlamaAttention.forward = get_hfopth2o_forward(args)
diff --git a/methods/baselines/topk/modify_gemma.py b/methods/baselines/topk/modify_gemma.py
index 3227354..2392193 100644
--- a/methods/baselines/topk/modify_gemma.py
+++ b/methods/baselines/topk/modify_gemma.py
@@ -34,9 +34,9 @@ def modified_forward(
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         if methods.G_TENSOR_SAVER is not None:
-            methods.G_TENSOR_SAVER.save("key", key_states, self.layer_idx)
-            methods.G_TENSOR_SAVER.save("query", query_states, self.layer_idx)
-            methods.G_TENSOR_SAVER.save("value", value_states, self.layer_idx)
+            methods.G_TENSOR_SAVER.save("key", key_states, self.layer_idx, "prerotary")
+            #methods.G_TENSOR_SAVER.save("query", query_states, self.layer_idx, "prerotary")
+            #methods.G_TENSOR_SAVER.save("value", value_states, self.layer_idx, "prerotary")
 
         past_key_value = getattr(self, "past_key_value", past_key_value)
 
@@ -48,6 +48,10 @@ def modified_forward(
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
+        if methods.G_TENSOR_SAVER is not None:
+            methods.G_TENSOR_SAVER.save("key", key_states, self.layer_idx, "postrotary")
+            #methods.G_TENSOR_SAVER.save("query", query_states, self.layer_idx, "postrotary")
+
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
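A small sketch (assumption) of the extended TensorSaver.save signature that the calls above rely on: this diff starts passing a fourth "prerotary"/"postrotary" stage label, and methods/common/saver.py itself is not shown, so the implementation below is only illustrative, including the file-naming scheme.

    import os
    import torch

    class TensorSaverSketch:
        def __init__(self, tensor_dir):
            self.tensor_dir = tensor_dir
            self.step = 0
        def save(self, name, tensor, layer_idx, stage="prerotary"):
            # One file per call, tagged with layer index and rotary stage.
            path = os.path.join(self.tensor_dir, f"layer{layer_idx}_{stage}_{name}_{self.step}.pt")
            torch.save(tensor.detach().cpu(), path)
            self.step += 1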
buffer + batch_size, num_attention_heads, query_length, attn_head_size = query.size() + key_length = key.size(-2) + + # dynamically increase the causal mask with the key length, if needed. + if key_length > self.bias.shape[-1]: + self._init_bias(key_length, device=key.device) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + + query = query.view(batch_size * num_attention_heads, query_length, attn_head_size) + key = key.view(batch_size * num_attention_heads, key_length, attn_head_size) + attn_scores = torch.zeros( + batch_size * num_attention_heads, + query_length, + key_length, + dtype=query.dtype, + device=key.device, + ) + attn_scores = torch.baddbmm( + attn_scores, + query, + key.transpose(1, 2), + beta=1.0, + alpha=self.norm_factor, + ) + attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length) + + mask_value = torch.finfo(attn_scores.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device) + attn_scores = torch.where(causal_mask, attn_scores, mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_scores = attn_scores + attention_mask + + # Get top-k attention weights + if args.top_k <= 1: + topk = int(args.top_k * attn_scores.shape[-1]) + else: + topk = int(args.top_k) + attn_scores = mask_attn_top_k(attn_scores, topk, dim=-1) + + attn_weights = nn.functional.softmax(attn_scores, dim=-1) + attn_weights = attn_weights.to(value.dtype) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_weights = self.attention_dropout(attn_weights) + + attn_output = torch.matmul(attn_weights, value) + return attn_output, attn_weights + return modified_attn + +def get_top_k_forward(args): def modified_forward( self, hidden_states: torch.FloatTensor, @@ -70,8 +129,8 @@ def modified_forward( if methods.G_TENSOR_SAVER is not None: methods.G_TENSOR_SAVER.save("key", key, self.layer_idx, "prerotary") - methods.G_TENSOR_SAVER.save("query", query, self.layer_idx, "prerotary") - methods.G_TENSOR_SAVER.save("value", value, self.layer_idx, "prerotary") + #methods.G_TENSOR_SAVER.save("query", query, self.layer_idx, "prerotary") + #methods.G_TENSOR_SAVER.save("value", value, self.layer_idx, "prerotary") # Compute rotary embeddings on rotary_ndims query_rot = query[..., : self.rotary_ndims] @@ -99,15 +158,13 @@ def modified_forward( # Compute attention attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) - # TODO: Implement top-k scheme - # Reshape outputs attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size) attn_output = self.dense(attn_output) if methods.G_TENSOR_SAVER is not None: methods.G_TENSOR_SAVER.save("key", key, self.layer_idx, "postrotary") - methods.G_TENSOR_SAVER.save("query", query, self.layer_idx, "postrotary") + #methods.G_TENSOR_SAVER.save("query", query, self.layer_idx, "postrotary") outputs = (attn_output, present) if output_attentions: @@ -116,12 +173,13 @@ def modified_forward( return outputs return modified_forward -def make_gptneox_attention_top_k(top_k, use_percentage=False): +def make_gptneox_attention_top_k(args): print ("Modifying GPT Neo X Attention -> TopK Attention") - if not use_percentage: - print (f"TopK - {top_k}") + 
if args.top_k <= 1: + print (f"TopK - {args.top_k} (Percentage)") else: - print (f"TopK% - {top_k}") + print (f"TopK - {args.top_k}") - GPTNeoXAttention.forward = get_top_k_forward(top_k, use_percentage) - GPTNeoXAttention.__init__ = get_topk_init(top_k) + GPTNeoXAttention.forward = get_top_k_forward(args) + GPTNeoXAttention.__init__ = get_topk_init(args) + GPTNeoXAttention._attn = get_topk_attn(args) diff --git a/methods/baselines/topk/modify_llama.py b/methods/baselines/topk/modify_llama.py index 68f1ca3..de2e706 100644 --- a/methods/baselines/topk/modify_llama.py +++ b/methods/baselines/topk/modify_llama.py @@ -19,7 +19,7 @@ AXONN_AVAILABLE=False -def get_top_k_forward(top_k): +def get_top_k_forward(args): def modified_forward( self, hidden_states: torch.Tensor, @@ -106,10 +106,10 @@ def modified_forward( attn_weights = attn_weights + causal_mask # Get top-k attention weights - if top_k <= 1: - topk = int(top_k * attn_weights.shape[-1]) + if args.top_k <= 1: + topk = int(args.top_k * attn_weights.shape[-1]) else: - topk = int(top_k) + topk = int(args.top_k) attn_weights = mask_attn_top_k(attn_weights, topk, dim=-1) # upcast attention to fp32 @@ -141,11 +141,11 @@ def modified_forward( return attn_output, attn_weights, past_key_value return modified_forward -def make_llama_attention_top_k(top_k, use_percentage=False): +def make_llama_attention_top_k(args): print ("Modifying Llama Attention -> TopK Attention") - if not use_percentage: - print (f"TopK - {top_k}") + if args.top_k <= 1: + print (f"TopK% - {args.top_k}") else: - print (f"TopK% - {top_k}") + print (f"TopK - {args.top_k}") - LlamaAttention.forward = get_top_k_forward(top_k) + LlamaAttention.forward = get_top_k_forward(args) diff --git a/methods/baselines/topk/modify_mistral.py b/methods/baselines/topk/modify_mistral.py index e316ed3..98ae9a7 100644 --- a/methods/baselines/topk/modify_mistral.py +++ b/methods/baselines/topk/modify_mistral.py @@ -3,6 +3,7 @@ import math import warnings from transformers.models.mistral.modeling_mistral import MistralAttention, repeat_kv, apply_rotary_pos_emb +from transformers.models.mixtral.modeling_mixtral import MixtralAttention from transformers.cache_utils import Cache import torch from torch import nn @@ -12,7 +13,7 @@ from methods.common.utils import mask_attn_top_k import methods -def get_top_k_forward(top_k): +def get_top_k_forward(args): def modified_forward( self, hidden_states: torch.Tensor, @@ -39,8 +40,8 @@ def modified_forward( if methods.G_TENSOR_SAVER is not None: methods.G_TENSOR_SAVER.save("key", key_states, self.layer_idx, "prerotary") - methods.G_TENSOR_SAVER.save("query", query_states, self.layer_idx, "prerotary") - methods.G_TENSOR_SAVER.save("value", value_states, self.layer_idx, "prerotary") + #methods.G_TENSOR_SAVER.save("query", query_states, self.layer_idx, "prerotary") + #methods.G_TENSOR_SAVER.save("value", value_states, self.layer_idx, "prerotary") kv_seq_len = key_states.shape[-2] if past_key_value is not None: @@ -60,7 +61,7 @@ def modified_forward( if methods.G_TENSOR_SAVER is not None: methods.G_TENSOR_SAVER.save("key", key_states, self.layer_idx, "postrotary") - methods.G_TENSOR_SAVER.save("query", query_states, self.layer_idx, "postrotary") + #methods.G_TENSOR_SAVER.save("query", query_states, self.layer_idx, "postrotary") # repeat k/v heads if n_kv_heads < n_heads @@ -84,10 +85,10 @@ def modified_forward( attn_weights = attn_weights + attention_mask # Get top-k attention weights - if top_k <= 1: - topk = int(top_k * attn_weights.shape[-1]) + if args.top_k <= 
1: + topk = int(args.top_k * attn_weights.shape[-1]) else: - topk = int(top_k) + topk = int(args.top_k) attn_weights = mask_attn_top_k(attn_weights, topk, dim=-1) # upcast attention to fp32 @@ -113,11 +114,12 @@ def modified_forward( return modified_forward # TODO: Remove use_percentage -def make_mistral_attention_top_k(top_k, use_percentage=False): - print ("Modifying Mistral Attention -> TopK Attention") - if not use_percentage: - print (f"TopK - {top_k}") +def make_mistral_attention_top_k(args): + print ("Modifying Mistral and Mixtral Attention -> TopK Attention") + if args.top_k <= 1: + print (f"TopK% - {args.top_k}") else: - print (f"TopK% - {top_k}") + print (f"TopK - {args.top_k}") - MistralAttention.forward = get_top_k_forward(top_k) \ No newline at end of file + MistralAttention.forward = get_top_k_forward(args) + MixtralAttention.forward = get_top_k_forward(args) \ No newline at end of file diff --git a/methods/baselines/topk/modify_opt.py b/methods/baselines/topk/modify_opt.py index a3ac12e..9cd405f 100644 --- a/methods/baselines/topk/modify_opt.py +++ b/methods/baselines/topk/modify_opt.py @@ -10,7 +10,7 @@ import methods -def get_top_k_forward(top_k, use_percentage=False): +def get_top_k_forward(args): def modified_forward( self, hidden_states: torch.Tensor, @@ -88,10 +88,10 @@ def modified_forward( attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) # Get top-k attention weights - if top_k <= 1: - topk = int(top_k * attn_weights.shape[-1]) + if args.top_k <= 1: + topk = int(args.top_k * attn_weights.shape[-1]) else: - topk = int(top_k) + topk = int(args.top_k) attn_weights = mask_attn_top_k(attn_weights, topk, dim=-1) # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437 @@ -143,11 +143,11 @@ def modified_forward( # TODO: Remove use_percentage -def make_opt_attention_top_k(top_k, use_percentage=False): +def make_opt_attention_top_k(args): print ("Modifying OPT Attention -> TopK Attention") - if not use_percentage: - print (f"TopK - {top_k}") + if args.top_k <= 1: + print (f"TopK% - {args.top_k}") else: - print (f"TopK% - {top_k}") + print (f"TopK - {args.top_k}") - OPTAttention.forward = get_top_k_forward(top_k, use_percentage) \ No newline at end of file + OPTAttention.forward = get_top_k_forward(args) \ No newline at end of file diff --git a/methods/baselines/topk/modify_phi.py b/methods/baselines/topk/modify_phi.py new file mode 100644 index 0000000..d14b7a6 --- /dev/null +++ b/methods/baselines/topk/modify_phi.py @@ -0,0 +1,117 @@ + +import sys +sys.path.append("/pscratch/sd/p/prajwal/hf_cache/modules/transformers_modules/microsoft/Phi-3-mini-4k-instruct/") +#sys.path.append("/pscratch/sd/p/prajwal/hf_cache/") + +from typing import List, Optional, Tuple, Union +import math +import warnings +#from modeling_phi3 import Phi3Attention, repeat_kv, apply_rotary_pos_emb +from bce928f38989812b69c6f8e3a86763e004387d16.modeling_phi3 import Phi3Attention, repeat_kv, apply_rotary_pos_emb +from transformers.cache_utils import Cache +import torch +from torch import nn +import torch.nn.functional as F +from functools import partial + +from methods.common.utils import mask_attn_top_k +import methods + +def get_top_k_forward(args): + def modified_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> 
Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # Eager (non flash-attention) implementation; expect numerical differences from the flash path. + + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + query_pos = self.num_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + if methods.G_TENSOR_SAVER is not None: + methods.G_TENSOR_SAVER.save("key", key_states, self.layer_idx, "postrotary") + #methods.G_TENSOR_SAVER.save("query", query_states, self.layer_idx, "postrotary") + + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # Get top-k attention weights (same convention as the other top-k paths) + if args.top_k <= 1: + topk = int(args.top_k * attn_weights.shape[-1]) + else: + topk = int(args.top_k) + attn_weights = mask_attn_top_k(attn_weights, topk, dim=-1) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + return modified_forward + + +def make_phi_attention_top_k(args): + print ("Modifying Phi Attention -> TopK Attention") + if args.top_k > 1: + print 
(f"TopK - {args.top_k}") + else: + print (f"TopK% - {args.top_k}") + + Phi3Attention.forward = get_top_k_forward(args) diff --git a/configure_model.py b/methods/common/configure_model.py similarity index 65% rename from configure_model.py rename to methods/common/configure_model.py index 8bda66f..1a7314e 100644 --- a/configure_model.py +++ b/methods/common/configure_model.py @@ -3,7 +3,6 @@ def get_h2o_args(parser): parser.add_argument("--use-h2o", action='store_true', default=False, help="use the H2O algos") parser.add_argument("--heavy-ratio", type=float, default=0.1, help="H2O heavy ratio," "set to 0.1 by default") - parser.add_argument("--recent-ratio", type=float, default=0.1, help="H2O recent ratio," "set to 0.1 by default") return parser def get_topk_args(parser): @@ -21,6 +20,8 @@ def get_spar_args(parser): def get_pca_args(parser): parser.add_argument("--use-pca", action='store_true', default=False, help="use the PCA algos") parser.add_argument("--use-pca-topk", action='store_true', default=False, help="use the PCA TopK algos") + parser.add_argument("--rotary-type", type=str, default="postrotary", help="rotary type") + parser.add_argument("--recent-ratio", type=float, default=-1, help="PcaTopK recent ratio," "set to -1 by default") return parser def get_save_tensor_args(parser): @@ -57,4 +58,36 @@ def get_modifier(args): method_name = "make_" + args.model_type + "_attention_" + method_name module = import_module(module_name, package="methods") method = getattr(module, method_name) - return method \ No newline at end of file + return method + +def get_config_dict(args): + config_dict = {} + config_dict["model"] = args.model_id + config_dict["sequence_length"] = args.sequence_length + if args.use_h2o: + config_dict["method"] = "h2o" + config_dict["heavy_ratio"] = args.heavy_ratio + elif args.use_topk: + config_dict["method"] = "topk" + config_dict["top_k"] = args.top_k + elif args.use_spark: + config_dict["method"] = "spark" + config_dict["top_r"] = args.top_r + config_dict["top_k"] = args.top_k + elif args.use_sparq: + config_dict["method"] = "sparq" + config_dict["top_r"] = args.top_r + config_dict["top_k"] = args.top_k + elif args.use_spar_hat: + config_dict["method"] = "spar_hat" + config_dict["top_r"] = args.top_r + elif args.use_pca: + config_dict["method"] = "pca" + config_dict["top_r"] = args.top_r + elif args.use_pca_topk: + config_dict["method"] = "pca_topk" + config_dict["top_r"] = args.top_r + config_dict["top_k"] = args.top_k + config_dict["rotary_type"] = args.rotary_type + config_dict["recent_ratio"] = args.recent_ratio + return config_dict \ No newline at end of file diff --git a/methods/common/logger.py b/methods/common/logger.py new file mode 100644 index 0000000..c6a4433 --- /dev/null +++ b/methods/common/logger.py @@ -0,0 +1,59 @@ +from methods.common.configure_model import get_config_dict +import wandb +import torch +import os + +os.environ["WANDB__SERVICE_WAIT"] = "300" + +class WandbLogger: + def __init__(self, args): + self.rank = os.environ.get("RANK") + if self.rank == '0': + self.config = get_config_dict(args) + jobid = os.environ.get("JOBID", '0') + if args.lm_harness_eval: + groupid = "lm_harness" + else: + groupid = "ppl" + self.run = wandb.init(project='PCA-TopK', config=self.config, name=jobid, + group=groupid, job_type='eval', tags=[groupid]) + + def update_config(self, kwargs): + if self.rank == '0': + self.run.config.update(kwargs) + + def log(self, kwargs): + if self.rank == '0': + self.run.log(kwargs) + + def log_ppl(self, ppl): + if self.rank == 
'0': + self.run.log({'perplexity': ppl}) + + def log_lm_harness_results(self, tasks, results): + if self.rank == '0': + assert results is not None + for task in tasks.keys(): + metric = tasks[task] + result = results[task] + if metric in result.keys(): + # Replace ,none with empty string + metric_name = metric.replace(",none", "") + self.run.log({task + "_" + metric_name: result[metric]}) + + def finish(self): + if self.rank == '0': + self.run.finish() + +class NoOpLogger: + def __init__(self, args): + pass + + def update_config(self, kwargs): + pass + + def log(self, kwargs): + pass + + def finish(self): + pass \ No newline at end of file diff --git a/methods/common/saver.py b/methods/common/saver.py index 488471c..1f0b3fd 100644 --- a/methods/common/saver.py +++ b/methods/common/saver.py @@ -18,6 +18,12 @@ def __init__(self, output_dir): self.index_dict[category] = 0 def save(self, category, tensor, extra_idx = None, extra_dir = ""): + # Only save the tensor if the rank is 0 + if torch.distributed.get_rank() != 0: + return + # Print the first time the function is called + if self.index_dict[category] == 0: + print(f"Saving tensor {category} with shape {tensor.shape}") os.makedirs(os.path.join(self.output_dir, extra_dir, category), exist_ok=True) output_dir = os.path.join(self.output_dir, extra_dir, category) # Clear the directory if it is the first tensor diff --git a/methods/pca_topk/cache_utils.py b/methods/pca_topk/cache_utils.py new file mode 100644 index 0000000..12d20bb --- /dev/null +++ b/methods/pca_topk/cache_utils.py @@ -0,0 +1,281 @@ + +from typing import Any, Dict, List, Optional, Tuple +from transformers.cache_utils import Cache +import math +import time + +import torch +import external.gather_matmul as G + + + +topk_time = 0 +iter_num = 0 +torch.backends.cuda.matmul.allow_tf32 = False + +# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. +torch.backends.cudnn.allow_tf32 = False + + +# Work In Progress +class PcaTopKCache(Cache): # Not used anymore + """ + Cache based on PcaTopK mechanism + """ + def __init__(self) -> None: + self.key_cache: List[torch.Tensor] = [] # Stores the reduced keys for each layer + self.value_cache: List[torch.Tensor] = [] + #self.top_r = r + #self.top_k = k + #print (f"Cache initialized with top_r = {r}, top_k = {k}") + + @torch.no_grad() + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + query_states: torch.Tensor, + layer_idx: int, + topk: bool = True, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. These are specific to each subclass and allow new types of + cache to be created. + + Return: + A tuple containing the updated key and value states. 
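+
+            Illustrative usage (a sketch following test_pcatopk_cache below; tensor names are
+            placeholders): the prompt pass stores the full keys/values, and each generation
+            step appends one new key/value pair and returns the grown cache:
+
+                cache = PcaTopKCache()
+                keys, values = cache.update(prompt_keys, prompt_values, prompt_queries, 0)
+                keys, values = cache.update(new_key, new_value, new_query, 0)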
+ """ + if len(self.key_cache) <= layer_idx: + # Empty cache + # Assume that the keys are alread in the PCA space + self.key_cache.append(key_states) + self.value_cache.append(value_states) + + # This is also the prompt iteration so we need all the keys for attention + return self.key_cache[layer_idx], self.value_cache[layer_idx] + else: + #global topk_time + #global iter_num + #start = time.time() + ## Growing cache + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + if not topk: + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + # Compute approximate attention scores assuming that keys and queries are already in PCA space + #query_states_pca = query_states[:, :, :, :self.top_r] + #key_states_pca = self.key_cache[layer_idx][:, :, :, :self.top_r] + + #head_dim = query_states.shape[-1] + scaling_factor = head_dim * torch.sqrt((torch.square(key_states_pca).sum(-1 , keepdim=True) / torch.square(self.key_cache[layer_idx]).sum(-1, keepdim = True))) + scaling_factor = scaling_factor.transpose(-1, -2) + + attn_weights = torch.matmul(query_states[:,:,:,:self.top_r], self.key_cache[layer_idx][:,:,:,:self.top_r].transpose(2, 3)) / math.sqrt(head_dim) + + # Get top-k keys and top-k values based on the attention scores + ################# Unoptimized Version + # key_states_topk_indices = torch.topk(attn_weights, self.top_k, dim=-1).indices + # key_states_topk_indices = key_states_topk_indices.transpose(-1, -2).expand(-1, -1, -1, head_dim) + + # key_states_topk = torch.gather(self.key_cache[layer_idx], -2, key_states_topk_indices) + # value_states_topk = torch.gather(self.value_cache[layer_idx], -2, key_states_topk_indices) + ################## + + + return key_states_topk, value_states_topk + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-2] + + def get_max_length(self) -> Optional[int]: + """Returns the maximum sequence length of the cached states, if there is any.""" + return None + + def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: + """Given the sequence length of the new inputs, returns the usable length of the cache.""" + # Cache without size limit -> all cache is usable + # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache + # length, we will need to evict part of the cache (and thus not all cache is usable) + max_length = self.get_max_length() + previous_seq_length = self.get_seq_length(layer_idx) + if max_length is not None and previous_seq_length + new_seq_length > max_length: + return max_length - new_seq_length + return previous_seq_length + +def test_pcatopk_cache(): + cache = PcaTopKCache(2, 4) + + torch.set_printoptions(threshold=float('inf')) + prompt_keys = torch.rand(1, 2, 8, 8) + print (prompt_keys) + + cache.update(prompt_keys, prompt_keys, prompt_keys, 0) + + generative_query = torch.rand(1, 2, 1, 8) + generative_key = torch.rand(1, 2, 1, 8) + + top_keys, top_vals = cache.update(generative_key, generative_key, generative_query, 0) + + print (f"Top K Keys: {top_keys.shape}") + print (top_keys) + + + +def micro_benchmark_pca_topk(cache, prompt_keys, top_r, top_k, num_gen_steps=2000, use_optimised_gather=False): + import time + torch.set_float32_matmul_precision("highest") + + head_dim = prompt_keys.shape[-1] + bs = prompt_keys.shape[0] + num_heads = prompt_keys.shape[1] + + generative_query = torch.rand(bs, num_heads, 1, head_dim).to("cuda") + generative_key = torch.rand(bs, num_heads, 1, head_dim).to("cuda") + + print ("Starting microbenchmark") + matmul_time = 0 + top_keys = torch.zeros(bs, num_heads, top_k, head_dim).to("cuda") + top_vals = torch.zeros(bs, num_heads, top_k, head_dim).to("cuda") + + if use_optimised_gather: + for i in range(num_gen_steps): + keys, vals = cache.update(generative_key, generative_key, generative_query, 0, False) + torch.cuda.synchronize() + + start = time.time() + attn_weights = torch.matmul(generative_query[:,:,:,:top_r], keys.transpose(2, 3)[:,:,:top_r,:]) / math.sqrt(head_dim) + + # Get top-k keys and top-k values based on the attention scores + key_states_topk_indices = torch.topk(attn_weights, top_k, dim=-1).indices.to("cuda") + key_states_topk_indices,_ = torch.sort(key_states_topk_indices, dim=-1) + key_states_topk_indices= key_states_topk_indices.reshape(-1, key_states_topk_indices.shape[-1]) + + keys = keys.reshape(-1, keys.shape[-2] , keys.shape[-1]) + vals = vals.reshape(-1, vals.shape[-2] , vals.shape[-1]) + + attn_weights = G.gather_outer_bmv( + generative_query.reshape(-1, 1, head_dim), + keys.transpose(-1, -2), + key_states_topk_indices, + #.squeeze(0).squeeze(-1), + chunk=256 + #chunk=min(k2, 65536 // Q.shape[-1]), + ) / math.sqrt(head_dim) + attn_weights = torch.softmax(attn_weights, dim=-1) + + attn_output = G.gather_inner_matrix_only_bmv( + attn_weights, vals, key_states_topk_indices, chunk=64 + ) + + torch.cuda.synchronize() + end = time.time() + + if i > 5: + matmul_time += end - start + else: + for i in range(num_gen_steps): + keys, vals = cache.update(generative_key, generative_key, generative_query, 0, False) + torch.cuda.synchronize() + + start = time.time() + attn_weights = torch.matmul(generative_query[:,:,:,:top_r], 
keys.transpose(2, 3)[:,:,:top_r,:]) / math.sqrt(head_dim) + + # Get top-k keys and top-k values based on the attention scores + key_states_topk_indices = torch.topk(attn_weights, top_k, dim=-1).indices.to("cuda") + key_states_topk_indices,_ = torch.sort(key_states_topk_indices, dim=-1) + key_states_topk_indices = key_states_topk_indices.transpose(-1, -2).expand(-1, -1, -1, head_dim) + + torch.gather(keys, -2, key_states_topk_indices, out=top_keys) + torch.gather(vals, -2, key_states_topk_indices, out=top_vals) + + attn_weights = torch.matmul(generative_query, top_keys.transpose(2, 3)) / math.sqrt(head_dim) + attn_weights = torch.softmax(attn_weights, dim=-1) + attn_output = torch.matmul(attn_weights, top_vals) + torch.cuda.synchronize() + end = time.time() + + if i > 5: + matmul_time += end - start + print (f"Matmul Time: {matmul_time}") + +def micro_bench_actual_attention(cache, prompt_keys, num_gen_steps=2000): + import time + torch.set_float32_matmul_precision("highest") + + head_dim = prompt_keys.shape[-1] + bs = prompt_keys.shape[0] + num_heads = prompt_keys.shape[1] + + generative_query = torch.rand(bs, num_heads, 1, head_dim).to("cuda") + generative_key = torch.rand(bs, num_heads, 1, head_dim).to("cuda") + + print ("Starting microbenchmark") + matmul_time = 0 + for i in range(num_gen_steps): + keys, vals = cache.update(generative_key, generative_key, generative_query, 0, False) + torch.cuda.synchronize() + + start = time.time() + attn_weights = torch.matmul(generative_query, keys.transpose(2, 3)) / math.sqrt(head_dim) + attn_weights = torch.softmax(attn_weights, dim=-1) + attn_output = torch.matmul(attn_weights, vals) + torch.cuda.synchronize() + end = time.time() + + if i > 5: + matmul_time += end - start + print (f"Matmul Time: {matmul_time}") + +def benchmark_attention(batch_size=1, + num_heads=32, + num_gen_steps=128, + prompt_length=3072, + topk=256): + + head_dim=128 + # Change this to change batch size, etc. + prompt_keys = torch.rand(batch_size, num_heads, prompt_length, head_dim).to("cuda") + + + print("PCA TOPK Unoptimized") + cache1 = PcaTopKCache() + cache1.update(prompt_keys, prompt_keys, prompt_keys, 0) + micro_benchmark_pca_topk(cache1, prompt_keys, 32, topk, num_gen_steps=num_gen_steps) + del cache1 + + print("PCA TOPK Optimized") + cache2 = PcaTopKCache() + cache2.update(prompt_keys, prompt_keys, prompt_keys, 0) + micro_benchmark_pca_topk(cache2, prompt_keys, 32, topk, num_gen_steps=num_gen_steps, use_optimised_gather=True) + del cache2 + + print("Actual Attention") + cache3= PcaTopKCache() + cache3.update(prompt_keys, prompt_keys, prompt_keys, 0) + micro_bench_actual_attention(cache3, prompt_keys, num_gen_steps=num_gen_steps) + del cache3 + +if __name__ == "__main__": + #test_pcatopk_cache() + with torch.no_grad(): + benchmark_attention(prompt_length=4096, num_gen_steps=2000, batch_size=16, topk=1024) + + diff --git a/methods/pca_topk/external/gather_matmul.py b/methods/pca_topk/external/gather_matmul.py new file mode 100644 index 0000000..d3e8dde --- /dev/null +++ b/methods/pca_topk/external/gather_matmul.py @@ -0,0 +1,234 @@ +# Copyright (c) 2024 Graphcore Ltd. All rights reserved. +# This code was copied from https://github.com/graphcore-research/llm-inference-research/blob/benchmarks/src/gather_matmul.py + +# MIT License +# +# Copyright (c) 2024 Graphcore Ltd. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import math +import warnings + +import torch +import triton +import triton.language as tl +from torch import Tensor + + +@triton.jit +def _kernel_gather_inner_bmv( + A_ptr, + B_ptr, + I_ptr, + Y_ptr, + k: tl.constexpr, # int + n: int, + n_chunk: tl.constexpr, # int + A_s0: int, + A_s2: int, + B_s0: int, + B_s1: int, + B_s2: int, + I_s0: int, + I_s1: int, + gather_A: tl.constexpr, # bool +): + pid = tl.program_id(axis=0).to(tl.int64) + i = tl.load(I_ptr + pid * I_s0 + tl.arange(0, k) * I_s1) # (k) + a = tl.load(A_ptr + pid * A_s0 + (i if gather_A else tl.arange(0, k)) * A_s2) # (k) + for chunk in range(0, tl.cdiv(n, n_chunk)): + chunk_idx = chunk * n_chunk + tl.arange(0, n_chunk) + b = tl.load( # (k x n_chunk) + B_ptr + pid * B_s0 + (i * B_s1)[:, None] + (chunk_idx * B_s2)[None, :] + ) + # As tl.dot() is unavailable for matrix-vector + y = tl.sum(a[:, None] * b, 0) # (n_chunk) + tl.store(Y_ptr + pid * n + chunk_idx, y, mask=(chunk_idx < n)) + + +def gather_inner_bmv( + A: Tensor, B: Tensor, I: Tensor, chunk: int, _matrix_only: bool = False +) -> Tensor: + """Batched vector-matrix multiplication, with a gather on the inner dimension. + + Dimensions: + b -- batch + k* -- (pre-gather) inner dimension + k -- (post-gather) inner dimension (k <= k*), must be a power of two + n -- outer dimension + + A -- (b, 1, k*) batch of vectors + B -- (b, k*, n) batch of matrices + I -- int(b, k) indices, in [0, k*) + chunk -- int size of chunks of `B` (along dimension `n`) to be processed at a time + _matrix_only -- bool don't use (see `gather_inner_matrix_only_bmv`) + + returns -- (b, 1, n) the inner product of `A` and `B`, after gathering the inner dimension + according to `I` + """ + if A.ndim > 3: + assert B.ndim == A.ndim and I.ndim == A.ndim - 1 + return gather_inner_bmv( + A.flatten(end_dim=-3), + B.flatten(end_dim=-3), + I.flatten(end_dim=-2), + chunk=chunk, + _matrix_only=_matrix_only, + ).unflatten(0, A.shape[:-2]) + assert A.ndim == 3 and B.ndim == 3 and A.shape[1] == 1 + assert ( + I.ndim == 2 + and I.shape[0] == A.shape[0] + and 2 ** int(math.log2(I.shape[1])) == I.shape[1] + ) + assert A.shape[2] == (I.shape[1] if _matrix_only else B.shape[1]) + if B.stride(2) != 1: + warnings.warn( + "gather_inner_bmv(A, B, ...) 
`B` should be contiguous in the last dimension" + ", otherwise it is very slow" + ) + + b, k, n = A.shape[0], I.shape[1], B.shape[2] + Y = torch.empty((b, 1, n), dtype=A.dtype, device=A.device) + assert Y.stride(0) == n and Y.stride(2) == 1 + + _kernel_gather_inner_bmv[(b,)]( + A_ptr=A, + B_ptr=B, + I_ptr=I, + Y_ptr=Y, + k=k, + n=n, + n_chunk=chunk, + A_s0=A.stride(0), + A_s2=A.stride(2), + B_s0=B.stride(0), + B_s1=B.stride(1), + B_s2=B.stride(2), + I_s0=I.stride(0), + I_s1=I.stride(1), + gather_A=not _matrix_only, + ) + return Y + + +def gather_inner_matrix_only_bmv(A: Tensor, B: Tensor, I: Tensor, chunk: int) -> Tensor: + """Batched vector-matrix multiplication, with a gather on the inner dimension of the matrix. + + Dimensions: + b -- batch + k* -- (pre-gather) inner dimension + k -- (post-gather) inner dimension (k <= k*), must be a power of two + n -- outer dimension + + A -- (b, 1, k) batch of vectors + B -- (b, k*, n) batch of matrices + I -- int(b, k) indices, in [0, k*) + chunk -- int size of chunks of `B` (along dimension `n`) to be processed at a time + + returns -- (b, 1, n) the inner product of `A` and `B`, after gathering the inner dimension + of `B` according to `I` + """ + return gather_inner_bmv(A, B, I, chunk=chunk, _matrix_only=True) + + +@triton.jit +def _kernel_gather_outer_bmv( + A_ptr, + B_ptr, + I_ptr, + Y_ptr, + k: tl.constexpr, + n: int, + n_chunk: tl.constexpr, + A_s0: int, + A_s2: int, + B_s0: int, + B_s1: int, + B_s2: int, + I_s0: int, + I_s1: int, +): + pid = tl.program_id(axis=0).to(tl.int64) + a = tl.load(A_ptr + pid * A_s0 + tl.arange(0, k) * A_s2) # (k) + for chunk in range(0, tl.cdiv(n, n_chunk)): + chunk_idx = chunk * n_chunk + tl.arange(0, n_chunk) + i = tl.load(I_ptr + pid * I_s0 + chunk_idx * I_s1) # (n_chunk) + b = tl.load( # (k x n_chunk) + B_ptr + + pid * B_s0 + + (tl.arange(0, k) * B_s1)[:, None] + + (i * B_s2)[None, :], + mask=(chunk_idx < n)[None, :], + ) + # # As tl.dot() is unavailable for matrix-vector + y = tl.sum(a[:, None] * b, 0) # (n_chunk) + tl.store(Y_ptr + pid * n + chunk_idx, y, mask=(chunk_idx < n)) + + +def gather_outer_bmv(A: Tensor, B: Tensor, I: Tensor, chunk: int) -> Tensor: + """Batched vector-matrix multiplication, with a gather on the matrix outer dimension. 
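+    For example, the optimized PCA-TopK Llama path calls this during generation roughly as
+    `gather_outer_bmv(query, keys.transpose(-1, -2), topk_indices, chunk=256)` (names here are
+    illustrative), with `query` of shape (b, 1, k) and the transposed keys of shape (b, k, n*),
+    so only the gathered top-k key columns are scored.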
+ + Dimensions: + b -- batch + k -- inner dimension, must be a power of two + n* -- (pre-gather) outer dimension + n -- (post-gather) outer dimension (n <= n*) + + A -- (b, 1, k) batch of vectors + B -- (b, k, n*) batch of matrices + I -- int(b, n) indices, in [0, n*) + chunk -- int size of chunks of `B` (along dimension `n`) to be processed at a time + + returns -- (b, 1, n) the inner product of `A` and `B`, after gathering the outer dimension + according to `I` + """ + if A.ndim > 3: + assert B.ndim == A.ndim and I.ndim == A.ndim - 1 + return gather_outer_bmv( + A.flatten(end_dim=-3), + B.flatten(end_dim=-3), + I.flatten(end_dim=-2), + chunk=chunk, + ).unflatten(0, A.shape[:-2]) + assert A.ndim == 3 and B.ndim == 3 and A.shape[1] == 1 and A.shape[2] == B.shape[1] + assert I.ndim == 2 and I.shape[0] == A.shape[0] + + b, k, n = A.shape[0], A.shape[2], I.shape[1] + Y = torch.empty((b, 1, n), dtype=A.dtype, device=A.device) + assert Y.stride(0) == n and Y.stride(2) == 1 + + _kernel_gather_outer_bmv[(b,)]( + A_ptr=A, + B_ptr=B, + I_ptr=I, + Y_ptr=Y, + k=k, + n=n, + n_chunk=chunk, + A_s0=A.stride(0), + A_s2=A.stride(2), + B_s0=B.stride(0), + B_s1=B.stride(1), + B_s2=B.stride(2), + I_s0=I.stride(0), + I_s1=I.stride(1), + ) + return Y \ No newline at end of file diff --git a/methods/pca_topk/modify_gptneox.py b/methods/pca_topk/modify_gptneox.py new file mode 100644 index 0000000..739ae81 --- /dev/null +++ b/methods/pca_topk/modify_gptneox.py @@ -0,0 +1,125 @@ +from typing import List, Optional, Tuple, Union +import math +import warnings +from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXAttention, apply_rotary_pos_emb +from transformers.cache_utils import Cache +import torch +from torch import nn +import torch.nn.functional as F +from functools import partial + +from .utils import mask_attn_pca_topk, get_pca_components +import methods +import os + +def get_pca_topk_init(args): + def modified_attention_init(self, config): + super(GPTNeoXAttention, self).__init__() + self.config = config + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + if self.hidden_size % self.num_attention_heads != 0: + raise ValueError( + "The hidden size is not divisble by the number of attention heads! Make sure to update them" + ) + self.head_size = self.hidden_size // self.num_attention_heads + self.rotary_ndims = int(self.head_size * config.rotary_pct) + self._init_bias(config.max_position_embeddings) + + self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False) + self._init_rope() + + self.norm_factor = self.head_size**-0.5 + self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.attention_bias) + self.dense = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) + self.attention_dropout = nn.Dropout(config.attention_dropout) + self.is_causal = True + + self.layer_idx = methods.G_TENSOR_SAVER.get_layer_idx() + return modified_attention_init + + +def get_pca_topk_attn(args): + def modified_attn(self, query, key, value, attention_mask=None, head_mask=None): + # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size] + # compute causal mask from causal mask buffer + batch_size, num_attention_heads, query_length, attn_head_size = query.size() + key_length = key.size(-2) + + # dynamically increase the causal mask with the key length, if needed. 
+ if key_length > self.bias.shape[-1]: + self._init_bias(key_length, device=key.device) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + + if not hasattr(self, "pca_components"): + self.pca_means, self.pca_components, self.pca_components_r_key = get_pca_components(args, self.layer_idx, self.head_size , args.top_r, self.num_attention_heads, None) + + self.pca_means = self.pca_means.to(key.dtype) + self.pca_components_r_key = self.pca_components_r_key.to(key.dtype) + self.pca_components = self.pca_components.to(key.dtype) + + query_pca = torch.matmul(query, self.pca_components) + key_pca = torch.matmul(key, self.pca_components) + + query_pca = query_pca.view(batch_size * num_attention_heads, query_length, attn_head_size) + key_pca = key_pca.view(batch_size * num_attention_heads, key_length, attn_head_size) + + attn_scores = torch.zeros( + batch_size * num_attention_heads, + query_length, + key_length, + dtype=query.dtype, + device=key.device, + ) + attn_scores = torch.baddbmm( + attn_scores, + query_pca, + key_pca.transpose(1, 2), + beta=1.0, + alpha=self.norm_factor, + ) + attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length) + + mask_value = torch.finfo(attn_scores.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device) + + + + attn_scores = torch.where(causal_mask, attn_scores, mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_scores = attn_scores + attention_mask + else: + # Create the attention mask if it is not provided + attention_mask = torch.where(causal_mask, torch.tensor(0.0).to(attn_scores.dtype), mask_value) + + # Get top-k attention weights + if args.top_k <= 1: + topk = int(args.top_k * attn_scores.shape[-1]) + else: + topk = int(args.top_k) + attn_scores, alpha = mask_attn_pca_topk(args, self.layer_idx, attn_scores, attention_mask, query, key, self.pca_components, self.pca_components_r_key, args.top_r, topk) + + attn_weights = nn.functional.softmax(attn_scores, dim=-1) + attn_weights = attn_weights.to(value.dtype) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_weights = self.attention_dropout(attn_weights) + + attn_output = torch.matmul(attn_weights, value) + return attn_output, attn_weights + return modified_attn + +def make_gptneox_attention_pca_topk(args): + print ("Modifying GPT NeoX Attention -> PCA TopK Attention") + print ("Top R:", args.top_r) + print ("Top K:", args.top_k) + print ("Not using alpha") + GPTNeoXAttention.__init__ = get_pca_topk_init(args) + GPTNeoXAttention._attn = get_pca_topk_attn(args) diff --git a/methods/pca_topk/modify_llama.py b/methods/pca_topk/modify_llama.py index af1a35d..cc906ea 100644 --- a/methods/pca_topk/modify_llama.py +++ b/methods/pca_topk/modify_llama.py @@ -10,13 +10,9 @@ import torch.nn.functional as F from functools import partial -from .utils import mask_attn_pca_topk +from .utils import mask_attn_pca_topk, get_pca_components import methods -import os -#pca_data_path = "/global/cfs/cdirs/m4641/ApproxAttn" -pca_data_path="/pscratch/sd/s/ssingh37/InferenceData/topk/" - try: from axonn import axonn as ax @@ -25,54 +21,7 @@ except ImportError: AXONN_AVAILABLE=False -def get_pca_components(layer_idx, head_dim, top_r): - 
model_folder_name = "Meta-Llama-3-8B" # FIXME: this should be made generic, and not hardcoded to a specific model - components_file_path = os.path.join(pca_data_path, f"{model_folder_name}/wikitext/postrotary/key/pca_components/pca_components_layer_{layer_idx}.pt") - mean_file_path = os.path.join(pca_data_path, f"{model_folder_name}/wikitext/postrotary/key/pca_means/pca_means_layer_{layer_idx}.pt") - explained_variance_file_path = os.path.join(pca_data_path, f"{model_folder_name}/wikitext/postrotary/key/pca_explained_variance/pca_explained_variance_layer_{layer_idx}.pt") - - # PCA Components with the shape (num_heads, head_dim, top_r) - pca_components = torch.load(components_file_path).to("cuda") - - # PCA Means with the shape (num_heads, head_dim) - pca_means = torch.load(mean_file_path).to("cuda") - - # Explained Variance with the shape (num_heads, head_dim) - pca_explained_variance = torch.load(explained_variance_file_path).to("cuda") - - # Reshaping the components and taking a transpose to have components along the column dimension and means to be easily broadcastable over the keys - pca_components = pca_components.reshape(1, -1, head_dim, head_dim).transpose(2, 3) - pca_means = pca_means.reshape(1, -1, 1, head_dim) - - # Get the point where the explained variance is 95% per head - explained_variance_cumsum = pca_explained_variance.cumsum(-1) - - - if top_r < 1: - # Find the maximum index where the explained variance is 95% across all heads - Uncomment this line adaptively set the top_r:w - top_correct_r = (explained_variance_cumsum < top_r).sum(-1).max().item() - - # # Instead of sum, we use the median index - # #top_r = (explained_variance_cumsum < 0.95).sum(-1).median().item() - else: - top_correct_r = int(top_r) - - # Only keep the top_r components of the pca_components - pca_components_r_key = pca_components[:, :, :, :top_correct_r] - - print ("{}: PCA Components Shape: {}".format(layer_idx, pca_components_r_key.shape)) - print ("{}: PCA Means Shape: {}".format(layer_idx, pca_means.shape)) - print ("Compression Ratio: {}".format(top_correct_r / head_dim)) - - if AXONN_AVAILABLE and ax.is_initialized: - ## only keep pca data for the heads on the GPU - pca_components = drop(pca_components, transpose=True, skip_batch=True, dim=1) - pca_means = drop(pca_means, transpose=True, skip_batch=True, dim=1) - pca_components_r_key = drop(pca_components_r_key, transpose=True, skip_batch=True, dim=1) - - return pca_means, pca_components, pca_components_r_key - -def get_pca_forward(top_r, top_k): +def get_pca_forward(args): def modified_forward( self, hidden_states: torch.Tensor, @@ -89,7 +38,7 @@ def modified_forward( ) if not hasattr(self, "pca_components"): - self.pca_means, self.pca_components, self.pca_components_r_key = get_pca_components(self.layer_idx, self.head_dim, top_r) + self.pca_means, self.pca_components, self.pca_components_r_key = get_pca_components(args, self.layer_idx, self.head_dim, args.top_r, self.num_key_value_groups, repeat_kv) bsz, q_len, _ = hidden_states.size() if self.config.pretraining_tp > 1: @@ -129,27 +78,39 @@ def modified_forward( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = (torch.matmul(query_states, key_states.transpose(2, 3))) / math.sqrt(self.head_dim) + #attn_weights = (torch.matmul(query_states, key_states.transpose(2, 3))) / math.sqrt(self.head_dim) + + #print ("Attn Weights DType:", attn_weights.dtype) + + self.pca_means = 
self.pca_means.to(key_states.dtype) + self.pca_components_r_key = self.pca_components_r_key.to(key_states.dtype) + self.pca_components = self.pca_components.to(key_states.dtype) + + + key_states_pca = torch.matmul(key_states, self.pca_components) + query_states_pca = torch.matmul(query_states, self.pca_components) + attn_weights = (torch.matmul(query_states_pca, key_states_pca.transpose(2, 3))) / math.sqrt(self.head_dim) + + #print ("Attn Weights DType:", attn_weights.dtype) if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask - pca_means = self.pca_means.to(key_states.dtype) pca_components_r_key = self.pca_components_r_key.to(key_states.dtype) pca_components = self.pca_components.to(key_states.dtype) - if top_k <= 1: - topk = int(top_k * attn_weights.shape[-1]) + if args.top_k <= 1: + topk = int(args.top_k * attn_weights.shape[-1]) else: - topk = int(top_k) - attn_weights, alpha = mask_attn_pca_topk(self.layer_idx, attn_weights, attention_mask, query_states, key_states, pca_components, pca_components_r_key, top_r, topk) + topk = int(args.top_k) + attn_weights, alpha = mask_attn_pca_topk(args, self.layer_idx, attn_weights, attention_mask, query_states, key_states, pca_components, pca_components_r_key, args.top_r, topk) assert alpha is not None, "alpha is None" - #print ("Alpha:", alpha.shape) + #print ("Alpha:", alpha.dtype) # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) @@ -163,11 +124,18 @@ def modified_forward( ) # Compute cumulative sum along the desired dimension - cumulative_sum = torch.cumsum(value_states, dim=2).cuda() + # cumulative_sum = torch.cumsum(value_states, dim=2).cuda() # Compute the cumulative mean along the desired dimension - cumulative_mean = cumulative_sum / torch.arange(1, value_states.size(2) + 1).float().unsqueeze(0).unsqueeze(1).unsqueeze(3).cuda() + # cumulative_mean = cumulative_sum / torch.arange(1, value_states.size(2) + 1).float().unsqueeze(0).unsqueeze(1).unsqueeze(3).cuda() + + # Compute cumulative sum along the desired dimension + #cumulative_sum = torch.cumsum(value_states, dim=2).cuda() + + ## Compute the cumulative mean along the desired dimension + #cumulative_mean = cumulative_sum / torch.arange(1, value_states.size(2) + 1).float().unsqueeze(0).unsqueeze(1).unsqueeze(3).cuda() - attn_output = ((1 - alpha) * cumulative_mean) + alpha * attn_output + #attn_output = ((1 - alpha) * cumulative_mean) + alpha * attn_output + #attn_output = attn_output.to(query_states.dtype) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): raise ValueError( @@ -192,9 +160,9 @@ def modified_forward( return attn_output, attn_weights, past_key_value return modified_forward -def make_llama_attention_pca_topk(top_r, top_k): +def make_llama_attention_pca_topk(args): print ("Modifying Llama Attention -> PCA TopK Attention") - print ("Top R:", top_r) - print ("Top K:", top_k) + print ("Top R:", args.top_r) + print ("Top K:", args.top_k) #LlamaAttention.__init__ = get_pca_init(top_r, top_k) - LlamaAttention.forward = get_pca_forward(top_r, top_k) + LlamaAttention.forward = get_pca_forward(args) diff --git a/methods/pca_topk/modify_llama_optimized.py b/methods/pca_topk/modify_llama_optimized.py new file mode 100644 index 0000000..073b8f0 --- /dev/null +++ b/methods/pca_topk/modify_llama_optimized.py @@ -0,0 +1,168 @@ +from typing import List, Optional, Tuple, 
Union +import math +import warnings +from transformers.models.llama.modeling_llama import LlamaAttention, repeat_kv, apply_rotary_pos_emb +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaMLP, ACT2FN +from transformers.cache_utils import Cache +import torch +from torch import nn +import torch.nn.functional as F +from functools import partial + +from .utils import mask_attn_pca_topk, get_pca_components +import methods.pca_topk.external.gather_matmul as G +import methods + + +try: + from axonn import axonn as ax + from axonn.intra_layer import drop + AXONN_AVAILABLE=True +except ImportError: + AXONN_AVAILABLE=False + +def get_pca_forward(args): + def modified_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + past_key_value = getattr(self, "past_key_value", past_key_value) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if not hasattr(self, "pca_components"): + _, self.pca_components, _= get_pca_components(args, self.layer_idx, self.head_dim, args.top_r, self.num_key_value_groups, repeat_kv) + self.pca_components = self.pca_components.to(query_states.dtype) + + # TODO: Keep it fixed or make it dynamic? 
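+        # Convention (mirroring the other top-k paths in this repo): a top_k value in (0, 1]
+        # is read as a fraction of the current key length, anything larger as an absolute
+        # count, e.g. top_k=0.25 over 4096 keys keeps 1024 keys, while top_k=256 keeps 256.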
+ if args.top_k <= 1: + self.top_k = int(args.top_k * key_states.shape[-2]) + else: + self.top_k = int(args.top_k) + + key_states = torch.matmul(key_states, self.pca_components) + query_states = torch.matmul(query_states, self.pca_components) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # Generation Step + if query_states.shape[-2] == 1: + # Compute Approximate Attention Weights + # We do not need a causal mask here since this is the generation step + attn_weights = torch.matmul(query_states[:,:,:,:args.top_r], key_states.transpose(2, 3)[:,:,:args.top_r,:]) / math.sqrt(self.head_dim) + + key_states_topk_indices = torch.topk(attn_weights, self.top_k, dim=-1).indices.to("cuda") + key_states_topk_indices , _ = torch.sort(key_states_topk_indices, dim=-1) + key_states_topk_indices = key_states_topk_indices.reshape(-1, key_states_topk_indices.shape[-1]) + + key_states = key_states.reshape(-1, key_states.shape[-2], key_states.shape[-1]) + query_states = query_states.reshape(-1, query_states.shape[-2], query_states.shape[-1]) + + attn_weights = G.gather_outer_bmv( + query_states.contiguous(), + key_states.transpose(-1, -2).contiguous(), + key_states_topk_indices, + chunk=256 # Varying this changes performance + #chunk=min(k2, 65536 // Q.shape[-1]), + ) / math.sqrt(self.head_dim) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + + value_states = value_states.reshape(-1, value_states.shape[-2], value_states.shape[-1]) + attn_output = G.gather_inner_matrix_only_bmv( + attn_weights.contiguous(), + value_states.contiguous(), + key_states_topk_indices, + chunk=64 + ) + attn_output = attn_output.reshape(bsz, self.num_heads, q_len, self.head_dim) + else: + # Compute Standard Attention + attn_weights = (torch.matmul(query_states, key_states.transpose(2, 3))) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, 
attn_weights, past_key_value + return modified_forward + +def make_llama_attention_pca_topk(args): + print ("Modifying Llama Attention -> PCA TopK Attention") + print ("Top R:", args.top_r) + print ("Top K:", args.top_k) + #if args.optimised: + print ("Optimised PCA TopK Attention") + #LlamaAttention.__init__ = get_pca_init(top_r, top_k) + LlamaAttention.forward = get_pca_forward(args) diff --git a/methods/pca_topk/modify_mistral.py b/methods/pca_topk/modify_mistral.py index e0cec72..0a42111 100644 --- a/methods/pca_topk/modify_mistral.py +++ b/methods/pca_topk/modify_mistral.py @@ -4,20 +4,16 @@ import warnings from transformers.models.mistral.modeling_mistral import MistralAttention, repeat_kv, apply_rotary_pos_emb, MistralRotaryEmbedding from transformers.models.mistral.configuration_mistral import MistralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralAttention from transformers.cache_utils import Cache import torch from torch import nn import torch.nn.functional as F from functools import partial -from .utils import mask_attn_pca_topk +from .utils import mask_attn_pca_topk, get_pca_components import methods - -import os -pca_data_path = "/global/cfs/cdirs/m4641/ApproxAttn" - - try: from axonn import axonn as ax from axonn.intra_layer import drop @@ -25,55 +21,7 @@ except ImportError: AXONN_AVAILABLE=False -def get_pca_components(layer_idx, head_dim, top_r, num_key_value_groups): - components_file_path = os.path.join(pca_data_path, "Mistral-7B-PCA/wikitext/postrotary/key/pca_components/pca_components_layer_{}.pt".format(layer_idx)) - mean_file_path = os.path.join(pca_data_path, "Mistral-7B-PCA/wikitext/postrotary/key/pca_means/pca_means_layer_{}.pt".format(layer_idx)) - explained_variance_file_path = os.path.join(pca_data_path, "Mistral-7B-PCA/wikitext/postrotary/key/pca_explained_variance/pca_explained_variance_layer_{}.pt".format(layer_idx)) - - # PCA Components with the shape (num_heads, head_dim, top_r) - pca_components = torch.load(components_file_path).to("cuda") - - # PCA Means with the shape (num_heads, head_dim) - pca_means = torch.load(mean_file_path).to("cuda") - - # Explained Variance with the shape (num_heads, head_dim) - pca_explained_variance = torch.load(explained_variance_file_path).to("cuda") - - # Reshaping the components and taking a transpose to have components along the column dimension and means to be easily broadcastable over the keys - pca_components = pca_components.reshape(1, -1, head_dim, head_dim).transpose(2, 3) - pca_means = pca_means.reshape(1, -1, 1, head_dim) - - # Get the point where the explained variance is 95% per head - explained_variance_cumsum = pca_explained_variance.cumsum(-1) - - - if top_r < 1: - # Find the maximum index where the explained variance is 95% across all heads - Uncomment this line adaptively set the top_r:w - top_correct_r = (explained_variance_cumsum < top_r).sum(-1).max().item() - - # # Instead of sum, we use the median index - # #top_r = (explained_variance_cumsum < 0.95).sum(-1).median().item() - else: - top_correct_r = int(top_r) - - # Only keep the top_r components of the pca_components - pca_components_r_key = pca_components[:, :, :, :top_correct_r] - pca_components_r_key = repeat_kv(pca_components_r_key, num_key_value_groups) - pca_components = repeat_kv(pca_components, num_key_value_groups) - - - print ("{}: PCA Components Shape: {}".format(layer_idx, pca_components_r_key.shape)) - print ("{}: PCA Means Shape: {}".format(layer_idx, pca_means.shape)) - print ("Compression Ratio: 
-    print ("Compression Ratio: {}".format(top_correct_r / head_dim))
-
-    if AXONN_AVAILABLE and ax.is_initialized:
-        ## only keep pca data for the heads on the GPU
-        pca_components = drop(pca_components, transpose=True, skip_batch=True, dim=1)
-        pca_means = drop(pca_means, transpose=True, skip_batch=True, dim=1)
-        pca_components_r_key = drop(pca_components_r_key, transpose=True, skip_batch=True, dim=1)
-
-    return pca_means, pca_components, pca_components_r_key
-def get_pca_forward(top_r, top_k):
+def get_pca_forward(args):
     def modified_forward(
         self,
         hidden_states: torch.Tensor,
@@ -90,7 +38,7 @@ def modified_forward(
             )
 
         if not hasattr(self, "pca_components"):
-            self.pca_means, self.pca_components, self.pca_components_r_key = get_pca_components(self.layer_idx, self.head_dim, top_r, self.num_key_value_groups)
+            self.pca_means, self.pca_components, self.pca_components_r_key = get_pca_components(args, self.layer_idx, self.head_dim, args.top_r, self.num_key_value_groups, repeat_kv)
 
         bsz, q_len, _ = hidden_states.size()
         query_states = self.q_proj(hidden_states)
@@ -121,7 +69,14 @@ def modified_forward(
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+        self.pca_means = self.pca_means.to(key_states.dtype)
+        self.pca_components_r_key = self.pca_components_r_key.to(key_states.dtype)
+        self.pca_components = self.pca_components.to(key_states.dtype)
+
+
+        key_states_pca = torch.matmul(key_states, self.pca_components)
+        query_states_pca = torch.matmul(query_states, self.pca_components)
+        attn_weights = (torch.matmul(query_states_pca, key_states_pca.transpose(2, 3))) / math.sqrt(self.head_dim)
 
         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             raise ValueError(
@@ -137,15 +92,11 @@
             attn_weights = attn_weights + attention_mask
 
-        self.pca_means = self.pca_means.to(key_states.dtype)
-        self.pca_components_r_key = self.pca_components_r_key.to(key_states.dtype)
-        self.pca_components = self.pca_components.to(key_states.dtype)
-
-        if top_k <= 1:
-            topk = int(top_k * attn_weights.shape[-1])
+        if args.top_k <= 1:
+            topk = int(args.top_k * attn_weights.shape[-1])
         else:
-            topk = int(top_k)
-        attn_weights, alpha = mask_attn_pca_topk(self.layer_idx, attn_weights, attention_mask, query_states, key_states, self.pca_components, self.pca_components_r_key, top_r, topk)
+            topk = int(args.top_k)
+        attn_weights, alpha = mask_attn_pca_topk(args, self.layer_idx, attn_weights, attention_mask, query_states, key_states, self.pca_components, self.pca_components_r_key, args.top_r, topk)
 
         assert alpha is not None, "alpha is None"
 
@@ -154,6 +105,15 @@
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)
 
+        ## Compute cumulative sum along the desired dimension
+        #cumulative_sum = torch.cumsum(value_states, dim=2).cuda()
+
+        ## Compute the cumulative mean along the desired dimension
+        #cumulative_mean = cumulative_sum / torch.arange(1, value_states.size(2) + 1).float().unsqueeze(0).unsqueeze(1).unsqueeze(3).cuda()
+
+        #attn_output = ((1 - alpha) * cumulative_mean) + alpha * attn_output
+        #attn_output = attn_output.to(query_states.dtype)
+
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             raise ValueError(
                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                 f" {attn_output.size()}"
@@ -171,8 +131,9 @@ def modified_forward(
         return attn_output, attn_weights, past_key_value
     return modified_forward
 
-def make_mistral_attention_pca_topk(top_r, top_k):
-    print ("Modifying Mistral Attention -> PCA Attention")
-    print ("Top R:", top_r)
-    print ("Top K:", top_k)
-    MistralAttention.forward = get_pca_forward(top_r, top_k)
+def make_mistral_attention_pca_topk(args):
+    print ("Modifying Mistral & Mixtral Attention -> PCA Attention")
+    print ("Top R:", args.top_r)
+    print ("Top K:", args.top_k)
+    MistralAttention.forward = get_pca_forward(args)
+    MixtralAttention.forward = get_pca_forward(args)
\ No newline at end of file
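A hedged sketch of how these patches might be driven end to end. The field names (model_id, top_r, top_k, rotary_type, recent_ratio) match the args used in this diff, but the concrete values and the plain transformers loading path are illustrative assumptions, not the repo's own scripts:

from argparse import Namespace
from transformers import AutoModelForCausalLM
from methods.pca_topk.modify_mistral import make_mistral_attention_pca_topk

# Hypothetical settings; in the repo these normally come from argparse flags.
args = Namespace(model_id="mistralai/Mistral-7B-v0.1", top_r=0.95, top_k=0.25,
                 rotary_type="postrotary", recent_ratio=-1)

# Patch the attention classes, then load the model: every MistralAttention
# (and MixtralAttention) layer now routes through the PCA-TopK forward.
make_mistral_attention_pca_topk(args)
model = AutoModelForCausalLM.from_pretrained(args.model_id).cuda().eval()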
diff --git a/methods/pca_topk/utils.py b/methods/pca_topk/utils.py
index 3740c4a..8ac71c6 100644
--- a/methods/pca_topk/utils.py
+++ b/methods/pca_topk/utils.py
@@ -1,16 +1,91 @@
 from networkx import intersection
 import torch
 import math
+import os
+import methods
+try:
+    from axonn import axonn as ax
+    from axonn.intra_layer import drop
+    AXONN_AVAILABLE=True
+except ImportError:
+    AXONN_AVAILABLE=False
 
-def mask_attn_pca_topk(layer_idx, attn_weights, attention_mask, query_states, key_states, pca_comps_full, pca_comps, top_r, top_k, l=-1):
+#PCA_DATA_PATH = "/pscratch/sd/p/prajwal/InferenceData"
+PCA_DATA_PATH = "/global/cfs/cdirs/m4641/ApproxAttn/"
+
+def get_pca_components(args, layer_idx, head_dim, top_r, num_key_value_groups, repeat_kv):
+    model_folder_name = args.model_id.split("/")[-1] + "-PCA"
+    rotary_type = args.rotary_type
+
+    components_file_path = os.path.join(PCA_DATA_PATH, f"{model_folder_name}/wikitext/{rotary_type}/key/pca_components/pca_components_layer_{layer_idx}.pt")
+    mean_file_path = os.path.join(PCA_DATA_PATH, f"{model_folder_name}/wikitext/{rotary_type}/key/pca_means/pca_means_layer_{layer_idx}.pt")
+    explained_variance_file_path = os.path.join(PCA_DATA_PATH, f"{model_folder_name}/wikitext/{rotary_type}/key/pca_explained_variance/pca_explained_variance_layer_{layer_idx}.pt")
+
+    #methods.LOGGER.update_config({"components_file_path": os.path.dirname(components_file_path)})
+
+    # PCA Components with the shape (num_heads, head_dim, top_r)
+    pca_components = torch.load(components_file_path).to("cuda")
+
+    # PCA Means with the shape (num_heads, head_dim)
+    pca_means = torch.load(mean_file_path).to("cuda")
+
+    # Explained Variance with the shape (num_heads, head_dim)
+    pca_explained_variance = torch.load(explained_variance_file_path).to("cuda")
+
+    # Reshape the components (transposed so they lie along the column dimension) and the means so they broadcast over the keys
+    pca_components = pca_components.reshape(1, -1, head_dim, head_dim).transpose(2, 3)
+    pca_means = pca_means.reshape(1, -1, 1, head_dim)
+
+    # Cumulative explained variance per head
+    explained_variance_cumsum = pca_explained_variance.cumsum(-1)
+
+    if top_r < 1:
+        # Treat top_r as an explained-variance threshold: take the largest count, across heads, of leading components whose cumulative explained variance is still below it
+        top_correct_r = (explained_variance_cumsum < top_r).sum(-1).max().item()
+
+        # # Alternatively, use the median index across heads
+        # #top_r = (explained_variance_cumsum < 0.95).sum(-1).median().item()
+    else:
+        top_correct_r = int(top_r)
+
+    # Only keep the top_r components of the pca_components
+    pca_components_r_key = pca_components[:, :, :, :top_correct_r]
+
+    if repeat_kv is not None:
+        pca_components_r_key = repeat_kv(pca_components_r_key, num_key_value_groups)
+        pca_components = repeat_kv(pca_components, num_key_value_groups)
+
+
+    print ("{}: PCA Components Shape: {}".format(layer_idx, pca_components_r_key.shape))
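Setting the rank from explained variance is the core of get_pca_components above; a small, self-contained sketch of just that selection rule follows (the function name and toy numbers are illustrative, not from the repo):

import torch

def choose_rank(explained_variance, top_r):
    # Mirror the rank-selection rule: fractional top_r acts as an
    # explained-variance threshold, integer top_r is used directly.
    if top_r >= 1:
        return int(top_r)
    cumsum = explained_variance.cumsum(-1)
    # Per head, count the leading components whose cumulative explained
    # variance is still below the threshold, then take the max over heads.
    return (cumsum < top_r).sum(-1).max().item()

# Toy example with 2 heads and 4 dims: head 0 stays under 0.95 for 1
# component, head 1 for 3, so the shared rank is 3.
ev = torch.tensor([[0.90, 0.06, 0.03, 0.01],
                   [0.60, 0.20, 0.10, 0.10]])
print(choose_rank(ev, 0.95))   # -> 3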
+    print ("{}: PCA Means Shape: {}".format(layer_idx, pca_means.shape))
+    print ("Compression Ratio: {}".format(top_correct_r / head_dim))
+
+    if methods.LOGGER is not None:
+        methods.LOGGER.update_config({"Compression Ratio": top_correct_r / head_dim})
+
+    if AXONN_AVAILABLE and ax.is_initialized:
+        print ("Dropping PCA Components and PCA Means")
+        ## only keep pca data for the heads on the GPU
+        pca_components = drop(pca_components, transpose=True, skip_batch=True, dim=1)
+        pca_means = drop(pca_means, transpose=True, skip_batch=True, dim=1)
+        pca_components_r_key = drop(pca_components_r_key, transpose=True, skip_batch=True, dim=1)
+
+    return pca_means, pca_components, pca_components_r_key
+
+
+def mask_attn_pca_topk(args, layer_idx, attn_weights, attention_mask, query_states, key_states, pca_comps_full, pca_comps, top_r, top_k, l=-1):
     head_dim = key_states.shape[-1]
     if top_r == -1:
         top_r = head_dim
 
     # Default recent history = k / 4
-    if l == -1:
-        l = top_k / 4
-        #l = 0
+
+    if hasattr(args, "recent_ratio"):
+        if args.recent_ratio == -1:
+            l = 0
+        else:
+            l = int(args.recent_ratio * key_states.shape[-2])
 
     # Transform key_states and query_states to PCA space
     #key_states_pca = torch.matmul(key_states, pca_comps_full).to(query_states.dtype)
@@ -28,8 +103,12 @@ def mask_attn_pca_topk(layer_idx, attn_weights, attention_mask, query_states, ke
     scaling_factor = scaling_factor.transpose(-1, -2)
 
     # Compute attention with the query_states and key_states_sparse
-    attn_weights_s_hat = torch.matmul(query_states_sparse, key_states_sparse.transpose(-1, -2)) / torch.sqrt(scaling_factor)
-    attn_weights_s_hat = attn_weights_s_hat + attention_mask
+    attn_weights_s_hat = torch.matmul(query_states_sparse, key_states_sparse.transpose(-1, -2)) / math.sqrt(head_dim)
+    if methods.LOGGER is not None:
+        methods.LOGGER.update_config({"scaling_factor": "fixed"})
+    if attention_mask is not None:  # no matter the length, we just slice it
+        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+        attn_weights_s_hat = attn_weights_s_hat + causal_mask
 
     s_hat = torch.nn.functional.softmax(attn_weights_s_hat, dim=-1, dtype=torch.float32).to(query_states.dtype)
 
@@ -41,8 +120,8 @@ def mask_attn_pca_topk(layer_idx, attn_weights, attention_mask, query_states, ke
     # Adding 1 to the recent token scores makes sure they are in the top-k
     #s_hat_recent = s_hat + mask_recent
 
-    if (top_k >= key_states.shape[2]):
-        top_k = key_states.shape[2]
+    if (top_k >= key_states.shape[-2]):
+        top_k = key_states.shape[-2]
 
     # Get top-k keys based on the s_hat_recent score matrix
     i2 = torch.topk(attn_weights_s_hat, top_k, dim=-1).indices
diff --git a/set_env_vars_slurm.sh b/set_env_vars_slurm.sh
index 7f3f86e..810bf17 100755
--- a/set_env_vars_slurm.sh
+++ b/set_env_vars_slurm.sh
@@ -1,5 +1,6 @@
 #!/bin/bash
 # select_gpu_device wrapper script
+export JOBID=${SLURM_JOB_ID}
 export RANK=${SLURM_PROCID}
 export WORLD_SIZE=${SLURM_NTASKS}
 export LOCAL_RANK=${SLURM_LOCALID}
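The JOBID/RANK/WORLD_SIZE/LOCAL_RANK exports above follow the standard SLURM-to-PyTorch mapping. As a minimal sketch of a consumer (assumed for illustration; this diff does not show how the repo actually reads these variables), a distributed entry point might do:

import os
import torch
import torch.distributed as dist

# Variables exported by set_env_vars_slurm.sh.
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])
job_id = os.environ.get("JOBID", "local")   # convenient for naming runs/log files

# MASTER_ADDR and MASTER_PORT must also be set (not shown in this diff).
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
print(f"job {job_id}: rank {rank}/{world_size} on local GPU {local_rank}")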