Skip to content

Commit

Permalink
Add testcases for test_knn.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Ami11111 committed Jul 24, 2024
1 parent 29cd635 commit 50c04c3
Show file tree
Hide file tree
Showing 2 changed files with 219 additions and 0 deletions.
47 changes: 47 additions & 0 deletions python/test/cases/test_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,50 @@ def test_sparse_knn_with_index(self, check_data):
"data_dir": common_values.TEST_TMP_DIR}], indirect=True)
def test_with_multiple_fusion(self, check_data):
self.test_infinity_obj._test_with_multiple_fusion(check_data)

@pytest.mark.parametrize("check_data", [{"file_name": "pysdk_test_knn.csv",
"data_dir": common_values.TEST_TMP_DIR}], indirect=True)
@pytest.mark.parametrize("index_column_name", ["gender_vector"])
@pytest.mark.parametrize("knn_column_name", ["gender_vector"])
@pytest.mark.parametrize("index_distance_type", ["l2","ip", "cosine"])
@pytest.mark.parametrize("knn_distance_type", ["l2", "ip", "cosine"])
@pytest.mark.parametrize("index_type", [index.IndexType.Hnsw, index.IndexType.IVFFlat])
def test_with_various_index_knn_distance_combination(self, check_data, index_column_name, knn_column_name,
index_distance_type, knn_distance_type, index_type):
self.test_infinity_obj._test_with_various_index_knn_distance_combination(check_data, index_column_name, knn_column_name,
index_distance_type, knn_distance_type, index_type)

def test_zero_dimension_vector(self):
self.test_infinity_obj._test_zero_dimension_vector()

@pytest.mark.parametrize("dim", [1000, 10000, 100000, 200000])
def test_big_dimension_vector(self, dim):
self.test_infinity_obj._test_big_dimension_vector(dim)

# "^5" indicates the point that column "body" get multipy by 5, default is multipy by 1
# refer to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-query-string-query.html
@pytest.mark.parametrize("fields_and_matching_text", [
["body","black"],
["doctitle,num,body", "black"],
["doctitle,num,body^5", "black"],
["", "body:black"],
["", "body:black^5"],
["", "'body':'(black)'"],
["", "body:'(black)^5'"],
["", "'body':'(black OR white)'"],
["", "'body':'(black AND white)'"],
["", "'body':'(black)^5 OR (white)'"],
["", "'body':'(black)^5 AND (white)'"],
["", "'body':'black - white'"],
["", "body:black OR doctitle:black"],
["", "body:black AND doctitle:black"],
["", "(body:black OR doctitle:black) AND (body:white OR doctitle:white)"],
["", "(body:black)^5 OR doctitle:black"],
["", "(body:black)^5 AND doctitle:black"],
["", "(body:black OR doctitle:black)^5 AND (body:white OR doctitle:white)"],
#["", "doc\*:back"] not support
])
@pytest.mark.parametrize("check_data", [{"file_name": "enwiki_embedding_99_commas.csv",
"data_dir": common_values.TEST_TMP_DIR}], indirect=True)
def test_with_various_fulltext_match(self, check_data, fields_and_matching_text):
self.test_infinity_obj._test_with_various_fulltext_match(check_data, fields_and_matching_text)
172 changes: 172 additions & 0 deletions python/test/internal/test_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,3 +797,175 @@ def _test_with_multiple_fusion(self, check_data):

res = db_obj.drop_table("test_with_multiple_fusion", ConflictType.Error)
assert res.error_code == ErrorCode.OK

def _test_with_various_index_knn_distance_combination(self, check_data, index_column_name, knn_column_name,
index_distance_type, knn_distance_type, index_type):
db_obj = self.infinity_obj.get_database("default_db")
db_obj.drop_table("test_with_index", ConflictType.Ignore)
table_obj = db_obj.create_table("test_with_index", {
"variant_id": {"type": "varchar"},
"gender_vector": {"type": "vector,4,float"},
"color_vector": {"type": "vector,4,float"},
"category_vector": {"type": "vector,4,float"},
"tag_vector": {"type": "vector,4,float"},
"other_vector": {"type": "vector,4,float"},
"query_is_recommend": {"type": "varchar"},
"query_gender": {"type": "varchar"},
"query_color": {"type": "varchar"},
"query_price": {"type": "float"}
}, ConflictType.Error)
if not check_data:
copy_data("pysdk_test_knn.csv")
test_csv_dir = "/var/infinity/test_data/pysdk_test_knn.csv"
table_obj.import_data(test_csv_dir, None)
if index_type == index.IndexType.Hnsw:
if index_distance_type == "cosine":
with pytest.raises(InfinityException) as e:
res = table_obj.create_index("my_index",
[index.IndexInfo(index_column_name,
index.IndexType.Hnsw,
[
index.InitParameter(
"M", "16"),
index.InitParameter(
"ef_construction", "50"),
index.InitParameter(
"ef", "50"),
index.InitParameter(
"metric", index_distance_type)
])], ConflictType.Error)
assert e.type == InfinityException
assert e.value.args[0] == ErrorCode.INVALID_INDEX_PARAM
else:
res = table_obj.create_index("my_index",
[index.IndexInfo(index_column_name,
index.IndexType.Hnsw,
[
index.InitParameter(
"M", "16"),
index.InitParameter(
"ef_construction", "50"),
index.InitParameter(
"ef", "50"),
index.InitParameter(
"metric", index_distance_type)
])], ConflictType.Error)
assert res.error_code == ErrorCode.OK
res = table_obj.output(["variant_id"]).knn(
knn_column_name, [1] * 4, "float", knn_distance_type, 5).to_pl()
print(res)
res = table_obj.drop_index("my_index", ConflictType.Error)
assert res.error_code == ErrorCode.OK
elif index_type == index.IndexType.IVFFlat:
if index_distance_type == "cosine":
with pytest.raises(InfinityException) as e:
res = table_obj.create_index("my_index",
[index.IndexInfo(index_column_name,
index.IndexType.IVFFlat,
[index.InitParameter("centroids_count", "128"),
index.InitParameter("metric",
index_distance_type)])],
ConflictType.Error)
assert e.type == InfinityException
assert e.value.args[0] == ErrorCode.LACK_INDEX_PARAM
else:
res = table_obj.create_index("my_index",
[index.IndexInfo(index_column_name,
index.IndexType.IVFFlat,
[index.InitParameter("centroids_count", "128"),
index.InitParameter("metric", index_distance_type)])], ConflictType.Error)
assert res.error_code == ErrorCode.OK
#for IVFFlat, index_distance_type has to match knn_distance_type?
res = table_obj.output(["variant_id"]).knn(
knn_column_name, [1] * 4, "float", index_distance_type, 5).to_pl()
print(res)
res = table_obj.drop_index("my_index", ConflictType.Error)
assert res.error_code == ErrorCode.OK

