From bae77d451c1bf3d27038e0fdd0e0b5824b3cc47f Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Thu, 1 May 2025 10:19:20 -0700
Subject: [PATCH] [Executorch][llm] Enable local global attention in
 export_llama script

Added a new option, --local_global_attention, that takes a pattern of sizes
to determine which layers use local sliding window attention. For example,
[0, 256, 256, 0, 256, 256] can be used for a 6-layer transformer, or you can
pass [0, 256, 256] as a pattern to be repeated across all layers.

Differential Revision: [D73891423](https://our.internmc.facebook.com/intern/diff/D73891423/)

[ghstack-poisoned]
---
 examples/models/llama/export_llama_lib.py    | 33 +++++++++++++++++++
 .../source_transformation/custom_kv_cache.py | 15 ++++++++-
 2 files changed, 47 insertions(+), 1 deletion(-)

diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 79a225232e0..67394ea5785 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -62,6 +62,7 @@
 from .source_transformation.custom_kv_cache import (
     replace_kv_cache_with_custom_kv_cache,
     replace_kv_cache_with_quantized_kv_cache,
+    replace_kv_cache_with_ring_kv_cache,
 )
 
 from .source_transformation.quantize import (
@@ -147,6 +148,23 @@ def build_model(
     return export_llama(args)
 
 
+def parse_list_of_ints(s):
+    import ast
+
+    try:
+        parsed = ast.literal_eval(s)
+        if isinstance(parsed, list) and all(isinstance(i, int) for i in parsed):
+            print(parsed)
+            return parsed
+        raise argparse.ArgumentTypeError(
+            "Must be a list of integers, e.g., [0, 16, 0, 16]"
+        )
+    except Exception:
+        raise argparse.ArgumentTypeError(
+            "Must be a list of integers, e.g., [0, 16, 0, 16]"
+        )
+
+
 def build_args_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()
     parser.add_argument("-o", "--output-dir", default=".", help="output directory")
@@ -357,6 +375,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="maximum length of context for model to remember",
     )
 
+    parser.add_argument(
+        "--local_global_attention",
+        type=parse_list_of_ints,
+        default=None,
+        help="List of integers specifying local and global attention pattern, e.g., [0, 16, 0, 16].",
+    )
+
     parser.add_argument("-2", "--fairseq2", action="store_true")
     parser.add_argument("-v", "--verbose", action="store_true")
     parser.add_argument(
@@ -1297,6 +1322,14 @@ def _get_source_transforms(  # noqa
     if args.vulkan:
         transforms.append(replace_with_vulkan_rotary_emb)
 
+    if args.local_global_attention:
+        transforms.append(
+            partial(
+                replace_kv_cache_with_ring_kv_cache,
+                layer_sizes=args.local_global_attention,
+            )
+        )
+
     return transforms
 
 
diff --git a/examples/models/llama/source_transformation/custom_kv_cache.py b/examples/models/llama/source_transformation/custom_kv_cache.py
index 24038959dba..3d0324af3d1 100644
--- a/examples/models/llama/source_transformation/custom_kv_cache.py
+++ b/examples/models/llama/source_transformation/custom_kv_cache.py
@@ -519,8 +519,17 @@ def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
     # This is needed to ensure that custom ops are registered
     from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
 
+    assert len(module.layers) >= len(
+        layer_sizes
+    ), f"Pattern length {len(layer_sizes)} must not exceed the number of layers in the module {len(module.layers)}."
+    multiplier = len(module.layers) // len(layer_sizes)
+    modulo = len(module.layers) % len(layer_sizes)
+    assert (
+        modulo == 0
+    ), f"Number of model layers must be a multiple of the pattern length. pattern: {layer_sizes}, model's num layers: {len(module.layers)}"
+    layer_sizes = layer_sizes * multiplier
     logging.info(
-        "Replacing kv cache with ring kv cache. This modifies the model in place."
+        f"Applying local sliding window attention with the following pattern: {layer_sizes}."
     )
     assert len(layer_sizes) == len(
         module.layers
@@ -534,4 +543,8 @@ def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
         ), f"Transfomer block must have attention module. Transformer block {transformer_block}"
         attention = transformer_block.attention
         _replace_kv_cache_with_ring_kv_cache(attention, sliding_window_size)
+        # If the attention's SDPA is the custom SDPA, make sure it is not
+        # doing causal attention and relies on the attention mask instead.
+        if "SDPACustom" in attention.SDPA.__class__.__name__:
+            attention.SDPA.use_attention_mask = True
     return module
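
As a minimal standalone sketch of the pattern handling described in the commit
message, the snippet below mirrors the multiplier/modulo expansion added to
replace_kv_cache_with_ring_kv_cache, assuming a 6-layer model; the
expand_pattern helper name is illustrative and not part of the patch, and a 0
entry is taken to mean the layer keeps regular (global) attention.

def expand_pattern(layer_sizes, num_layers):
    # The number of model layers must be a multiple of the pattern length,
    # matching the modulo check in replace_kv_cache_with_ring_kv_cache.
    assert num_layers % len(layer_sizes) == 0, (
        f"pattern {layer_sizes} does not evenly divide {num_layers} layers"
    )
    # Tile the short pattern until it covers every layer.
    return layer_sizes * (num_layers // len(layer_sizes))

print(expand_pattern([0, 256, 256], 6))
# [0, 256, 256, 0, 256, 256] -> nonzero entries are sliding window sizes
# for local-attention layers, matching the commit message example.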