Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into unit_test_6
Browse files Browse the repository at this point in the history
  • Loading branch information
Ami11111 committed Aug 29, 2024
2 parents 7a2de20 + 663d4ad commit 058f385
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 7 deletions.
10 changes: 9 additions & 1 deletion src/storage/knn_index/knn_hnsw/hnsw_util.cppm
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@ export module hnsw_util;

import stl;
import statement_common;
import infinity_exception;
import status;

namespace infinity {

export struct HnswOptimizeOptions {
bool compress_to_lvq = false;
bool lvq_avg = false;
};

export struct HnswUtil {
Expand All @@ -33,9 +36,14 @@ export struct HnswUtil {
for (const auto &param : opt_params) {
if (IsEqual(param->param_name_, "compress_to_lvq")) {
options.compress_to_lvq = true;
} else if (IsEqual(param->param_name_, "lvq_avg")) {
options.lvq_avg = true;
}
}
if (!options.compress_to_lvq) {
if (options.compress_to_lvq && options.lvq_avg) {
RecoverableError(Status::InvalidIndexParam("compress_to_lvq and lvq_avg cannot be set at the same time"));
}
if (!options.compress_to_lvq && !options.lvq_avg) {
return None;
}
return options;
Expand Down
17 changes: 11 additions & 6 deletions src/storage/meta/entry/segment_index_entry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -789,12 +789,17 @@ void SegmentIndexEntry::OptIndex(IndexBase *index_base,
UnrecoverableError("Invalid index type.");
} else {
using HnswIndexDataType = typename std::remove_pointer_t<T>::DataType;
if constexpr (IsAnyOf<HnswIndexDataType, i8, u8>) {
UnrecoverableError("Invalid index type.");
} else {
auto *p = std::move(*index).CompressToLVQ().release();
delete index;
*abstract_hnsw = p;
if (params->compress_to_lvq) {
if constexpr (IsAnyOf<HnswIndexDataType, i8, u8>) {
UnrecoverableError("Invalid index type.");
} else {
auto *p = std::move(*index).CompressToLVQ().release();
delete index;
*abstract_hnsw = p;
}
}
if (params->lvq_avg) {
index->Optimize();
}
}
},
Expand Down
5 changes: 5 additions & 0 deletions src/storage/meta/entry/table_index_entry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,11 @@ void TableIndexEntry::OptIndex(TxnTableStore *txn_table_store, const Vector<Uniq
}
}
}
if (params->lvq_avg) {
for (const auto &[segment_id, segment_index_entry] : index_by_segment_) {
segment_index_entry->OptIndex(hnsw_index, txn_table_store, opt_params, false /*replay*/);
}
}
break;
}
default: {
Expand Down
10 changes: 10 additions & 0 deletions test/sql/dql/knn/embedding/test_knn_hnsw_l2_lvq.slt
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,15 @@ SELECT c1 FROM test_knn_hnsw_l2 SEARCH MATCH VECTOR (c2, [0.3, 0.3, 0.2, 0.2], '
8
6

statement ok
OPTIMIZE idx1 ON test_knn_hnsw_l2 WITH (lvq_avg);

query I
SELECT c1 FROM test_knn_hnsw_l2 SEARCH MATCH VECTOR (c2, [0.3, 0.3, 0.2, 0.2], 'float', 'l2', 3) WITH (ef = 6, rerank);
----
8
8
6

statement ok
DROP TABLE test_knn_hnsw_l2;

0 comments on commit 058f385

Please sign in to comment.