Skip to content

Commit 68a3b81

Browse files
committed
binary search approach for searchsorted(), closes issue 14833
1 parent 1342657 commit 68a3b81

File tree

1 file changed

+25
-17
lines changed

1 file changed

+25
-17
lines changed

pandas/core/indexes/multi.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3855,6 +3855,30 @@ def searchsorted(
38553855
if side not in ["left", "right"]:
38563856
raise ValueError("side must be either 'left' or 'right'")
38573857

3858+
if sorter is None:
3859+
get_val = lambda i: self.values[i]
3860+
else:
3861+
get_val = lambda i: self.values[sorter[i]]
3862+
3863+
def has_missing(val):
3864+
if isinstance(val, tuple):
3865+
return np.any(isna(list(val)))
3866+
return np.any(isna(val))
3867+
3868+
def binary_search(key, side="left", sorter=None):
3869+
l_ptr, r_ptr = 0, len(self) if sorter is None else len(sorter)
3870+
while l_ptr < r_ptr:
3871+
mid = l_ptr + (r_ptr - l_ptr) // 2
3872+
3873+
mid_val = get_val(mid)
3874+
if has_missing(mid_val):
3875+
raise ValueError(f"Unsortable or missing value: {mid_val}")
3876+
if mid_val > key or (side == "left" and mid_val == key):
3877+
r_ptr = mid
3878+
else:
3879+
l_ptr = mid + 1
3880+
return sorter[l_ptr] if sorter is not None else l_ptr
3881+
38583882
indexer = self.get_indexer(value)
38593883
result = []
38603884

@@ -3863,23 +3887,7 @@ def searchsorted(
38633887
val = i if side == "left" else i + 1
38643888
result.append(np.intp(val))
38653889
else:
3866-
fields = []
3867-
for j, level in enumerate(self.levels):
3868-
level_dtype = level.dtype
3869-
if isinstance(level_dtype, ExtensionDtype):
3870-
fields.append((f"level_{j}", object))
3871-
else:
3872-
fields.append((f"level_{j}", level_dtype))
3873-
dtype = np.dtype(fields)
3874-
3875-
val_array = np.array([v], dtype=dtype)
3876-
pos = np.searchsorted(
3877-
np.asarray(self.values, dtype=dtype),
3878-
val_array,
3879-
side=side,
3880-
sorter=sorter,
3881-
)
3882-
result.append(np.intp(pos[0]))
3890+
result.append(binary_search(v, side=side, sorter=sorter))
38833891

38843892
if len(result) == 1:
38853893
return result[0]

0 commit comments

Comments
 (0)