Skip to content

Commit

Permalink
Update test script (infiniflow#1722)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

Update test script
Remove the "+" sign from the exponent part of formatted floats

### Type of change

- [x] Test cases
  • Loading branch information
yangzq50 authored Aug 26, 2024
1 parent 080a8b7 commit b5183d5
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions tools/generate_multivector_knn_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,18 @@
import random
from math import sin, cos, acos, pi
import numpy as np
import string


class MyFormatter(string.Formatter):
    """Formatter that adds an ``'m'`` presentation type for floats.

    ``'m'`` renders a value exactly like ``'e'`` (scientific notation) but
    omits the ``+`` sign of a non-negative exponent, e.g. ``1.234500e03``
    instead of ``1.234500e+03``. Negative exponents keep their ``-`` sign.
    An optional precision may precede the type, e.g. ``'.3m'``.
    """

    def format_field(self, value, format_spec):
        """Format ``value`` according to ``format_spec``.

        Args:
            value: The object to format.
            format_spec: A standard format spec, or ``'m'`` optionally
                preceded by a precision (``'.<digits>m'``).

        Returns:
            The formatted string.

        Raises:
            ValueError: If ``format_spec`` is not valid for ``value``; this
                is raised by the underlying ``format()`` call.
        """
        if format_spec.endswith('m'):
            prefix = format_spec[:-1]
            # Only a bare 'm' or a precision prefix is translated. Width or
            # alignment prefixes are left to the base class (which rejects
            # them) because dropping the '+' after padding would leave the
            # field one character short.
            if not prefix or (prefix[0] == '.' and prefix[1:].isdecimal()):
                formatted = super().format_field(value, prefix + 'e')
                return formatted.replace('e+', 'e')
        return super().format_field(value, format_spec)


fmt = MyFormatter()


def get_sphere_point(r, rand1, rand2):
Expand All @@ -15,11 +27,11 @@ def get_sphere_point(r, rand1, rand2):


row_n = 1000
query_data = [3, 500, 990, random.randint(0, row_n - 1)]
top_n = [5, 255, 256, row_n]
query_data = [random.randint(0, row_n - 1), 3, 500, 990]
top_n = [row_n, 5, 255, 256]
num_in_row = 10
radius = 10000.0
require_min_diff = 1e-6
require_min_diff = 1e-5


def get_numpy_l2_norm(v1: np.ndarray, v2: np.ndarray):
Expand All @@ -44,7 +56,7 @@ def get_random_data():
# check query result for query_data, return if all query distances are different
distance_results = []
good_diff = True
for i in query_data:
for pos, i in enumerate(query_data):
query_v = all_multivector_centers[i]
distance_pair = []
for j in range(row_n):
Expand All @@ -53,7 +65,7 @@ def get_random_data():
distance_pair.append((j, np.min(l2_d)))
# sort by distance
distance_pair.sort(key=lambda x: x[1])
for j in range(row_n - 1):
for j in range(min(row_n - 1, top_n[pos])):
if (distance_pair[j + 1][1] - distance_pair[j][1]) / distance_pair[j + 1][1] < require_min_diff:
good_diff = False
break
Expand Down Expand Up @@ -127,7 +139,8 @@ def write_twice(statement: str, output_table_name: bool = False):
write_twice("\n# multivector scan\n")
for i, q_id in enumerate(query_data):
write_twice("\nquery I\n")
query_v = np.array2string(all_multivector_centers[q_id], separator=',')
query_v = np.array2string(all_multivector_centers[q_id], separator=',',
formatter={'float': lambda x: fmt.format_field(x, 'm')})
query_str = f"SELECT c1 FROM \u007b\u007d SEARCH MATCH VECTOR (c2, {query_v}, 'float', 'l2', {top_n[i]});\n"
write_twice(query_str, True)
write_twice("----\n")
Expand Down

0 comments on commit b5183d5

Please sign in to comment.