diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 0e48a8520d..2ef125a25b 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -62,6 +62,7 @@
 from .source_transformation.custom_kv_cache import (
     replace_kv_cache_with_custom_kv_cache,
     replace_kv_cache_with_quantized_kv_cache,
+    replace_kv_cache_with_ring_kv_cache,
 )

 from .source_transformation.quantize import (
@@ -153,6 +154,23 @@ def build_model(
     return export_llama(args)


+def parse_list_of_ints(s):
+    import ast
+
+    try:
+        parsed = ast.literal_eval(s)
+        if isinstance(parsed, list) and all(isinstance(i, int) for i in parsed):
+            print(parsed)
+            return parsed
+        raise argparse.ArgumentTypeError(
+            "Must be a list of integers, e.g., [0, 16, 0, 16]"
+        )
+    except Exception:
+        raise argparse.ArgumentTypeError(
+            "Must be a list of integers, e.g., [0, 16, 0, 16]"
+        )
+
+
 def build_args_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()
     parser.add_argument("-o", "--output-dir", default=".", help="output directory")
@@ -363,6 +381,15 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="maximum length of context for model to remember",
     )

+    parser.add_argument(
+        "--local_global_attention",
+        type=parse_list_of_ints,
+        default=None,
+        help="List of integers specifying the local/global attention pattern, e.g., [0, 16, 0, 16] specifies that every other layer uses a sliding window of 16."
+        " A [0, 16, 32] pattern specifies that the 2nd and 3rd layers use sliding windows of 16 and 32, respectively."
+        " A [16] pattern specifies that all layers use a sliding window of 16.",
+    )
+
     parser.add_argument("-2", "--fairseq2", action="store_true")
     parser.add_argument("-v", "--verbose", action="store_true")
     parser.add_argument(
@@ -1307,6 +1334,14 @@ def _get_source_transforms(  # noqa
     if args.vulkan:
         transforms.append(replace_with_vulkan_rotary_emb)

+    if getattr(args, "local_global_attention", None) is not None:
+        transforms.append(
+            partial(
+                replace_kv_cache_with_ring_kv_cache,
+                layer_sizes=args.local_global_attention,
+            )
+        )
+
     return transforms


diff --git a/examples/models/llama/source_transformation/custom_kv_cache.py b/examples/models/llama/source_transformation/custom_kv_cache.py
index ffe6732dd5..25ec207d0e 100644
--- a/examples/models/llama/source_transformation/custom_kv_cache.py
+++ b/examples/models/llama/source_transformation/custom_kv_cache.py
@@ -555,8 +555,17 @@ def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
     # This is needed to ensure that custom ops are registered
     from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401

+    assert len(module.layers) >= len(
+        layer_sizes
+    ), f"Pattern length {len(layer_sizes)} must not exceed the number of layers in the module {len(module.layers)}."
+    multiplier = len(module.layers) // len(layer_sizes)
+    modulo = len(module.layers) % len(layer_sizes)
+    assert (
+        modulo == 0
+    ), f"The model's number of layers must be a multiple of the pattern length. pattern: {layer_sizes}, model's num layers: {len(module.layers)}"
+    layer_sizes = layer_sizes * multiplier
     logging.info(
-        "Replacing kv cache with ring kv cache. This modifies the model in place."
+        f"Applying local sliding window attention with the following pattern {layer_sizes}."
     )
     assert len(layer_sizes) == len(
         module.layers
@@ -570,4 +579,8 @@ def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
         ), f"Transfomer block must have attention module. Transformer block {transformer_block}"
         attention = transformer_block.attention
         _replace_kv_cache_with_ring_kv_cache(attention, sliding_window_size)
+        # If the attention's SDPA is the custom SDPA op, make sure it does not
+        # assume causal attention and instead respects the explicit attention mask.
+        if "SDPACustom" in attention.SDPA.__class__.__name__:
+            attention.SDPA.use_attention_mask = True
     return module
diff --git a/examples/models/llama/tests/TARGETS b/examples/models/llama/tests/TARGETS
index 40ab6653c6..fe79e405ca 100644
--- a/examples/models/llama/tests/TARGETS
+++ b/examples/models/llama/tests/TARGETS
@@ -85,3 +85,19 @@ python_unittest(
         "//executorch/examples/models/llama:sdpa",
     ],
 )
+
+python_unittest(
+    name = "test_export_llama_lib",
+    srcs = [
+        "test_export_llama_lib.py",
+    ],
+    preload_deps = [
+        "//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/examples/models/llama:export_library",
+        "//executorch/examples/models/llama:llama_transformer",
+        "//executorch/extension/pybindings:portable_lib",
+    ],
+)
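
For reviewers, the pattern-tiling step that the custom_kv_cache.py hunk adds to replace_kv_cache_with_ring_kv_cache can be sketched in isolation as below. This is a minimal illustration rather than part of the patch; expand_layer_pattern is a hypothetical name introduced only here, and the real transform goes on to swap each selected layer's KV cache for a ring KV cache in place.

# Illustrative sketch only, not part of the patch.
def expand_layer_pattern(layer_sizes, num_layers):
    """Tile a local/global attention pattern across all layers.

    A size of 0 keeps regular (global) attention for that layer; a positive
    value N gives that layer a sliding window of N tokens.
    """
    assert num_layers >= len(layer_sizes), (
        f"Pattern length {len(layer_sizes)} must not exceed num layers {num_layers}."
    )
    assert num_layers % len(layer_sizes) == 0, (
        f"num layers {num_layers} must be a multiple of the pattern length "
        f"so that {layer_sizes} tiles evenly."
    )
    return layer_sizes * (num_layers // len(layer_sizes))


# Example: a 4-layer model where every other layer uses a window of 16.
assert expand_layer_pattern([0, 16], 4) == [0, 16, 0, 16]

On the CLI side, since parse_list_of_ints relies on ast.literal_eval, the new flag should be passed as a quoted list, e.g. --local_global_attention "[0, 16, 0, 16]".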