
Commit 7035111

[Executorch][llm] Enable local global attention in export_llama script
Pull Request resolved: #10612

Added a new option, --local_global_attention, that takes a pattern of sizes determining which layers use local sliding window attention. For example, [0, 256, 256, 0, 256, 256] can be used for a 6-layer transformer, or you can pass [0, 256, 256] as a pattern to be repeated across the layers.

ghstack-source-id: 282706487
@exported-using-ghexport

Differential Revision: [D73891423](https://our.internmc.facebook.com/intern/diff/D73891423/)
1 parent 13c43cb commit 7035111
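To illustrate the option's semantics in plain Python (independent of the executorch changes below; the helper name expand_pattern is illustrative and not part of this commit): a size of 0 appears to leave a layer on regular global attention, a nonzero size gives that layer a sliding window of that many tokens, and a short pattern is tiled across the layers.

def expand_pattern(pattern: list[int], n_layers: int) -> list[int]:
    # The pattern must tile the model evenly, mirroring the divisibility check
    # added to replace_kv_cache_with_ring_kv_cache in this commit.
    assert n_layers % len(pattern) == 0, "pattern length must divide the number of layers"
    return pattern * (n_layers // len(pattern))

print(expand_pattern([0, 256, 256], 6))  # [0, 256, 256, 0, 256, 256]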

File tree: 2 files changed (+47, −1 lines)


examples/models/llama/export_llama_lib.py (+33 lines)
@@ -62,6 +62,7 @@
 from .source_transformation.custom_kv_cache import (
     replace_kv_cache_with_custom_kv_cache,
     replace_kv_cache_with_quantized_kv_cache,
+    replace_kv_cache_with_ring_kv_cache,
 )
 
 from .source_transformation.quantize import (
@@ -153,6 +154,23 @@ def build_model(
     return export_llama(args)
 
 
+def parse_list_of_ints(s):
+    import ast
+
+    try:
+        parsed = ast.literal_eval(s)
+        if isinstance(parsed, list) and all(isinstance(i, int) for i in parsed):
+            print(parsed)
+            return parsed
+        raise argparse.ArgumentTypeError(
+            "Must be a list of integers, e.g., [0, 16, 0, 16]"
+        )
+    except Exception:
+        raise argparse.ArgumentTypeError(
+            "Must be a list of integers, e.g., [0, 16, 0, 16]"
+        )
+
+
 def build_args_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()
     parser.add_argument("-o", "--output-dir", default=".", help="output directory")
@@ -363,6 +381,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="maximum length of context for model to remember",
     )
 
+    parser.add_argument(
+        "--local_global_attention",
+        type=parse_list_of_ints,
+        default=None,
+        help="List of integers specifying local and global attention pattern, e.g., [0, 16, 0, 16].",
+    )
+
     parser.add_argument("-2", "--fairseq2", action="store_true")
     parser.add_argument("-v", "--verbose", action="store_true")
     parser.add_argument(
@@ -1307,6 +1332,14 @@ def _get_source_transforms( # noqa
     if args.vulkan:
         transforms.append(replace_with_vulkan_rotary_emb)
 
+    if args.local_global_attention:
+        transforms.append(
+            partial(
+                replace_kv_cache_with_ring_kv_cache,
+                layer_sizes=args.local_global_attention,
+            )
+        )
+
     return transforms
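As a hedged, standalone sketch of how the new flag is parsed via the parse_list_of_ints type hook above (the throwaway parser here stands in for the full export_llama argument parser, and the debug print from the original is omitted):

import argparse
import ast


def parse_list_of_ints(s):
    # Accept only a literal Python list of ints, e.g. "[0, 16, 0, 16]".
    try:
        parsed = ast.literal_eval(s)
        if isinstance(parsed, list) and all(isinstance(i, int) for i in parsed):
            return parsed
    except Exception:
        pass
    raise argparse.ArgumentTypeError("Must be a list of integers, e.g., [0, 16, 0, 16]")


parser = argparse.ArgumentParser()
parser.add_argument("--local_global_attention", type=parse_list_of_ints, default=None)

args = parser.parse_args(["--local_global_attention", "[0, 16, 0, 16]"])
print(args.local_global_attention)  # [0, 16, 0, 16]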

examples/models/llama/source_transformation/custom_kv_cache.py (+14, −1 lines)
@@ -555,8 +555,17 @@ def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
     # This is needed to ensure that custom ops are registered
     from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
 
+    assert len(module.layers) >= len(
+        layer_sizes
+    ), f"Length of layer sizes {len(layer_sizes)} must match the number of layers in the module {len(module.layers)}."
+    multiplier = len(module.layers) // len(layer_sizes)
+    modulo = len(module.layers) % len(layer_sizes)
+    assert (
+        modulo == 0
+    ), f"num layers specified must be multiple of model layers in order to specify pattern. pattern: {layer_sizes} model's num layers {len(module.layers)}"
+    layer_sizes = layer_sizes * multiplier
     logging.info(
-        "Replacing kv cache with ring kv cache. This modifies the model in place."
+        f"Applying local sliding window attention with following pattern {layer_sizes}."
     )
     assert len(layer_sizes) == len(
         module.layers
@@ -570,4 +579,8 @@ def replace_kv_cache_with_ring_kv_cache(module, layer_sizes):
         ), f"Transfomer block must have attention module. Transformer block {transformer_block}"
         attention = transformer_block.attention
         _replace_kv_cache_with_ring_kv_cache(attention, sliding_window_size)
+        # if attention's sdpa is custom sdpa then we have to make sure
+        # it is not doing causal attention
+        if "SDPACustom" in attention.SDPA.__class__.__name__:
+            attention.SDPA.use_attention_mask = True
     return module
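For completeness, a minimal sketch of how this transform gets bound and applied the way _get_source_transforms does it. The import path is inferred from the file location and is an assumption; the final call is commented out because it needs an already-built Llama module.

from functools import partial

# Import path inferred from examples/models/llama/source_transformation/custom_kv_cache.py;
# adjust to match your executorch checkout.
from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
    replace_kv_cache_with_ring_kv_cache,
)

# Bind the per-layer pattern just as _get_source_transforms does with
# args.local_global_attention; the result is a callable that takes the module.
ring_transform = partial(replace_kv_cache_with_ring_kv_cache, layer_sizes=[0, 256, 256])

# model = ring_transform(model)  # modifies the kv caches in place and returns the module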
