Commit d7db60d

[Executorch][llm] Enable local global attention in export_llama script
Pull Request resolved: #10612

Added a new option, --local_global_attention, that takes a pattern of sizes determining which layers use local sliding-window attention. For example, [0, 256, 256, 0, 256, 256] can be used for a 6-layer transformer, or you can pass [0, 256, 256] as a pattern to be repeated across the layers.

ghstack-source-id: 282013415
exported-using-ghexport

Differential Revision: [D73891423](https://our.internmc.facebook.com/intern/diff/D73891423/)
1 parent 48877ff commit d7db60d
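As a usage sketch (illustrative only, not part of the commit): the snippet below mirrors how the new --local_global_attention value is expected to be parsed from a string into a Python list, using a simplified stand-in for the parse_list_of_ints helper that this commit adds to export_llama_lib.py.

# Illustrative sketch only: a simplified stand-in for parse_list_of_ints
# (added to export_llama_lib.py below) wired into argparse.
import argparse
import ast

def parse_list_of_ints(s):
    parsed = ast.literal_eval(s)
    if isinstance(parsed, list) and all(isinstance(i, int) for i in parsed):
        return parsed
    raise argparse.ArgumentTypeError("Must be a list of integers, e.g., [0, 16, 0, 16]")

parser = argparse.ArgumentParser()
parser.add_argument("--local_global_attention", type=parse_list_of_ints, default=None)

# e.g. running the export_llama script with: --local_global_attention "[0, 256, 256]"
args = parser.parse_args(["--local_global_attention", "[0, 256, 256]"])
print(args.local_global_attention)  # [0, 256, 256]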

File tree

2 files changed: +47 -1 lines changed
examples/models/llama/export_llama_lib.py
examples/models/llama/source_transformation/custom_kv_cache.py


examples/models/llama/export_llama_lib.py (+33)
@@ -62,6 +62,7 @@
 from .source_transformation.custom_kv_cache import (
     replace_kv_cache_with_custom_kv_cache,
     replace_kv_cache_with_quantized_kv_cache,
+    replace_kv_cache_with_ring_kv_cache,
 )

 from .source_transformation.quantize import (

@@ -153,6 +154,23 @@ def build_model(
     return export_llama(args)


+def parse_list_of_ints(s):
+    import ast
+
+    try:
+        parsed = ast.literal_eval(s)
+        if isinstance(parsed, list) and all(isinstance(i, int) for i in parsed):
+            print(parsed)
+            return parsed
+        raise argparse.ArgumentTypeError(
+            "Must be a list of integers, e.g., [0, 16, 0, 16]"
+        )
+    except Exception:
+        raise argparse.ArgumentTypeError(
+            "Must be a list of integers, e.g., [0, 16, 0, 16]"
+        )
+
+
 def build_args_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()
     parser.add_argument("-o", "--output-dir", default=".", help="output directory")

@@ -363,6 +381,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="maximum length of context for model to remember",
     )

+    parser.add_argument(
+        "--local_global_attention",
+        type=parse_list_of_ints,
+        default=None,
+        help="List of integers specifying local and global attention pattern, e.g., [0, 16, 0, 16].",
+    )
+
     parser.add_argument("-2", "--fairseq2", action="store_true")
     parser.add_argument("-v", "--verbose", action="store_true")
     parser.add_argument(

@@ -1307,6 +1332,14 @@ def _get_source_transforms( # noqa
     if args.vulkan:
         transforms.append(replace_with_vulkan_rotary_emb)

+    if args.local_global_attention:
+        transforms.append(
+            partial(
+                replace_kv_cache_with_ring_kv_cache,
+                layer_sizes=args.local_global_attention,
+            )
+        )
+
     return transforms
examples/models/llama/source_transformation/custom_kv_cache.py (+14, -1)
@@ -519,8 +519,17 @@ def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
     # This is needed to ensure that custom ops are registered
     from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401

+    assert len(module.layers) > len(
+        layer_sizes
+    ), f"Length of layer sizes {len(layer_sizes)} must match the number of layers in the module {len(module.layers)}."
+    multiplier = len(module.layers) // len(layer_sizes)
+    modulo = len(module.layers) % len(layer_sizes)
+    assert (
+        modulo == 0
+    ), f"num layers specified must be multiple of model layers in order to specify pattern. pattern: {layer_sizes} model's num layers {len(module.layers)}"
+    layer_sizes = layer_sizes * multiplier
     logging.info(
-        "Replacing kv cache with ring kv cache. This modifies the model in place."
+        f"Applying local sliding window attention with following pattern {layer_sizes}."
     )
     assert len(layer_sizes) == len(
         module.layers

@@ -534,4 +543,8 @@ def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
         ), f"Transfomer block must have attention module. Transformer block {transformer_block}"
         attention = transformer_block.attention
         _replace_kv_cache_with_ring_kv_cache(attention, sliding_window_size)
+        # if attention's sdpa is custom sdpa then we have to make sure
+        # it is not doing causal attention
+        if "SDPACustom" in attention.SDPA.__class__.__name__:
+            attention.SDPA.use_attention_mask = True
     return module
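For intuition, a minimal sketch (not from the commit) of what the multiplier logic above does: a shorter pattern is tiled across all layers, where 0 keeps a layer on global attention and a nonzero value is that layer's sliding-window size.

# Minimal sketch, not part of the commit: tiles a local/global pattern across
# the model's layers the same way the multiplier logic above does.
def expand_pattern(layer_sizes, num_layers):
    # The pattern length must divide the layer count evenly (as the assert above enforces).
    assert num_layers % len(layer_sizes) == 0, (
        f"pattern {layer_sizes} does not evenly tile {num_layers} layers"
    )
    return layer_sizes * (num_layers // len(layer_sizes))

# [0, 256, 256] on a 6-layer model becomes [0, 256, 256, 0, 256, 256]:
# layers 0 and 3 keep global attention, the rest use a 256-token window.
print(expand_pattern([0, 256, 256], 6))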
