Skip to content

PosgreSQL hybrid search #958

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 24 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 17 additions & 23 deletions extensions/Postgres/Postgres.TestApplication/Program.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.KernelMemory;
using Microsoft.KernelMemory.AI.Ollama;
using Microsoft.KernelMemory.DocumentStorage.DevTools;
using Microsoft.KernelMemory.FileSystem.DevTools;

Expand All @@ -26,16 +27,13 @@ private static async Task Test1()
var postgresConfig = cfg.GetSection("KernelMemory:Services:Postgres").Get<PostgresConfig>();
ArgumentNullExceptionEx.ThrowIfNull(postgresConfig, nameof(postgresConfig), "Postgres config not found");

var azureOpenAIEmbeddingConfig = cfg.GetSection("KernelMemory:Services:AzureOpenAIEmbedding").Get<AzureOpenAIConfig>();
ArgumentNullExceptionEx.ThrowIfNull(azureOpenAIEmbeddingConfig, nameof(azureOpenAIEmbeddingConfig), "AzureOpenAIEmbedding config not found");

var azureOpenAITextConfig = cfg.GetSection("KernelMemory:Services:AzureOpenAIText").Get<AzureOpenAIConfig>();
ArgumentNullExceptionEx.ThrowIfNull(azureOpenAITextConfig, nameof(azureOpenAITextConfig), "AzureOpenAIText config not found");
var ollamaConfig = cfg.GetSection("KernelMemory:Services:Ollama").Get<OllamaConfig>();
ArgumentNullExceptionEx.ThrowIfNull(ollamaConfig, nameof(ollamaConfig), "Ollama config not found");

// Concatenate our 'WithPostgresMemoryDb()' after 'WithOpenAIDefaults()' from the core nuget
var mem1 = new KernelMemoryBuilder()
.WithAzureOpenAITextGeneration(azureOpenAITextConfig)
.WithAzureOpenAITextEmbeddingGeneration(azureOpenAIEmbeddingConfig)
.WithOllamaTextEmbeddingGeneration(ollamaConfig)
.WithOllamaTextGeneration(ollamaConfig)
.WithPostgresMemoryDb(postgresConfig)
.WithSimpleFileStorage(SimpleFileStorageConfig.Persistent)
.Build();
Expand All @@ -44,16 +42,16 @@ private static async Task Test1()
var mem2 = new KernelMemoryBuilder()
.WithPostgresMemoryDb(postgresConfig)
.WithSimpleFileStorage(SimpleFileStorageConfig.Persistent)
.WithAzureOpenAITextGeneration(azureOpenAITextConfig)
.WithAzureOpenAITextEmbeddingGeneration(azureOpenAIEmbeddingConfig)
.WithOllamaTextEmbeddingGeneration(ollamaConfig)
.WithOllamaTextGeneration(ollamaConfig)
.Build();

// Concatenate our 'WithPostgresMemoryDb()' before and after KM builder extension methods from the core nuget
var mem3 = new KernelMemoryBuilder()
.WithSimpleFileStorage(SimpleFileStorageConfig.Persistent)
.WithAzureOpenAITextGeneration(azureOpenAITextConfig)
.WithOllamaTextEmbeddingGeneration(ollamaConfig)
.WithOllamaTextGeneration(ollamaConfig)
.WithPostgresMemoryDb(postgresConfig)
.WithAzureOpenAITextEmbeddingGeneration(azureOpenAIEmbeddingConfig)
.Build();

