|
62 | 62 | from .source_transformation.custom_kv_cache import (
|
63 | 63 | replace_kv_cache_with_custom_kv_cache,
|
64 | 64 | replace_kv_cache_with_quantized_kv_cache,
|
| 65 | + replace_kv_cache_with_ring_kv_cache, |
65 | 66 | )
|
66 | 67 |
|
67 | 68 | from .source_transformation.quantize import (
|
@@ -153,6 +154,23 @@ def build_model(
|
153 | 154 | return export_llama(args)
|
154 | 155 |
|
155 | 156 |
|
| 157 | +def parse_list_of_ints(s): |
| 158 | + import ast |
| 159 | + |
| 160 | + try: |
| 161 | + parsed = ast.literal_eval(s) |
| 162 | + if isinstance(parsed, list) and all(isinstance(i, int) for i in parsed): |
| 163 | + print(parsed) |
| 164 | + return parsed |
| 165 | + raise argparse.ArgumentTypeError( |
| 166 | + "Must be a list of integers, e.g., [0, 16, 0, 16]" |
| 167 | + ) |
| 168 | + except Exception: |
| 169 | + raise argparse.ArgumentTypeError( |
| 170 | + "Must be a list of integers, e.g., [0, 16, 0, 16]" |
| 171 | + ) |
| 172 | + |
| 173 | + |
156 | 174 | def build_args_parser() -> argparse.ArgumentParser:
|
157 | 175 | parser = argparse.ArgumentParser()
|
158 | 176 | parser.add_argument("-o", "--output-dir", default=".", help="output directory")
|
@@ -363,6 +381,13 @@ def build_args_parser() -> argparse.ArgumentParser:
|
363 | 381 | help="maximum length of context for model to remember",
|
364 | 382 | )
|
365 | 383 |
|
| 384 | + parser.add_argument( |
| 385 | + "--local_global_attention", |
| 386 | + type=parse_list_of_ints, |
| 387 | + default=None, |
| 388 | + help="List of integers specifying local and global attention pattern, e.g., [0, 16, 0, 16].", |
| 389 | + ) |
| 390 | + |
366 | 391 | parser.add_argument("-2", "--fairseq2", action="store_true")
|
367 | 392 | parser.add_argument("-v", "--verbose", action="store_true")
|
368 | 393 | parser.add_argument(
|
@@ -1307,6 +1332,14 @@ def _get_source_transforms( # noqa
|
1307 | 1332 | if args.vulkan:
|
1308 | 1333 | transforms.append(replace_with_vulkan_rotary_emb)
|
1309 | 1334 |
|
| 1335 | + if args.local_global_attention: |
| 1336 | + transforms.append( |
| 1337 | + partial( |
| 1338 | + replace_kv_cache_with_ring_kv_cache, |
| 1339 | + layer_sizes=args.local_global_attention, |
| 1340 | + ) |
| 1341 | + ) |
| 1342 | + |
1310 | 1343 | return transforms
|
1311 | 1344 |
|
1312 | 1345 |
|
|
0 commit comments