res = db_obj.drop_table("test_with_index", ConflictType.Error)
assert res.error_code == ErrorCode.OK

def _test_zero_dimension_vector(self):
db_obj = self.infinity_obj.get_database("default_db")
db_obj.drop_table("test_zero_dimension_vector",
conflict_type=ConflictType.Ignore)
table_obj = db_obj.create_table("test_zero_dimension_vector", {
"zero_vector": {"type": "vector,0,float"},
}, ConflictType.Error)

# try to insert and search a non-zero dim vector
with pytest.raises(InfinityException) as e:
table_obj.insert([{"zero_vector":[0.0]}])
assert e.type == InfinityException
assert e.value.args[0] == ErrorCode.DATA_TYPE_MISMATCH
with pytest.raises(InfinityException) as e:
res = table_obj.output(["_row_id"]).knn(
"zero_vector", [0.0], "float", "l2", 5).to_pl()
assert e.type == InfinityException
assert e.value.args[0] == ErrorCode.SYNTAX_ERROR

# try to insert and search a zero dim vector
with pytest.raises(IndexError) as e:
table_obj.insert([{"zero_vector":[]}])
try:
res = table_obj.output(["_row_id"]).knn(
"zero_vector", [], "float", "l2", 5).to_pl()
except:
print("Exception")

res = db_obj.drop_table("test_zero_dimension_vector", ConflictType.Error)
assert res.error_code == ErrorCode.OK

def _test_big_dimension_vector(self, dim):
db_obj = self.infinity_obj.get_database("default_db")
db_obj.drop_table("test_big_dimension_vector",
conflict_type=ConflictType.Ignore)
table_obj = db_obj.create_table("test_big_dimension_vector", {
"big_vector": {"type": f"vector,{dim},float"},
}, ConflictType.Error)
table_obj.insert([{"big_vector": [1.0]*dim},
{"big_vector": [2.0]*dim},
{"big_vector": [3.0]*dim},
{"big_vector": [4.0]*dim},
{"big_vector": [5.0]*dim}])
res = table_obj.output(["_row_id"]).knn(
"big_vector", [0.0]*dim, "float", "l2", 5).to_pl()
print(res)

def _test_with_various_fulltext_match(self, check_data, fields_and_matching_text):
db_obj = self.infinity_obj.get_database("default_db")
db_obj.drop_table(
"test_with_various_fulltext_match", ConflictType.Ignore)
table_obj = db_obj.create_table("test_with_various_fulltext_match",
{"doctitle": {"type": "varchar"},
"docdate": {"type": "varchar"},
"body": {"type": "varchar"},
"num": {"type": "int"},
"vec": {"type": "vector, 4, float"}})
table_obj.create_index("my_index",
[index.IndexInfo("body",
index.IndexType.FullText,
[index.InitParameter("ANALYZER", "standard")]),
], ConflictType.Error)

if not check_data:
generate_commas_enwiki(
"enwiki_99.csv", "enwiki_embedding_99_commas.csv", 1)
copy_data("enwiki_embedding_99_commas.csv")

test_csv_dir = common_values.TEST_TMP_DIR + "enwiki_embedding_99_commas.csv"
table_obj.import_data(test_csv_dir, import_options={"delimiter": ","})
res = (table_obj
.output(["*"])
.knn("vec", [3.0, 2.8, 2.7, 3.1], "float", "ip", 1)
.match(fields_and_matching_text[0], fields_and_matching_text[1], "topn=1")
.fusion('rrf')
.to_pl())
print(res)

res = table_obj.drop_index("my_index", ConflictType.Error)
assert res.error_code == ErrorCode.OK

res = db_obj.drop_table(
"test_with_various_fulltext_match", ConflictType.Error)
assert res.error_code == ErrorCode.OK

0 comments on commit 50c04c3

Please sign in to comment.