Skip to content

Commit 4f57444

Browse files
committed
Port metadata from the linear node onto the reference custom op for int4
Summary: Allow the numerical debugger in ExecuTorch to use the from_node info for correlation. Test Plan: CI Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 858a6f4 Pull Request resolved: #2860
1 parent 8669213 commit 4f57444

File tree

1 file changed

+36
-2
lines changed

1 file changed

+36
-2
lines changed

torchao/quantization/pt2e/reference_representation_rewrite.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
1616
from torch.fx import GraphModule
1717
from torch.fx.passes.utils.matcher_with_name_node_map_utils import InternalMatch
18-
from torch.fx.subgraph_rewriter import replace_pattern_with_filters
18+
from torch.fx.subgraph_rewriter import ReplacedPatterns, replace_pattern_with_filters
1919

2020
from torchao.quantization.pt2e.export_utils import WrapperModule
2121
from torchao.quantization.pt2e.utils import (
@@ -455,6 +455,34 @@ def _filter_fn_for_dynamic_quantized_linear_4bit_groupwise(
455455
return weight_is_int4 and act_quant_is_int8
456456

457457

458+
def _port_metadata_for_dynamic_quantized_linear_4bit_groupwise(
459+
replacement_pattern: ReplacedPatterns,
460+
):
461+
"""
462+
Port metadata for dynamically quantized linear 4-bit groupwise operation.
463+
It custom_op node's metadata with corresponding linear node's metadata.
464+
"""
465+
from torch.fx.traceback import NodeSource, NodeSourceAction
466+
467+
linear_node = None
468+
int4_custom_op_node = None
469+
for _, g_n in replacement_pattern.nodes_map.items():
470+
if g_n.target == torch.ops.aten.linear.default:
471+
linear_node = g_n
472+
break
473+
if len(replacement_pattern.replacements) > 0:
474+
int4_custom_op_node = replacement_pattern.replacements[-1]
475+
if linear_node is not None and int4_custom_op_node is not None:
476+
int4_custom_op_node.meta = linear_node.meta.copy()
477+
int4_custom_op_node.meta["from_node"] = [
478+
NodeSource(
479+
linear_node,
480+
"ReplaceInt4DynamicQuantWithCustomOp",
481+
NodeSourceAction.REPLACE,
482+
)
483+
]
484+
485+
458486
def _qdq_quantized_conv2d(
459487
x_i8,
460488
x_scale,
@@ -883,6 +911,7 @@ class _RewriteInfo:
883911
list[Callable[["InternalMatch", torch.fx.Graph, torch.fx.Graph], bool]]
884912
] = None
885913
ignore_literals: bool = False
914+
port_metadata_fn: Optional[Callable[["ReplacedPatterns"], None]] = None
886915

887916

888917
def reference_representation_rewrite(model: GraphModule) -> GraphModule:
@@ -1053,6 +1082,7 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
10531082
),
10541083
filter_fn=[_filter_fn_for_dynamic_quantized_linear_4bit_groupwise],
10551084
ignore_literals=True,
1085+
port_metadata_fn=_port_metadata_for_dynamic_quantized_linear_4bit_groupwise,
10561086
),
10571087
_RewriteInfo(
10581088
_DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS_2,
@@ -1074,6 +1104,7 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
10741104
),
10751105
filter_fn=[_filter_fn_for_dynamic_quantized_linear_4bit_groupwise],
10761106
ignore_literals=True,
1107+
port_metadata_fn=_port_metadata_for_dynamic_quantized_linear_4bit_groupwise,
10771108
),
10781109
_RewriteInfo(
10791110
_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
@@ -1153,12 +1184,15 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
11531184
replacement = replacement_post_trans(replacement)
11541185
pattern.recompile() # type: ignore[attr-defined]
11551186
replacement.recompile() # type: ignore[attr-defined]
1156-
replace_pattern_with_filters(
1187+
matches = replace_pattern_with_filters(
11571188
model,
11581189
pattern,
11591190
replacement,
11601191
match_filters=rewrite_info.filter_fn,
11611192
ignore_literals=rewrite_info.ignore_literals,
11621193
) # type: ignore[arg-type]
1194+
if rewrite_info.port_metadata_fn:
1195+
for m in matches:
1196+
rewrite_info.port_metadata_fn(m) # type: ignore[arg-type]
11631197

11641198
return model

0 commit comments

Comments
 (0)