Skip to content

Commit

Permalink
Correct scipy.optimize.linear_sum_assignment usage (ultralytics#4390)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZouJiu1 authored Aug 16, 2023
1 parent 9a0555e commit 17e6b9c
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions ultralytics/trackers/utils/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,21 @@ def linear_assignment(cost_matrix, thresh, use_lap=True):
return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))

if use_lap:
# https://github.com/gatagat/lap
_, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0]
unmatched_a = np.where(x < 0)[0]
unmatched_b = np.where(y < 0)[0]
else:
# Scipy linear sum assignment is NOT working correctly, DO NOT USE
y, x = scipy.optimize.linear_sum_assignment(cost_matrix) # row y, col x
matches = np.asarray([[i, x] for i, x in enumerate(x) if cost_matrix[i, x] <= thresh])
unmatched = np.ones(cost_matrix.shape)
for i, xi in matches:
unmatched[i, xi] = 0.0
unmatched_a = np.where(unmatched.all(1))[0]
unmatched_b = np.where(unmatched.all(0))[0]
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html
x, y = scipy.optimize.linear_sum_assignment(cost_matrix) # row x, col y
matches = np.asarray([[x[i], y[i]] for i in range(len(x)) if cost_matrix[x[i], y[i]] <= thresh])
if len(matches) == 0:
unmatched_a = list(np.arange(cost_matrix.shape[0]))
unmatched_b = list(np.arange(cost_matrix.shape[1]))
else:
unmatched_a = list(set(np.arange(cost_matrix.shape[0])) - set(matches[:, 0]))
unmatched_b = list(set(np.arange(cost_matrix.shape[1])) - set(matches[:, 1]))

return matches, unmatched_a, unmatched_b

Expand Down

0 comments on commit 17e6b9c

Please sign in to comment.