chore(search_family): Simplify FT.SEARCH reply code
Signed-off-by: Stepan Bagritsevich <[email protected]>
BagritsevichStepan committed Feb 3, 2025
1 parent 80e4012 commit 191b6d7
Showing 4 changed files with 86 additions and 80 deletions.
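The diff collapses the two FT.SEARCH reply paths (`ReplyWithResults` for plain queries, `ReplySorted` for KNN/SORTBY) into a single `SearchReply` that sorts only when aggregation info is present. Below is a minimal standalone sketch of that slicing logic under simplified assumptions: `Doc` and `SliceResults` are hypothetical stand-ins, whereas the real code operates on `SerializedSearchDoc` pointers collected from shard results and streams the reply through a `RedisReplyBuilder`.

```cpp
// Hypothetical sketch of the unified FT.SEARCH slicing logic (not Dragonfly code).
#include <algorithm>
#include <cstddef>
#include <optional>
#include <vector>

struct Doc {
  float score;
};

// Mirrors the new SearchReply: cap by the aggregation limit and sort only when
// aggregation info is present, then apply LIMIT offset/count to the result set.
std::vector<Doc> SliceResults(std::vector<Doc> docs, std::optional<size_t> agg_limit,
                              bool descending, size_t offset, size_t count) {
  size_t size = docs.size();

  if (agg_limit) {
    size = std::min(size, *agg_limit);

    auto cmp = [descending](const Doc& l, const Doc& r) {
      return descending ? l.score > r.score : l.score < r.score;
    };

    // Sort only the prefix that can actually be returned.
    const size_t prefix = std::min(offset + count, size);
    if (prefix == docs.size())
      std::sort(docs.begin(), docs.end(), cmp);
    else
      std::partial_sort(docs.begin(), docs.begin() + prefix, docs.end(), cmp);
  }

  const size_t start = std::min(offset, size);
  const size_t n = std::min(size - start, count);
  return std::vector<Doc>(docs.begin() + start, docs.begin() + start + n);
}
```

When `agg_limit` is empty (a plain search), no sorting happens and the documents are returned in retrieval order, which matches the behavior of the removed `ReplyWithResults`.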
9 changes: 6 additions & 3 deletions src/core/search/search.cc
@@ -688,18 +688,21 @@ SearchResult SearchAlgorithm::Search(const FieldIndices* index, size_t limit) co
return bs.Search(*query_);
}

optional<AggregationInfo> SearchAlgorithm::HasAggregation() const {
optional<AggregationInfo> SearchAlgorithm::GetAggregationInfo() const {
DCHECK(query_);

// KNN query
if (auto* knn = get_if<AstKnnNode>(query_.get()); knn)
return AggregationInfo{knn->limit, string_view{knn->score_alias}, false};
return AggregationInfo{string_view{knn->score_alias}, false, knn->limit};

// SEARCH query with SORTBY option
if (auto* sort = get_if<AstSortNode>(query_.get()); sort) {
string_view alias = "";
if (auto* knn = get_if<AstKnnNode>(&sort->filter->Variant());
knn && knn->score_alias == sort->field)
alias = knn->score_alias;

return AggregationInfo{nullopt, alias, sort->descending};
return AggregationInfo{alias, sort->descending};
}

return nullopt;
4 changes: 2 additions & 2 deletions src/core/search/search.h
@@ -136,9 +136,9 @@ struct SearchResult {
};

struct AggregationInfo {
std::optional<size_t> limit;
std::string_view alias;
bool descending;
size_t limit = std::numeric_limits<size_t>::max();
};

// SearchAlgorithm allows searching field indices with a query
@@ -154,7 +154,7 @@ class SearchAlgorithm {
size_t limit = std::numeric_limits<size_t>::max()) const;

// if enabled, return limit & alias for knn query
std::optional<AggregationInfo> HasAggregation() const;
std::optional<AggregationInfo> GetAggregationInfo() const;

void EnableProfiling();

115 changes: 40 additions & 75 deletions src/server/search/search_family.cc
@@ -467,89 +467,62 @@ void SendSerializedDoc(const SerializedSearchDoc& doc, SinkReplyBuilder* builder
}
}

void ReplyWithResults(const SearchParams& params, absl::Span<SearchResult> results,
SinkReplyBuilder* builder) {
size_t total_count = 0;
for (const auto& shard_docs : results)
total_count += shard_docs.total_hits;

size_t result_count =
min(total_count - min(total_count, params.limit_offset), params.limit_total);

facade::SinkReplyBuilder::ReplyAggregator agg{builder};

bool ids_only = params.IdsOnly();
size_t reply_size = ids_only ? (result_count + 1) : (result_count * 2 + 1);

auto* rb = static_cast<RedisReplyBuilder*>(builder);
rb->StartArray(reply_size);
rb->SendLong(total_count);

size_t sent = 0;
size_t to_skip = params.limit_offset;
for (const auto& shard_docs : results) {
for (const auto& serialized_doc : shard_docs.docs) {
// Scoring is not implemented yet, so we just cut them in the order they were retrieved
if (to_skip > 0) {
to_skip--;
continue;
}

if (sent++ >= result_count)
return;

if (ids_only)
rb->SendBulkString(serialized_doc.key);
else
SendSerializedDoc(serialized_doc, builder);
}
}
}

void ReplySorted(search::AggregationInfo agg, const SearchParams& params,
void SearchReply(const SearchParams& params, std::optional<search::AggregationInfo> agg_info,
absl::Span<SearchResult> results, SinkReplyBuilder* builder) {
size_t total = 0;
vector<SerializedSearchDoc*> docs;
size_t total_hits = 0;
std::vector<SerializedSearchDoc*> docs;
docs.reserve(results.size());
for (auto& shard_results : results) {
total += shard_results.total_hits;
total_hits += shard_results.total_hits;
for (auto& doc : shard_results.docs) {
docs.push_back(&doc);
}
}

size_t agg_limit = agg.limit.value_or(total);
size_t prefix = min(params.limit_offset + params.limit_total, agg_limit);
size_t size = docs.size();
bool should_add_score_field = false;

partial_sort(docs.begin(), docs.begin() + min(docs.size(), prefix), docs.end(),
[desc = agg.descending](const auto* l, const auto* r) {
return desc ? (*l >= *r) : (*l < *r);
});
if (agg_info) {
size = std::min(size, agg_info->limit);
total_hits = std::min(total_hits, agg_info->limit);
should_add_score_field = !agg_info->alias.empty();

auto comparator = [desc = agg_info->descending](const auto* l, const auto* r) {
return desc ? (*l >= *r) : (*l < *r);
};

const size_t prefix_size_to_sort = std::min(params.limit_offset + params.limit_total, size);
if (prefix_size_to_sort == docs.size()) {
std::sort(docs.begin(), docs.end(), std::move(comparator));
} else {
std::partial_sort(docs.begin(), docs.begin() + prefix_size_to_sort, docs.end(),
std::move(comparator));
}
}

docs.resize(min(docs.size(), agg_limit));
const size_t offset = std::min(params.limit_offset, size);
const size_t limit = std::min(size - offset, params.limit_total);

size_t start_idx = min(params.limit_offset, docs.size());
size_t result_count = min(docs.size() - start_idx, params.limit_total);
bool ids_only = params.IdsOnly();
size_t reply_size = ids_only ? (result_count + 1) : (result_count * 2 + 1);
const bool reply_with_ids_only = params.IdsOnly();
const size_t reply_size = reply_with_ids_only ? (limit + 1) : (limit * 2 + 1);

// Clear score alias if it's excluded from return values
if (!params.ShouldReturnField(agg.alias))
agg.alias = "";
facade::SinkReplyBuilder::ReplyAggregator agg{builder};

facade::SinkReplyBuilder::ReplyAggregator agg_reply{builder};
auto* rb = static_cast<RedisReplyBuilder*>(builder);
rb->StartArray(reply_size);
rb->SendLong(min(total, agg_limit));
for (auto* doc : absl::MakeSpan(docs).subspan(start_idx, result_count)) {
if (ids_only) {
rb->SendBulkString(doc->key);
rb->SendLong(total_hits);

const size_t end = offset + limit;
for (size_t i = offset; i < end; i++) {
if (reply_with_ids_only) {
rb->SendBulkString(docs[i]->key);
continue;
}

if (!agg.alias.empty() && holds_alternative<float>(doc->score))
doc->values[agg.alias] = absl::StrCat(get<float>(doc->score));
if (should_add_score_field && holds_alternative<float>(docs[i]->score))
docs[i]->values[agg_info->alias] = absl::StrCat(get<float>(docs[i]->score));

SendSerializedDoc(*doc, builder);
SendSerializedDoc(*docs[i], builder);
}
}

@@ -819,10 +792,7 @@ void SearchFamily::FtSearch(CmdArgList args, const CommandContext& cmd_cntx) {
return builder->SendError(*res.error);
}

if (auto agg = search_algo.HasAggregation(); agg)
ReplySorted(*agg, *params, absl::MakeSpan(docs), builder);
else
ReplyWithResults(*params, absl::MakeSpan(docs), builder);
SearchReply(*params, search_algo.GetAggregationInfo(), absl::MakeSpan(docs), builder);
}

void SearchFamily::FtProfile(CmdArgList args, const CommandContext& cmd_cntx) {
@@ -898,12 +868,7 @@

// Result of the search command
if (!result_is_empty) {
auto agg = search_algo.HasAggregation();
if (agg) {
ReplySorted(*agg, *params, absl::MakeSpan(search_results), rb);
} else {
ReplyWithResults(*params, absl::MakeSpan(search_results), rb);
}
SearchReply(*params, search_algo.GetAggregationInfo(), absl::MakeSpan(search_results), rb);
} else {
rb->StartArray(1);
rb->SendLong(0);
38 changes: 38 additions & 0 deletions src/server/search/search_family_test.cc
@@ -1884,4 +1884,42 @@ TEST_F(SearchFamilyTest, InvalidSearchOptions) {
EXPECT_THAT(resp, IsArray(IntArg(1), "j1"));
}

TEST_F(SearchFamilyTest, KnnSearchOptions) {
Run({"JSON.SET", "doc:1", ".", R"({"vector": [0.1, 0.2, 0.3, 0.4]})"});
Run({"JSON.SET", "doc:2", ".", R"({"vector": [0.5, 0.6, 0.7, 0.8]})"});
Run({"JSON.SET", "doc:3", ".", R"({"vector": [0.9, 0.1, 0.4, 0.3]})"});

auto resp = Run({"FT.CREATE", "my_index", "ON", "JSON", "PREFIX", "1", "doc:",
"SCHEMA", "$.vector", "AS", "vector", "VECTOR", "FLAT", "6",
"TYPE", "FLOAT32", "DIM", "4", "DISTANCE_METRIC", "COSINE"});
EXPECT_EQ(resp, "OK");

std::string query_vector("\x00\x00\x00\x3f\x00\x00\x00\x40\x00\x00\x00\x41\x00\x00\x80\x42", 16);

// KNN 2
resp = Run({"FT.SEARCH", "my_index", "*=>[KNN 2 @vector $query_vector]", "PARAMS", "2",
"query_vector", query_vector});
EXPECT_THAT(resp, AreDocIds("doc:1", "doc:2"));

// KNN 11929939
resp = Run({"FT.SEARCH", "my_index", "*=>[KNN 11929939 @vector $query_vector]", "PARAMS", "2",
"query_vector", query_vector});
EXPECT_THAT(resp, AreDocIds("doc:1", "doc:2", "doc:3"));

// KNN 11929939, LIMIT 4 2
resp = Run({"FT.SEARCH", "my_index", "*=>[KNN 11929939 @vector $query_vector]", "PARAMS", "2",
"query_vector", query_vector, "LIMIT", "4", "2"});
EXPECT_THAT(resp, IntArg(3));

// KNN 11929939, LIMIT 0 10
resp = Run({"FT.SEARCH", "my_index", "*=>[KNN 11929939 @vector $query_vector]", "PARAMS", "2",
"query_vector", query_vector, "LIMIT", "0", "10"});
EXPECT_THAT(resp, AreDocIds("doc:1", "doc:2", "doc:3"));

// KNN 1, LIMIT 0 2
resp = Run({"FT.SEARCH", "my_index", "*=>[KNN 1 @vector $query_vector]", "PARAMS", "2",
"query_vector", query_vector, "LIMIT", "0", "2"});
EXPECT_THAT(resp, AreDocIds("doc:1"));
}

} // namespace dfly
