15
15
from torch .ao .quantization .fx ._decomposed import quantized_decomposed_lib # noqa: F401
16
16
from torch .fx import GraphModule
17
17
from torch .fx .passes .utils .matcher_with_name_node_map_utils import InternalMatch
18
- from torch .fx .subgraph_rewriter import replace_pattern_with_filters
18
+ from torch .fx .subgraph_rewriter import ReplacedPatterns , replace_pattern_with_filters
19
19
20
20
from torchao .quantization .pt2e .export_utils import WrapperModule
21
21
from torchao .quantization .pt2e .utils import (
@@ -455,6 +455,34 @@ def _filter_fn_for_dynamic_quantized_linear_4bit_groupwise(
455
455
return weight_is_int4 and act_quant_is_int8
456
456
457
457
458
+ def _port_metadata_for_dynamic_quantized_linear_4bit_groupwise (
459
+ replacement_pattern : ReplacedPatterns ,
460
+ ):
461
+ """
462
+ Port metadata for dynamically quantized linear 4-bit groupwise operation.
463
+ It custom_op node's metadata with corresponding linear node's metadata.
464
+ """
465
+ from torch .fx .traceback import NodeSource , NodeSourceAction
466
+
467
+ linear_node = None
468
+ int4_custom_op_node = None
469
+ for _ , g_n in replacement_pattern .nodes_map .items ():
470
+ if g_n .target == torch .ops .aten .linear .default :
471
+ linear_node = g_n
472
+ break
473
+ if len (replacement_pattern .replacements ) > 0 :
474
+ int4_custom_op_node = replacement_pattern .replacements [- 1 ]
475
+ if linear_node is not None and int4_custom_op_node is not None :
476
+ int4_custom_op_node .meta = linear_node .meta .copy ()
477
+ int4_custom_op_node .meta ["from_node" ] = [
478
+ NodeSource (
479
+ linear_node ,
480
+ "ReplaceInt4DynamicQuantWithCustomOp" ,
481
+ NodeSourceAction .REPLACE ,
482
+ )
483
+ ]
484
+
485
+
458
486
def _qdq_quantized_conv2d (
459
487
x_i8 ,
460
488
x_scale ,
@@ -883,6 +911,7 @@ class _RewriteInfo:
883
911
list [Callable [["InternalMatch" , torch .fx .Graph , torch .fx .Graph ], bool ]]
884
912
] = None
885
913
ignore_literals : bool = False
914
+ port_metadata_fn : Optional [Callable [["ReplacedPatterns" ], None ]] = None
886
915
887
916
888
917
def reference_representation_rewrite (model : GraphModule ) -> GraphModule :
@@ -1053,6 +1082,7 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
1053
1082
),
1054
1083
filter_fn = [_filter_fn_for_dynamic_quantized_linear_4bit_groupwise ],
1055
1084
ignore_literals = True ,
1085
+ port_metadata_fn = _port_metadata_for_dynamic_quantized_linear_4bit_groupwise ,
1056
1086
),
1057
1087
_RewriteInfo (
1058
1088
_DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS_2 ,
@@ -1074,6 +1104,7 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
1074
1104
),
1075
1105
filter_fn = [_filter_fn_for_dynamic_quantized_linear_4bit_groupwise ],
1076
1106
ignore_literals = True ,
1107
+ port_metadata_fn = _port_metadata_for_dynamic_quantized_linear_4bit_groupwise ,
1077
1108
),
1078
1109
_RewriteInfo (
1079
1110
_QUANTIZED_LINEAR_EXAMPLE_INPUTS ,
@@ -1153,12 +1184,15 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
1153
1184
replacement = replacement_post_trans (replacement )
1154
1185
pattern .recompile () # type: ignore[attr-defined]
1155
1186
replacement .recompile () # type: ignore[attr-defined]
1156
- replace_pattern_with_filters (
1187
+ matches = replace_pattern_with_filters (
1157
1188
model ,
1158
1189
pattern ,
1159
1190
replacement ,
1160
1191
match_filters = rewrite_info .filter_fn ,
1161
1192
ignore_literals = rewrite_info .ignore_literals ,
1162
1193
) # type: ignore[arg-type]
1194
+ if rewrite_info .port_metadata_fn :
1195
+ for m in matches :
1196
+ rewrite_info .port_metadata_fn (m ) # type: ignore[arg-type]
1163
1197
1164
1198
return model
0 commit comments