diff --git a/FEATURE_MATRIX.md b/FEATURE_MATRIX.md index 294ad5e629f0..2f40e7fdb064 100644 --- a/FEATURE_MATRIX.md +++ b/FEATURE_MATRIX.md @@ -51,6 +51,7 @@ | Weaviate (Memory) | ❌ | ❌ | Vector optimized | | CosmosDB (Memory) | ✅ | ❌ | CosmosDB is not optimized for vector storage | | Sqlite (Memory) | ✅ | ❌ | Sqlite is not optimized for vector storage | +| Postgres (Memory) | ✅ | ❌ | Vector optimized (requires the [pgvector](https://github.com/pgvector/pgvector) extension) | | Azure Cognitive Search | ❌ | ❌ | | | MsGraph | ✅ | ❌ | Contains connectors for OneDrive, Outlook, ToDos, and Organization Hierarchies | | Document Skills | ✅ | ❌ | Currently only supports Word documents | diff --git a/dotnet/Directory.Packages.props b/dotnet/Directory.Packages.props index e08eb2dc578d..ffa5ee01deb9 100644 --- a/dotnet/Directory.Packages.props +++ b/dotnet/Directory.Packages.props @@ -7,6 +7,8 @@ + + diff --git a/dotnet/SK-dotnet.sln b/dotnet/SK-dotnet.sln index bf226ddef523..3981d26ffe1a 100644 --- a/dotnet/SK-dotnet.sln +++ b/dotnet/SK-dotnet.sln @@ -67,6 +67,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Memory.Sqlite", EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Memory.CosmosDB", "src\Connectors\Connectors.Memory.CosmosDB\Connectors.Memory.CosmosDB.csproj", "{EA61C289-7928-4B78-A9C1-7AAD61F907CD}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.Memory.Postgres", "src\Connectors\Connectors.Memory.Postgres\Connectors.Memory.Postgres.csproj", "{C9F957FA-A70F-4A6D-8F95-23FCD7F4FB87}" +EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Connectors.AI.OpenAI", "src\Connectors\Connectors.AI.OpenAI\Connectors.AI.OpenAI.csproj", "{AFA81EB7-F869-467D-8A90-744305D80AAC}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "SemanticKernel.Abstractions", "src\SemanticKernel.Abstractions\SemanticKernel.Abstractions.csproj", "{627742DB-1E52-468A-99BD-6FF1A542D25B}" @@ -215,6 +217,12 @@ 
Global {EA61C289-7928-4B78-A9C1-7AAD61F907CD}.Publish|Any CPU.Build.0 = Release|Any CPU {EA61C289-7928-4B78-A9C1-7AAD61F907CD}.Release|Any CPU.ActiveCfg = Release|Any CPU {EA61C289-7928-4B78-A9C1-7AAD61F907CD}.Release|Any CPU.Build.0 = Release|Any CPU + {C9F957FA-A70F-4A6D-8F95-23FCD7F4FB87}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {C9F957FA-A70F-4A6D-8F95-23FCD7F4FB87}.Debug|Any CPU.Build.0 = Debug|Any CPU + {C9F957FA-A70F-4A6D-8F95-23FCD7F4FB87}.Publish|Any CPU.ActiveCfg = Publish|Any CPU + {C9F957FA-A70F-4A6D-8F95-23FCD7F4FB87}.Publish|Any CPU.Build.0 = Publish|Any CPU + {C9F957FA-A70F-4A6D-8F95-23FCD7F4FB87}.Release|Any CPU.ActiveCfg = Release|Any CPU + {C9F957FA-A70F-4A6D-8F95-23FCD7F4FB87}.Release|Any CPU.Build.0 = Release|Any CPU {AFA81EB7-F869-467D-8A90-744305D80AAC}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {AFA81EB7-F869-467D-8A90-744305D80AAC}.Debug|Any CPU.Build.0 = Debug|Any CPU {AFA81EB7-F869-467D-8A90-744305D80AAC}.Publish|Any CPU.ActiveCfg = Publish|Any CPU @@ -308,6 +316,7 @@ Global {5DEBAA62-F117-496A-8778-FED3604B70E2} = {0247C2C9-86C3-45BA-8873-28B0948EDC0C} {EC004F12-2F60-4EDD-B3CD-3A504900D929} = {0247C2C9-86C3-45BA-8873-28B0948EDC0C} {EA61C289-7928-4B78-A9C1-7AAD61F907CD} = {0247C2C9-86C3-45BA-8873-28B0948EDC0C} + {C9F957FA-A70F-4A6D-8F95-23FCD7F4FB87} = {0247C2C9-86C3-45BA-8873-28B0948EDC0C} {AFA81EB7-F869-467D-8A90-744305D80AAC} = {0247C2C9-86C3-45BA-8873-28B0948EDC0C} {627742DB-1E52-468A-99BD-6FF1A542D25B} = {831DDCA2-7D2C-4C31-80DB-6BDB3E1F7AE0} {E3299033-EB81-4C4C-BCD9-E8DC40937969} = {831DDCA2-7D2C-4C31-80DB-6BDB3E1F7AE0} diff --git a/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj b/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj new file mode 100644 index 000000000000..08f9a84f751d --- /dev/null +++ b/dotnet/src/Connectors/Connectors.Memory.Postgres/Connectors.Memory.Postgres.csproj @@ -0,0 +1,28 @@ + + + + + Microsoft.SemanticKernel.Connectors.Memory.Postgres + 
/// <summary>
/// A postgres memory entry, i.e. one row of the memory table as read back from the database.
/// </summary>
internal struct DatabaseEntry
{
    /// <summary>
    /// Unique identifier of the memory entry.
    /// </summary>
    public string Key { get; set; }

    /// <summary>
    /// Metadata as a string.
    /// </summary>
    public string MetadataString { get; set; }

    /// <summary>
    /// The embedding data as a <see cref="Vector"/>. Null when the embedding was not requested or is not stored.
    /// </summary>
    public Vector? Embedding { get; set; }

    /// <summary>
    /// Optional timestamp.
    /// </summary>
    public long? Timestamp { get; set; }
}

/// <summary>
/// The class for managing postgres database operations.
/// All collections share one table; rows are distinguished by the (collection, key) primary key.
/// </summary>
internal sealed class Database
{
    private const string TableName = "sk_memory_table";

    /// <summary>
    /// Create pgvector extensions.
    /// </summary>
    /// <param name="conn">An opened <see cref="NpgsqlConnection"/> instance.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
    /// <returns></returns>
    public async Task CreatePgVectorExtensionAsync(NpgsqlConnection conn, CancellationToken cancellationToken = default)
    {
        using NpgsqlCommand cmd = conn.CreateCommand();
        cmd.CommandText = "CREATE EXTENSION IF NOT EXISTS vector";
        await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);

        // Reload the connection's type information after the extension may have just been created.
        await conn.ReloadTypesAsync().ConfigureAwait(false);
    }

    /// <summary>
    /// Create memory table.
    /// </summary>
    /// <param name="conn">An opened <see cref="NpgsqlConnection"/> instance.</param>
    /// <param name="vectorSize">Vector size of embedding column</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
    /// <returns></returns>
    public async Task CreateTableAsync(NpgsqlConnection conn, int vectorSize, CancellationToken cancellationToken = default)
    {
        // The table uses the "vector" column type, so the extension must exist first.
        await this.CreatePgVectorExtensionAsync(conn, cancellationToken).ConfigureAwait(false);

        using NpgsqlCommand cmd = conn.CreateCommand();
#pragma warning disable CA2100 // Review SQL queries for security vulnerabilities
        cmd.CommandText = $@"
            CREATE TABLE IF NOT EXISTS {TableName} (
                collection TEXT,
                key TEXT,
                metadata TEXT,
                embedding vector({vectorSize}),
                timestamp BIGINT,
                PRIMARY KEY(collection, key))";
#pragma warning restore CA2100 // Review SQL queries for security vulnerabilities
        await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Create index for memory table.
    /// </summary>
    /// <param name="conn">An opened <see cref="NpgsqlConnection"/> instance.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
    /// <returns></returns>
    public async Task CreateIndexAsync(NpgsqlConnection conn, CancellationToken cancellationToken = default)
    {
        using NpgsqlCommand cmd = conn.CreateCommand();
        cmd.CommandText = $@"
            CREATE INDEX IF NOT EXISTS {TableName}_ivfflat_embedding_vector_cosine_ops_idx
            ON {TableName} USING ivfflat (embedding vector_cosine_ops) WITH (lists = 1000)";
        await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Create a collection.
    /// </summary>
    /// <remarks>
    /// A collection is represented by a placeholder row with an empty key; its metadata, embedding and
    /// timestamp columns are left NULL.
    /// </remarks>
    /// <param name="conn">An opened <see cref="NpgsqlConnection"/> instance.</param>
    /// <param name="collectionName">The name assigned to a collection of entries.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
    /// <returns></returns>
    public async Task CreateCollectionAsync(NpgsqlConnection conn, string collectionName, CancellationToken cancellationToken = default)
    {
        if (await this.DoesCollectionExistsAsync(conn, collectionName, cancellationToken).ConfigureAwait(false))
        {
            // Collection already exists
            return;
        }

        using NpgsqlCommand cmd = conn.CreateCommand();
        cmd.CommandText = $@"
            INSERT INTO {TableName} (collection, key)
            VALUES(@collection, '')";
        cmd.Parameters.AddWithValue("@collection", collectionName);
        await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Upsert entry into a collection.
    /// </summary>
    /// <param name="conn">An opened <see cref="NpgsqlConnection"/> instance.</param>
    /// <param name="collectionName">The name assigned to a collection of entries.</param>
    /// <param name="key">The key of the entry to upsert.</param>
    /// <param name="metadata">The metadata of the entry. Null is stored as an empty string.</param>
    /// <param name="embedding">The embedding of the entry. Null is stored as SQL NULL.</param>
    /// <param name="timestamp">The timestamp of the entry. Null is stored as SQL NULL.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
    /// <returns></returns>
    public async Task UpsertAsync(NpgsqlConnection conn,
        string collectionName, string key, string? metadata, Vector? embedding, long? timestamp, CancellationToken cancellationToken = default)
    {
        using NpgsqlCommand cmd = conn.CreateCommand();
        cmd.CommandText = $@"
            INSERT INTO {TableName} (collection, key, metadata, embedding, timestamp)
            VALUES(@collection, @key, @metadata, @embedding, @timestamp)
            ON CONFLICT (collection, key)
            DO UPDATE SET metadata=@metadata, embedding=@embedding, timestamp=@timestamp";
        cmd.Parameters.AddWithValue("@collection", collectionName);
        cmd.Parameters.AddWithValue("@key", key);
        cmd.Parameters.AddWithValue("@metadata", metadata ?? string.Empty);
        cmd.Parameters.AddWithValue("@embedding", embedding ?? (object)DBNull.Value);
        cmd.Parameters.AddWithValue("@timestamp", timestamp ?? (object)DBNull.Value);
        await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Check if a collection exists.
    /// </summary>
    /// <param name="conn">An opened <see cref="NpgsqlConnection"/> instance.</param>
    /// <param name="collectionName">The name assigned to a collection of entries.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
    /// <returns>True when at least one row belongs to <paramref name="collectionName"/>.</returns>
    public async Task<bool> DoesCollectionExistsAsync(NpgsqlConnection conn,
        string collectionName,
        CancellationToken cancellationToken = default)
    {
        // Stream the names and stop at the first hit instead of materialising the whole list.
        return await this.GetCollectionsAsync(conn, cancellationToken)
            .ContainsAsync(collectionName, cancellationToken)
            .ConfigureAwait(false);
    }

    /// <summary>
    /// Get all collections.
    /// </summary>
    /// <param name="conn">An opened <see cref="NpgsqlConnection"/> instance.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
    /// <returns>The distinct collection names present in the table.</returns>
    public async IAsyncEnumerable<string> GetCollectionsAsync(NpgsqlConnection conn,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        using NpgsqlCommand cmd = conn.CreateCommand();
        cmd.CommandText = $@"
            SELECT DISTINCT(collection)
            FROM {TableName}";

        using var dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
        int collectionOrdinal = dataReader.GetOrdinal("collection");
        while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
        {
            yield return dataReader.GetString(collectionOrdinal);
        }
    }

    /// <summary>
    /// Gets the nearest matches to the <see cref="Vector"/>.
    /// </summary>
    /// <param name="conn">An opened <see cref="NpgsqlConnection"/> instance.</param>
    /// <param name="collectionName">The name assigned to a collection of entries.</param>
    /// <param name="embeddingFilter">The <see cref="Vector"/> to compare the collection's embeddings with.</param>
    /// <param name="limit">The maximum number of similarity results to return.</param>
    /// <param name="minRelevanceScore">The minimum relevance threshold for returned results.</param>
    /// <param name="withEmbeddings">If true, the embeddings will be returned in the entries.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
    /// <returns>Matching entries paired with their cosine similarity, most similar first.</returns>
    public async IAsyncEnumerable<(DatabaseEntry, double)> GetNearestMatchesAsync(NpgsqlConnection conn,
        string collectionName, Vector embeddingFilter, int limit, double minRelevanceScore = 0, bool withEmbeddings = false,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        string queryColumns = GetQueryColumns(withEmbeddings);

        using NpgsqlCommand cmd = conn.CreateCommand();
        // "<=>" is applied as a distance, so similarity is computed as 1 - distance.
        cmd.CommandText = @$"
            SELECT * FROM (SELECT {queryColumns}, 1 - (embedding <=> @embedding) AS cosine_similarity FROM {TableName}
                WHERE collection = @collection
            ) AS sk_memory_cosine_similarity_table
            WHERE cosine_similarity >= @min_relevance_score
            ORDER BY cosine_similarity DESC
            Limit @limit";
        cmd.Parameters.AddWithValue("@embedding", embeddingFilter);
        cmd.Parameters.AddWithValue("@collection", collectionName);
        cmd.Parameters.AddWithValue("@min_relevance_score", minRelevanceScore);
        cmd.Parameters.AddWithValue("@limit", limit);

        using var dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
        int similarityOrdinal = dataReader.GetOrdinal("cosine_similarity");

        while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
        {
            double cosineSimilarity = dataReader.GetDouble(similarityOrdinal);
            DatabaseEntry entry = await this.ReadEntryAsync(dataReader, withEmbeddings, cancellationToken).ConfigureAwait(false);
            yield return (entry, cosineSimilarity);
        }
    }

    /// <summary>
    /// Read all entries from a collection
    /// </summary>
    /// <param name="conn">An opened <see cref="NpgsqlConnection"/> instance.</param>
    /// <param name="collectionName">The name assigned to a collection of entries.</param>
    /// <param name="withEmbeddings">If true, the embeddings will be returned in the entries.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
    /// <returns>Every row stored under <paramref name="collectionName"/>.</returns>
    public async IAsyncEnumerable<DatabaseEntry> ReadAllAsync(NpgsqlConnection conn,
        string collectionName, bool withEmbeddings = false,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        string queryColumns = GetQueryColumns(withEmbeddings);

        using NpgsqlCommand cmd = conn.CreateCommand();
        cmd.CommandText = $@"
            SELECT {queryColumns} FROM {TableName}
            WHERE collection=@collection";
        cmd.Parameters.AddWithValue("@collection", collectionName);

        using var dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
        while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
        {
            yield return await this.ReadEntryAsync(dataReader, withEmbeddings, cancellationToken).ConfigureAwait(false);
        }
    }

    /// <summary>
    /// Read a entry by its key.
    /// </summary>
    /// <param name="conn">An opened <see cref="NpgsqlConnection"/> instance.</param>
    /// <param name="collectionName">The name assigned to a collection of entries.</param>
    /// <param name="key">The key of the entry to read.</param>
    /// <param name="withEmbeddings">If true, the embeddings will be returned in the entries.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
    /// <returns>The entry, or null when no row matches the collection and key.</returns>
    public async Task<DatabaseEntry?> ReadAsync(NpgsqlConnection conn,
        string collectionName, string key, bool withEmbeddings = false,
        CancellationToken cancellationToken = default)
    {
        string queryColumns = GetQueryColumns(withEmbeddings);

        using NpgsqlCommand cmd = conn.CreateCommand();
        cmd.CommandText = $@"
            SELECT {queryColumns} FROM {TableName}
            WHERE collection=@collection AND key=@key";
        cmd.Parameters.AddWithValue("@collection", collectionName);
        cmd.Parameters.AddWithValue("@key", key);

        using var dataReader = await cmd.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
        if (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
        {
            return await this.ReadEntryAsync(dataReader, withEmbeddings, cancellationToken).ConfigureAwait(false);
        }

        return null;
    }

    /// <summary>
    /// Delete a collection.
    /// </summary>
    /// <param name="conn">An opened <see cref="NpgsqlConnection"/> instance.</param>
    /// <param name="collectionName">The name assigned to a collection of entries.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
    /// <returns></returns>
    public async Task DeleteCollectionAsync(NpgsqlConnection conn, string collectionName, CancellationToken cancellationToken = default)
    {
        using NpgsqlCommand cmd = conn.CreateCommand();
        cmd.CommandText = $@"
            DELETE FROM {TableName}
            WHERE collection=@collection";
        cmd.Parameters.AddWithValue("@collection", collectionName);

        // Await here: returning the pending task would dispose the command before it finishes executing.
        await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Delete a entry by its key.
    /// </summary>
    /// <param name="conn">An opened <see cref="NpgsqlConnection"/> instance.</param>
    /// <param name="collectionName">The name assigned to a collection of entries.</param>
    /// <param name="key">The key of the entry to delete.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
    /// <returns></returns>
    public async Task DeleteAsync(NpgsqlConnection conn, string collectionName, string key, CancellationToken cancellationToken = default)
    {
        using NpgsqlCommand cmd = conn.CreateCommand();
        cmd.CommandText = $@"
            DELETE FROM {TableName}
            WHERE collection=@collection AND key=@key ";
        cmd.Parameters.AddWithValue("@collection", collectionName);
        cmd.Parameters.AddWithValue("@key", key);

        // Await here: returning the pending task would dispose the command before it finishes executing.
        await cmd.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
    }

    /// <summary>
    /// Select the column list for a read query; the embedding column is only fetched on request.
    /// </summary>
    /// <param name="withEmbeddings">If true, all columns (including the embedding) are selected.</param>
    private static string GetQueryColumns(bool withEmbeddings)
    {
        return withEmbeddings ? "*" : "collection, key, metadata, timestamp";
    }

    /// <summary>
    /// Read a entry.
    /// </summary>
    /// <param name="dataReader">The <see cref="NpgsqlDataReader"/> to read.</param>
    /// <param name="withEmbeddings">If true, the embeddings will be returned in the entries.</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
    /// <returns>The entry built from the reader's current row.</returns>
    private async Task<DatabaseEntry> ReadEntryAsync(NpgsqlDataReader dataReader, bool withEmbeddings = false, CancellationToken cancellationToken = default)
    {
        string key = dataReader.GetString(dataReader.GetOrdinal("key"));

        // Rows inserted by CreateCollectionAsync leave metadata/embedding NULL, so guard before reading.
        // NULL metadata maps to the empty string, mirroring how UpsertAsync stores a null metadata.
        int metadataOrdinal = dataReader.GetOrdinal("metadata");
        string metadata = dataReader.IsDBNull(metadataOrdinal)
            ? string.Empty
            : dataReader.GetString(metadataOrdinal);

        Vector? embedding = null;
        if (withEmbeddings)
        {
            int embeddingOrdinal = dataReader.GetOrdinal("embedding");
            if (!dataReader.IsDBNull(embeddingOrdinal))
            {
                embedding = await dataReader.GetFieldValueAsync<Vector>(embeddingOrdinal, cancellationToken).ConfigureAwait(false);
            }
        }

        long? timestamp = await dataReader.GetFieldValueAsync<long?>(dataReader.GetOrdinal("timestamp"), cancellationToken).ConfigureAwait(false);

        return new DatabaseEntry() { Key = key, MetadataString = metadata, Embedding = embedding, Timestamp = timestamp };
    }
}
/// <summary>
/// An implementation of <see cref="IMemoryStore"/> backed by a Postgres database with pgvector extension.
/// </summary>
/// <remarks>
/// Instances are created through <see cref="ConnectAsync"/>. Each operation opens its own connection
/// from the underlying data source and disposes it when the operation completes.
/// </remarks>
public class PostgresMemoryStore : IMemoryStore, IDisposable
{
    /// <summary>
    /// Connect a Postgres database
    /// </summary>
    /// <param name="connectionString">Database connection string. If table does not exist, it will be created.</param>
    /// <param name="vectorSize">Embedding vector size</param>
    /// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
    /// <returns>A store whose extension, table and index have been created if they were missing.</returns>
    public static async Task<PostgresMemoryStore> ConnectAsync(string connectionString, int vectorSize,
        CancellationToken cancellationToken = default)
    {
        var dataSourceBuilder = new NpgsqlDataSourceBuilder(connectionString);
        // Use pgvector
        dataSourceBuilder.UseVector();

        var memoryStore = new PostgresMemoryStore(dataSourceBuilder.Build());
        try
        {
            using NpgsqlConnection dbConnection = await memoryStore._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);
            await memoryStore._dbConnector.CreatePgVectorExtensionAsync(dbConnection, cancellationToken).ConfigureAwait(false);
            await memoryStore._dbConnector.CreateTableAsync(dbConnection, vectorSize, cancellationToken).ConfigureAwait(false);
            await memoryStore._dbConnector.CreateIndexAsync(dbConnection, cancellationToken).ConfigureAwait(false);
        }
        catch
        {
            // The caller never receives the store on failure, so release the data source here.
            memoryStore.Dispose();
            throw;
        }

        return memoryStore;
    }

    /// <inheritdoc/>
    public async Task CreateCollectionAsync(string collectionName, CancellationToken cancellationToken = default)
    {
        using NpgsqlConnection dbConnection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);
        await this._dbConnector.CreateCollectionAsync(dbConnection, collectionName, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc/>
    public async Task<bool> DoesCollectionExistAsync(string collectionName, CancellationToken cancellationToken = default)
    {
        using NpgsqlConnection dbConnection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);
        return await this._dbConnector.DoesCollectionExistsAsync(dbConnection, collectionName, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc/>
    public async IAsyncEnumerable<string> GetCollectionsAsync([EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        using NpgsqlConnection dbConnection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);
        await foreach (var collection in this._dbConnector.GetCollectionsAsync(dbConnection, cancellationToken).ConfigureAwait(false))
        {
            yield return collection;
        }
    }

    /// <inheritdoc/>
    public async Task DeleteCollectionAsync(string collectionName, CancellationToken cancellationToken = default)
    {
        using NpgsqlConnection dbConnection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);
        await this._dbConnector.DeleteCollectionAsync(dbConnection, collectionName, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc/>
    public async Task<string> UpsertAsync(string collectionName, MemoryRecord record, CancellationToken cancellationToken = default)
    {
        using NpgsqlConnection dbConnection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);
        return await this.InternalUpsertAsync(dbConnection, collectionName, record, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc/>
    public async IAsyncEnumerable<string> UpsertBatchAsync(string collectionName, IEnumerable<MemoryRecord> records,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        // One connection is reused for the whole batch.
        using NpgsqlConnection dbConnection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);
        foreach (var record in records)
        {
            yield return await this.InternalUpsertAsync(dbConnection, collectionName, record, cancellationToken).ConfigureAwait(false);
        }
    }

    /// <inheritdoc/>
    public async Task<MemoryRecord?> GetAsync(string collectionName, string key, bool withEmbedding = false, CancellationToken cancellationToken = default)
    {
        using NpgsqlConnection dbConnection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);
        return await this.InternalGetAsync(dbConnection, collectionName, key, withEmbedding, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc/>
    public async IAsyncEnumerable<MemoryRecord> GetBatchAsync(string collectionName, IEnumerable<string> keys, bool withEmbeddings = false,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        using NpgsqlConnection dbConnection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);
        foreach (var key in keys)
        {
            var result = await this.InternalGetAsync(dbConnection, collectionName, key, withEmbeddings, cancellationToken).ConfigureAwait(false);

            // A missing key is skipped; it must not stop the lookup of the remaining keys.
            if (result != null)
            {
                yield return result;
            }
        }
    }

    /// <inheritdoc/>
    public async Task RemoveAsync(string collectionName, string key, CancellationToken cancellationToken = default)
    {
        using NpgsqlConnection dbConnection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);
        await this._dbConnector.DeleteAsync(dbConnection, collectionName, key, cancellationToken).ConfigureAwait(false);
    }

    /// <inheritdoc/>
    public async Task RemoveBatchAsync(string collectionName, IEnumerable<string> keys, CancellationToken cancellationToken = default)
    {
        using NpgsqlConnection dbConnection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);
        foreach (var key in keys)
        {
            await this._dbConnector.DeleteAsync(dbConnection, collectionName, key, cancellationToken).ConfigureAwait(false);
        }
    }

    /// <inheritdoc/>
    public async IAsyncEnumerable<(MemoryRecord, double)> GetNearestMatchesAsync(
        string collectionName,
        Embedding<float> embedding,
        int limit,
        double minRelevanceScore = 0,
        bool withEmbeddings = false,
        [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        // Nothing to return for a non-positive limit; avoid opening a connection at all.
        if (limit <= 0)
        {
            yield break;
        }

        using NpgsqlConnection dbConnection = await this._dataSource.OpenConnectionAsync(cancellationToken).ConfigureAwait(false);

        IAsyncEnumerable<(DatabaseEntry, double)> results = this._dbConnector.GetNearestMatchesAsync(
            dbConnection,
            collectionName: collectionName,
            embeddingFilter: new Vector(embedding.Vector.ToArray()),
            limit: limit,
            minRelevanceScore: minRelevanceScore,
            withEmbeddings: withEmbeddings,
            cancellationToken: cancellationToken);

        await foreach (var (entry, cosineSimilarity) in results.ConfigureAwait(false))
        {
            MemoryRecord record = MemoryRecord.FromJsonMetadata(
                json: entry.MetadataString,
                withEmbeddings && entry.Embedding != null ? new Embedding<float>(entry.Embedding!.ToArray()) : Embedding<float>.Empty,
                entry.Key,
                ParseTimestamp(entry.Timestamp));
            yield return (record, cosineSimilarity);
        }
    }

    /// <inheritdoc/>
    public async Task<(MemoryRecord, double)?> GetNearestMatchAsync(string collectionName, Embedding<float> embedding, double minRelevanceScore = 0, bool withEmbedding = false,
        CancellationToken cancellationToken = default)
    {
        // Enumerate manually rather than using FirstOrDefaultAsync: the default of a value tuple is
        // (null, 0), which would be returned as a non-null result when there is no match.
        await foreach ((MemoryRecord, double) match in this.GetNearestMatchesAsync(
            collectionName: collectionName,
            embedding: embedding,
            limit: 1,
            minRelevanceScore: minRelevanceScore,
            withEmbeddings: withEmbedding,
            cancellationToken: cancellationToken).ConfigureAwait(false))
        {
            return match;
        }

        return null;
    }

    /// <inheritdoc/>
    public void Dispose()
    {
        this.Dispose(true);
        GC.SuppressFinalize(this);
    }

    #region protected ================================================================================

    /// <summary>
    /// Releases the underlying data source the first time it is called with <paramref name="disposing"/> set to true.
    /// </summary>
    /// <param name="disposing">True when called from <see cref="Dispose()"/>.</param>
    protected virtual void Dispose(bool disposing)
    {
        if (!this._disposedValue)
        {
            if (disposing)
            {
                this._dataSource.Dispose();
            }

            this._disposedValue = true;
        }
    }

    #endregion

    #region private ================================================================================

    private readonly Database _dbConnector;
    private readonly NpgsqlDataSource _dataSource;
    private bool _disposedValue;

    /// <summary>
    /// Constructor
    /// </summary>
    /// <param name="dataSource">Postgres data source.</param>
    private PostgresMemoryStore(NpgsqlDataSource dataSource)
    {
        this._dataSource = dataSource;
        this._dbConnector = new Database();
        this._disposedValue = false;
    }

    /// <summary>
    /// Converts an optional timestamp to Unix time in milliseconds for storage.
    /// </summary>
    private static long? ToTimestampLong(DateTimeOffset? timestamp)
    {
        return timestamp?.ToUnixTimeMilliseconds();
    }

    /// <summary>
    /// Converts a stored Unix-milliseconds value back to a <see cref="DateTimeOffset"/>; null stays null.
    /// </summary>
    private static DateTimeOffset? ParseTimestamp(long? timestamp)
    {
        if (timestamp.HasValue)
        {
            return DateTimeOffset.FromUnixTimeMilliseconds(timestamp.Value);
        }

        return null;
    }

    /// <summary>
    /// Upserts one record over an already opened connection and returns the key it was stored under.
    /// </summary>
    private async Task<string> InternalUpsertAsync(NpgsqlConnection connection, string collectionName, MemoryRecord record, CancellationToken cancellationToken)
    {
        // The metadata id is used as the storage key, overwriting any key already set on the record.
        record.Key = record.Metadata.Id;

        await this._dbConnector.UpsertAsync(
            conn: connection,
            collectionName: collectionName,
            key: record.Key,
            metadata: record.GetSerializedMetadata(),
            embedding: new Vector(record.Embedding.Vector.ToArray()),
            timestamp: ToTimestampLong(record.Timestamp),
            cancellationToken: cancellationToken).ConfigureAwait(false);

        return record.Key;
    }

    /// <summary>
    /// Reads one record over an already opened connection; returns null when the key is not found.
    /// </summary>
    private async Task<MemoryRecord?> InternalGetAsync(NpgsqlConnection connection, string collectionName, string key, bool withEmbedding, CancellationToken cancellationToken)
    {
        DatabaseEntry? entry = await this._dbConnector.ReadAsync(connection, collectionName, key, withEmbedding, cancellationToken).ConfigureAwait(false);

        if (!entry.HasValue) { return null; }

        // Fall back to an empty embedding when none was requested or none is stored.
        Vector? storedEmbedding = withEmbedding ? entry.Value.Embedding : null;

        return MemoryRecord.FromJsonMetadata(
            json: entry.Value.MetadataString,
            embedding: storedEmbedding != null ? new Embedding<float>(storedEmbedding.ToArray()) : Embedding<float>.Empty,
            entry.Value.Key,
            ParseTimestamp(entry.Value.Timestamp));
    }

    #endregion
}
To use Postgres as a semantic memory store: + +```csharp +using PostgresMemoryStore memoryStore = await PostgresMemoryStore.ConnectAsync("Host=localhost;Port=5432;Database=sk_memory;User Id=postgres;Password=mysecretpassword", vectorSize: 1536); + +IKernel kernel = Kernel.Builder + .WithLogger(ConsoleLogger.Log) + .Configure(c => c.AddOpenAITextEmbeddingGenerationService("text-embedding-ada-002", Env.Var("OPENAI_API_KEY"))) + .WithMemoryStorage(memoryStore) + .Build(); +``` + diff --git a/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresMemoryStoreTests.cs b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresMemoryStoreTests.cs new file mode 100644 index 000000000000..c74775c4922b --- /dev/null +++ b/dotnet/src/IntegrationTests/Connectors/Memory/Postgres/PostgresMemoryStoreTests.cs @@ -0,0 +1,721 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Globalization; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.SemanticKernel.AI.Embeddings; +using Microsoft.SemanticKernel.Connectors.Memory.Postgres; +using Microsoft.SemanticKernel.Memory; +using Npgsql; +using Xunit; + +namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres; + +/// +/// Integration tests of . +/// +public class PostgresMemoryStoreTests : IDisposable +{ + // Set null enable tests + private const string SkipOrNot = "Required posgres with pgvector extension"; + + private const string ConnectionString = "Host=localhost;Database={0};User Id=postgres"; + private readonly string _databaseName; + + private bool _disposedValue = false; + + public PostgresMemoryStoreTests() + { +#pragma warning disable CA5394 + this._databaseName = $"sk_pgvector_dotnet_it_{Random.Shared.Next(0, 1000)}"; +#pragma warning restore CA5394 + } + + public void Dispose() + { + // Do not change this code. 
Put cleanup code in 'Dispose(bool disposing)' method
+        this.Dispose(disposing: true);
+        GC.SuppressFinalize(this);
+    }
+
+    protected virtual void Dispose(bool disposing)
+    {
+        if (!this._disposedValue)
+        {
+            if (disposing) // drop this._databaseName through a connection to the "postgres" database
+            {
+                using NpgsqlConnection conn = new NpgsqlConnection(string.Format(CultureInfo.InvariantCulture, ConnectionString, "postgres"));
+                conn.Open();
+#pragma warning disable CA2100 // Review SQL queries for security vulnerabilities
+                using NpgsqlCommand command = new NpgsqlCommand($"DROP DATABASE IF EXISTS \"{this._databaseName}\"", conn);
+#pragma warning restore CA2100 // Review SQL queries for security vulnerabilities
+                command.ExecuteNonQuery();
+            }
+
+            this._disposedValue = true;
+        }
+    }
+
+    private int _collectionNum = 0; // suffix appended to "test_collection" to build collection names
+
+    private async Task TryCreateDatabaseAsync() // creates this._databaseName when pg_database has no row for it
+    {
+        using NpgsqlConnection conn = new NpgsqlConnection(string.Format(CultureInfo.InvariantCulture, ConnectionString, "postgres"));
+        await conn.OpenAsync();
+        using NpgsqlCommand checkCmd = new NpgsqlCommand("SELECT COUNT(*) FROM pg_database WHERE datname = @databaseName", conn);
+        checkCmd.Parameters.AddWithValue("@databaseName", this._databaseName);
+
+        var count = (long?)await checkCmd.ExecuteScalarAsync();
+        if (count == 0) // no database with that name yet
+        {
+#pragma warning disable CA2100 // Review SQL queries for security vulnerabilities
+            using var createCmd = new NpgsqlCommand($"CREATE DATABASE \"{this._databaseName}\"", conn);
+#pragma warning restore CA2100 // Review SQL queries for security vulnerabilities
+            await createCmd.ExecuteNonQueryAsync();
+        }
+    }
+
+    private async Task<PostgresMemoryStore> CreateMemoryStoreAsync() // store connected to this._databaseName, vectors of size 3
+    {
+        await this.TryCreateDatabaseAsync();
+        return await PostgresMemoryStore.ConnectAsync(string.Format(CultureInfo.InvariantCulture, ConnectionString, this._databaseName), vectorSize: 3);
+    }
+
+    private IEnumerable<MemoryRecord> CreateBatchRecords(int numRecords) // numRecords must be even and positive
+    {
+        Assert.True(numRecords % 2 == 0, "Number of records must be even");
+        Assert.True(numRecords > 0, "Number of records must be greater than 0");
+
+        List<MemoryRecord> records = new(numRecords);
+        for (int i = 0; i < numRecords / 2; i++) // first half: local records
+        {
+            var testRecord = MemoryRecord.LocalRecord(
+                id: "test" + i,
+                text: "text" + i,
+                description: "description" + i,
+                embedding: new Embedding<float>(new float[] { 1, 1, 1 }));
+            records.Add(testRecord);
+        }
+
+        for (int i = numRecords / 2; i < numRecords; i++) // second half: reference records
+        {
+            var testRecord = MemoryRecord.ReferenceRecord(
+                externalId: "test" + i,
+                sourceName: "sourceName" + i,
+                description: "description" + i,
+                embedding: new Embedding<float>(new float[] { 1, 2, 3 }));
+            records.Add(testRecord);
+        }
+
+        return records;
+    }
+
+    [Fact(Skip = SkipOrNot)]
+    public async Task InitializeDbConnectionSucceedsAsync()
+    {
+        using PostgresMemoryStore db = await this.CreateMemoryStoreAsync();
+        // Assert
+        Assert.NotNull(db);
+    }
+
+    [Fact(Skip = SkipOrNot)]
+    public async Task ItCanCreateAndGetCollectionAsync()
+    {
+        // Arrange
+        using PostgresMemoryStore db = await this.CreateMemoryStoreAsync();
+        string collection = "test_collection" + this._collectionNum;
+        this._collectionNum++;
+
+        // Act
+        await db.CreateCollectionAsync(collection);
+        var collections = db.GetCollectionsAsync();
+
+        // Assert
+        Assert.NotEmpty(await collections.ToListAsync());
+        Assert.True(await collections.ContainsAsync(collection));
+    }
+
+    [Fact(Skip = SkipOrNot)]
+    public async Task ItCanCheckIfCollectionExistsAsync()
+    {
+        // Arrange
+        using PostgresMemoryStore db = await this.CreateMemoryStoreAsync();
+        string collection = "my_collection";
+        this._collectionNum++;
+
+        // Act
+        await db.CreateCollectionAsync(collection);
+
+        // Assert
+        Assert.True(await db.DoesCollectionExistAsync(collection));
+        Assert.False(await db.DoesCollectionExistAsync("my_collection2"));
+    }
+
+    [Fact(Skip = SkipOrNot)]
+    public async Task CreatingDuplicateCollectionDoesNothingAsync()
+    {
+        // Arrange
+        using PostgresMemoryStore db = await this.CreateMemoryStoreAsync();
+        string collection = "test_collection" + this._collectionNum;
+        this._collectionNum++;
+
+        // Act: record the collection count before creating the same collection a second time
+        await db.CreateCollectionAsync(collection);
+        int countBefore = await db.GetCollectionsAsync().CountAsync();
+        await db.CreateCollectionAsync(collection);
+
+        // Assert: the duplicate create must not have changed the number of collections
+        int countAfter = await db.GetCollectionsAsync().CountAsync();
+        Assert.Equal(countBefore, countAfter);
+    }
+
+    [Fact(Skip = SkipOrNot)]
+    public async Task CollectionsCanBeDeletedAsync()
+    {
+        // Arrange
+        using PostgresMemoryStore db = await this.CreateMemoryStoreAsync();
+        string collection = "test_collection" + this._collectionNum;
+        this._collectionNum++;
+        await db.CreateCollectionAsync(collection);
+        var collections = await db.GetCollectionsAsync().ToListAsync();
+        Assert.NotEmpty(collections);
+
+        // Act
+        foreach (var c in collections)
+        {
+            await db.DeleteCollectionAsync(c);
+        }
+
+        // Assert
+        var collections2 = db.GetCollectionsAsync();
+        Assert.Equal(0, await collections2.CountAsync());
+    }
+
+    [Fact(Skip = SkipOrNot)]
+    public async Task ItCanInsertIntoNonExistentCollectionAsync()
+    {
+        // Arrange
+        using PostgresMemoryStore db = await this.CreateMemoryStoreAsync();
+        MemoryRecord testRecord = MemoryRecord.LocalRecord(
+            id: "test",
+            text: "text",
+            description: "description",
+            embedding: new Embedding<float>(new float[] { 1, 2, 3 }),
+            key: null,
+            timestamp: null);
+
+        // Act: upsert without calling CreateCollectionAsync first
+        var key = await db.UpsertAsync("random collection", testRecord);
+        var actual = await db.GetAsync("random collection", key, true);
+
+        // Assert
+        Assert.NotNull(actual);
+        Assert.Equal(testRecord.Metadata.Id, key);
+        Assert.Equal(testRecord.Metadata.Id, actual.Key);
+        Assert.Equal(testRecord.Embedding.Vector, actual.Embedding.Vector);
+        Assert.Equal(testRecord.Metadata.Text, actual.Metadata.Text);
+        Assert.Equal(testRecord.Metadata.Description, actual.Metadata.Description);
+        Assert.Equal(testRecord.Metadata.ExternalSourceName, actual.Metadata.ExternalSourceName);
+        Assert.Equal(testRecord.Metadata.Id, actual.Metadata.Id);
+    }
+
+    [Fact(Skip = SkipOrNot)]
+    public async Task GetAsyncReturnsEmptyEmbeddingUnlessSpecifiedAsync()
+    {
+        // Arrange
+        using PostgresMemoryStore db = await this.CreateMemoryStoreAsync();
+        MemoryRecord testRecord = MemoryRecord.LocalRecord(
+            id: "test",
+            text: "text",
+            description: "description",
+            embedding: new Embedding<float>(new float[] { 1, 2, 3 }),
+            key: null,
+            timestamp: null);
+        string collection = "test_collection" + this._collectionNum;
+        this._collectionNum++;
+
+        // Act
+        await db.CreateCollectionAsync(collection);
+        var key = await db.UpsertAsync(collection, testRecord);
+        var actualDefault = await db.GetAsync(collection, key);
+        var actualWithEmbedding = await db.GetAsync(collection, key, true);
+
+        // Assert
+        Assert.NotNull(actualDefault);
+        Assert.NotNull(actualWithEmbedding);
+        Assert.Empty(actualDefault.Embedding.Vector);
+        Assert.NotEmpty(actualWithEmbedding.Embedding.Vector);
+    }
+
+    [Fact(Skip = SkipOrNot)]
+    public async Task ItCanUpsertAndRetrieveARecordWithNoTimestampAsync()
+    {
+        // Arrange
+        using PostgresMemoryStore db = await this.CreateMemoryStoreAsync();
+        MemoryRecord testRecord = MemoryRecord.LocalRecord(
+            id: "test",
+            text: "text",
+            description: "description",
+            embedding: new Embedding<float>(new float[] { 1, 2, 3 }),
+            key: null,
+            timestamp: null);
+        string collection = "test_collection" + this._collectionNum;
+        this._collectionNum++;
+
+        // Act
+        await db.CreateCollectionAsync(collection);
+        var key = await db.UpsertAsync(collection, testRecord);
+        var actual = await db.GetAsync(collection, key, true);
+
+        // Assert
+        Assert.NotNull(actual);
+        Assert.Equal(testRecord.Metadata.Id, key);
+        Assert.Equal(testRecord.Metadata.Id, actual.Key);
+        Assert.Equal(testRecord.Embedding.Vector, actual.Embedding.Vector);
+        Assert.Equal(testRecord.Metadata.Text, actual.Metadata.Text);
+        Assert.Equal(testRecord.Metadata.Description, actual.Metadata.Description);
+        Assert.Equal(testRecord.Metadata.ExternalSourceName, actual.Metadata.ExternalSourceName);
+        Assert.Equal(testRecord.Metadata.Id, actual.Metadata.Id);
+    }
+
+    [Fact(Skip = SkipOrNot)]
+    public async Task ItCanUpsertAndRetrieveARecordWithTimestampAsync()
+    {
+        // Arrange
+        using PostgresMemoryStore db = await this.CreateMemoryStoreAsync();
+        MemoryRecord testRecord = MemoryRecord.LocalRecord(
+            id: "test",
+            text: "text",
+            description: "description",
+            embedding: new Embedding<float>(new float[] { 1, 2, 3 }),
+            key: null,
+            timestamp: DateTimeOffset.UtcNow);
+        string collection = "test_collection" + this._collectionNum;
+        this._collectionNum++;
+
+        // Act
+        await db.CreateCollectionAsync(collection);
+        var key = await db.UpsertAsync(collection, testRecord);
+        var actual = await db.GetAsync(collection, key, true);
+
+        // Assert
+        Assert.NotNull(actual);
+        Assert.Equal(testRecord.Metadata.Id, key);
+        Assert.Equal(testRecord.Metadata.Id, actual.Key);
+        Assert.Equal(testRecord.Embedding.Vector, actual.Embedding.Vector);
+        Assert.Equal(testRecord.Metadata.Text, actual.Metadata.Text);
+        Assert.Equal(testRecord.Metadata.Description, actual.Metadata.Description);
+        Assert.Equal(testRecord.Metadata.ExternalSourceName, actual.Metadata.ExternalSourceName);
+        Assert.Equal(testRecord.Metadata.Id, actual.Metadata.Id);
+    }
+
+    [Fact(Skip = SkipOrNot)]
+    public async Task UpsertReplacesExistingRecordWithSameIdAsync()
+    {
+        // Arrange
+        using PostgresMemoryStore db = await this.CreateMemoryStoreAsync();
+        string commonId = "test";
+        MemoryRecord testRecord = MemoryRecord.LocalRecord(
+            id: commonId,
+            text: "text",
+            description: "description",
+            embedding: new Embedding<float>(new float[] { 1, 2, 3 }));
+        MemoryRecord testRecord2 = MemoryRecord.LocalRecord(
+            id: commonId,
+            text: "text2",
+            description: "description2",
+            embedding: new Embedding<float>(new float[] { 1, 2, 4 }));
+        string collection = "test_collection" + this._collectionNum;
+        this._collectionNum++;
+
+        // Act
+        await db.CreateCollectionAsync(collection);
+        var key = await db.UpsertAsync(collection, testRecord);
+        var key2 = await db.UpsertAsync(collection, testRecord2);
+        var actual = await db.GetAsync(collection, key, true);
+
+        // Assert
+        Assert.NotNull(actual);
+        Assert.Equal(testRecord.Metadata.Id, key);
+        Assert.Equal(testRecord2.Metadata.Id, actual.Key);
+        Assert.NotEqual(testRecord.Embedding.Vector, actual.Embedding.Vector);
+        Assert.Equal(testRecord2.Embedding.Vector, actual.Embedding.Vector);
+        Assert.NotEqual(testRecord.Metadata.Text, actual.Metadata.Text);
+        Assert.Equal(testRecord2.Metadata.Description, actual.Metadata.Description);
+    }
+
+    [Fact(Skip = SkipOrNot)]
+    public async Task ExistingRecordCanBeRemovedAsync()
+    {
+        // Arrange
+        using PostgresMemoryStore db = await this.CreateMemoryStoreAsync();
+        MemoryRecord testRecord = MemoryRecord.LocalRecord(
+            id: "test",
+            text: "text",
+            description: "description",
+            embedding: new Embedding<float>(new float[] { 1, 2, 3 }));
+        string collection = "test_collection" + this._collectionNum;
+        this._collectionNum++;
+
+        // Act
+        await db.CreateCollectionAsync(collection);
+        var key = await db.UpsertAsync(collection, testRecord);
+        await db.RemoveAsync(collection, key);
+        var actual = await db.GetAsync(collection, key);
+
+        // Assert
+        Assert.Null(actual);
+    }
+
+    [Fact(Skip = SkipOrNot)]
+    public async Task RemovingNonExistingRecordDoesNothingAsync()
+    {
+        // Arrange
+        using PostgresMemoryStore db = await this.CreateMemoryStoreAsync();
+        string collection = "test_collection" + this._collectionNum;
+        this._collectionNum++;
+
+        // Act
+        await db.CreateCollectionAsync(collection);
+        await db.RemoveAsync(collection, "key");
+        var actual = await db.GetAsync(collection, "key");
+
+        // Assert
+        Assert.Null(actual);
+    }
+
+    [Fact(Skip = SkipOrNot)]
+    public async Task ItCanListAllDatabaseCollectionsAsync()
+    {
+        // Arrange
+        using PostgresMemoryStore db = await this.CreateMemoryStoreAsync();
+        string[] testCollections = { "random_collection1", "random_collection2", "random_collection3" };
+        this._collectionNum += 3;
+        await db.CreateCollectionAsync(testCollections[0]);
+        await db.CreateCollectionAsync(testCollections[1]);
+        await db.CreateCollectionAsync(testCollections[2]);
+
+        // Act
+        var collections = await db.GetCollectionsAsync().ToListAsync();
+
+        // Assert
+        foreach (var collection in testCollections)
+        {
+            Assert.True(await db.DoesCollectionExistAsync(collection));
+        }
+
+        Assert.NotNull(collections);
+        Assert.NotEmpty(collections);
+        Assert.Equal(testCollections.Length, collections.Count);
+        Assert.True(collections.Contains(testCollections[0]),
+            $"Collections does not contain the newly-created collection {testCollections[0]}");
+        Assert.True(collections.Contains(testCollections[1]),
+            $"Collections does not contain the newly-created collection {testCollections[1]}");
+        Assert.True(collections.Contains(testCollections[2]),
+            $"Collections does not contain the newly-created collection {testCollections[2]}");
+    }
+
+    [Fact(Skip = SkipOrNot)]
+    public async Task GetNearestMatchesReturnsAllResultsWithNoMinScoreAsync()
+    {
+        // Arrange
+        using PostgresMemoryStore db = await this.CreateMemoryStoreAsync();
+        var compareEmbedding = new Embedding<float>(new float[] { 1, 1, 1 });
+        int topN = 4;
+        string collection = "test_collection" + this._collectionNum;
+        this._collectionNum++;
+        await db.CreateCollectionAsync(collection);
+        int i = 0;
+        MemoryRecord testRecord = MemoryRecord.LocalRecord(
+            id: "test" + i,
+            text: "text" + i,
+            description: "description" + i,
+            embedding: new Embedding<float>(new float[] { 1, 1, 1 }));
+        _ = await db.UpsertAsync(collection, testRecord);
+
+        i++;
+        testRecord = MemoryRecord.LocalRecord(
+            id: "test" + i,
+            text: "text" + i,
+            description: "description" + i,
+            embedding: new Embedding<float>(new float[] { -1, -1, -1 }));
+        _ = await db.UpsertAsync(collection, testRecord);
+
+        i++;
+        testRecord = MemoryRecord.LocalRecord(
+            id: "test" + i,
+            text: "text" + i,
+            description: "description" + i,
+            embedding: new Embedding<float>(new float[] { 1, 2, 3 }));
+        _ = await db.UpsertAsync(collection, testRecord);
+
+        i++;
+        testRecord = MemoryRecord.LocalRecord(
+            id: "test" + i,
+            text: "text" + i,
+            description: "description" + i,
+            embedding: new Embedding<float>(new float[] { -1, -2, -3 }));
+        _ = await db.UpsertAsync(collection, testRecord);
+
+        i++;
+        testRecord = MemoryRecord.LocalRecord(
+            id: "test" + i,
+            text: "text" + i,
+            description: "description" + i,
+            embedding: new Embedding<float>(new float[] { 1, -1, -2 }));
+        _ = await db.UpsertAsync(collection, testRecord);
+
+        // Act
+        double threshold = -1;
+        var topNResults = await db.GetNearestMatchesAsync(collection, compareEmbedding, limit: topN, minRelevanceScore: threshold).ToArrayAsync();
+
+        // Assert: results are ordered by non-increasing relevance score
+        Assert.Equal(topN, topNResults.Length);
+        for (int j = 0; j < topN - 1; j++)
+        {
+            int compare = topNResults[j].Item2.CompareTo(topNResults[j + 1].Item2);
+            Assert.True(compare >= 0);
+        }
+    }
+
+    [Fact(Skip = SkipOrNot)]
+    public async Task GetNearestMatchAsyncReturnsEmptyEmbeddingUnlessSpecifiedAsync()
+    {
+        // Arrange
+        using PostgresMemoryStore db = await this.CreateMemoryStoreAsync();
+        var compareEmbedding = new Embedding<float>(new float[] { 1, 1, 1 });
+        string collection = "test_collection" + this._collectionNum;
+        this._collectionNum++;
+        await db.CreateCollectionAsync(collection);
+        int i = 0;
+        MemoryRecord testRecord = MemoryRecord.LocalRecord(
+            id: "test" + i,
+            text: "text" + i,
+            description: "description" + i,
+            embedding: new Embedding<float>(new float[] { 1, 1, 1 }));
+        _ = await db.UpsertAsync(collection, testRecord);
+
+        i++;
+        testRecord = MemoryRecord.LocalRecord(
+            id: "test" + i,
+            text: "text" + i,
+            description: "description" + i,
+            embedding: new Embedding<float>(new float[] { -1, -1, -1 }));
+        _ = await db.UpsertAsync(collection, testRecord);
+
+        i++;
+        testRecord = MemoryRecord.LocalRecord(
+            id: "test" + i,
+            text: "text" + i,
+            description: "description" + i,
+            embedding: new Embedding<float>(new float[] { 1, 2, 3 }));
+        _ = await db.UpsertAsync(collection, testRecord);
+
+        i++;
+        testRecord = MemoryRecord.LocalRecord(
+            id: "test" + i,
+            text: "text" + i,
+            description: "description" + i,
+            embedding: new Embedding<float>(new float[] { -1, -2, -3 }));
+        _ = await db.UpsertAsync(collection, testRecord);
+
+        i++;
+        testRecord = MemoryRecord.LocalRecord(
+            id: "test" + i,
+            text: "text" + i,
+            description: "description" + i,
+            embedding: new Embedding<float>(new float[] { 1, -1, -2 }));
+        _ = await db.UpsertAsync(collection, testRecord);
+
+        // Act
+        double threshold = 0.75;
+        var topNResultDefault = await db.GetNearestMatchAsync(collection, compareEmbedding, minRelevanceScore: threshold);
+        var topNResultWithEmbedding = await db.GetNearestMatchAsync(collection, compareEmbedding, minRelevanceScore: threshold, withEmbedding: true);
+
+        // Assert
+        Assert.NotNull(topNResultDefault);
+        Assert.NotNull(topNResultWithEmbedding);
+        Assert.Empty(topNResultDefault.Value.Item1.Embedding.Vector);
+        Assert.NotEmpty(topNResultWithEmbedding.Value.Item1.Embedding.Vector);
+    }
+
+    [Fact(Skip = SkipOrNot)]
+    public async Task GetNearestMatchAsyncReturnsExpectedAsync()
+    {
+        // Arrange
+        using PostgresMemoryStore db = await this.CreateMemoryStoreAsync();
+        var compareEmbedding = new Embedding<float>(new float[] { 1, 1, 1 });
+        string collection = "test_collection" + this._collectionNum;
+        this._collectionNum++;
+        await db.CreateCollectionAsync(collection);
+        int i = 0;
+        MemoryRecord testRecord = MemoryRecord.LocalRecord(
+            id: "test" + i,
+            text: "text" + i,
+            description: "description" + i,
+            embedding: new Embedding<float>(new float[] { 1, 1, 1 }));
+        _ = await db.UpsertAsync(collection, testRecord);
+
+        i++;
+        testRecord = MemoryRecord.LocalRecord(
+            id: "test" + i,
+            text: "text" + i,
+            description: "description" + i,
+            embedding: new Embedding<float>(new float[] { -1, -1, -1 }));
+        _ = await db.UpsertAsync(collection, testRecord);
+
+        i++;
+        testRecord = MemoryRecord.LocalRecord(
+            id: "test" + i,
+            text: "text" + i,
+            description: "description" + i,
+            embedding: new Embedding<float>(new float[] { 1, 2, 3 }));
+        _ = await db.UpsertAsync(collection, testRecord);
+
+        i++;
+        testRecord = MemoryRecord.LocalRecord(
+            id: "test" + i,
+            text: "text" + i,
+            description: "description" + i,
+            embedding: new Embedding<float>(new float[] { -1, -2, -3 }));
+        _ = await db.UpsertAsync(collection, testRecord);
+
+        i++;
+        testRecord = MemoryRecord.LocalRecord(
+            id: "test" + i,
+            text: "text" + i,
+            description: "description" + i,
+            embedding: new Embedding<float>(new float[] { 1, -1, -2 }));
+        _ = await db.UpsertAsync(collection, testRecord);
+
+        // Act
+        double threshold = 0.75;
+        var topNResult = await db.GetNearestMatchAsync(collection, compareEmbedding, minRelevanceScore: threshold);
+
+        // Assert
+        Assert.NotNull(topNResult);
+        Assert.Equal("test0", topNResult.Value.Item1.Metadata.Id);
+        Assert.True(topNResult.Value.Item2 >= threshold);
+    }
+
+    [Fact(Skip = SkipOrNot)]
+    public async Task GetNearestMatchesDifferentiatesIdenticalVectorsByKeyAsync()
+    {
+        // Arrange
+        using PostgresMemoryStore db = await this.CreateMemoryStoreAsync();
+        var compareEmbedding = new Embedding<float>(new float[] { 1, 1, 1 });
+        int topN = 4;
+        string collection = "test_collection" + this._collectionNum;
+        this._collectionNum++;
+        await db.CreateCollectionAsync(collection);
+
+        for (int i = 0; i < 10; i++)
+        {
+            MemoryRecord testRecord = MemoryRecord.LocalRecord(
+                id: "test" + i,
+                text: "text" + i,
+                description: "description" + i,
+                embedding: new Embedding<float>(new float[] { 1, 1, 1 }));
+            _ = await db.UpsertAsync(collection, testRecord);
+        }
+
+        // Act
+        var topNResults = await db.GetNearestMatchesAsync(collection, compareEmbedding, limit: topN, minRelevanceScore: 0.75).ToArrayAsync();
+        IEnumerable<string> topNKeys = topNResults.Select(x => x.Item1.Key).ToImmutableSortedSet();
+
+        // Assert: topN results with topN distinct keys (the sorted set drops duplicates)
+        Assert.Equal(topN, topNResults.Length);
+        Assert.Equal(topN, topNKeys.Count());
+
+        for (int i = 0; i < topNResults.Length; i++)
+        {
+            int compare = topNResults[i].Item2.CompareTo(0.75);
+            Assert.True(compare >= 0);
+        }
+    }
+
+    [Fact(Skip = SkipOrNot)]
+    public async Task ItCanBatchUpsertRecordsAsync()
+    {
+        // Arrange
+        using PostgresMemoryStore db = await this.CreateMemoryStoreAsync();
+        int numRecords = 10;
+        string collection = "test_collection" + this._collectionNum;
+        this._collectionNum++;
+        IEnumerable<MemoryRecord> records = this.CreateBatchRecords(numRecords);
+
+        // Act: collect the keys into a list so the upsert stream is enumerated exactly once
+        await db.CreateCollectionAsync(collection);
+        var keys = await db.UpsertBatchAsync(collection, records).ToListAsync();
+        var resultRecords = db.GetBatchAsync(collection, keys);
+
+        // Assert
+        Assert.NotNull(keys);
+        Assert.Equal(numRecords, keys.Count);
+        Assert.Equal(numRecords, await resultRecords.CountAsync());
+    }
+
+    [Fact(Skip = SkipOrNot)]
+    public async Task ItCanBatchGetRecordsAsync()
+    {
+        // Arrange
+        using PostgresMemoryStore db = await this.CreateMemoryStoreAsync();
+        int numRecords = 10;
+        string collection = "test_collection" + this._collectionNum;
+        this._collectionNum++;
+        IEnumerable<MemoryRecord> records = this.CreateBatchRecords(numRecords);
+        await db.CreateCollectionAsync(collection);
+
+        // Act: upsert once into the already-created collection, then fetch the batch by key
+        var keys = await db.UpsertBatchAsync(collection, records).ToListAsync();
+        var results = db.GetBatchAsync(collection, keys);
+
+        // Assert
+        Assert.NotNull(keys);
+        Assert.NotNull(results);
+        Assert.Equal(numRecords, await results.CountAsync());
+    }
+
+    [Fact(Skip = SkipOrNot)]
+    public async Task ItCanBatchRemoveRecordsAsync()
+    {
+        // Arrange
+        using PostgresMemoryStore db = await this.CreateMemoryStoreAsync();
+        int numRecords = 10;
+        string collection = "test_collection" + this._collectionNum;
+        this._collectionNum++;
+        IEnumerable<MemoryRecord> records = this.CreateBatchRecords(numRecords);
+        await db.CreateCollectionAsync(collection);
+
+        List<string> keys = new();
+
+        // Act
+        await foreach (var key in db.UpsertBatchAsync(collection, records))
+        {
+            keys.Add(key);
+        }
+
+        await db.RemoveBatchAsync(collection, keys);
+
+        // Assert
+        await foreach (var result in db.GetBatchAsync(collection, keys))
+        {
+            Assert.Null(result);
+        }
+    }
+
+    [Fact(Skip = SkipOrNot)]
+    public async Task DeletingNonExistentCollectionDoesNothingAsync()
+    {
+        // Arrange
+        using PostgresMemoryStore db = await this.CreateMemoryStoreAsync();
+        string collection = "test_collection" + this._collectionNum;
+        this._collectionNum++;
+
+        // Act
+        await db.DeleteCollectionAsync(collection);
+    }
+}
diff --git a/dotnet/src/IntegrationTests/IntegrationTests.csproj b/dotnet/src/IntegrationTests/IntegrationTests.csproj
index 706210b64442..b0f5c7701b44 100644
--- a/dotnet/src/IntegrationTests/IntegrationTests.csproj
+++ b/dotnet/src/IntegrationTests/IntegrationTests.csproj
@@ -32,6 +32,7 @@
+
diff --git a/samples/dotnet/kernel-syntax-examples/Example39_Postgres.cs b/samples/dotnet/kernel-syntax-examples/Example39_Postgres.cs
new file mode 100644
index 000000000000..fecbec424488
--- /dev/null
+++ b/samples/dotnet/kernel-syntax-examples/Example39_Postgres.cs
@@ -0,0 +1,87 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.Threading.Tasks;
+using Microsoft.SemanticKernel;
+using Microsoft.SemanticKernel.Connectors.Memory.Postgres;
+using Microsoft.SemanticKernel.Memory;
+using RepoUtils;
+
+// ReSharper disable once InconsistentNaming
+/// <summary>
+/// Sample that plugs a <see cref="PostgresMemoryStore"/> into the kernel as memory storage, then saves,
+/// retrieves, searches and finally deletes a small collection of memories.
+/// </summary>
+public static class Example39_Postgres
+{
+    private const string MemoryCollectionName = "postgres-test";
+
+    /// <summary>
+    /// Runs the sample; the POSTGRES_CONNECTIONSTRING and OPENAI_API_KEY settings are read through <c>Env.Var</c>.
+    /// </summary>
+    public static async Task RunAsync()
+    {
+        string connectionString = Env.Var("POSTGRES_CONNECTIONSTRING");
+        using PostgresMemoryStore memoryStore = await PostgresMemoryStore.ConnectAsync(connectionString, vectorSize: 1536);
+        IKernel kernel = Kernel.Builder
+            .WithLogger(ConsoleLogger.Log)
+            .Configure(c =>
+            {
+                c.AddOpenAITextCompletionService("text-davinci-003", Env.Var("OPENAI_API_KEY"));
+                c.AddOpenAITextEmbeddingGenerationService("text-embedding-ada-002", Env.Var("OPENAI_API_KEY"));
+            })
+            .WithMemoryStorage(memoryStore)
+            .Build();
+
+        Console.WriteLine("== Printing Collections in DB ==");
+        var collections = memoryStore.GetCollectionsAsync();
+        await foreach (var collection in collections)
+        {
+            Console.WriteLine(collection);
+        }
+
+        Console.WriteLine("== Adding Memories ==");
+
+        var key1 = await kernel.Memory.SaveInformationAsync(MemoryCollectionName, id: "cat1", text: "british short hair");
+        var key2 = await kernel.Memory.SaveInformationAsync(MemoryCollectionName, id: "cat2", text: "orange tabby");
+        var key3 = await kernel.Memory.SaveInformationAsync(MemoryCollectionName, id: "cat3", text: "norwegian forest cat");
+
+        Console.WriteLine("== Printing Collections in DB ==");
+        collections = memoryStore.GetCollectionsAsync();
+        await foreach (var collection in collections)
+        {
+            Console.WriteLine(collection);
+        }
+
+        Console.WriteLine("== Retrieving Memories Through the Kernel ==");
+        MemoryQueryResult? lookup = await kernel.Memory.GetAsync(MemoryCollectionName, "cat1");
+        Console.WriteLine(lookup != null ? lookup.Metadata.Text : "ERROR: memory not found");
+
+        Console.WriteLine("== Retrieving Memories Directly From the Store ==");
+        var memory1 = await memoryStore.GetAsync(MemoryCollectionName, key1);
+        var memory2 = await memoryStore.GetAsync(MemoryCollectionName, key2);
+        var memory3 = await memoryStore.GetAsync(MemoryCollectionName, key3);
+        Console.WriteLine(memory1 != null ? memory1.Metadata.Text : "ERROR: memory not found");
+        Console.WriteLine(memory2 != null ? memory2.Metadata.Text : "ERROR: memory not found");
+        Console.WriteLine(memory3 != null ? memory3.Metadata.Text : "ERROR: memory not found");
+
+        Console.WriteLine("== Similarity Searching Memories: My favorite color is orange ==");
+        var searchResults = kernel.Memory.SearchAsync(MemoryCollectionName, "My favorite color is orange", limit: 3, minRelevanceScore: 0.8);
+
+        await foreach (var item in searchResults)
+        {
+            Console.WriteLine(item.Metadata.Text + " : " + item.Relevance);
+        }
+
+        Console.WriteLine("== Removing Collection {0} ==", MemoryCollectionName);
+        await memoryStore.DeleteCollectionAsync(MemoryCollectionName);
+
+        Console.WriteLine("== Printing Collections in DB ==");
+        // Query the store again so the listing is taken after the deletion above.
+        collections = memoryStore.GetCollectionsAsync();
+        await foreach (var collection in collections)
+        {
+            Console.WriteLine(collection);
+        }
+    }
+}
diff --git a/samples/dotnet/kernel-syntax-examples/KernelSyntaxExamples.csproj b/samples/dotnet/kernel-syntax-examples/KernelSyntaxExamples.csproj
index 39d82c074b03..b03a8ad55479 100644
--- a/samples/dotnet/kernel-syntax-examples/KernelSyntaxExamples.csproj
+++ b/samples/dotnet/kernel-syntax-examples/KernelSyntaxExamples.csproj
@@ -28,6 +28,7 @@
+
diff --git a/samples/dotnet/kernel-syntax-examples/Program.cs b/samples/dotnet/kernel-syntax-examples/Program.cs
index dd7f9c28113a..0e3763ed4112 100644
--- a/samples/dotnet/kernel-syntax-examples/Program.cs
+++ b/samples/dotnet/kernel-syntax-examples/Program.cs
@@ -121,5 +121,8 @@ public static async Task Main()
         await Example38_Pinecone.RunAsync();
         Console.WriteLine("== DONE ==");
+
+        await Example39_Postgres.RunAsync();
+        Console.WriteLine("== DONE ==");
     }
 }
diff --git a/samples/dotnet/kernel-syntax-examples/README.md b/samples/dotnet/kernel-syntax-examples/README.md
index 16c063da5f32..92a5e5e4b5ba 100644
--- a/samples/dotnet/kernel-syntax-examples/README.md
+++ b/samples/dotnet/kernel-syntax-examples/README.md
@@ -28,6 +28,7 @@ dotnet user-secrets set "ACS_API_KEY" "..."
 dotnet user-secrets set "QDRANT_ENDPOINT" "..."
 dotnet user-secrets set "QDRANT_PORT" "..."
 dotnet user-secrets set "GITHUB_PERSONAL_ACCESS_TOKEN" "github_pat_..."
+dotnet user-secrets set "POSTGRES_CONNECTIONSTRING" "..." ``` To set your secrets with environment variables, use these names: @@ -43,3 +44,4 @@ To set your secrets with environment variables, use these names: * QDRANT_ENDPOINT * QDRANT_PORT * GITHUB_PERSONAL_ACCESS_TOKEN +* POSTGRES_CONNECTIONSTRING