Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into unit_test_3
Browse files Browse the repository at this point in the history
  • Loading branch information
Ami11111 committed Aug 28, 2024
2 parents d7c58dd + 155b371 commit 9119d47
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 8 deletions.
9 changes: 7 additions & 2 deletions python/infinity_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,8 @@ def match(self, *args, **kwargs):
deprecated_api("match is deprecated, please use match_text instead")
return self.match_text(*args, **kwargs)

def match_tensor(self, column_name: str, query_data, query_data_type: str, topn: int):
def match_tensor(self, column_name: str, query_data, query_data_type: str, topn: int,
extra_option: Optional[dict] = None):
self._match_tensor = {}
self._match_tensor["search_method"] = "maxsim"
if isinstance(column_name, list):
Expand All @@ -584,7 +585,11 @@ def match_tensor(self, column_name: str, query_data, query_data_type: str, topn:
column_name = [column_name]
self._match_tensor["fields"] = column_name
self._match_tensor["query_tensor"] = query_data
self._match_tensor["options"] = f"topn={topn}"
option_str = f"topn={topn}"
if extra_option is not None:
for k, v in extra_option.items():
option_str += f";{k}={v}"
self._match_tensor["options"] = option_str
self._match_tensor["element_type"] = type_transfrom[query_data_type]
return self

Expand Down
14 changes: 8 additions & 6 deletions python/test_pysdk/test_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,6 +894,11 @@ def test_tensor_scan(self, check_data, save_elem_t, query_elem_t, suffix):
.match_tensor('t', [[0.0, -10.0, 0.0, 0.7], [9.2, 45.6, -55.8, 3.5]], query_elem_t, 3)
.to_pl())
print(res)
with pytest.raises(InfinityException) as e:
table_obj.output(["*", "_row_id", "_score"]).match_tensor('t',
[[0.0, -10.0, 0.0, 0.7], [9.2, 45.6, -55.8, 3.5]],
query_elem_t, 3, {"some_option": 222}).to_df()
print(e)
pd.testing.assert_frame_equal(
table_obj.output(["title"]).match_tensor('t', [[0.0, -10.0, 0.0, 0.7], [9.2, 45.6, -55.8, 3.5]],
query_elem_t, 3).to_df(),
Expand Down Expand Up @@ -1205,7 +1210,6 @@ def test_tensor_scan_with_invalid_data_type(self, check_data, data_type, suffix)
pytest.param(1),
pytest.param(1.1),
pytest.param([]),
pytest.param({}),
pytest.param(()),
])
@pytest.mark.parametrize("check_data", [{"file_name": "tensor_maxsim.csv",
Expand All @@ -1223,11 +1227,9 @@ def test_tensor_scan_with_invalid_extra_option(self, check_data, extra_option, s
test_csv_dir = common_values.TEST_TMP_DIR + "tensor_maxsim.csv"
table_obj.import_data(test_csv_dir, import_options={"delimiter": ","})
with pytest.raises(Exception):
res = (table_obj
.output(["*", "_row_id", "_score"])
.match_tensor('t', [[0.0, -10.0, 0.0, 0.7], [9.2, 45.6, -55.8, 3.5]], 'float', 'maxsim',
extra_option)
.to_pl())
table_obj.output(["*", "_row_id", "_score"]).match_tensor('t',
[[0.0, -10.0, 0.0, 0.7], [9.2, 45.6, -55.8, 3.5]],
'float', 3, extra_option).to_pl()

res = db_obj.drop_table("test_tensor_scan"+suffix, ConflictType.Error)
assert res.error_code == ErrorCode.OK
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ import filter_value_type_classification;
import logical_match_tensor_scan;
import simd_functions;
import knn_expression;
import search_options;

namespace infinity {

Expand Down Expand Up @@ -200,6 +201,12 @@ void PhysicalMatchTensorScan::PlanWithIndex(QueryContext *query_context) {
}
}
}
if (!block_column_entries_.empty()) {
// check unused option text
if (const SearchOptions options(src_match_tensor_expr_->options_text_); options.size() != options.options_.count("topn")) {
RecoverableError(Status::SyntaxError(fmt::format(R"(Input option text "{}" has unused part.)", src_match_tensor_expr_->options_text_)));
}
}
LOG_TRACE(fmt::format("MatchTensorScan: brute force task: {}, index task: {}", block_column_entries_.size(), index_entries_.size()));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,16 @@ import logger;
namespace infinity {

void LogicalMatchTensorScan::InitExtraOptions() {
static const std::set<String> valid_options =
{"topn", "emvb_centroid_nprobe", "emvb_threshold_first", "emvb_n_doc_to_score", "emvb_n_doc_out_second_stage", "emvb_threshold_final"};
auto match_tensor_expr = static_cast<MatchTensorExpression *>(query_expression_.get());
SearchOptions options(match_tensor_expr->options_text_);
for (const auto &[x, _] : options.options_) {
if (!valid_options.contains(x)) {
RecoverableError(
Status::SyntaxError(fmt::format(R"(Input option text "{}" has invalid part "{}".)", match_tensor_expr->options_text_, x)));
}
}
// topn option
if (auto top_n_it = options.options_.find("topn"); top_n_it != options.options_.end()) {
const int top_n_option = std::stoi(top_n_it->second);
Expand Down

0 comments on commit 9119d47

Please sign in to comment.