Skip to content

Commit

Permalink
[#25859] DocDB: Pass ybhnsw options to docdb layer
Browse files Browse the repository at this point in the history
Summary:
Support specifying index creation options  m, m0, ef_construction for ybhnsw.
Standard Postgres pgvector hsnw supports the following options:
index creation options: m, ef_construction
query options: ef_search
Reference: https://github.com/pgvector/pgvector?tab=readme-ov-file#index-options
For ybhsnw, this diff:
adds plumbing for index creation options: m, ef_construction to be passed down
adds a new index creation option m0 (used by usearch) and adds plumbing for it

It does not yet handle query option ef_search.

**Upgrade/Rollback safety:** Safe, since modified field is used only by the code that was not yet released.
Jira: DB-15153

Test Plan: PgVectorIndexTest.Options/*

Reviewers: tnayak, arybochkin, jason

Reviewed By: tnayak

Subscribers: mihnea, smishra, jason, slingam, aleksandr.ponomarenko, ybase, yql

Tags: #jenkins-ready

Differential Revision: https://phorge.dev.yugabyte.com/D41785
  • Loading branch information
spolitov committed Mar 4, 2025
1 parent 4789974 commit 83b8ea1
Show file tree
Hide file tree
Showing 10 changed files with 152 additions and 50 deletions.
19 changes: 16 additions & 3 deletions src/postgres/third-party-extensions/pgvector/src/ybvector/ybhnsw.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
#define YBHNSW_MIN_M 5
#define YBHNSW_MAX_M 64

#define YBHNSW_DEFAULT_M0 0
#define YBHNSW_MIN_M0 0
#define YBHNSW_MAX_M0 (YBHNSW_MAX_M * 4)

#define YBHNSW_DEFAULT_EF_CONSTRUCTION 200
#define YBHNSW_MIN_EF_CONSTRUCTION 50
#define YBHNSW_MAX_EF_CONSTRUCTION 1000
Expand All @@ -47,6 +51,7 @@ typedef struct YbHnswOptions
{
int32 vl_len_; /* varlena header (do not touch directly!) */
int m; /* number of connections per node */
int m0; /* number of connections per node in base level */
int ef_construction; /* size of dynamic candidate list */
} YbHnswOptions;

