Skip to content

Correct debug handle merging logic #12073

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 48 additions & 30 deletions devtools/inspector/_inspector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,47 +530,63 @@ def compare_results(
return results


def merge_overlapping_debug_handles(intermediate_outputs: Dict[DebugHandle, Any]):
def merge_overlapping_debug_handles(
intermediate_outputs: Dict[DebugHandle, Any]
) -> Dict[DebugHandle, Any]:
"""
Merge overlapping debug handles int a single key
Merges overlapping debug handles into a single key in the dict.
For each debug handle, this function checks for overlaps with existing keys in the merged dict.
If overlaps are found, it combines the overlapping keys into a single key by taking the union of their elements.
The value associated with the merged key is determined by the debug handle with the highest last element.
"""

if len(intermediate_outputs) == 0:
return
# Extract and normalize into (start, end, val)
intervals = [(min(key), max(key), val) for key, val in intermediate_outputs.items()]
intervals.sort(key=lambda x: x[0])

# Merge overlapping debug_hanldes, picking the last value
merged_intermediate_outputs = []
cur_start, cur_end, cur_val = intervals[0]
for start, end, val in intervals[1:]:
if start <= cur_end: # Overlaps
if end > cur_end: # Extend if this one goes further
cur_end, cur_val = end, val
return {}

else:
merged_intermediate_outputs.append((cur_start, cur_end, cur_val))
cur_start, cur_end, cur_val = start, end, val
merged_intermediate_outputs.append((cur_start, cur_end, cur_val))
merged: Dict[DebugHandle, Any] = {}

for debug_handle, value in intermediate_outputs.items():
debug_handle_set = set(debug_handle)
curr_debug_handle, last_value = debug_handle, value

# collect any existing keys that overlap with the current key
to_remove = []
for existing_debug_handle, existing_value in merged.items():
if debug_handle_set.intersection(set(existing_debug_handle)):
# abosrb their ints
debug_handle_set |= set(existing_debug_handle)
if existing_debug_handle[-1] > curr_debug_handle[-1]:
curr_debug_handle, last_value = (
existing_debug_handle,
existing_value,
)
to_remove.append(existing_debug_handle)

# Clear original one and populate with merged keys (value will point to the same object)
intermediate_outputs.clear()
for start, end, val in merged_intermediate_outputs:
intermediate_outputs[tuple(range(start, end + 1))] = val
# remove all the keys that overlap with the current key
for debug_handle in to_remove:
merged.pop(debug_handle)

# add the current key to the merged one
new_debug_handle = tuple(sorted(debug_handle_set))
merged[new_debug_handle] = last_value

# Sort the merged debug handles in ascending order based on their last element
# TODO: Consider adding more logic to align the order with the execution order
return dict(sorted(merged.items(), key=lambda item: item[0][-1]))


def _debug_handles_have_overlap(
aot_debug_hanlde: DebugHandle, runtime_debug_handle: DebugHandle
debug_handle: DebugHandle, target_debug_handle: DebugHandle
) -> bool:
"""
Check if the AOT debug handle and the runtime debug handle have any overlap.
Check if the debug handle and the target runtime debug handle have any overlap.
"""
aot_set = set(aot_debug_hanlde)
runtime_set = set(runtime_debug_handle)
aot_set = set(debug_handle)
runtime_set = set(target_debug_handle)
return len(aot_set.intersection(runtime_set)) > 0


def _combine_debug_hanldes(debug_handles: List[DebugHandle]) -> DebugHandle:
def _combine_debug_handles(debug_handles: List[DebugHandle]) -> DebugHandle:
"""Combine multiple debug handles into one debug handle"""
combined_debug_handles_set = set()
for debug_handle in debug_handles:
Expand All @@ -584,7 +600,7 @@ def _combine_overlapped_intermediate_outputs(
"""Combine multiple overlapped intermediate outputs into one with combined debug_handles and last output"""
debug_handles = [debug_handle for debug_handle, _ in nodes]
outputs = [output for _, output in nodes]
combined_debug_handle = _combine_debug_hanldes(debug_handles)
combined_debug_handle = _combine_debug_handles(debug_handles)
output = outputs[-1] # Pick the last one
return combined_debug_handle, output

Expand Down Expand Up @@ -673,8 +689,10 @@ def map_runtime_aot_intermediate_outputs(
from runtime intermediate output to AOT intermediate output
"""
# Merge overlapping debug handles
merge_overlapping_debug_handles(aot_intermediate_outputs)
merge_overlapping_debug_handles(runtime_intermediate_outputs)
aot_intermediate_outputs = merge_overlapping_debug_handles(aot_intermediate_outputs)
runtime_intermediate_outputs = merge_overlapping_debug_handles(
runtime_intermediate_outputs
)

# Create a graph(nodes and edges) of overlapping(between aot and runtime) debug handles
nodes, edges = _create_debug_handle_overlap_graph(
Expand Down
10 changes: 5 additions & 5 deletions devtools/inspector/tests/inspector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ def _gen_random_events(self) -> List[Event]:
events = []
for i in range(2):
events.append(
# OPERATOR_CALL with debug_hanldes/instruction_id 0 and 2
# OPERATOR_CALL with debug_handle/instruction_id 0 and 2
Event(
name="OPERATOR_CALL",
op_types=[OP_TYPE],
Expand All @@ -676,7 +676,7 @@ def _gen_random_events(self) -> List[Event]:
)
)
events.append(
# op_0/op_1 wiht empty op_types and with debug_hanldes/instruction_id 1 and 3
# op_0/op_1 wiht empty op_types and with debug_handle/instruction_id 1 and 3
Event(
name=f"op_{i}",
op_types=[],
Expand All @@ -687,7 +687,7 @@ def _gen_random_events(self) -> List[Event]:
)
)

# op_2 with debug_hanldes/instruction_id 4
# op_2 with debug_handle/instruction_id 4
events.append(
Event(
name="op_2",
Expand All @@ -698,7 +698,7 @@ def _gen_random_events(self) -> List[Event]:
_instruction_id=4,
)
)
# op_3 also with debug_hanldes 4 but with instruction_id 5
# op_3 also with debug_handle 4 but with instruction_id 5
events.append(
Event(
name="op_3",
Expand All @@ -710,7 +710,7 @@ def _gen_random_events(self) -> List[Event]:
)
)

# op_4 to op_7 with debug_hanldes 5 to 8 and instruction_id 6 to 9
# op_4 to op_7 with debug_handle 5 to 8 and instruction_id 6 to 9
for i in range(4, EVENTS_SIZE - 2):
events.append(
Event(
Expand Down
26 changes: 24 additions & 2 deletions devtools/inspector/tests/inspector_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def test_compare_results_uint8(self):
self.assertGreater(calculate_snr([a], [b])[0], 30.0)
self.assertAlmostEqual(calculate_cosine_similarity([a], [b])[0], 1.0)

def test_merge_overlapping_debug_handles(self):
def test_merge_overlapping_debug_handles_basic(self):
big_tensor = torch.rand(100, 100)
intermediate_outputs = {
(1, 2, 3): "val1",
Expand All @@ -233,7 +233,7 @@ def test_merge_overlapping_debug_handles(self):
(11, 12): big_tensor,
}
# basic merge behavior
merge_overlapping_debug_handles(intermediate_outputs)
intermediate_outputs = merge_overlapping_debug_handles(intermediate_outputs)
expected_intermediate_outputs = {
(1, 2, 3, 4, 5): "val2",
(6, 7, 8): "val3",
Expand All @@ -243,6 +243,28 @@ def test_merge_overlapping_debug_handles(self):
self.assertEqual(intermediate_outputs, expected_intermediate_outputs)
self.assertIs(expected_intermediate_outputs[(10, 11, 12)], big_tensor)

def test_merge_overlapping_debug_handles_non_continuous(self):
tensor1 = (torch.randn(3, 4),)
tensor2 = (torch.randn(2, 3),)
tensor3 = (torch.randn(4, 5),)
tensor4 = (torch.randn(6, 7),)
tensor5 = (torch.randn(8, 9),)
intermediate_outputs = {
(1, 10): tensor1,
(2, 5): tensor2,
(1, 7, 9): tensor3,
(11, 13): tensor4,
(11, 15): tensor5,
}
intermediate_outputs = merge_overlapping_debug_handles(intermediate_outputs)
expected_intermediate_outputs = {
(2, 5): tensor2,
(1, 7, 9, 10): tensor1,
(11, 13, 15): tensor5,
}

self.assertEqual(intermediate_outputs, expected_intermediate_outputs)

def test_map_runtime_aot_intermediate_outputs_empty_inputs(self):
# When the inputs are empty, the output should also be empty
aot_intermediate_outputs = {}
Expand Down
Loading