diff --git a/src/postgres/third-party-extensions/pgvector/src/ybvector/ybhnsw.c b/src/postgres/third-party-extensions/pgvector/src/ybvector/ybhnsw.c index 0df803a469e4..297517562f49 100644 --- a/src/postgres/third-party-extensions/pgvector/src/ybvector/ybhnsw.c +++ b/src/postgres/third-party-extensions/pgvector/src/ybvector/ybhnsw.c @@ -34,6 +34,10 @@ #define YBHNSW_MIN_M 5 #define YBHNSW_MAX_M 64 +#define YBHNSW_DEFAULT_M0 0 +#define YBHNSW_MIN_M0 0 +#define YBHNSW_MAX_M0 (YBHNSW_MAX_M * 4) + #define YBHNSW_DEFAULT_EF_CONSTRUCTION 200 #define YBHNSW_MIN_EF_CONSTRUCTION 50 #define YBHNSW_MAX_EF_CONSTRUCTION 1000 @@ -47,6 +51,7 @@ typedef struct YbHnswOptions { int32 vl_len_; /* varlena header (do not touch directly!) */ int m; /* number of connections per node */ + int m0; /* number of connections per node in base level */ int ef_construction; /* size of dynamic candidate list */ } YbHnswOptions; @@ -58,6 +63,9 @@ YbHnswInit(void) add_int_reloption(ybhnsw_relopt_kind, "m", "Max number of connections", YBHNSW_DEFAULT_M, YBHNSW_MIN_M, YBHNSW_MAX_M, AccessExclusiveLock); + add_int_reloption(ybhnsw_relopt_kind, "m0", "Max number of connections in base level", + YBHNSW_DEFAULT_M0, YBHNSW_MIN_M0, YBHNSW_MAX_M0, + AccessExclusiveLock); add_int_reloption(ybhnsw_relopt_kind, "ef_construction", "Size of the dynamic candidate list for construction", YBHNSW_DEFAULT_EF_CONSTRUCTION, @@ -76,6 +84,7 @@ ybhnswoptions(Datum reloptions, bool validate) */ static const relopt_parse_elt tab[] = { {"m", RELOPT_TYPE_INT, offsetof(YbHnswOptions, m)}, + {"m0", RELOPT_TYPE_INT, offsetof(YbHnswOptions, m0)}, {"ef_construction", RELOPT_TYPE_INT, offsetof(YbHnswOptions, ef_construction)}, }; @@ -90,15 +99,19 @@ static void bindYbHnswIndexOptions(YbcPgStatement handle, Datum reloptions) { YbHnswOptions *hnsw_options = (YbHnswOptions *) ybhnswoptions(reloptions, false); - int ef_construction = YBHNSW_DEFAULT_EF_CONSTRUCTION; int m = YBHNSW_DEFAULT_M; + int m0 = YBHNSW_DEFAULT_M; + int 
ef_construction = YBHNSW_DEFAULT_EF_CONSTRUCTION; if (hnsw_options) { - ef_construction = hnsw_options->ef_construction; m = hnsw_options->m; + m0 = hnsw_options->m0; + if (m0 < m) + m0 = m; + ef_construction = hnsw_options->ef_construction; } - YBCPgCreateIndexSetHnswOptions(handle, ef_construction, m); + YBCPgCreateIndexSetHnswOptions(handle, m, m0, ef_construction); } static void diff --git a/src/yb/common/common.proto b/src/yb/common/common.proto index c057f849ef32..96c1161c7341 100644 --- a/src/yb/common/common.proto +++ b/src/yb/common/common.proto @@ -134,13 +134,20 @@ enum PgVectorIndexType { HNSW = 3; } +message PgHnswIndexOptionsPB { + optional uint32 m = 1; + optional uint32 m0 = 2; + optional uint32 ef_construction = 3; +} + message PgVectorIdxOptionsPB { optional PgVectorDistanceType dist_type = 1; optional PgVectorIndexType idx_type = 2; optional uint32 dimensions = 3; optional uint32 column_id = 4; - optional uint32 hnsw_ef = 5; - optional uint32 hnsw_m = 6; + + reserved 5, 6; + optional PgHnswIndexOptionsPB hnsw = 7; } message PgVectorReadOptionsPB { diff --git a/src/yb/docdb/vector_index.cc b/src/yb/docdb/vector_index.cc index 858896a4b9e5..4e5b139d7c13 100644 --- a/src/yb/docdb/vector_index.cc +++ b/src/yb/docdb/vector_index.cc @@ -37,37 +37,12 @@ DEFINE_RUNTIME_uint64(vector_index_initial_chunk_size, 100000, DEFINE_RUNTIME_PREVIEW_uint32(vector_index_ef, 128, "The \"expansion\" parameter for search"); -DEFINE_RUNTIME_PREVIEW_uint32(vector_index_ef_construction, 256, - "The \"expansion\" parameter during graph construction"); - -DEFINE_RUNTIME_PREVIEW_uint32(vector_index_num_neighbors_per_vertex, 32, - "Number of neighbors per graph node"); - -DEFINE_RUNTIME_PREVIEW_uint32(vector_index_num_neighbors_per_vertex_base, 128, - "Number of neighbors per graph node in base level graph"); - namespace yb::docdb { const std::string kVectorIndexDirPrefix = "vi-"; namespace { -template class Factory, class LSM> -auto 
VectorLSMFactory(vector_index::DistanceKind distance_kind, size_t dimensions) { - using FactoryImpl = vector_index::MakeVectorIndexFactory; - return [distance_kind, dimensions] { - vector_index::HNSWOptions hnsw_options = { - .dimensions = dimensions, - .num_neighbors_per_vertex = FLAGS_vector_index_num_neighbors_per_vertex, - .num_neighbors_per_vertex_base = FLAGS_vector_index_num_neighbors_per_vertex_base, - .ef_construction = FLAGS_vector_index_ef_construction, - .ef = FLAGS_vector_index_ef, - .distance_kind = distance_kind, - }; - return FactoryImpl::Create(hnsw_options); - }; -} - vector_index::DistanceKind ConvertDistanceKind(PgVectorDistanceType dist_type) { switch (dist_type) { case PgVectorDistanceType::DIST_L2: @@ -82,22 +57,41 @@ vector_index::DistanceKind ConvertDistanceKind(PgVectorDistanceType dist_type) { FATAL_INVALID_ENUM_VALUE(PgVectorDistanceType, dist_type); } +vector_index::HNSWOptions ConvertToHnswOptions(const PgVectorIdxOptionsPB& options) { + return { + .dimensions = options.dimensions(), + .num_neighbors_per_vertex = options.hnsw().m(), + .num_neighbors_per_vertex_base = options.hnsw().m0(), + .ef_construction = options.hnsw().ef_construction(), + .ef = FLAGS_vector_index_ef, + .distance_kind = ConvertDistanceKind(options.dist_type()), + }; +} + +template class Factory, class LSM> +auto VectorLSMFactory(const PgVectorIdxOptionsPB& options) { + using FactoryImpl = vector_index::MakeVectorIndexFactory; + return [hnsw_options = ConvertToHnswOptions(options)] { + return FactoryImpl::Create(hnsw_options); + }; +} + template -Result::VectorIndexFactory> - GetVectorLSMFactory(PgVectorIndexType type, vector_index::DistanceKind distance_kind, - size_t dimensions) { +auto GetVectorLSMFactory(const PgVectorIdxOptionsPB& options) + -> Result::VectorIndexFactory>{ using LSM = vector_index::VectorLSM; - switch (type) { + switch (options.idx_type()) { case PgVectorIndexType::HNSW: - return VectorLSMFactory(distance_kind, dimensions); + return 
VectorLSMFactory(options); case PgVectorIndexType::DUMMY: [[fallthrough]]; case PgVectorIndexType::IVFFLAT: [[fallthrough]]; case PgVectorIndexType::UNKNOWN_IDX: break; } return STATUS_FORMAT( - NotSupported, "Vector index $0 is not supported", PgVectorIndexType_Name(type)); + NotSupported, "Vector index $0 is not supported", + PgVectorIndexType_Name(options.idx_type())); } template @@ -180,8 +174,7 @@ class VectorIndexImpl : public VectorIndex { .log_prefix = log_prefix, .storage_dir = GetStorageDir(data_root_dir, DirName()), .vector_index_factory = VERIFY_RESULT((GetVectorLSMFactory( - idx_options.idx_type(), ConvertDistanceKind(idx_options.dist_type()), - idx_options.dimensions()))), + idx_options))), .points_per_chunk = FLAGS_vector_index_initial_chunk_size, .thread_pool = &thread_pool, .frontiers_factory = [] { return std::make_unique(); }, diff --git a/src/yb/yql/pggate/pg_ddl.cc b/src/yb/yql/pggate/pg_ddl.cc index fbdb5a8ca858..586321d42ba6 100644 --- a/src/yb/yql/pggate/pg_ddl.cc +++ b/src/yb/yql/pggate/pg_ddl.cc @@ -249,10 +249,11 @@ Status PgCreateTableBase::SetVectorOptions(YbcPgVectorIdxOptions* options) { return Status::OK(); } -Status PgCreateTableBase::SetHnswOptions(int ef_construction, int m) { - auto options_pb = req_.mutable_vector_idx_options(); - options_pb->set_hnsw_ef(ef_construction); - options_pb->set_hnsw_m(m); +Status PgCreateTableBase::SetHnswOptions(int m, int m0, int ef_construction) { + auto& options_pb = *req_.mutable_vector_idx_options()->mutable_hnsw(); + options_pb.set_m(m); + options_pb.set_m0(m0); + options_pb.set_ef_construction(ef_construction); return Status::OK(); } diff --git a/src/yb/yql/pggate/pg_ddl.h b/src/yb/yql/pggate/pg_ddl.h index 346fe65a9387..ac4ce4210cac 100644 --- a/src/yb/yql/pggate/pg_ddl.h +++ b/src/yb/yql/pggate/pg_ddl.h @@ -112,7 +112,7 @@ class PgCreateTableBase : public PgDdl { Status SetNumTablets(int32_t num_tablets); - Status SetHnswOptions(int ef_construction, int m); + Status SetHnswOptions(int m, 
int m0, int ef_construction); Status SetVectorOptions(YbcPgVectorIdxOptions* options); diff --git a/src/yb/yql/pggate/pggate.cc b/src/yb/yql/pggate/pggate.cc index 689f92f60ae0..7222404280bf 100644 --- a/src/yb/yql/pggate/pggate.cc +++ b/src/yb/yql/pggate/pggate.cc @@ -1151,9 +1151,10 @@ Status PgApiImpl::CreateIndexSetVectorOptions(PgStatement* handle, YbcPgVectorId return VERIFY_RESULT_REF(GetStatementAs(handle)).SetVectorOptions(options); } -Status PgApiImpl::CreateIndexSetHnswOptions(PgStatement* handle, int ef_construction, int m) { +Status PgApiImpl::CreateIndexSetHnswOptions( + PgStatement* handle, int m, int m0, int ef_construction) { return VERIFY_RESULT_REF(GetStatementAs(handle)) - .SetHnswOptions(ef_construction, m); + .SetHnswOptions(m, m0, ef_construction); } Status PgApiImpl::ExecCreateIndex(PgStatement* handle) { diff --git a/src/yb/yql/pggate/pggate.h b/src/yb/yql/pggate/pggate.h index f3b90be2e9fb..828ce44d2f8f 100644 --- a/src/yb/yql/pggate/pggate.h +++ b/src/yb/yql/pggate/pggate.h @@ -393,7 +393,7 @@ class PgApiImpl { Status CreateIndexSetVectorOptions(PgStatement *handle, YbcPgVectorIdxOptions *options); - Status CreateIndexSetHnswOptions(PgStatement *handle, int ef_construction, int m); + Status CreateIndexSetHnswOptions(PgStatement *handle, int m, int m0, int ef_construction); Status CreateIndexAddSplitRow(PgStatement *handle, int num_cols, YbcPgTypeEntity **types, uint64_t *data); diff --git a/src/yb/yql/pggate/ybc_pggate.cc b/src/yb/yql/pggate/ybc_pggate.cc index 5d778bfeb0b7..4e8cd6e60d62 100644 --- a/src/yb/yql/pggate/ybc_pggate.cc +++ b/src/yb/yql/pggate/ybc_pggate.cc @@ -1355,8 +1355,9 @@ YbcStatus YBCPgCreateIndexSetVectorOptions(YbcPgStatement handle, YbcPgVectorIdx return ToYBCStatus(pgapi->CreateIndexSetVectorOptions(handle, options)); } -YbcStatus YBCPgCreateIndexSetHnswOptions(YbcPgStatement handle, int ef_construction, int m) { - return ToYBCStatus(pgapi->CreateIndexSetHnswOptions(handle, ef_construction, m)); +YbcStatus 
YBCPgCreateIndexSetHnswOptions( + YbcPgStatement handle, int m, int m0, int ef_construction) { + return ToYBCStatus(pgapi->CreateIndexSetHnswOptions(handle, m, m0, ef_construction)); } YbcStatus YBCPgExecCreateIndex(YbcPgStatement handle) { diff --git a/src/yb/yql/pggate/ybc_pggate.h b/src/yb/yql/pggate/ybc_pggate.h index a0d685b96bbe..eb0087e0cde6 100644 --- a/src/yb/yql/pggate/ybc_pggate.h +++ b/src/yb/yql/pggate/ybc_pggate.h @@ -434,7 +434,8 @@ YbcStatus YBCPgCreateIndexSetNumTablets(YbcPgStatement handle, int32_t num_table YbcStatus YBCPgCreateIndexSetVectorOptions(YbcPgStatement handle, YbcPgVectorIdxOptions *options); -YbcStatus YBCPgCreateIndexSetHnswOptions(YbcPgStatement handle, int ef_construction, int m); +YbcStatus YBCPgCreateIndexSetHnswOptions( + YbcPgStatement handle, int m, int m0, int ef_construction); YbcStatus YBCPgExecCreateIndex(YbcPgStatement handle); diff --git a/src/yb/yql/pgwrapper/pg_vector_index-test.cc b/src/yb/yql/pgwrapper/pg_vector_index-test.cc index df76e63fa347..ff73ec4849c9 100644 --- a/src/yb/yql/pgwrapper/pg_vector_index-test.cc +++ b/src/yb/yql/pgwrapper/pg_vector_index-test.cc @@ -18,6 +18,9 @@ #include "yb/consensus/consensus.h" #include "yb/consensus/log.h" +#include "yb/docdb/doc_read_context.h" +#include "yb/docdb/vector_index.h" + #include "yb/integration-tests/cluster_itest_util.h" #include "yb/qlexpr/index.h" @@ -108,7 +111,10 @@ class PgVectorIndexTest : public PgMiniTestBase, public testing::WithParamInterf } Status CreateIndex(PGConn& conn) { - return conn.ExecuteFormat("CREATE INDEX ON test USING ybhnsw (embedding $0)", VectorOpsName()); + return conn.ExecuteFormat( + "CREATE INDEX ON test USING ybhnsw (embedding $0) " + "WITH (ef_construction = 256, m = 32, m0 = 128)", + VectorOpsName()); } Result MakeIndex(size_t dimensions = 3) { @@ -695,6 +701,85 @@ TEST_P(PgVectorIndexTest, Cosine) { TestMetric("2; 3; 1"); } +TEST_P(PgVectorIndexTest, Options) { + auto conn = ASSERT_RESULT(MakeTable()); + std::unordered_map> 
checked_indexes; + std::vector option_names = {"m", "m0", "ef_construction"}; + // We need unique values for used params. Since m and ef_construction have different allowed intervals, + // use different counters for them. counters[0] for m/m0 and counters[1] for ef_construction. + std::array counters = {32, 64}; + for (int i = 0; i != 1 << option_names.size(); ++i) { + std::string expected_options; + { + std::string options; + size_t prev_value = 0; + for (size_t j = 0; j != option_names.size(); ++j) { + if (!expected_options.empty()) { + expected_options += " "; + } + size_t value; + if ((i & (1 << j))) { + value = ++counters[j >= 2]; + if (!options.empty()) { + options += ", "; + } + options += Format("$0 = $1", option_names[j], value); + } else { + switch (j) { + case 0: + value = 32; // Default value for m + break; + case 1: + // When not specified m0 uses value of m. + value = prev_value; + break; + case 2: + value = 200; // Default value for ef + break; + default: + ASSERT_LT(j, 4U) << "Unexpected number of options"; + value = 0; + break; + } + } + expected_options += Format("$0: $1", option_names[j], value); + prev_value = value; + } + if (!options.empty()) { + options = " WITH (" + options + ")"; + } + auto query = "CREATE INDEX ON test USING ybhnsw (embedding vector_l2_ops)" + options; + LOG(INFO) << "Query: " << query; + ASSERT_OK(conn.Execute(query)); + } + auto peers = ListTabletPeers( + cluster_.get(), ListPeersFilter::kLeaders, IncludeTransactionStatusTablets::kFalse); + for (const auto& peer : peers) { + auto tablet = peer->shared_tablet(); + auto vector_indexes = tablet->vector_indexes().List(); + if (!vector_indexes) { + continue; + } + auto& tablet_indexes = checked_indexes[peer->tablet_id()]; + size_t num_new_indexes = 0; + for (const auto& vector_index : *vector_indexes) { + if (!tablet_indexes.insert(vector_index->table_id()).second) { + continue; + } + ++num_new_indexes; + auto doc_read_context = ASSERT_RESULT( + 
tablet->metadata()->GetTableInfo(vector_index->table_id()))->doc_read_context; + const auto& hnsw_options = doc_read_context->vector_idx_options->hnsw(); + LOG(INFO) + << "Vector index: " << AsString(vector_index) << ", options: " + << AsString(hnsw_options); + ASSERT_EQ(AsString(hnsw_options), expected_options); + } + ASSERT_EQ(num_new_indexes, 1); + } + } +} + std::string ColocatedToString(const testing::TestParamInfo& param_info) { return param_info.param ? "Colocated" : "Distributed"; }