Commit 4ae8347

[Executorch][llm] Enable local global attention in export_llama script
Added a new option, --local_global_attention, that takes a pattern of sizes determining which layers use local sliding-window attention. For example, [0, 256, 256, 0, 256, 256] can be used for a 6-layer transformer, or [0, 256, 256] can be passed as a pattern to repeat across the layers.

Differential Revision: [D73891423](https://our.internmc.facebook.com/intern/diff/D73891423/)

ghstack-source-id: 281455704
Pull Request resolved: #10612
1 parent 4304f5a commit 4ae8347
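
To make the repetition concrete, here is a small illustrative sketch (not part of the commit; it reads a 0 entry as "keep regular global attention" and a non-zero entry as that layer's sliding-window size, following the option name, and the layer count is made up):

# Pattern of per-layer attention windows, repeated to cover all layers.
layer_sizes = [0, 256, 256]  # pattern to repeat
num_layers = 6               # hypothetical transformer depth

# The pattern length must divide the number of layers evenly; the same check is
# asserted in custom_kv_cache.py further down in this diff.
assert num_layers % len(layer_sizes) == 0
expanded = layer_sizes * (num_layers // len(layer_sizes))
print(expanded)  # [0, 256, 256, 0, 256, 256]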

File tree

2 files changed: +47 -1 lines changed


examples/models/llama/export_llama_lib.py

+33
@@ -62,6 +62,7 @@
 from .source_transformation.custom_kv_cache import (
     replace_kv_cache_with_custom_kv_cache,
     replace_kv_cache_with_quantized_kv_cache,
+    replace_kv_cache_with_ring_kv_cache,
 )

 from .source_transformation.quantize import (
@@ -147,6 +148,23 @@ def build_model(
     return export_llama(args)


+def parse_list_of_ints(s):
+    import ast
+
+    try:
+        parsed = ast.literal_eval(s)
+        if isinstance(parsed, list) and all(isinstance(i, int) for i in parsed):
+            print(parsed)
+            return parsed
+        raise argparse.ArgumentTypeError(
+            "Must be a list of integers, e.g., [0, 16, 0, 16]"
+        )
+    except Exception:
+        raise argparse.ArgumentTypeError(
+            "Must be a list of integers, e.g., [0, 16, 0, 16]"
+        )
+
+
 def build_args_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()
     parser.add_argument("-o", "--output-dir", default=".", help="output directory")
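
As a usage note (illustrative, not part of the diff), a throwaway parser wired up the same way shows what the converter yields; the helper is restated below so the snippet stands alone:

import argparse
import ast


def parse_list_of_ints(s):
    # Mirrors the helper added above: accept only a Python-style list literal of ints.
    try:
        parsed = ast.literal_eval(s)
    except (ValueError, SyntaxError):
        raise argparse.ArgumentTypeError("Must be a list of integers, e.g., [0, 16, 0, 16]")
    if isinstance(parsed, list) and all(isinstance(i, int) for i in parsed):
        return parsed
    raise argparse.ArgumentTypeError("Must be a list of integers, e.g., [0, 16, 0, 16]")


parser = argparse.ArgumentParser()
parser.add_argument("--local_global_attention", type=parse_list_of_ints, default=None)

args = parser.parse_args(["--local_global_attention", "[0, 256, 256]"])
print(args.local_global_attention)  # [0, 256, 256]

On a real shell command line the list has to be quoted, e.g. --local_global_attention "[0, 256, 256]", so it reaches argparse as a single token.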
@@ -357,6 +375,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="maximum length of context for model to remember",
     )

+    parser.add_argument(
+        "--local_global_attention",
+        type=parse_list_of_ints,
+        default=None,
+        help="List of integers specifying local and global attention pattern, e.g., [0, 16, 0, 16].",
+    )
+
     parser.add_argument("-2", "--fairseq2", action="store_true")
     parser.add_argument("-v", "--verbose", action="store_true")
     parser.add_argument(
@@ -1297,6 +1322,14 @@ def _get_source_transforms( # noqa
     if args.vulkan:
         transforms.append(replace_with_vulkan_rotary_emb)

+    if args.local_global_attention:
+        transforms.append(
+            partial(
+                replace_kv_cache_with_ring_kv_cache,
+                layer_sizes=args.local_global_attention,
+            )
+        )
+
     return transforms

examples/models/llama/source_transformation/custom_kv_cache.py

+14 -1
@@ -519,8 +519,17 @@ def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
     # This is needed to ensure that custom ops are registered
     from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401

+    assert len(module.layers) > len(
+        layer_sizes
+    ), f"Length of layer sizes {len(layer_sizes)} must match the number of layers in the module {len(module.layers)}."
+    multiplier = len(module.layers) // len(layer_sizes)
+    modulo = len(module.layers) % len(layer_sizes)
+    assert (
+        modulo == 0
+    ), f"num layers specified must be multiple of model layers in order to specify pattern. pattern: {layer_sizes} model's num layers {len(module.layers)}"
+    layer_sizes = layer_sizes * multiplier
     logging.info(
-        "Replacing kv cache with ring kv cache. This modifies the model in place."
+        f"Applying local sliding window attention with following pattern {layer_sizes}."
     )
     assert len(layer_sizes) == len(
         module.layers
@@ -534,4 +543,8 @@ def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
         ), f"Transfomer block must have attention module. Transformer block {transformer_block}"
         attention = transformer_block.attention
         _replace_kv_cache_with_ring_kv_cache(attention, sliding_window_size)
+        # if attention's sdpa is custom sdpa then we have to make sure
+        # it is not doing causal attention
+        if "SDPACustom" in attention.SDPA.__class__.__name__:
+            attention.SDPA.use_attention_mask = True
     return module
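
Taken together, the export_llama_lib.py change above registers this function as a source transform with layer_sizes pre-bound via functools.partial. A self-contained sketch of that calling pattern (the function body and the model object below are stand-ins, not the real implementation):

from functools import partial


def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
    # Stand-in with the same signature as the real function in this file: the real
    # one swaps each layer's KV cache for a ring KV cache and returns the module.
    print(f"would apply sliding-window pattern {layer_sizes}")
    return module


# Mirrors the partial(...) appended in _get_source_transforms.
transforms = [partial(replace_kv_cache_with_ring_kv_cache, layer_sizes=[0, 256, 256])]

model = object()  # stand-in for the eager Llama module
for transform in transforms:
    model = transform(model)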