await mem1.DeleteIndexAsync("index1");
Expand Down Expand Up @@ -92,22 +90,20 @@ private static async Task Test1()
private static async Task Test2()
{
var postgresConfig = new PostgresConfig();
var azureOpenAIEmbeddingConfig = new AzureOpenAIConfig();
var azureOpenAITextConfig = new AzureOpenAIConfig();
var ollamaConfig = new OllamaConfig();

new ConfigurationBuilder()
.AddJsonFile("appsettings.json")
.AddJsonFile("appsettings.development.json", optional: true)
.AddJsonFile("appsettings.Development.json", optional: true)
.Build()
.BindSection("KernelMemory:Services:Postgres", postgresConfig)
.BindSection("KernelMemory:Services:AzureOpenAIEmbedding", azureOpenAIEmbeddingConfig)
.BindSection("KernelMemory:Services:AzureOpenAIText", azureOpenAITextConfig);
.BindSection("KernelMemory:Services:Ollama", ollamaConfig);

var memory = new KernelMemoryBuilder()
.WithPostgresMemoryDb(postgresConfig)
.WithAzureOpenAITextGeneration(azureOpenAITextConfig)
.WithAzureOpenAITextEmbeddingGeneration(azureOpenAIEmbeddingConfig)
.WithOllamaTextGeneration(ollamaConfig)
.WithOllamaTextEmbeddingGeneration(ollamaConfig)
.WithSimpleFileStorage(new SimpleFileStorageConfig
{
StorageType = FileSystemTypes.Disk,
Expand Down Expand Up @@ -140,8 +136,7 @@ private static async Task Test2()
private static async Task Test3()
{
var postgresConfig = new PostgresConfig();
var azureOpenAIEmbeddingConfig = new AzureOpenAIConfig();
var azureOpenAITextConfig = new AzureOpenAIConfig();
var ollamaConfig = new OllamaConfig();

// Note: using appsettings.custom-sql.json
new ConfigurationBuilder()
Expand All @@ -151,13 +146,12 @@ private static async Task Test3()
.AddJsonFile("appsettings.custom-sql.json")
.Build()
.BindSection("KernelMemory:Services:Postgres", postgresConfig)
.BindSection("KernelMemory:Services:AzureOpenAIEmbedding", azureOpenAIEmbeddingConfig)
.BindSection("KernelMemory:Services:AzureOpenAIText", azureOpenAITextConfig);
.BindSection("KernelMemory:Services:Ollama", ollamaConfig);

var memory = new KernelMemoryBuilder()
.WithPostgresMemoryDb(postgresConfig)
.WithAzureOpenAITextGeneration(azureOpenAITextConfig)
.WithAzureOpenAITextEmbeddingGeneration(azureOpenAIEmbeddingConfig)
.WithOllamaTextGeneration(ollamaConfig)
.WithOllamaTextEmbeddingGeneration(ollamaConfig)
.WithSimpleFileStorage(new SimpleFileStorageConfig
{
StorageType = FileSystemTypes.Disk,
Expand Down
75 changes: 67 additions & 8 deletions extensions/Postgres/Postgres/Internals/PostgresDbClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ public PostgresDbClient(PostgresConfig config, ILoggerFactory? loggerFactory = n
this._dbNamePresent = config.ConnectionString.Contains("Database=", StringComparison.OrdinalIgnoreCase);
this._schema = config.Schema;
this._tableNamePrefix = config.TableNamePrefix;
this._textSearchLanguage = config.TextSearchLanguage;
this._rrfK = config.RRFK;

this._colId = config.Columns[PostgresConfig.ColumnId];
this._colEmbedding = config.Columns[PostgresConfig.ColumnEmbedding];
Expand All @@ -59,6 +61,14 @@ public PostgresDbClient(PostgresConfig config, ILoggerFactory? loggerFactory = n

this._columnsListNoEmbeddings = $"{this._colId},{this._colTags},{this._colContent},{this._colPayload}";
this._columnsListWithEmbeddings = $"{this._colId},{this._colTags},{this._colContent},{this._colPayload},{this._colEmbedding}";
this._columnsListHybrid = $"{this._colId},{this._colTags},{this._colContent},{this._colPayload},{this._colEmbedding}";
this._columnsListHybridCoalesce = $@"
COALESCE(semantic_search.{this._colId}, keyword_search.{this._colId}) AS {this._colId},
COALESCE(semantic_search.{this._colTags}, keyword_search.{this._colTags}) AS {this._colTags},
COALESCE(semantic_search.{this._colContent}, keyword_search.{this._colContent}) AS {this._colContent},
COALESCE(semantic_search.{this._colPayload}, keyword_search.{this._colPayload}) AS {this._colPayload},
COALESCE(semantic_search.{this._colEmbedding}, keyword_search.{this._colEmbedding}) AS {this._colEmbedding}
";

this._createTableSql = string.Empty;
if (config.CreateTableSql?.Count > 0)
Expand Down Expand Up @@ -138,6 +148,8 @@ public async Task CreateTableAsync(
CancellationToken cancellationToken = default)
{
var origInputTableName = tableName;
var indexTags = this.WithTableNamePrefix(tableName) + "_idx_tags";
var indexContent = this.WithTableNamePrefix(tableName) + "_idx_content";
tableName = this.WithSchemaAndTableNamePrefix(tableName);
this._log.LogTrace("Creating table: {0}", tableName);

Expand Down Expand Up @@ -175,7 +187,8 @@ public async Task CreateTableAsync(
{this._colContent} TEXT DEFAULT '' NOT NULL,
{this._colPayload} JSONB DEFAULT '{{}}'::JSONB NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_tags ON {tableName} USING GIN({this._colTags});
CREATE INDEX IF NOT EXISTS ""{indexTags}"" ON {tableName} USING GIN({this._colTags});
CREATE INDEX IF NOT EXISTS ""{indexContent}"" ON {tableName} USING GIN(to_tsvector('{this._textSearchLanguage}',{this._colContent}));
COMMIT;
";
#pragma warning restore CA2100
Expand Down Expand Up @@ -388,23 +401,27 @@ DO UPDATE SET
/// Get a list of records
/// </summary>
/// <param name="tableName">Table containing the records to fetch</param>
/// <param name="query">Prompt query. Only used in the case of hybrid search</param>
/// <param name="target">Source vector to compare for similarity</param>
/// <param name="minSimilarity">Minimum similarity threshold</param>
/// <param name="filterSql">SQL filter to apply</param>
/// <param name="sqlUserValues">List of user values passed with placeholders to avoid SQL injection</param>
/// <param name="limit">Max number of records to retrieve</param>
/// <param name="offset">Records to skip from the top</param>
/// <param name="withEmbeddings">Whether to include embedding vectors</param>
/// <param name="useHybridSearch">Whether to use hybrid search or vector search</param>
/// <param name="cancellationToken">Async task cancellation token</param>
public async IAsyncEnumerable<(PostgresMemoryRecord record, double similarity)> GetSimilarAsync(
string tableName,
string query,
Vector target,
double minSimilarity,
string? filterSql = null,
Dictionary<string, object>? sqlUserValues = null,
int limit = 1,
int offset = 0,
bool withEmbeddings = false,
bool useHybridSearch = false,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
tableName = this.WithSchemaAndTableNamePrefix(tableName);
Expand All @@ -415,12 +432,15 @@ DO UPDATE SET
string columns = withEmbeddings ? this._columnsListWithEmbeddings : this._columnsListNoEmbeddings;

// Filtering logic, including filter by similarity
//
filterSql = filterSql?.Trim().Replace(PostgresSchema.PlaceholdersTags, this._colTags, StringComparison.Ordinal);
if (string.IsNullOrWhiteSpace(filterSql))
{
filterSql = "TRUE";
}

string filterSqlHybridText = filterSql;

var maxDistance = 1 - minSimilarity;
filterSql += $" AND {this._colEmbedding} <=> @embedding < @maxDistance";

Expand All @@ -440,16 +460,51 @@ DO UPDATE SET
#pragma warning disable CA2100 // SQL reviewed
string colDistance = "__distance";

// When using 1 - (embedding <=> target) the index is not being used, therefore we calculate
// the similarity (1 - distance) later. Furthermore, colDistance can't be used in the WHERE clause.
cmd.CommandText = @$"
SELECT {columns}, {this._colEmbedding} <=> @embedding AS {colDistance}
FROM {tableName}
WHERE {filterSql}
ORDER BY {colDistance} ASC
if (useHybridSearch)
{
// When using 1 - (embedding <=> target) the index is not being used, therefore we calculate
// the similarity (1 - distance) later. Furthermore, colDistance can't be used in the WHERE clause.
cmd.CommandText = @$"
WITH semantic_search AS (
SELECT {this._columnsListHybrid}, RANK () OVER (ORDER BY {this._colEmbedding} <=> @embedding) AS rank
FROM {tableName}
WHERE {filterSql}
ORDER BY {this._colEmbedding} <=> @embedding
LIMIT @limit
),
keyword_search AS (
SELECT {this._columnsListHybrid}, RANK () OVER (ORDER BY ts_rank_cd(to_tsvector('{this._textSearchLanguage}', {this._colContent}), query) DESC)
FROM {tableName}, plainto_tsquery('{this._textSearchLanguage}', @query) query
WHERE {filterSqlHybridText} AND to_tsvector('{this._textSearchLanguage}', {this._colContent}) @@ query
ORDER BY ts_rank_cd(to_tsvector('{this._textSearchLanguage}', {this._colContent}), query) DESC
LIMIT @limit
)
SELECT
{this._columnsListHybridCoalesce},
COALESCE(1.0 / ({this._rrfK} + semantic_search.rank), 0.0) +
COALESCE(1.0 / ({this._rrfK} + keyword_search.rank), 0.0) AS {colDistance}
FROM semantic_search
FULL OUTER JOIN keyword_search ON semantic_search.{this._colId} = keyword_search.{this._colId}
ORDER BY {colDistance} DESC
LIMIT @limit
OFFSET @offset
";
cmd.Parameters.AddWithValue("@query", query);
cmd.Parameters.AddWithValue("@minSimilarity", minSimilarity);
}
else
{
// When using 1 - (embedding <=> target) the index is not being used, therefore we calculate
// the similarity (1 - distance) later. Furthermore, colDistance can't be used in the WHERE clause.
cmd.CommandText = @$"
SELECT {columns}, {this._colEmbedding} <=> @embedding AS {colDistance}
FROM {tableName}
WHERE {filterSql}
ORDER BY {colDistance} ASC
LIMIT @limit
OFFSET @offset
";
}

cmd.Parameters.AddWithValue("@embedding", target);
cmd.Parameters.AddWithValue("@maxDistance", maxDistance);
Expand Down Expand Up @@ -692,7 +747,11 @@ public async ValueTask DisposeAsync()
private readonly string _colPayload;
private readonly string _columnsListNoEmbeddings;
private readonly string _columnsListWithEmbeddings;
private readonly string _columnsListHybrid;
private readonly string _columnsListHybridCoalesce;
private readonly bool _dbNamePresent;
private readonly string _textSearchLanguage;
private readonly int _rrfK;

/// <summary>
/// Try to connect to PG, handling exceptions in case the DB doesn't exist
Expand Down
18 changes: 18 additions & 0 deletions extensions/Postgres/Postgres/PostgresConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,24 @@ public class PostgresConfig
/// </summary>
public List<string> CreateTableSql { get; set; } = [];

/// <summary>
/// Important: when using hybrid search, relevance scores
/// are very different from when using just vector search.
/// </summary>
public bool UseHybridSearch { get; set; } = false;

/// <summary>
/// Defines the dictionary language used for the textual part of hybrid search.
/// see: https://www.postgresql.org/docs/current/textsearch-dictionaries.html
/// This query can help you to get the list of dictionaries: SELECT * FROM pg_catalog.pg_ts_dict;
/// </summary>
public string TextSearchLanguage { get; set; } = "english";

/// <summary>
/// Reciprocal Ranked Fusion to score results of Hybrid Search
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the "K" for?

"RRF" => "Reciprocal Ranked Fusion"
"RRFK" => ?

pls add a link to the documentation

/// </summary>
public int RRFK { get; set; } = 50;

/// <summary>
/// Create a new instance of the configuration
/// </summary>
Expand Down
5 changes: 5 additions & 0 deletions extensions/Postgres/Postgres/PostgresMemory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ public sealed class PostgresMemory : IMemoryDb, IDisposable, IAsyncDisposable
private readonly ITextEmbeddingGenerator _embeddingGenerator;
private readonly ILogger<PostgresMemory> _log;

private readonly bool _useHybridSearch;

/// <summary>
/// Create a new instance of Postgres KM connector
/// </summary>
Expand All @@ -41,6 +43,7 @@ public PostgresMemory(
ILoggerFactory? loggerFactory = null)
{
this._log = (loggerFactory ?? DefaultLogger.Factory).CreateLogger<PostgresMemory>();
this._useHybridSearch = config.UseHybridSearch;

this._embeddingGenerator = embeddingGenerator;
if (this._embeddingGenerator == null)
Expand Down Expand Up @@ -160,12 +163,14 @@ await this._db.UpsertAsync(

var records = this._db.GetSimilarAsync(
index,
query: text,
target: new Vector(textEmbedding.Data),
minSimilarity: minRelevance,
filterSql: sql,
sqlUserValues: unsafeSqlUserValues,
limit: limit,
withEmbeddings: withEmbeddings,
useHybridSearch: this._useHybridSearch,
cancellationToken: cancellationToken).ConfigureAwait(false);

await foreach ((PostgresMemoryRecord record, double similarity) result in records)
Expand Down
Loading