|
62 | 62 | from .source_transformation.custom_kv_cache import (
|
63 | 63 | replace_kv_cache_with_custom_kv_cache,
|
64 | 64 | replace_kv_cache_with_quantized_kv_cache,
|
| 65 | + replace_kv_cache_with_ring_kv_cache, |
65 | 66 | )
|
66 | 67 |
|
67 | 68 | from .source_transformation.quantize import (
|
@@ -147,6 +148,23 @@ def build_model(
|
147 | 148 | return export_llama(args)
|
148 | 149 |
|
149 | 150 |
|
| 151 | +def parse_list_of_ints(s): |
| 152 | + import ast |
| 153 | + |
| 154 | + try: |
| 155 | + parsed = ast.literal_eval(s) |
| 156 | + if isinstance(parsed, list) and all(isinstance(i, int) for i in parsed): |
| 157 | + print(parsed) |
| 158 | + return parsed |
| 159 | + raise argparse.ArgumentTypeError( |
| 160 | + "Must be a list of integers, e.g., [0, 16, 0, 16]" |
| 161 | + ) |
| 162 | + except Exception: |
| 163 | + raise argparse.ArgumentTypeError( |
| 164 | + "Must be a list of integers, e.g., [0, 16, 0, 16]" |
| 165 | + ) |
| 166 | + |
| 167 | + |
150 | 168 | def build_args_parser() -> argparse.ArgumentParser:
|
151 | 169 | parser = argparse.ArgumentParser()
|
152 | 170 | parser.add_argument("-o", "--output-dir", default=".", help="output directory")
|
@@ -357,6 +375,13 @@ def build_args_parser() -> argparse.ArgumentParser:
|
357 | 375 | help="maximum length of context for model to remember",
|
358 | 376 | )
|
359 | 377 |
|
| 378 | + parser.add_argument( |
| 379 | + "--local_global_attention", |
| 380 | + type=parse_list_of_ints, |
| 381 | + default=None, |
| 382 | + help="List of integers specifying local and global attention pattern, e.g., [0, 16, 0, 16].", |
| 383 | + ) |
| 384 | + |
360 | 385 | parser.add_argument("-2", "--fairseq2", action="store_true")
|
361 | 386 | parser.add_argument("-v", "--verbose", action="store_true")
|
362 | 387 | parser.add_argument(
|
@@ -1297,6 +1322,14 @@ def _get_source_transforms( # noqa
|
1297 | 1322 | if args.vulkan:
|
1298 | 1323 | transforms.append(replace_with_vulkan_rotary_emb)
|
1299 | 1324 |
|
| 1325 | + if args.local_global_attention: |
| 1326 | + transforms.append( |
| 1327 | + partial( |
| 1328 | + replace_kv_cache_with_ring_kv_cache, |
| 1329 | + layer_sizes=args.local_global_attention, |
| 1330 | + ) |
| 1331 | + ) |
| 1332 | + |
1300 | 1333 | return transforms
|
1301 | 1334 |
|
1302 | 1335 |
|
|
0 commit comments