diff --git a/tools/generate_multivector_knn_scan.py b/tools/generate_multivector_knn_scan.py index 51564730c2..f909ecf680 100644 --- a/tools/generate_multivector_knn_scan.py +++ b/tools/generate_multivector_knn_scan.py @@ -3,6 +3,18 @@ import random from math import sin, cos, acos, pi import numpy as np +import string + + +class MyFormatter(string.Formatter): + def format_field(self, value, format_spec): + if format_spec == 'm': + return super().format_field(value, 'e').replace('e+', 'e') + else: + return super().format_field(value, format_spec) + + +fmt = MyFormatter() def get_sphere_point(r, rand1, rand2): @@ -15,11 +27,11 @@ def get_sphere_point(r, rand1, rand2): row_n = 1000 -query_data = [3, 500, 990, random.randint(0, row_n - 1)] -top_n = [5, 255, 256, row_n] +query_data = [random.randint(0, row_n - 1), 3, 500, 990] +top_n = [row_n, 5, 255, 256] num_in_row = 10 radius = 10000.0 -require_min_diff = 1e-6 +require_min_diff = 1e-5 def get_numpy_l2_norm(v1: np.ndarray, v2: np.ndarray): @@ -44,7 +56,7 @@ def get_random_data(): # check query result for query_data, return if all query distances are different distance_results = [] good_diff = True - for i in query_data: + for pos, i in enumerate(query_data): query_v = all_multivector_centers[i] distance_pair = [] for j in range(row_n): @@ -53,7 +65,7 @@ def get_random_data(): distance_pair.append((j, np.min(l2_d))) # sort by distance distance_pair.sort(key=lambda x: x[1]) - for j in range(row_n - 1): + for j in range(min(row_n - 1, top_n[pos])): if (distance_pair[j + 1][1] - distance_pair[j][1]) / distance_pair[j + 1][1] < require_min_diff: good_diff = False break @@ -127,7 +139,8 @@ def write_twice(statement: str, output_table_name: bool = False): write_twice("\n# multivector scan\n") for i, q_id in enumerate(query_data): write_twice("\nquery I\n") - query_v = np.array2string(all_multivector_centers[q_id], separator=',') + query_v = np.array2string(all_multivector_centers[q_id], separator=',', + formatter={'float': lambda x: fmt.format_field(x, 'm')}) query_str = f"SELECT c1 FROM \u007b\u007d SEARCH MATCH VECTOR (c2, {query_v}, 'float', 'l2', {top_n[i]});\n" write_twice(query_str, True) write_twice("----\n")