Expand All @@ -58,6 +63,9 @@ YbHnswInit(void)
add_int_reloption(ybhnsw_relopt_kind, "m", "Max number of connections",
YBHNSW_DEFAULT_M, YBHNSW_MIN_M, YBHNSW_MAX_M,
AccessExclusiveLock);
add_int_reloption(ybhnsw_relopt_kind, "m0", "Max number of connections in base level",
YBHNSW_DEFAULT_M0, YBHNSW_MIN_M0, YBHNSW_MAX_M0,
AccessExclusiveLock);
add_int_reloption(ybhnsw_relopt_kind, "ef_construction",
"Size of the dynamic candidate list for construction",
YBHNSW_DEFAULT_EF_CONSTRUCTION,
Expand All @@ -76,6 +84,7 @@ ybhnswoptions(Datum reloptions, bool validate)
*/
static const relopt_parse_elt tab[] = {
{"m", RELOPT_TYPE_INT, offsetof(YbHnswOptions, m)},
{"m0", RELOPT_TYPE_INT, offsetof(YbHnswOptions, m0)},
{"ef_construction", RELOPT_TYPE_INT,
offsetof(YbHnswOptions, ef_construction)},
};
Expand All @@ -90,15 +99,19 @@ static void
bindYbHnswIndexOptions(YbcPgStatement handle, Datum reloptions)
{
YbHnswOptions *hnsw_options = (YbHnswOptions *) ybhnswoptions(reloptions, false);
int ef_construction = YBHNSW_DEFAULT_EF_CONSTRUCTION;
int m = YBHNSW_DEFAULT_M;
int m0 = YBHNSW_DEFAULT_M;
int ef_construction = YBHNSW_DEFAULT_EF_CONSTRUCTION;

if (hnsw_options)
{
ef_construction = hnsw_options->ef_construction;
m = hnsw_options->m;
m0 = hnsw_options->m0;
if (m0 < m)
m0 = m;
ef_construction = hnsw_options->ef_construction;
}
YBCPgCreateIndexSetHnswOptions(handle, ef_construction, m);
YBCPgCreateIndexSetHnswOptions(handle, m, m0, ef_construction);
}

static void
Expand Down
11 changes: 9 additions & 2 deletions src/yb/common/common.proto
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,20 @@ enum PgVectorIndexType {
HNSW = 3;
}

message PgHnswIndexOptionsPB {
optional uint32 m = 1;
optional uint32 m0 = 2;
optional uint32 ef_construction = 3;
}

message PgVectorIdxOptionsPB {
optional PgVectorDistanceType dist_type = 1;
optional PgVectorIndexType idx_type = 2;
optional uint32 dimensions = 3;
optional uint32 column_id = 4;
optional uint32 hnsw_ef = 5;
optional uint32 hnsw_m = 6;

reserved 5, 6;
optional PgHnswIndexOptionsPB hnsw = 7;
}

message PgVectorReadOptionsPB {
Expand Down
59 changes: 26 additions & 33 deletions src/yb/docdb/vector_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,37 +37,12 @@ DEFINE_RUNTIME_uint64(vector_index_initial_chunk_size, 100000,
DEFINE_RUNTIME_PREVIEW_uint32(vector_index_ef, 128,
"The \"expansion\" parameter for search");

DEFINE_RUNTIME_PREVIEW_uint32(vector_index_ef_construction, 256,
"The \"expansion\" parameter during graph construction");

DEFINE_RUNTIME_PREVIEW_uint32(vector_index_num_neighbors_per_vertex, 32,
"Number of neighbors per graph node");

DEFINE_RUNTIME_PREVIEW_uint32(vector_index_num_neighbors_per_vertex_base, 128,
"Number of neighbors per graph node in base level graph");

namespace yb::docdb {

const std::string kVectorIndexDirPrefix = "vi-";

namespace {

template <template<class, class> class Factory, class LSM>
auto VectorLSMFactory(vector_index::DistanceKind distance_kind, size_t dimensions) {
using FactoryImpl = vector_index::MakeVectorIndexFactory<Factory, LSM>;
return [distance_kind, dimensions] {
vector_index::HNSWOptions hnsw_options = {
.dimensions = dimensions,
.num_neighbors_per_vertex = FLAGS_vector_index_num_neighbors_per_vertex,
.num_neighbors_per_vertex_base = FLAGS_vector_index_num_neighbors_per_vertex_base,
.ef_construction = FLAGS_vector_index_ef_construction,
.ef = FLAGS_vector_index_ef,
.distance_kind = distance_kind,
};
return FactoryImpl::Create(hnsw_options);
};
}

vector_index::DistanceKind ConvertDistanceKind(PgVectorDistanceType dist_type) {
switch (dist_type) {
case PgVectorDistanceType::DIST_L2:
Expand All @@ -82,22 +57,41 @@ vector_index::DistanceKind ConvertDistanceKind(PgVectorDistanceType dist_type) {
FATAL_INVALID_ENUM_VALUE(PgVectorDistanceType, dist_type);
}

vector_index::HNSWOptions ConvertToHnswOptions(const PgVectorIdxOptionsPB& options) {
return {
.dimensions = options.dimensions(),
.num_neighbors_per_vertex = options.hnsw().m(),
.num_neighbors_per_vertex_base = options.hnsw().m0(),
.ef_construction = options.hnsw().ef_construction(),
.ef = FLAGS_vector_index_ef,
.distance_kind = ConvertDistanceKind(options.dist_type()),
};
}

template <template<class, class> class Factory, class LSM>
auto VectorLSMFactory(const PgVectorIdxOptionsPB& options) {
using FactoryImpl = vector_index::MakeVectorIndexFactory<Factory, LSM>;
return [hnsw_options = ConvertToHnswOptions(options)] {
return FactoryImpl::Create(hnsw_options);
};
}

template<vector_index::IndexableVectorType Vector,
vector_index::ValidDistanceResultType DistanceResult>
Result<typename vector_index::VectorLSMTypes<Vector, DistanceResult>::VectorIndexFactory>
GetVectorLSMFactory(PgVectorIndexType type, vector_index::DistanceKind distance_kind,
size_t dimensions) {
auto GetVectorLSMFactory(const PgVectorIdxOptionsPB& options)
-> Result<typename vector_index::VectorLSMTypes<Vector, DistanceResult>::VectorIndexFactory>{
using LSM = vector_index::VectorLSM<Vector, DistanceResult>;
switch (type) {
switch (options.idx_type()) {
case PgVectorIndexType::HNSW:
return VectorLSMFactory<vector_index::UsearchIndexFactory, LSM>(distance_kind, dimensions);
return VectorLSMFactory<vector_index::UsearchIndexFactory, LSM>(options);
case PgVectorIndexType::DUMMY: [[fallthrough]];
case PgVectorIndexType::IVFFLAT: [[fallthrough]];
case PgVectorIndexType::UNKNOWN_IDX:
break;
}
return STATUS_FORMAT(
NotSupported, "Vector index $0 is not supported", PgVectorIndexType_Name(type));
NotSupported, "Vector index $0 is not supported",
PgVectorIndexType_Name(options.idx_type()));
}

template<vector_index::IndexableVectorType Vector>
Expand Down Expand Up @@ -180,8 +174,7 @@ class VectorIndexImpl : public VectorIndex {
.log_prefix = log_prefix,
.storage_dir = GetStorageDir(data_root_dir, DirName()),
.vector_index_factory = VERIFY_RESULT((GetVectorLSMFactory<Vector, DistanceResult>(
idx_options.idx_type(), ConvertDistanceKind(idx_options.dist_type()),
idx_options.dimensions()))),
idx_options))),
.points_per_chunk = FLAGS_vector_index_initial_chunk_size,
.thread_pool = &thread_pool,
.frontiers_factory = [] { return std::make_unique<docdb::ConsensusFrontiers>(); },
Expand Down
9 changes: 5 additions & 4 deletions src/yb/yql/pggate/pg_ddl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,11 @@ Status PgCreateTableBase::SetVectorOptions(YbcPgVectorIdxOptions* options) {
return Status::OK();
}

Status PgCreateTableBase::SetHnswOptions(int ef_construction, int m) {
auto options_pb = req_.mutable_vector_idx_options();
options_pb->set_hnsw_ef(ef_construction);
options_pb->set_hnsw_m(m);
Status PgCreateTableBase::SetHnswOptions(int m, int m0, int ef_construction) {
auto& options_pb = *req_.mutable_vector_idx_options()->mutable_hnsw();
options_pb.set_m(m);
options_pb.set_m0(m0);
options_pb.set_ef_construction(ef_construction);
return Status::OK();
}

Expand Down
2 changes: 1 addition & 1 deletion src/yb/yql/pggate/pg_ddl.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class PgCreateTableBase : public PgDdl {

Status SetNumTablets(int32_t num_tablets);

Status SetHnswOptions(int ef_construction, int m);
Status SetHnswOptions(int m, int m0, int ef_construction);

Status SetVectorOptions(YbcPgVectorIdxOptions* options);

Expand Down
5 changes: 3 additions & 2 deletions src/yb/yql/pggate/pggate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1151,9 +1151,10 @@ Status PgApiImpl::CreateIndexSetVectorOptions(PgStatement* handle, YbcPgVectorId
return VERIFY_RESULT_REF(GetStatementAs<PgCreateIndex>(handle)).SetVectorOptions(options);
}

Status PgApiImpl::CreateIndexSetHnswOptions(PgStatement* handle, int ef_construction, int m) {
Status PgApiImpl::CreateIndexSetHnswOptions(
PgStatement* handle, int m, int m0, int ef_construction) {
return VERIFY_RESULT_REF(GetStatementAs<PgCreateIndex>(handle))
.SetHnswOptions(ef_construction, m);
.SetHnswOptions(m, m0, ef_construction);
}

Status PgApiImpl::ExecCreateIndex(PgStatement* handle) {
Expand Down
2 changes: 1 addition & 1 deletion src/yb/yql/pggate/pggate.h
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ class PgApiImpl {

Status CreateIndexSetVectorOptions(PgStatement *handle, YbcPgVectorIdxOptions *options);

Status CreateIndexSetHnswOptions(PgStatement *handle, int ef_construction, int m);
Status CreateIndexSetHnswOptions(PgStatement *handle, int m, int m0, int ef_construction);

Status CreateIndexAddSplitRow(PgStatement *handle, int num_cols,
YbcPgTypeEntity **types, uint64_t *data);
Expand Down
5 changes: 3 additions & 2 deletions src/yb/yql/pggate/ybc_pggate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1355,8 +1355,9 @@ YbcStatus YBCPgCreateIndexSetVectorOptions(YbcPgStatement handle, YbcPgVectorIdx
return ToYBCStatus(pgapi->CreateIndexSetVectorOptions(handle, options));
}

YbcStatus YBCPgCreateIndexSetHnswOptions(YbcPgStatement handle, int ef_construction, int m) {
return ToYBCStatus(pgapi->CreateIndexSetHnswOptions(handle, ef_construction, m));
YbcStatus YBCPgCreateIndexSetHnswOptions(
YbcPgStatement handle, int m, int m0, int ef_construction) {
return ToYBCStatus(pgapi->CreateIndexSetHnswOptions(handle, m, m0, ef_construction));
}

YbcStatus YBCPgExecCreateIndex(YbcPgStatement handle) {
Expand Down
3 changes: 2 additions & 1 deletion src/yb/yql/pggate/ybc_pggate.h
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,8 @@ YbcStatus YBCPgCreateIndexSetNumTablets(YbcPgStatement handle, int32_t num_table

YbcStatus YBCPgCreateIndexSetVectorOptions(YbcPgStatement handle, YbcPgVectorIdxOptions *options);

YbcStatus YBCPgCreateIndexSetHnswOptions(YbcPgStatement handle, int ef_construction, int m);
YbcStatus YBCPgCreateIndexSetHnswOptions(
YbcPgStatement handle, int m, int m0, int ef_construction);

YbcStatus YBCPgExecCreateIndex(YbcPgStatement handle);

Expand Down
87 changes: 86 additions & 1 deletion src/yb/yql/pgwrapper/pg_vector_index-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
#include "yb/consensus/consensus.h"
#include "yb/consensus/log.h"

#include "yb/docdb/doc_read_context.h"
#include "yb/docdb/vector_index.h"

#include "yb/integration-tests/cluster_itest_util.h"

#include "yb/qlexpr/index.h"
Expand Down Expand Up @@ -108,7 +111,10 @@ class PgVectorIndexTest : public PgMiniTestBase, public testing::WithParamInterf
}

Status CreateIndex(PGConn& conn) {
return conn.ExecuteFormat("CREATE INDEX ON test USING ybhnsw (embedding $0)", VectorOpsName());
return conn.ExecuteFormat(
"CREATE INDEX ON test USING ybhnsw (embedding $0) "
"WITH (ef_construction = 256, m = 32, m0 = 128)",
VectorOpsName());
}

Result<PGConn> MakeIndex(size_t dimensions = 3) {
Expand Down Expand Up @@ -695,6 +701,85 @@ TEST_P(PgVectorIndexTest, Cosine) {
TestMetric("2; 3; 1");
}

TEST_P(PgVectorIndexTest, Options) {
auto conn = ASSERT_RESULT(MakeTable());
std::unordered_map<TabletId, std::unordered_set<TableId>> checked_indexes;
std::vector<std::string> option_names = {"m", "m0", "ef_construction"};
// We need unique values for used params. Since m and ef has different allowed intervals,
// use different counters for them. counters[0] for m/m0 and counters[1] for ef_construction.
std::array<size_t, 2> counters = {32, 64};
for (int i = 0; i != 1 << option_names.size(); ++i) {
std::string expected_options;
{
std::string options;
size_t prev_value = 0;
for (size_t j = 0; j != option_names.size(); ++j) {
if (!expected_options.empty()) {
expected_options += " ";
}
size_t value;
if ((i & (1 << j))) {
value = ++counters[j >= 2];
if (!options.empty()) {
options += ", ";
}
options += Format("$0 = $1", option_names[j], value);
} else {
switch (j) {
case 0:
value = 32; // Default value for m
break;
case 1:
// When not specified m0 uses value of m.
value = prev_value;
break;
case 2:
value = 200; // Default value for ef
break;
default:
ASSERT_LT(j, 4U) << "Unexpected number of options";
value = 0;
break;
}
}
expected_options += Format("$0: $1", option_names[j], value);
prev_value = value;
}
if (!options.empty()) {
options = " WITH (" + options + ")";
}
auto query = "CREATE INDEX ON test USING ybhnsw (embedding vector_l2_ops)" + options;
LOG(INFO) << "Query: " << query;
ASSERT_OK(conn.Execute(query));
}
auto peers = ListTabletPeers(
cluster_.get(), ListPeersFilter::kLeaders, IncludeTransactionStatusTablets::kFalse);
for (const auto& peer : peers) {
auto tablet = peer->shared_tablet();
auto vector_indexes = tablet->vector_indexes().List();
if (!vector_indexes) {
continue;
}
auto& tablet_indexes = checked_indexes[peer->tablet_id()];
size_t num_new_indexes = 0;
for (const auto& vector_index : *vector_indexes) {
if (!tablet_indexes.insert(vector_index->table_id()).second) {
continue;
}
++num_new_indexes;
auto doc_read_context = ASSERT_RESULT(
tablet->metadata()->GetTableInfo(vector_index->table_id()))->doc_read_context;
const auto& hnsw_options = doc_read_context->vector_idx_options->hnsw();
LOG(INFO)
<< "Vector index: " << AsString(vector_index) << ", options: "
<< AsString(hnsw_options);
ASSERT_EQ(AsString(hnsw_options), expected_options);
}
ASSERT_EQ(num_new_indexes, 1);
}
}
}

std::string ColocatedToString(const testing::TestParamInfo<bool>& param_info) {
return param_info.param ? "Colocated" : "Distributed";
}
Expand Down

0 comments on commit 83b8ea1

Please sign in to comment.