From fe56bbd76e28cdedf52d74923301a9412ed2d070 Mon Sep 17 00:00:00 2001 From: leijendary Date: Fri, 11 Oct 2024 23:44:59 +0200 Subject: [PATCH 1/8] Added ChatMemory implementation for PgVector --- .../modules/ROOT/pages/api/chatclient.adoc | 11 +- .../PgVectorChatMemoryAutoConfiguration.java | 56 +++++ .../PgVectorChatMemoryProperties.java | 91 +++++++ ...ot.autoconfigure.AutoConfiguration.imports | 1 + ...PgVectorChatMemoryAutoConfigurationIT.java | 69 ++++++ .../PgVectorChatMemoryPropertiesTests.java | 45 ++++ .../ai/chat/memory/PgVectorChatMemory.java | 130 ++++++++++ .../chat/memory/PgVectorChatMemoryConfig.java | 222 ++++++++++++++++++ .../memory/PgVectorChatMemoryConfigIT.java | 157 +++++++++++++ .../ai/chat/memory/PgVectorChatMemoryIT.java | 219 +++++++++++++++++ 10 files changed, 999 insertions(+), 2 deletions(-) create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfiguration.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryProperties.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfigurationIT.java create mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryPropertiesTests.java create mode 100644 vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemory.java create mode 100644 vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfig.java create mode 100644 vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfigIT.java create mode 100644 vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryIT.java diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc index 06ab23a8976..c1d6cc786d8 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc @@ -410,7 +410,7 @@ The `FILTER_EXPRESSION` parameter allows you to dynamically filter the search re The interface `ChatMemory` represents a storage for chat conversation history. It provides methods to add messages to a conversation, retrieve messages from a conversation, and clear the conversation history. -There are currently two implementations, `InMemoryChatMemory` and `CassandraChatMemory`, that provide storage for chat conversation history, in-memory and persisted with `time-to-live`, correspondingly. +There are currently three implementations, `InMemoryChatMemory`, `CassandraChatMemory`, and `PgVectorChatMemory` that provide storage for chat conversation history, in-memory and persisted with `time-to-live`, correspondingly. To create a `CassandraChatMemory` with `time-to-live`: @@ -419,11 +419,18 @@ To create a `CassandraChatMemory` with `time-to-live`: CassandraChatMemory.create(CassandraChatMemoryConfig.builder().withTimeToLive(Duration.ofDays(1)).build()); ---- +To create a `PgVectorChatMemory`: + +[source,java] +---- +PgVectorChatMemory.create(PgVectorChatMemoryConfig.builder().withJdbcTemplate(jdbcTemplate).build()); +---- + The following advisor implementations use the `ChatMemory` interface to advice the prompt with conversation history which differ in the details of how the memory is added to the prompt * `MessageChatMemoryAdvisor` : Memory is retrieved and added as a collection of messages to the prompt * `PromptChatMemoryAdvisor` : Memory is retrieved and added into the prompt's system text. -* `VectorStoreChatMemoryAdvisor` : The constructor `VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId, int chatHistoryWindowSize, int order)` This constructor allows you to: +* `VectorStoreChatMemoryAdvisor` : The constructor `VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId, int chatHistoryWindowSize, int order)` This constructor allows you to: . Specify the VectorStore instance used for managing and querying documents. . Set a default conversation ID to be used if none is provided in the context. diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfiguration.java new file mode 100644 index 00000000000..0cc5d0359ed --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfiguration.java @@ -0,0 +1,56 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.chat.memory.pgvector; + +import org.springframework.ai.chat.memory.PgVectorChatMemory; +import org.springframework.ai.chat.memory.PgVectorChatMemoryConfig; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.jdbc.core.JdbcTemplate; + +import javax.sql.DataSource; + +/** + * @author Jonathan Leijendekker + * @since 1.0.0 + */ +@AutoConfiguration(after = JdbcTemplateAutoConfiguration.class) +@ConditionalOnClass({ PgVectorChatMemory.class, DataSource.class, JdbcTemplate.class }) +@EnableConfigurationProperties(PgVectorChatMemoryProperties.class) +public class PgVectorChatMemoryAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + public PgVectorChatMemory chatMemory(PgVectorChatMemoryProperties properties, JdbcTemplate jdbcTemplate) { + var config = PgVectorChatMemoryConfig.builder() + .withInitializeSchema(properties.isInitializeSchema()) + .withSchemaName(properties.getSchemaName()) + .withTableName(properties.getTableName()) + .withSessionIdColumnName(properties.getSessionIdColumnName()) + .withExchangeIdColumnName(properties.getExchangeIdColumnName()) + .withAssistantColumnName(properties.getAssistantColumnName()) + .withUserColumnName(properties.getUserColumnName()) + .withJdbcTemplate(jdbcTemplate) + .build(); + + return PgVectorChatMemory.create(config); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryProperties.java new file mode 100644 index 00000000000..be1b192682e --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryProperties.java @@ -0,0 +1,91 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.autoconfigure.chat.memory.pgvector; + +import org.springframework.ai.autoconfigure.chat.memory.CommonChatMemoryProperties; +import org.springframework.ai.chat.memory.PgVectorChatMemoryConfig; +import org.springframework.boot.context.properties.ConfigurationProperties; + +/** + * @author Jonathan Leijendekker + * @since 1.0.0 + */ +@ConfigurationProperties(PgVectorChatMemoryProperties.CONFIG_PREFIX) +public class PgVectorChatMemoryProperties extends CommonChatMemoryProperties { + + public static final String CONFIG_PREFIX = "spring.ai.chat.memory.pgvector"; + + private String schemaName = PgVectorChatMemoryConfig.DEFAULT_SCHEMA_NAME; + + private String tableName = PgVectorChatMemoryConfig.DEFAULT_TABLE_NAME; + + private String sessionIdColumnName = PgVectorChatMemoryConfig.DEFAULT_SESSION_ID_COLUMN_NAME; + + private String exchangeIdColumnName = PgVectorChatMemoryConfig.DEFAULT_EXCHANGE_ID_COLUMN_NAME; + + private String assistantColumnName = PgVectorChatMemoryConfig.DEFAULT_ASSISTANT_COLUMN_NAME; + + private String userColumnName = PgVectorChatMemoryConfig.DEFAULT_USER_COLUMN_NAME; + + public String getSchemaName() { + return schemaName; + } + + public void setSchemaName(String schemaName) { + this.schemaName = schemaName; + } + + public String getTableName() { + return tableName; + } + + public void setTableName(String tableName) { + this.tableName = tableName; + } + + public String getSessionIdColumnName() { + return sessionIdColumnName; + } + + public void setSessionIdColumnName(String sessionIdColumnName) { + this.sessionIdColumnName = sessionIdColumnName; + } + + public String getExchangeIdColumnName() { + return exchangeIdColumnName; + } + + public void setExchangeIdColumnName(String exchangeIdColumnName) { + this.exchangeIdColumnName = exchangeIdColumnName; + } + + public String getAssistantColumnName() { + return assistantColumnName; + } + + public void setAssistantColumnName(String assistantColumnName) { + this.assistantColumnName = assistantColumnName; + } + + public String getUserColumnName() { + return userColumnName; + } + + public void setUserColumnName(String userColumnName) { + this.userColumnName = userColumnName; + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index d5b849aa3c8..7866086de32 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -48,3 +48,4 @@ org.springframework.ai.autoconfigure.minimax.MiniMaxAutoConfiguration org.springframework.ai.autoconfigure.vertexai.embedding.VertexAiEmbeddingAutoConfiguration org.springframework.ai.autoconfigure.chat.memory.cassandra.CassandraChatMemoryAutoConfiguration org.springframework.ai.autoconfigure.vectorstore.observation.VectorStoreObservationAutoConfiguration +org.springframework.ai.autoconfigure.chat.memory.pgvector.PgVectorChatMemoryAutoConfiguration diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfigurationIT.java new file mode 100644 index 00000000000..b9522382dd0 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfigurationIT.java @@ -0,0 +1,69 @@ +package org.springframework.ai.autoconfigure.chat.memory.pgvector; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.memory.PgVectorChatMemory; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import java.util.List; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Jonathan Leijendekker + */ +@Testcontainers +class PgVectorChatMemoryAutoConfigurationIT { + + @Container + @SuppressWarnings("resource") + static PostgreSQLContainer postgresContainer = new PostgreSQLContainer<>("pgvector/pgvector:pg17") + .withUsername("postgres") + .withPassword("postgres"); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(PgVectorChatMemoryAutoConfiguration.class, + JdbcTemplateAutoConfiguration.class, DataSourceAutoConfiguration.class)) + .withPropertyValues("spring.ai.chat.memory.pgvector.schemaName=test_autoconfigure", + // JdbcTemplate configuration + String.format("spring.datasource.url=jdbc:postgresql://%s:%d/%s", postgresContainer.getHost(), + postgresContainer.getMappedPort(5432), postgresContainer.getDatabaseName()), + "spring.datasource.username=" + postgresContainer.getUsername(), + "spring.datasource.password=" + postgresContainer.getPassword()); + + @Test + void addGetAndClear_shouldAllExecute() { + contextRunner.run(context -> { + var chatMemory = context.getBean(PgVectorChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + var userMessage = new UserMessage("Message from the user"); + + chatMemory.add(conversationId, userMessage); + + assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).hasSize(1); + assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEqualTo(List.of(userMessage)); + + chatMemory.clear(conversationId); + + assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEmpty(); + + var multipleMessages = List.of(new UserMessage("Message from the user 1"), + new AssistantMessage("Message from the assistant 1")); + + chatMemory.add(conversationId, multipleMessages); + + assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).hasSize(multipleMessages.size()); + assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEqualTo(multipleMessages); + }); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryPropertiesTests.java new file mode 100644 index 00000000000..4925c4bec69 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryPropertiesTests.java @@ -0,0 +1,45 @@ +package org.springframework.ai.autoconfigure.chat.memory.pgvector; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.memory.PgVectorChatMemoryConfig; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * @author Jonathan Leijendekker + */ +class PgVectorChatMemoryPropertiesTests { + + @Test + void defaultValues() { + var props = new PgVectorChatMemoryProperties(); + assertEquals(PgVectorChatMemoryConfig.DEFAULT_SCHEMA_NAME, props.getSchemaName()); + assertEquals(PgVectorChatMemoryConfig.DEFAULT_TABLE_NAME, props.getTableName()); + assertEquals(PgVectorChatMemoryConfig.DEFAULT_SESSION_ID_COLUMN_NAME, props.getSessionIdColumnName()); + assertEquals(PgVectorChatMemoryConfig.DEFAULT_EXCHANGE_ID_COLUMN_NAME, props.getExchangeIdColumnName()); + assertEquals(PgVectorChatMemoryConfig.DEFAULT_ASSISTANT_COLUMN_NAME, props.getAssistantColumnName()); + assertEquals(PgVectorChatMemoryConfig.DEFAULT_USER_COLUMN_NAME, props.getUserColumnName()); + assertTrue(props.isInitializeSchema()); + } + + @Test + void customValues() { + var props = new PgVectorChatMemoryProperties(); + props.setSchemaName("custom_schema_name"); + props.setTableName("custom_table_name"); + props.setSessionIdColumnName("custom_session_id_column_name"); + props.setExchangeIdColumnName("custom_exchange_id_column_name"); + props.setAssistantColumnName("custom_assistant_column_name"); + props.setUserColumnName("custom_user_column_name"); + props.setInitializeSchema(false); + + assertEquals("custom_schema_name", props.getSchemaName()); + assertEquals("custom_table_name", props.getTableName()); + assertEquals("custom_session_id_column_name", props.getSessionIdColumnName()); + assertEquals("custom_exchange_id_column_name", props.getExchangeIdColumnName()); + assertEquals("custom_assistant_column_name", props.getAssistantColumnName()); + assertEquals("custom_user_column_name", props.getUserColumnName()); + assertFalse(props.isInitializeSchema()); + } + +} diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemory.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemory.java new file mode 100644 index 00000000000..31c371f286e --- /dev/null +++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemory.java @@ -0,0 +1,130 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.chat.memory; + +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.jdbc.core.BatchPreparedStatementSetter; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.core.RowMapper; + +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Types; +import java.util.List; + +/** + * An implementation of {@link ChatMemory} for PgVector. Creating an instance of + * PgVectorChatMemory example: + * PgVectorChatMemory.create(PgVectorChatMemoryConfig.builder().withJdbcTemplate(jdbcTemplate).build()); + * + * @author Jonathan Leijendekker + * @since 1.0.0 + */ +public class PgVectorChatMemory implements ChatMemory { + + private final PgVectorChatMemoryConfig config; + + private final JdbcTemplate jdbcTemplate; + + public PgVectorChatMemory(PgVectorChatMemoryConfig config) { + this.config = config; + this.config.initializeSchema(); + this.jdbcTemplate = this.config.getJdbcTemplate(); + } + + public static PgVectorChatMemory create(PgVectorChatMemoryConfig config) { + return new PgVectorChatMemory(config); + } + + @Override + public void add(String conversationId, List messages) { + var sql = String.format("INSERT INTO %s (%s, %s, %s) VALUES (?, ?, ?)", + this.config.getFullyQualifiedTableName(), this.config.getSessionIdColumnName(), + this.config.getAssistantColumnName(), this.config.getUserColumnName()); + + this.jdbcTemplate.batchUpdate(sql, new AddBatchPreparedStatement(conversationId, messages)); + } + + @Override + public List get(String conversationId, int lastN) { + var sql = String.format("SELECT %s, %s FROM %s WHERE %s = ? ORDER BY %s DESC LIMIT ?", + this.config.getAssistantColumnName(), this.config.getUserColumnName(), + this.config.getFullyQualifiedTableName(), this.config.getSessionIdColumnName(), + this.config.getExchangeIdColumnName()); + + return this.jdbcTemplate.query(sql, new MessageRowMapper(), conversationId, lastN); + } + + @Override + public void clear(String conversationId) { + var sql = String.format("DELETE FROM %s WHERE %s = ?", this.config.getFullyQualifiedTableName(), + this.config.getSessionIdColumnName()); + + this.jdbcTemplate.update(sql, conversationId); + } + + private record AddBatchPreparedStatement(String conversationId, + List messages) implements BatchPreparedStatementSetter { + @Override + public void setValues(PreparedStatement ps, int i) throws SQLException { + var message = this.messages.get(i); + + ps.setString(1, this.conversationId); + + switch (message.getMessageType()) { + case ASSISTANT -> { + ps.setString(2, message.getContent()); + ps.setNull(3, Types.VARCHAR); + } + case USER -> { + ps.setNull(2, Types.VARCHAR); + ps.setString(3, message.getContent()); + } + default -> throw new IllegalArgumentException("Can't add type " + message); + } + } + + @Override + public int getBatchSize() { + return this.messages.size(); + } + } + + private static class MessageRowMapper implements RowMapper { + + @Override + public Message mapRow(ResultSet rs, int i) throws SQLException { + var assistant = rs.getString(1); + + if (assistant != null) { + return new AssistantMessage(assistant); + } + + var user = rs.getString(2); + + if (user != null) { + return new UserMessage(user); + } + + return null; + } + + } + +} diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfig.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfig.java new file mode 100644 index 00000000000..09deb248a2c --- /dev/null +++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfig.java @@ -0,0 +1,222 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.chat.memory; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.util.Assert; + +/** + * @author Jonathan Leijendekker + * @since 1.0.0 + */ +public class PgVectorChatMemoryConfig { + + private static final Logger logger = LoggerFactory.getLogger(PgVectorChatMemoryConfig.class); + + public static final boolean DEFAULT_SCHEMA_INITIALIZATION = false; + + public static final String DEFAULT_SCHEMA_NAME = "public"; + + public static final String DEFAULT_TABLE_NAME = "ai_chat_memory"; + + public static final String DEFAULT_SESSION_ID_COLUMN_NAME = "session_id"; + + public static final String DEFAULT_EXCHANGE_ID_COLUMN_NAME = "message_timestamp"; + + public static final String DEFAULT_ASSISTANT_COLUMN_NAME = "assistant"; + + // "user" is a reserved keyword in postgres, hence the double quotes. + public static final String DEFAULT_USER_COLUMN_NAME = "\"user\""; + + private final boolean initializeSchema; + + private final String schemaName; + + private final String tableName; + + private final String sessionIdColumnName; + + private final String exchangeIdColumnName; + + private final String assistantColumnName; + + private final String userColumnName; + + private final JdbcTemplate jdbcTemplate; + + private PgVectorChatMemoryConfig(Builder builder) { + this.initializeSchema = builder.initializeSchema; + this.schemaName = builder.schemaName; + this.tableName = builder.tableName; + this.sessionIdColumnName = builder.sessionIdColumnName; + this.exchangeIdColumnName = builder.exchangeIdColumnName; + this.assistantColumnName = builder.assistantColumnName; + this.userColumnName = builder.userColumnName; + this.jdbcTemplate = builder.jdbcTemplate; + } + + public static Builder builder() { + return new Builder(); + } + + String getFullyQualifiedTableName() { + return this.schemaName + "." + this.tableName; + } + + String getSchemaName() { + return this.schemaName; + } + + String getTableName() { + return this.tableName; + } + + String getSessionIdColumnName() { + return this.sessionIdColumnName; + } + + String getExchangeIdColumnName() { + return this.exchangeIdColumnName; + } + + String getAssistantColumnName() { + return this.assistantColumnName; + } + + String getUserColumnName() { + return this.userColumnName; + } + + JdbcTemplate getJdbcTemplate() { + return this.jdbcTemplate; + } + + void initializeSchema() { + if (!this.initializeSchema) { + logger.warn("Skipping the schema initialization for table: {}", this.getFullyQualifiedTableName()); + return; + } + + logger.info("Initializing PGChatMemory schema for table: {} in schema: {}", this.getTableName(), + this.getSchemaName()); + + var indexName = String + .format("%s_%s_%s_idx", this.getTableName(), this.getSessionIdColumnName(), this.getExchangeIdColumnName()) + // Keywords in postgres has to be wrapped in double quotes. It is possible + // that the table or column may + // be a reserved keyword. If so, just remove them. + .replaceAll("\"", ""); + + this.jdbcTemplate.execute(String.format("CREATE SCHEMA IF NOT EXISTS %s", this.getSchemaName())); + this.jdbcTemplate.execute(String.format(""" + CREATE TABLE IF NOT EXISTS %s ( + %s character varying(40) NOT NULL, + %s timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, + %s text, + %s text + ) + """, this.getFullyQualifiedTableName(), this.getSessionIdColumnName(), this.getExchangeIdColumnName(), + this.getAssistantColumnName(), this.getUserColumnName())); + this.jdbcTemplate.execute(String.format("CREATE INDEX IF NOT EXISTS %s ON %s(%s, %s DESC)", indexName, + this.getFullyQualifiedTableName(), this.getSessionIdColumnName(), this.getExchangeIdColumnName())); + } + + public static class Builder { + + private boolean initializeSchema = DEFAULT_SCHEMA_INITIALIZATION; + + private String schemaName = DEFAULT_SCHEMA_NAME; + + private String tableName = DEFAULT_TABLE_NAME; + + private String sessionIdColumnName = DEFAULT_SESSION_ID_COLUMN_NAME; + + private String exchangeIdColumnName = DEFAULT_EXCHANGE_ID_COLUMN_NAME; + + private String assistantColumnName = DEFAULT_ASSISTANT_COLUMN_NAME; + + private String userColumnName = DEFAULT_USER_COLUMN_NAME; + + private JdbcTemplate jdbcTemplate; + + private Builder() { + } + + public Builder withInitializeSchema(boolean initializeSchema) { + this.initializeSchema = initializeSchema; + return this; + } + + public Builder withSchemaName(String schemaName) { + Assert.hasText(schemaName, "schema name must not be empty"); + + this.schemaName = schemaName; + return this; + } + + public Builder withTableName(String tableName) { + Assert.hasText(tableName, "table name must not be empty"); + + this.tableName = tableName; + return this; + } + + public Builder withSessionIdColumnName(String sessionIdColumnName) { + Assert.hasText(sessionIdColumnName, "session id column name must not be empty"); + + this.sessionIdColumnName = sessionIdColumnName; + return this; + } + + public Builder withExchangeIdColumnName(String exchangeIdColumnName) { + Assert.hasText(exchangeIdColumnName, "exchange id column name must not be empty"); + + this.exchangeIdColumnName = exchangeIdColumnName; + return this; + } + + public Builder withAssistantColumnName(String assistantColumnName) { + Assert.hasText(assistantColumnName, "assistant column name must not be empty"); + + this.assistantColumnName = assistantColumnName; + return this; + } + + public Builder withUserColumnName(String userColumnName) { + Assert.hasText(userColumnName, "user column name must not be empty"); + + this.userColumnName = userColumnName; + return this; + } + + public Builder withJdbcTemplate(JdbcTemplate jdbcTemplate) { + Assert.notNull(jdbcTemplate, "jdbc template must not be null"); + + this.jdbcTemplate = jdbcTemplate; + return this; + } + + public PgVectorChatMemoryConfig build() { + Assert.notNull(jdbcTemplate, "jdbc template must not be null"); + + return new PgVectorChatMemoryConfig(this); + } + + } + +} diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfigIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfigIT.java new file mode 100644 index 00000000000..188051c6a11 --- /dev/null +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfigIT.java @@ -0,0 +1,157 @@ +package org.springframework.ai.chat.memory; + +import com.zaxxer.hikari.HikariDataSource; +import org.junit.jupiter.api.Test; +import org.springframework.ai.vectorstore.PgVectorImage; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceProperties; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Primary; +import org.springframework.jdbc.core.JdbcTemplate; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import javax.sql.DataSource; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * @author Jonathan Leijendekker + */ +@Testcontainers +class PgVectorChatMemoryConfigIT { + + @Container + @SuppressWarnings("resource") + static PostgreSQLContainer postgresContainer = new PostgreSQLContainer<>(PgVectorImage.DEFAULT_IMAGE) + .withUsername("postgres") + .withPassword("postgres"); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(PgVectorChatMemoryIT.TestApplication.class) + .withPropertyValues( + // JdbcTemplate configuration + String.format("app.datasource.url=jdbc:postgresql://%s:%d/%s", postgresContainer.getHost(), + postgresContainer.getMappedPort(5432), "postgres"), + "app.datasource.username=postgres", "app.datasource.password=postgres", + "app.datasource.type=com.zaxxer.hikari.HikariDataSource"); + + static String schemaName = "config_test"; + static String tableName = "ai_chat_config_test"; + static String sessionIdColumnName = "id_config_test"; + static String exchangeIdColumnName = "timestamp_config_test"; + static String assistantColumnName = "assistant_config_test"; + static String userColumnName = "\"user_config_test\""; + + @Test + void initializeSchema_withValueTrue_shouldCreateSchema() { + contextRunner.run(context -> { + var jdbcTemplate = context.getBean(JdbcTemplate.class); + var config = PgVectorChatMemoryConfig.builder() + .withInitializeSchema(true) + .withSchemaName(schemaName) + .withTableName(tableName) + .withSessionIdColumnName(sessionIdColumnName) + .withExchangeIdColumnName(exchangeIdColumnName) + .withAssistantColumnName(assistantColumnName) + .withUserColumnName(userColumnName) + .withJdbcTemplate(jdbcTemplate) + .build(); + config.initializeSchema(); + + var expectedColumns = List.of(sessionIdColumnName, exchangeIdColumnName, assistantColumnName, + userColumnName.replace("\"", "")); + + // Verify that the schema, table, and index was created + var hasSchema = jdbcTemplate.queryForObject( + "SELECT EXISTS (SELECT 1 FROM information_schema.schemata WHERE schema_name = ?)", boolean.class, + schemaName); + var hasTable = jdbcTemplate.queryForObject( + "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema = ? AND table_name = ?)", + boolean.class, schemaName, tableName); + var tableColumns = jdbcTemplate.queryForList( + "SELECT column_name FROM information_schema.columns WHERE table_schema = ? AND table_name = ?", + String.class, schemaName, tableName); + var indexName = jdbcTemplate.queryForObject( + "SELECT indexname FROM pg_indexes WHERE schemaname = ? AND tablename = ?", String.class, schemaName, + tableName); + + assertEquals(Boolean.TRUE, hasSchema); + assertEquals(Boolean.TRUE, hasTable); + assertTrue(expectedColumns.containsAll(tableColumns)); + assertEquals(String.format("%s_%s_%s_idx", tableName, sessionIdColumnName, exchangeIdColumnName), + indexName); + + // Cleanup for the other tests + jdbcTemplate.update(String.format("DROP SCHEMA IF EXISTS %s CASCADE", schemaName)); + }); + } + + @Test + void initializeSchema_withValueFalse_shouldNotCreateSchema() { + contextRunner.run(context -> { + var jdbcTemplate = context.getBean(JdbcTemplate.class); + var config = PgVectorChatMemoryConfig.builder() + .withInitializeSchema(false) + .withSchemaName(schemaName) + .withTableName(tableName) + .withSessionIdColumnName(sessionIdColumnName) + .withExchangeIdColumnName(exchangeIdColumnName) + .withAssistantColumnName(assistantColumnName) + .withUserColumnName(userColumnName) + .withJdbcTemplate(jdbcTemplate) + .build(); + config.initializeSchema(); + + // Verify that the schema, table, and index was created + var hasSchema = jdbcTemplate.queryForObject( + "SELECT EXISTS (SELECT 1 FROM information_schema.schemata WHERE schema_name = ?)", boolean.class, + schemaName); + var hasTable = jdbcTemplate.queryForObject( + "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema = ? AND table_name = ?)", + boolean.class, schemaName, tableName); + var columnCount = jdbcTemplate.queryForObject( + "SELECT COUNT(*) FROM information_schema.columns WHERE table_schema = ? AND table_name = ?", + Integer.class, schemaName, tableName); + var hasIndex = jdbcTemplate.queryForObject( + "SELECT EXISTS (SELECT 1 FROM pg_indexes WHERE schemaname = ? AND tablename = ?)", Boolean.class, + schemaName, tableName); + + assertEquals(Boolean.FALSE, hasSchema); + assertEquals(Boolean.FALSE, hasTable); + assertEquals(0, columnCount); + assertEquals(Boolean.FALSE, hasIndex); + }); + } + + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + static class TestApplication { + + @Bean + public JdbcTemplate jdbcTemplate(DataSource dataSource) { + return new JdbcTemplate(dataSource); + } + + @Bean + @Primary + @ConfigurationProperties("app.datasource") + public DataSourceProperties dataSourceProperties() { + return new DataSourceProperties(); + } + + @Bean + public HikariDataSource dataSource(DataSourceProperties dataSourceProperties) { + return dataSourceProperties.initializeDataSourceBuilder().type(HikariDataSource.class).build(); + } + + } + +} diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryIT.java new file mode 100644 index 00000000000..347d0fc1d14 --- /dev/null +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryIT.java @@ -0,0 +1,219 @@ +package org.springframework.ai.chat.memory; + +import com.zaxxer.hikari.HikariDataSource; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.vectorstore.PgVectorImage; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceProperties; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Primary; +import org.springframework.jdbc.core.JdbcTemplate; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import javax.sql.DataSource; +import java.sql.Timestamp; +import java.util.List; +import java.util.UUID; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * @author Jonathan Leijendekker + */ +@Testcontainers +class PgVectorChatMemoryIT { + + @Container + @SuppressWarnings("resource") + static PostgreSQLContainer postgresContainer = new PostgreSQLContainer<>(PgVectorImage.DEFAULT_IMAGE) + .withUsername("postgres") + .withPassword("postgres"); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class) + .withPropertyValues( + // JdbcTemplate configuration + String.format("app.datasource.url=jdbc:postgresql://%s:%d/%s", postgresContainer.getHost(), + postgresContainer.getMappedPort(5432), "postgres"), + "app.datasource.username=postgres", "app.datasource.password=postgres", + "app.datasource.type=com.zaxxer.hikari.HikariDataSource"); + + static String schemaName = "public_test"; + static String tableName = "ai_chat_memory_test"; + static String sessionIdColumnName = "session_id_test"; + static String exchangeIdColumnName = "message_timestamp_test"; + static String assistantColumnName = "assistant_test"; + static String userColumnName = "\"user_test\""; + + @Test + void correctChatMemoryInstance() { + contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + + assertInstanceOf(PgVectorChatMemory.class, chatMemory); + }); + } + + @ParameterizedTest + @CsvSource({ "Message from assistant,ASSISTANT", "Message from user,USER" }) + void add_shouldInsertSingleMessage(String content, MessageType messageType) { + contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + String assistantContent = null; + String userContent = null; + var message = switch (messageType) { + case ASSISTANT -> { + assistantContent = content + " - " + conversationId; + yield new AssistantMessage(assistantContent); + } + case USER -> { + userContent = content + " - " + conversationId; + yield new UserMessage(userContent); + } + default -> throw new IllegalArgumentException("Type not supported: " + messageType); + }; + + chatMemory.add(conversationId, message); + + var jdbcTemplate = context.getBean(JdbcTemplate.class); + var query = String.format("SELECT %s, %s, %s, %s FROM %s.%s WHERE %s = ?", sessionIdColumnName, + exchangeIdColumnName, assistantColumnName, userColumnName, schemaName, tableName, + sessionIdColumnName); + var result = jdbcTemplate.queryForMap(query, conversationId); + + assertEquals(4, result.size()); + assertEquals(conversationId, result.get(sessionIdColumnName)); + assertInstanceOf(Timestamp.class, result.get(exchangeIdColumnName)); + assertNotNull(result.get(exchangeIdColumnName)); + assertEquals(assistantContent, result.get(assistantColumnName)); + assertEquals(userContent, result.get(userColumnName.replace("\"", ""))); + }); + } + + @Test + void add_shouldInsertMessages() { + contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from assistant - " + conversationId), + new UserMessage("Message from user - " + conversationId)); + + chatMemory.add(conversationId, messages); + + var jdbcTemplate = context.getBean(JdbcTemplate.class); + var query = String.format("SELECT %s, %s, %s, %s FROM %s.%s WHERE %s = ?", sessionIdColumnName, + exchangeIdColumnName, assistantColumnName, userColumnName, schemaName, tableName, + sessionIdColumnName); + var results = jdbcTemplate.queryForList(query, conversationId); + + assertEquals(messages.size(), results.size()); + + for (var i = 0; i < messages.size(); i++) { + var message = messages.get(i); + var result = results.get(i); + + assertEquals(conversationId, result.get(sessionIdColumnName)); + assertInstanceOf(Timestamp.class, result.get(exchangeIdColumnName)); + assertNotNull(result.get(exchangeIdColumnName)); + + if (message.getMessageType() == MessageType.ASSISTANT) { + assertEquals(message.getContent(), result.get(assistantColumnName)); + } + else { + assertEquals(message.getContent(), result.get(userColumnName.replace("\"", ""))); + } + } + }); + } + + @Test + void get_shouldReturnMessages() { + contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from assistant 1 - " + conversationId), + new AssistantMessage("Message from assistant 2 - " + conversationId), + new UserMessage("Message from user - " + conversationId)); + + chatMemory.add(conversationId, messages); + + var results = chatMemory.get(conversationId, Integer.MAX_VALUE); + + assertEquals(messages.size(), results.size()); + assertEquals(messages, results); + }); + } + + @Test + void clear_shouldDeleteMessages() { + contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from assistant - " + conversationId), + new UserMessage("Message from user - " + conversationId)); + + chatMemory.add(conversationId, messages); + + chatMemory.clear(conversationId); + + var jdbcTemplate = context.getBean(JdbcTemplate.class); + var count = jdbcTemplate.queryForObject(String.format("SELECT COUNT(*) FROM %s.%s WHERE %s = ?", schemaName, + tableName, sessionIdColumnName), Integer.class, conversationId); + + assertEquals(0, count); + }); + } + + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + static class TestApplication { + + @Bean + public ChatMemory chatMemory(JdbcTemplate jdbcTemplate) { + var config = PgVectorChatMemoryConfig.builder() + .withInitializeSchema(true) + .withSchemaName(schemaName) + .withTableName(tableName) + .withSessionIdColumnName(sessionIdColumnName) + .withExchangeIdColumnName(exchangeIdColumnName) + .withAssistantColumnName(assistantColumnName) + .withUserColumnName(userColumnName) + .withJdbcTemplate(jdbcTemplate) + .build(); + + return PgVectorChatMemory.create(config); + } + + @Bean + public JdbcTemplate jdbcTemplate(DataSource dataSource) { + return new JdbcTemplate(dataSource); + } + + @Bean + @Primary + @ConfigurationProperties("app.datasource") + public DataSourceProperties dataSourceProperties() { + return new DataSourceProperties(); + } + + @Bean + public HikariDataSource dataSource(DataSourceProperties dataSourceProperties) { + return dataSourceProperties.initializeDataSourceBuilder().type(HikariDataSource.class).build(); + } + + } + +} From 7f4f23dedfe29482b6c1cd0ea8f44dc712b046df Mon Sep 17 00:00:00 2001 From: leijendary Date: Sat, 12 Oct 2024 09:23:31 +0200 Subject: [PATCH 2/8] Used values from postgresContainer for properties --- .../pgvector/PgVectorChatMemoryAutoConfigurationIT.java | 8 ++++---- .../ai/chat/memory/PgVectorChatMemoryConfigIT.java | 6 +++--- .../ai/chat/memory/PgVectorChatMemoryIT.java | 6 +++--- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfigurationIT.java index b9522382dd0..a6c6a3672ab 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfigurationIT.java @@ -35,10 +35,10 @@ class PgVectorChatMemoryAutoConfigurationIT { JdbcTemplateAutoConfiguration.class, DataSourceAutoConfiguration.class)) .withPropertyValues("spring.ai.chat.memory.pgvector.schemaName=test_autoconfigure", // JdbcTemplate configuration - String.format("spring.datasource.url=jdbc:postgresql://%s:%d/%s", postgresContainer.getHost(), - postgresContainer.getMappedPort(5432), postgresContainer.getDatabaseName()), - "spring.datasource.username=" + postgresContainer.getUsername(), - "spring.datasource.password=" + postgresContainer.getPassword()); + String.format("spring.datasource.url=%s", postgresContainer.getJdbcUrl()), + String.format("spring.datasource.username=%s", postgresContainer.getUsername()), + String.format("spring.datasource.password=%s", postgresContainer.getPassword()), + "spring.datasource.type=com.zaxxer.hikari.HikariDataSource"); @Test void addGetAndClear_shouldAllExecute() { diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfigIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfigIT.java index 188051c6a11..2aac72965f7 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfigIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfigIT.java @@ -38,9 +38,9 @@ class PgVectorChatMemoryConfigIT { .withUserConfiguration(PgVectorChatMemoryIT.TestApplication.class) .withPropertyValues( // JdbcTemplate configuration - String.format("app.datasource.url=jdbc:postgresql://%s:%d/%s", postgresContainer.getHost(), - postgresContainer.getMappedPort(5432), "postgres"), - "app.datasource.username=postgres", "app.datasource.password=postgres", + String.format("app.datasource.url=%s", postgresContainer.getJdbcUrl()), + String.format("app.datasource.username=%s", postgresContainer.getUsername()), + String.format("app.datasource.password=%s", postgresContainer.getPassword()), "app.datasource.type=com.zaxxer.hikari.HikariDataSource"); static String schemaName = "config_test"; diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryIT.java index 347d0fc1d14..695975a644c 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryIT.java @@ -45,9 +45,9 @@ class PgVectorChatMemoryIT { .withUserConfiguration(TestApplication.class) .withPropertyValues( // JdbcTemplate configuration - String.format("app.datasource.url=jdbc:postgresql://%s:%d/%s", postgresContainer.getHost(), - postgresContainer.getMappedPort(5432), "postgres"), - "app.datasource.username=postgres", "app.datasource.password=postgres", + String.format("app.datasource.url=%s", postgresContainer.getJdbcUrl()), + String.format("app.datasource.username=%s", postgresContainer.getUsername()), + String.format("app.datasource.password=%s", postgresContainer.getPassword()), "app.datasource.type=com.zaxxer.hikari.HikariDataSource"); static String schemaName = "public_test"; From c85c421fc81a7b3f5ac418ca77cbfcafb224d37a Mon Sep 17 00:00:00 2001 From: leijendary Date: Sun, 3 Nov 2024 13:01:26 +0100 Subject: [PATCH 3/8] Updated comment --- .../ai/chat/memory/PgVectorChatMemoryConfig.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfig.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfig.java index 09deb248a2c..0eeea60d96c 100644 --- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfig.java +++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfig.java @@ -118,8 +118,8 @@ void initializeSchema() { var indexName = String .format("%s_%s_%s_idx", this.getTableName(), this.getSessionIdColumnName(), this.getExchangeIdColumnName()) // Keywords in postgres has to be wrapped in double quotes. It is possible - // that the table or column may - // be a reserved keyword. If so, just remove them. + // that the table or column may be a reserved keyword. If so, just remove + // them. .replaceAll("\"", ""); this.jdbcTemplate.execute(String.format("CREATE SCHEMA IF NOT EXISTS %s", this.getSchemaName())); From 81cff4e3e59ac63583909b2d7a1806f00769c2f4 Mon Sep 17 00:00:00 2001 From: leijendary Date: Mon, 4 Nov 2024 01:21:45 +0100 Subject: [PATCH 4/8] Fixed checkstyle errors --- .../PgVectorChatMemoryAutoConfiguration.java | 9 ++- .../PgVectorChatMemoryProperties.java | 17 ++-- ...PgVectorChatMemoryAutoConfigurationIT.java | 31 ++++++-- .../PgVectorChatMemoryPropertiesTests.java | 47 +++++++---- .../ai/chat/memory/PgVectorChatMemory.java | 17 ++-- .../chat/memory/PgVectorChatMemoryConfig.java | 12 +-- .../memory/PgVectorChatMemoryConfigIT.java | 55 ++++++++----- .../ai/chat/memory/PgVectorChatMemoryIT.java | 78 ++++++++++++------- 8 files changed, 170 insertions(+), 96 deletions(-) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfiguration.java index 0cc5d0359ed..ff8f237b1b7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfiguration.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2024-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.chat.memory.pgvector; +import javax.sql.DataSource; + import org.springframework.ai.chat.memory.PgVectorChatMemory; import org.springframework.ai.chat.memory.PgVectorChatMemoryConfig; import org.springframework.boot.autoconfigure.AutoConfiguration; @@ -25,8 +28,6 @@ import org.springframework.context.annotation.Bean; import org.springframework.jdbc.core.JdbcTemplate; -import javax.sql.DataSource; - /** * @author Jonathan Leijendekker * @since 1.0.0 diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryProperties.java index be1b192682e..ab35c7204a8 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryProperties.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2024-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.autoconfigure.chat.memory.pgvector; import org.springframework.ai.autoconfigure.chat.memory.CommonChatMemoryProperties; @@ -41,7 +42,7 @@ public class PgVectorChatMemoryProperties extends CommonChatMemoryProperties { private String userColumnName = PgVectorChatMemoryConfig.DEFAULT_USER_COLUMN_NAME; public String getSchemaName() { - return schemaName; + return this.schemaName; } public void setSchemaName(String schemaName) { @@ -49,7 +50,7 @@ public void setSchemaName(String schemaName) { } public String getTableName() { - return tableName; + return this.tableName; } public void setTableName(String tableName) { @@ -57,7 +58,7 @@ public void setTableName(String tableName) { } public String getSessionIdColumnName() { - return sessionIdColumnName; + return this.sessionIdColumnName; } public void setSessionIdColumnName(String sessionIdColumnName) { @@ -65,7 +66,7 @@ public void setSessionIdColumnName(String sessionIdColumnName) { } public String getExchangeIdColumnName() { - return exchangeIdColumnName; + return this.exchangeIdColumnName; } public void setExchangeIdColumnName(String exchangeIdColumnName) { @@ -73,7 +74,7 @@ public void setExchangeIdColumnName(String exchangeIdColumnName) { } public String getAssistantColumnName() { - return assistantColumnName; + return this.assistantColumnName; } public void setAssistantColumnName(String assistantColumnName) { @@ -81,7 +82,7 @@ public void setAssistantColumnName(String assistantColumnName) { } public String getUserColumnName() { - return userColumnName; + return this.userColumnName; } public void setUserColumnName(String userColumnName) { diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfigurationIT.java index a6c6a3672ab..f5b81f667d7 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfigurationIT.java @@ -1,6 +1,29 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.autoconfigure.chat.memory.pgvector; +import java.util.List; +import java.util.UUID; + import org.junit.jupiter.api.Test; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.chat.memory.PgVectorChatMemory; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; @@ -9,12 +32,6 @@ import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import org.testcontainers.containers.PostgreSQLContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import java.util.List; -import java.util.UUID; import static org.assertj.core.api.Assertions.assertThat; @@ -42,7 +59,7 @@ class PgVectorChatMemoryAutoConfigurationIT { @Test void addGetAndClear_shouldAllExecute() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { var chatMemory = context.getBean(PgVectorChatMemory.class); var conversationId = UUID.randomUUID().toString(); var userMessage = new UserMessage("Message from the user"); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryPropertiesTests.java index 4925c4bec69..be0ed55c54c 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryPropertiesTests.java @@ -1,9 +1,26 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.autoconfigure.chat.memory.pgvector; import org.junit.jupiter.api.Test; + import org.springframework.ai.chat.memory.PgVectorChatMemoryConfig; -import static org.junit.jupiter.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Jonathan Leijendekker @@ -13,13 +30,13 @@ class PgVectorChatMemoryPropertiesTests { @Test void defaultValues() { var props = new PgVectorChatMemoryProperties(); - assertEquals(PgVectorChatMemoryConfig.DEFAULT_SCHEMA_NAME, props.getSchemaName()); - assertEquals(PgVectorChatMemoryConfig.DEFAULT_TABLE_NAME, props.getTableName()); - assertEquals(PgVectorChatMemoryConfig.DEFAULT_SESSION_ID_COLUMN_NAME, props.getSessionIdColumnName()); - assertEquals(PgVectorChatMemoryConfig.DEFAULT_EXCHANGE_ID_COLUMN_NAME, props.getExchangeIdColumnName()); - assertEquals(PgVectorChatMemoryConfig.DEFAULT_ASSISTANT_COLUMN_NAME, props.getAssistantColumnName()); - assertEquals(PgVectorChatMemoryConfig.DEFAULT_USER_COLUMN_NAME, props.getUserColumnName()); - assertTrue(props.isInitializeSchema()); + assertThat(props.getSchemaName()).isEqualTo(PgVectorChatMemoryConfig.DEFAULT_SCHEMA_NAME); + assertThat(props.getTableName()).isEqualTo(PgVectorChatMemoryConfig.DEFAULT_TABLE_NAME); + assertThat(props.getSessionIdColumnName()).isEqualTo(PgVectorChatMemoryConfig.DEFAULT_SESSION_ID_COLUMN_NAME); + assertThat(props.getExchangeIdColumnName()).isEqualTo(PgVectorChatMemoryConfig.DEFAULT_EXCHANGE_ID_COLUMN_NAME); + assertThat(props.getAssistantColumnName()).isEqualTo(PgVectorChatMemoryConfig.DEFAULT_ASSISTANT_COLUMN_NAME); + assertThat(props.getUserColumnName()).isEqualTo(PgVectorChatMemoryConfig.DEFAULT_USER_COLUMN_NAME); + assertThat(props.isInitializeSchema()).isTrue(); } @Test @@ -33,13 +50,13 @@ void customValues() { props.setUserColumnName("custom_user_column_name"); props.setInitializeSchema(false); - assertEquals("custom_schema_name", props.getSchemaName()); - assertEquals("custom_table_name", props.getTableName()); - assertEquals("custom_session_id_column_name", props.getSessionIdColumnName()); - assertEquals("custom_exchange_id_column_name", props.getExchangeIdColumnName()); - assertEquals("custom_assistant_column_name", props.getAssistantColumnName()); - assertEquals("custom_user_column_name", props.getUserColumnName()); - assertFalse(props.isInitializeSchema()); + assertThat(props.getSchemaName()).isEqualTo("custom_schema_name"); + assertThat(props.getTableName()).isEqualTo("custom_table_name"); + assertThat(props.getSessionIdColumnName()).isEqualTo("custom_session_id_column_name"); + assertThat(props.getExchangeIdColumnName()).isEqualTo("custom_exchange_id_column_name"); + assertThat(props.getAssistantColumnName()).isEqualTo("custom_assistant_column_name"); + assertThat(props.getUserColumnName()).isEqualTo("custom_user_column_name"); + assertThat(props.isInitializeSchema()).isFalse(); } } diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemory.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemory.java index 31c371f286e..9ce6cc2196c 100644 --- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemory.java +++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemory.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2024-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.memory; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Types; +import java.util.List; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; @@ -22,12 +29,6 @@ import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.core.RowMapper; -import java.sql.PreparedStatement; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Types; -import java.util.List; - /** * An implementation of {@link ChatMemory} for PgVector. Creating an instance of * PgVectorChatMemory example: diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfig.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfig.java index 0eeea60d96c..7b86f03c78e 100644 --- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfig.java +++ b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfig.java @@ -1,11 +1,11 @@ /* - * Copyright 2024 - 2024 the original author or authors. + * Copyright 2024-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * https://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package org.springframework.ai.chat.memory; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.util.Assert; @@ -24,7 +26,7 @@ * @author Jonathan Leijendekker * @since 1.0.0 */ -public class PgVectorChatMemoryConfig { +public final class PgVectorChatMemoryConfig { private static final Logger logger = LoggerFactory.getLogger(PgVectorChatMemoryConfig.class); @@ -136,7 +138,7 @@ void initializeSchema() { this.getFullyQualifiedTableName(), this.getSessionIdColumnName(), this.getExchangeIdColumnName())); } - public static class Builder { + public static final class Builder { private boolean initializeSchema = DEFAULT_SCHEMA_INITIALIZATION; @@ -212,7 +214,7 @@ public Builder withJdbcTemplate(JdbcTemplate jdbcTemplate) { } public PgVectorChatMemoryConfig build() { - Assert.notNull(jdbcTemplate, "jdbc template must not be null"); + Assert.notNull(this.jdbcTemplate, "jdbc template must not be null"); return new PgVectorChatMemoryConfig(this); } diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfigIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfigIT.java index 2aac72965f7..17bf707f744 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfigIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfigIT.java @@ -1,7 +1,31 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.memory; +import java.util.List; + +import javax.sql.DataSource; + import com.zaxxer.hikari.HikariDataSource; import org.junit.jupiter.api.Test; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.vectorstore.PgVectorImage; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; @@ -12,15 +36,8 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Primary; import org.springframework.jdbc.core.JdbcTemplate; -import org.testcontainers.containers.PostgreSQLContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import javax.sql.DataSource; -import java.util.List; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Jonathan Leijendekker @@ -52,7 +69,7 @@ class PgVectorChatMemoryConfigIT { @Test void initializeSchema_withValueTrue_shouldCreateSchema() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { var jdbcTemplate = context.getBean(JdbcTemplate.class); var config = PgVectorChatMemoryConfig.builder() .withInitializeSchema(true) @@ -83,11 +100,11 @@ void initializeSchema_withValueTrue_shouldCreateSchema() { "SELECT indexname FROM pg_indexes WHERE schemaname = ? AND tablename = ?", String.class, schemaName, tableName); - assertEquals(Boolean.TRUE, hasSchema); - assertEquals(Boolean.TRUE, hasTable); - assertTrue(expectedColumns.containsAll(tableColumns)); - assertEquals(String.format("%s_%s_%s_idx", tableName, sessionIdColumnName, exchangeIdColumnName), - indexName); + assertThat(hasSchema).isTrue(); + assertThat(hasTable).isTrue(); + assertThat(expectedColumns.containsAll(tableColumns)).isTrue(); + assertThat(String.format("%s_%s_%s_idx", tableName, sessionIdColumnName, exchangeIdColumnName)) + .isEqualTo(indexName); // Cleanup for the other tests jdbcTemplate.update(String.format("DROP SCHEMA IF EXISTS %s CASCADE", schemaName)); @@ -96,7 +113,7 @@ void initializeSchema_withValueTrue_shouldCreateSchema() { @Test void initializeSchema_withValueFalse_shouldNotCreateSchema() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { var jdbcTemplate = context.getBean(JdbcTemplate.class); var config = PgVectorChatMemoryConfig.builder() .withInitializeSchema(false) @@ -124,10 +141,10 @@ void initializeSchema_withValueFalse_shouldNotCreateSchema() { "SELECT EXISTS (SELECT 1 FROM pg_indexes WHERE schemaname = ? AND tablename = ?)", Boolean.class, schemaName, tableName); - assertEquals(Boolean.FALSE, hasSchema); - assertEquals(Boolean.FALSE, hasTable); - assertEquals(0, columnCount); - assertEquals(Boolean.FALSE, hasIndex); + assertThat(hasSchema).isFalse(); + assertThat(hasTable).isFalse(); + assertThat(columnCount).isZero(); + assertThat(hasIndex).isFalse(); }); } diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryIT.java index 695975a644c..0f6e8e0236a 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryIT.java @@ -1,9 +1,35 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.springframework.ai.chat.memory; +import java.sql.Timestamp; +import java.util.List; +import java.util.UUID; + +import javax.sql.DataSource; + import com.zaxxer.hikari.HikariDataSource; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; @@ -18,16 +44,8 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Primary; import org.springframework.jdbc.core.JdbcTemplate; -import org.testcontainers.containers.PostgreSQLContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import javax.sql.DataSource; -import java.sql.Timestamp; -import java.util.List; -import java.util.UUID; -import static org.junit.jupiter.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; /** * @author Jonathan Leijendekker @@ -59,17 +77,17 @@ class PgVectorChatMemoryIT { @Test void correctChatMemoryInstance() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { var chatMemory = context.getBean(ChatMemory.class); - assertInstanceOf(PgVectorChatMemory.class, chatMemory); + assertThat(chatMemory).isInstanceOf(PgVectorChatMemory.class); }); } @ParameterizedTest @CsvSource({ "Message from assistant,ASSISTANT", "Message from user,USER" }) void add_shouldInsertSingleMessage(String content, MessageType messageType) { - contextRunner.run(context -> { + this.contextRunner.run(context -> { var chatMemory = context.getBean(ChatMemory.class); var conversationId = UUID.randomUUID().toString(); String assistantContent = null; @@ -94,18 +112,18 @@ void add_shouldInsertSingleMessage(String content, MessageType messageType) { sessionIdColumnName); var result = jdbcTemplate.queryForMap(query, conversationId); - assertEquals(4, result.size()); - assertEquals(conversationId, result.get(sessionIdColumnName)); - assertInstanceOf(Timestamp.class, result.get(exchangeIdColumnName)); - assertNotNull(result.get(exchangeIdColumnName)); - assertEquals(assistantContent, result.get(assistantColumnName)); - assertEquals(userContent, result.get(userColumnName.replace("\"", ""))); + assertThat(result.size()).isEqualTo(4); + assertThat(result.get(sessionIdColumnName)).isEqualTo(conversationId); + assertThat(result.get(exchangeIdColumnName)).isInstanceOf(Timestamp.class); + assertThat(result.get(exchangeIdColumnName)).isNotNull(); + assertThat(result.get(assistantColumnName)).isEqualTo(assistantContent); + assertThat(result.get(userColumnName.replace("\"", ""))).isEqualTo(userContent); }); } @Test void add_shouldInsertMessages() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { var chatMemory = context.getBean(ChatMemory.class); var conversationId = UUID.randomUUID().toString(); var messages = List.of(new AssistantMessage("Message from assistant - " + conversationId), @@ -119,21 +137,21 @@ void add_shouldInsertMessages() { sessionIdColumnName); var results = jdbcTemplate.queryForList(query, conversationId); - assertEquals(messages.size(), results.size()); + assertThat(results.size()).isEqualTo(messages.size()); for (var i = 0; i < messages.size(); i++) { var message = messages.get(i); var result = results.get(i); - assertEquals(conversationId, result.get(sessionIdColumnName)); - assertInstanceOf(Timestamp.class, result.get(exchangeIdColumnName)); - assertNotNull(result.get(exchangeIdColumnName)); + assertThat(result.get(exchangeIdColumnName)).isNotNull(); + assertThat(result.get(sessionIdColumnName)).isEqualTo(conversationId); + assertThat(result.get(exchangeIdColumnName)).isInstanceOf(Timestamp.class); if (message.getMessageType() == MessageType.ASSISTANT) { - assertEquals(message.getContent(), result.get(assistantColumnName)); + assertThat(result.get(assistantColumnName)).isEqualTo(message.getContent()); } else { - assertEquals(message.getContent(), result.get(userColumnName.replace("\"", ""))); + assertThat(result.get(userColumnName.replace("\"", ""))).isEqualTo(message.getContent()); } } }); @@ -141,7 +159,7 @@ void add_shouldInsertMessages() { @Test void get_shouldReturnMessages() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { var chatMemory = context.getBean(ChatMemory.class); var conversationId = UUID.randomUUID().toString(); var messages = List.of(new AssistantMessage("Message from assistant 1 - " + conversationId), @@ -152,14 +170,14 @@ void get_shouldReturnMessages() { var results = chatMemory.get(conversationId, Integer.MAX_VALUE); - assertEquals(messages.size(), results.size()); - assertEquals(messages, results); + assertThat(results.size()).isEqualTo(messages.size()); + assertThat(results).isEqualTo(messages); }); } @Test void clear_shouldDeleteMessages() { - contextRunner.run(context -> { + this.contextRunner.run(context -> { var chatMemory = context.getBean(ChatMemory.class); var conversationId = UUID.randomUUID().toString(); var messages = List.of(new AssistantMessage("Message from assistant - " + conversationId), @@ -173,7 +191,7 @@ void clear_shouldDeleteMessages() { var count = jdbcTemplate.queryForObject(String.format("SELECT COUNT(*) FROM %s.%s WHERE %s = ?", schemaName, tableName, sessionIdColumnName), Integer.class, conversationId); - assertEquals(0, count); + assertThat(count).isZero(); }); } From 4a38070340206569a7ea2403040ccb51c42d6e7b Mon Sep 17 00:00:00 2001 From: leijendary Date: Mon, 4 Nov 2024 01:26:01 +0100 Subject: [PATCH 5/8] Fixed javaformat --- .../ai/chat/memory/PgVectorChatMemoryConfigIT.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfigIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfigIT.java index 17bf707f744..0392c557813 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfigIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfigIT.java @@ -104,7 +104,7 @@ void initializeSchema_withValueTrue_shouldCreateSchema() { assertThat(hasTable).isTrue(); assertThat(expectedColumns.containsAll(tableColumns)).isTrue(); assertThat(String.format("%s_%s_%s_idx", tableName, sessionIdColumnName, exchangeIdColumnName)) - .isEqualTo(indexName); + .isEqualTo(indexName); // Cleanup for the other tests jdbcTemplate.update(String.format("DROP SCHEMA IF EXISTS %s CASCADE", schemaName)); From 7ec3cd897f938c90dd1e141aa2a88ef1d92de620 Mon Sep 17 00:00:00 2001 From: leijendary Date: Sat, 14 Dec 2024 00:37:41 +0100 Subject: [PATCH 6/8] Moved PgVectorChatMemory to JdbcChatMemory --- README.md | 7 +- .../spring-ai-chat-memory-jdbc/README.md | 1 + .../spring-ai-chat-memory-jdbc/pom.xml | 107 ++++++++ .../ai/chat/memory/JdbcChatMemory.java | 109 ++++++++ .../ai/chat/memory/JdbcChatMemoryConfig.java | 115 +++++++++ .../aot/hint/JdbcChatMemoryRuntimeHints.java | 26 ++ .../resources/META-INF/spring/aot.factories | 2 + .../ai/chat/memory/schema-drop.mariadb.sql | 1 + .../ai/chat/memory/schema-drop.postgresql.sql | 1 + .../ai/chat/memory/schema.mariadb.sql | 10 + .../ai/chat/memory/schema.postgresql.sql | 9 + .../chat/memory/JdbcChatMemoryConfigIT.java | 244 ++++++++++++++++++ .../ai/chat/memory/JdbcChatMemoryIT.java | 78 ++---- .../hint/JdbcChatMemoryRuntimeHintsTest.java | 82 ++++++ pom.xml | 7 +- spring-ai-bom/pom.xml | 6 + .../modules/ROOT/pages/api/chatclient.adoc | 6 +- spring-ai-spring-boot-autoconfigure/pom.xml | 8 + .../JdbcChatMemoryAutoConfiguration.java} | 28 +- .../memory/jdbc/JdbcChatMemoryProperties.java | 31 +++ .../PgVectorChatMemoryProperties.java | 92 ------- ...ot.autoconfigure.AutoConfiguration.imports | 2 +- .../JdbcChatMemoryAutoConfigurationIT.java} | 15 +- .../jdbc/JdbcChatMemoryPropertiesTests.java | 42 +++ .../PgVectorChatMemoryPropertiesTests.java | 62 ----- .../pom.xml | 58 +++++ .../ai/chat/memory/PgVectorChatMemory.java | 131 ---------- .../chat/memory/PgVectorChatMemoryConfig.java | 224 ---------------- .../memory/PgVectorChatMemoryConfigIT.java | 174 ------------- 29 files changed, 905 insertions(+), 773 deletions(-) create mode 100644 chat-memory/spring-ai-chat-memory-jdbc/README.md create mode 100644 chat-memory/spring-ai-chat-memory-jdbc/pom.xml create mode 100644 chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/JdbcChatMemory.java create mode 100644 chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/JdbcChatMemoryConfig.java create mode 100644 chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/aot/hint/JdbcChatMemoryRuntimeHints.java create mode 100644 chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/META-INF/spring/aot.factories create mode 100644 chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema-drop.mariadb.sql create mode 100644 chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema-drop.postgresql.sql create mode 100644 chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema.mariadb.sql create mode 100644 chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema.postgresql.sql create mode 100644 chat-memory/spring-ai-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/JdbcChatMemoryConfigIT.java rename vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryIT.java => chat-memory/spring-ai-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/JdbcChatMemoryIT.java (69%) create mode 100644 chat-memory/spring-ai-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/aot/hint/JdbcChatMemoryRuntimeHintsTest.java rename spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/{pgvector/PgVectorChatMemoryAutoConfiguration.java => jdbc/JdbcChatMemoryAutoConfiguration.java} (55%) create mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryProperties.java delete mode 100644 spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryProperties.java rename spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/{pgvector/PgVectorChatMemoryAutoConfigurationIT.java => jdbc/JdbcChatMemoryAutoConfigurationIT.java} (87%) create mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryPropertiesTests.java delete mode 100644 spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryPropertiesTests.java create mode 100644 spring-ai-spring-boot-starters/spring-ai-starter-chat-memory-jdbc/pom.xml delete mode 100644 vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemory.java delete mode 100644 vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfig.java delete mode 100644 vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfigIT.java diff --git a/README.md b/README.md index c57ad89e468..0a48fad155f 100644 --- a/README.md +++ b/README.md @@ -100,9 +100,9 @@ One way to run integration tests on part of the code is to first do a quick comp ```shell ./mvnw clean install -DskipTests -Dmaven.javadoc.skip=true ``` -Then run the integration test for a specifi module using the `-pl` option +Then run the integration test for a specific module using the `-pl` option ```shell -./mvnw verify -Pintegration-tests -pl spring-ai-spring-boot-autoconfigure +./mvnw verify -Pintegration-tests -pl spring-ai-spring-boot-autoconfigure ``` ### Documentation @@ -134,6 +134,3 @@ Checkstyles are currently disabled, but you can enable them by doing the followi ```shell ./mvnw clean package -DskipTests -Ddisable.checks=false ``` - - - diff --git a/chat-memory/spring-ai-chat-memory-jdbc/README.md b/chat-memory/spring-ai-chat-memory-jdbc/README.md new file mode 100644 index 00000000000..8e100ad20a3 --- /dev/null +++ b/chat-memory/spring-ai-chat-memory-jdbc/README.md @@ -0,0 +1 @@ +[Chat Memory Documentation](https://docs.spring.io/spring-ai/reference/api/chatclient.html#_chat_memory) diff --git a/chat-memory/spring-ai-chat-memory-jdbc/pom.xml b/chat-memory/spring-ai-chat-memory-jdbc/pom.xml new file mode 100644 index 00000000000..61e1737e263 --- /dev/null +++ b/chat-memory/spring-ai-chat-memory-jdbc/pom.xml @@ -0,0 +1,107 @@ + + + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-chat-memory-jdbc + jar + Spring AI Chat Memory JDBC + Spring AI Chat Memory implementation with JDBC + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + 17 + 17 + + + + + org.springframework.ai + spring-ai-core + ${project.parent.version} + + + + com.zaxxer + HikariCP + + + + org.springframework + spring-jdbc + + + + org.postgresql + postgresql + ${postgresql.version} + true + + + + org.mariadb.jdbc + mariadb-java-client + ${mariadb.version} + true + + + + + org.springframework.boot + spring-boot-starter-test + test + + + + org.testcontainers + testcontainers + test + + + + org.testcontainers + postgresql + test + + + + org.testcontainers + mariadb + test + + + + org.testcontainers + junit-jupiter + test + + + diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/JdbcChatMemory.java b/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/JdbcChatMemory.java new file mode 100644 index 00000000000..7f6957d7fd9 --- /dev/null +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/JdbcChatMemory.java @@ -0,0 +1,109 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory; + +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.List; + +import org.springframework.ai.chat.messages.*; +import org.springframework.jdbc.core.BatchPreparedStatementSetter; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.core.RowMapper; + +/** + * An implementation of {@link ChatMemory} for JDBC. When this class is instantiated, + * {@link JdbcChatMemoryConfig#initializeSchema()} will automatically be called. Creating + * an instance of JdbcChatMemory example: + * JdbcChatMemory.create(JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build()); + * + * @author Jonathan Leijendekker + * @since 1.0.0 + */ +public class JdbcChatMemory implements ChatMemory { + + private static final String QUERY_ADD = """ + INSERT INTO ai_chat_memory (conversation_id, content, type) VALUES (?, ?, ?)"""; + + private static final String QUERY_GET = """ + SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" DESC LIMIT ?"""; + + private static final String QUERY_CLEAR = "DELETE FROM ai_chat_memory WHERE conversation_id = ?"; + + private final JdbcTemplate jdbcTemplate; + + public JdbcChatMemory(JdbcChatMemoryConfig config) { + config.initializeSchema(); + + this.jdbcTemplate = config.getJdbcTemplate(); + } + + public static JdbcChatMemory create(JdbcChatMemoryConfig config) { + return new JdbcChatMemory(config); + } + + @Override + public void add(String conversationId, List messages) { + this.jdbcTemplate.batchUpdate(QUERY_ADD, new AddBatchPreparedStatement(conversationId, messages)); + } + + @Override + public List get(String conversationId, int lastN) { + return this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId, lastN); + } + + @Override + public void clear(String conversationId) { + this.jdbcTemplate.update(QUERY_CLEAR, conversationId); + } + + private record AddBatchPreparedStatement(String conversationId, + List messages) implements BatchPreparedStatementSetter { + @Override + public void setValues(PreparedStatement ps, int i) throws SQLException { + var message = this.messages.get(i); + + ps.setString(1, this.conversationId); + ps.setString(2, message.getText()); + ps.setString(3, message.getMessageType().getValue()); + } + + @Override + public int getBatchSize() { + return this.messages.size(); + } + } + + private static class MessageRowMapper implements RowMapper { + + @Override + public Message mapRow(ResultSet rs, int i) throws SQLException { + var content = rs.getString(1); + var type = MessageType.fromValue(rs.getString(2)); + + return switch (type) { + case USER -> new UserMessage(content); + case ASSISTANT -> new AssistantMessage(content); + case SYSTEM -> new SystemMessage(content); + default -> null; + }; + } + + } + +} diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/JdbcChatMemoryConfig.java b/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/JdbcChatMemoryConfig.java new file mode 100644 index 00000000000..cda7fba971f --- /dev/null +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/JdbcChatMemoryConfig.java @@ -0,0 +1,115 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory; + +import java.sql.SQLException; +import java.util.Objects; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.jdbc.CannotGetJdbcConnectionException; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.datasource.init.ResourceDatabasePopulator; +import org.springframework.util.Assert; + +/** + * Configuration for the JDBC {@link ChatMemory}. When + * {@link JdbcChatMemoryConfig#initializeSchema} is set to {@code true} (default is + * {@code false}) and {@link #initializeSchema()} is called, then the schema based on the + * database will be created. + * + * @author Jonathan Leijendekker + * @since 1.0.0 + */ +public final class JdbcChatMemoryConfig { + + private static final Logger logger = LoggerFactory.getLogger(JdbcChatMemoryConfig.class); + + public static final boolean DEFAULT_SCHEMA_INITIALIZATION = false; + + private final boolean initializeSchema; + + private final JdbcTemplate jdbcTemplate; + + private JdbcChatMemoryConfig(Builder builder) { + this.initializeSchema = builder.initializeSchema; + this.jdbcTemplate = builder.jdbcTemplate; + } + + public static Builder builder() { + return new Builder(); + } + + JdbcTemplate getJdbcTemplate() { + return this.jdbcTemplate; + } + + void initializeSchema() { + if (!this.initializeSchema) { + return; + } + + logger.info("Initializing JdbcChatMemory schema"); + + String productName; + + try (var connection = Objects.requireNonNull(this.jdbcTemplate.getDataSource()).getConnection()) { + var metadata = connection.getMetaData(); + productName = metadata.getDatabaseProductName(); + } + catch (SQLException e) { + throw new CannotGetJdbcConnectionException("Failed to obtain JDBC metadata", e); + } + + var fileName = String.format("schema.%s.sql", productName.toLowerCase()); + var resource = new ClassPathResource(fileName, JdbcChatMemoryConfig.class); + var databasePopulator = new ResourceDatabasePopulator(resource); + databasePopulator.execute(this.jdbcTemplate.getDataSource()); + } + + public static final class Builder { + + private boolean initializeSchema = DEFAULT_SCHEMA_INITIALIZATION; + + private JdbcTemplate jdbcTemplate; + + private Builder() { + } + + public Builder setInitializeSchema(boolean initializeSchema) { + this.initializeSchema = initializeSchema; + return this; + } + + public Builder jdbcTemplate(JdbcTemplate jdbcTemplate) { + Assert.notNull(jdbcTemplate, "jdbc template must not be null"); + + this.jdbcTemplate = jdbcTemplate; + return this; + } + + public JdbcChatMemoryConfig build() { + Assert.notNull(this.jdbcTemplate, "jdbc template must not be null"); + + return new JdbcChatMemoryConfig(this); + } + + } + +} diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/aot/hint/JdbcChatMemoryRuntimeHints.java b/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/aot/hint/JdbcChatMemoryRuntimeHints.java new file mode 100644 index 00000000000..2ffc99af9f7 --- /dev/null +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/aot/hint/JdbcChatMemoryRuntimeHints.java @@ -0,0 +1,26 @@ +package org.springframework.ai.chat.memory.aot.hint; + +import javax.sql.DataSource; + +import org.springframework.aot.hint.MemberCategory; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; + +/** + * A {@link RuntimeHintsRegistrar} for JDBC Chat Memory hints + * + * @author Jonathan Leijendekker + */ +class JdbcChatMemoryRuntimeHints implements RuntimeHintsRegistrar { + + @Override + public void registerHints(RuntimeHints hints, ClassLoader classLoader) { + hints.reflection() + .registerType(DataSource.class, (hint) -> hint.withMembers(MemberCategory.INVOKE_DECLARED_METHODS)); + + hints.resources() + .registerPattern("org/springframework/ai/chat/memory/schema.mariadb.sql") + .registerPattern("org/springframework/ai/chat/memory/schema.postgresql.sql"); + } + +} diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/META-INF/spring/aot.factories b/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/META-INF/spring/aot.factories new file mode 100644 index 00000000000..1877d6377a4 --- /dev/null +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/META-INF/spring/aot.factories @@ -0,0 +1,2 @@ +org.springframework.aot.hint.RuntimeHintsRegistrar=\ +org.springframework.ai.chat.memory.aot.hint.JdbcChatMemoryRuntimeHints diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema-drop.mariadb.sql b/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema-drop.mariadb.sql new file mode 100644 index 00000000000..72f313114ba --- /dev/null +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema-drop.mariadb.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS ai_chat_memory; diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema-drop.postgresql.sql b/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema-drop.postgresql.sql new file mode 100644 index 00000000000..72f313114ba --- /dev/null +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema-drop.postgresql.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS ai_chat_memory; diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema.mariadb.sql b/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema.mariadb.sql new file mode 100644 index 00000000000..3e3256151b8 --- /dev/null +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema.mariadb.sql @@ -0,0 +1,10 @@ +CREATE TABLE IF NOT EXISTS ai_chat_memory ( + conversation_id VARCHAR(36) NOT NULL, + content TEXT NOT NULL, + type VARCHAR(10) NOT NULL, + `timestamp` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + CONSTRAINT type_check CHECK (type IN ('user', 'assistant')) +); + +CREATE INDEX ai_chat_memory_conversation_id_timestamp_idx +ON ai_chat_memory(conversation_id, `timestamp` DESC); diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema.postgresql.sql b/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema.postgresql.sql new file mode 100644 index 00000000000..3c93d24ad77 --- /dev/null +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema.postgresql.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS ai_chat_memory ( + conversation_id VARCHAR(36) NOT NULL, + content text NOT NULL, + type VARCHAR(10) NOT NULL CHECK (type IN ('user', 'assistant')), + "timestamp" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX IF NOT EXISTS ai_chat_memory_conversation_id_timestamp_idx +ON ai_chat_memory(conversation_id, "timestamp" DESC); diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/JdbcChatMemoryConfigIT.java b/chat-memory/spring-ai-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/JdbcChatMemoryConfigIT.java new file mode 100644 index 00000000000..b6bb75f87c0 --- /dev/null +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/JdbcChatMemoryConfigIT.java @@ -0,0 +1,244 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory; + +import java.util.List; +import java.util.Objects; + +import javax.sql.DataSource; + +import com.zaxxer.hikari.HikariDataSource; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.MariaDBContainer; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceProperties; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Primary; +import org.springframework.core.io.ClassPathResource; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.datasource.init.ScriptUtils; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Jonathan Leijendekker + */ +@Testcontainers +class JdbcChatMemoryConfigIT { + + @Nested + class PostgresChatMemoryConfigIT { + + @Container + @SuppressWarnings("resource") + static PostgreSQLContainer postgresContainer = new PostgreSQLContainer<>("postgres:17") + .withDatabaseName("chat_memory_config_test") + .withUsername("postgres") + .withPassword("postgres"); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class) + .withPropertyValues( + // JdbcTemplate configuration + String.format("app.datasource.url=%s", postgresContainer.getJdbcUrl()), + String.format("app.datasource.username=%s", postgresContainer.getUsername()), + String.format("app.datasource.password=%s", postgresContainer.getPassword()), + "app.datasource.type=com.zaxxer.hikari.HikariDataSource"); + + @Test + void initializeSchema_withValueTrue_shouldCreateSchema() { + this.contextRunner.run(context -> { + var jdbcTemplate = context.getBean(JdbcTemplate.class); + var config = JdbcChatMemoryConfig.builder() + .setInitializeSchema(true) + .jdbcTemplate(jdbcTemplate) + .build(); + config.initializeSchema(); + + var expectedColumns = List.of("conversation_id", "content", "type", "timestamp"); + + // Verify that the table and index are created + var hasTable = jdbcTemplate.queryForObject( + "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = ?)", boolean.class, + "ai_chat_memory"); + var tableColumns = jdbcTemplate.queryForList( + "SELECT column_name FROM information_schema.columns WHERE table_name = ?", String.class, + "ai_chat_memory"); + var indexName = jdbcTemplate.queryForObject("SELECT indexname FROM pg_indexes WHERE tablename = ?", + String.class, "ai_chat_memory"); + + assertThat(hasTable).isTrue(); + assertThat(expectedColumns.containsAll(tableColumns)).isTrue(); + assertThat(indexName).isEqualTo("ai_chat_memory_conversation_id_timestamp_idx"); + }); + } + + @Test + void initializeSchema_withValueFalse_shouldNotCreateSchema() { + this.contextRunner.run(context -> { + var jdbcTemplate = context.getBean(JdbcTemplate.class); + + // Make sure the schema does not exist in the first place + ScriptUtils.executeSqlScript(Objects.requireNonNull(jdbcTemplate.getDataSource()).getConnection(), + new ClassPathResource("schema-drop.postgresql.sql", JdbcChatMemoryConfigIT.class)); + + var config = JdbcChatMemoryConfig.builder() + .setInitializeSchema(false) + .jdbcTemplate(jdbcTemplate) + .build(); + config.initializeSchema(); + + // Verify that the table and index was created + var hasTable = jdbcTemplate.queryForObject( + "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = ?)", boolean.class, + "ai_chat_memory"); + var columnCount = jdbcTemplate.queryForObject( + "SELECT COUNT(*) FROM information_schema.columns WHERE table_name = ?", Integer.class, + "ai_chat_memory"); + var hasIndex = jdbcTemplate.queryForObject( + "SELECT EXISTS (SELECT 1 FROM pg_indexes WHERE tablename = ?)", Boolean.class, + "ai_chat_memory"); + + assertThat(hasTable).isFalse(); + assertThat(columnCount).isZero(); + assertThat(hasIndex).isFalse(); + }); + } + + } + + @Nested + class MariaChatMemoryConfigIT { + + @Container + @SuppressWarnings("resource") + static MariaDBContainer mariaContainer = new MariaDBContainer<>("mariadb:11") + .withDatabaseName("chat_memory_config_test") + .withUsername("mysql") + .withPassword("mysql"); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(JdbcChatMemoryConfigIT.TestApplication.class) + .withPropertyValues( + // JdbcTemplate configuration + String.format("app.datasource.url=%s", mariaContainer.getJdbcUrl()), + String.format("app.datasource.username=%s", mariaContainer.getUsername()), + String.format("app.datasource.password=%s", mariaContainer.getPassword()), + "app.datasource.type=com.zaxxer.hikari.HikariDataSource"); + + @Test + void initializeSchema_withValueTrue_shouldCreateSchema() { + this.contextRunner.run(context -> { + var jdbcTemplate = context.getBean(JdbcTemplate.class); + var config = JdbcChatMemoryConfig.builder() + .setInitializeSchema(true) + .jdbcTemplate(jdbcTemplate) + .build(); + config.initializeSchema(); + + var expectedColumns = List.of("conversation_id", "content", "type", "timestamp"); + + // Verify that the table and index are created + var hasTable = jdbcTemplate.queryForObject( + "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = ?)", boolean.class, + "ai_chat_memory"); + var tableColumns = jdbcTemplate.queryForList( + "SELECT column_name FROM information_schema.columns WHERE table_name = ?", String.class, + "ai_chat_memory"); + var indexNames = jdbcTemplate.queryForList( + "SELECT index_name, column_name FROM information_schema.statistics WHERE table_name = ?", + "ai_chat_memory"); + + assertThat(hasTable).isTrue(); + assertThat(expectedColumns.containsAll(tableColumns)).isTrue(); + assertThat(indexNames).hasSize(2); + assertThat(indexNames.get(0).get("index_name")) + .isEqualTo("ai_chat_memory_conversation_id_timestamp_idx"); + assertThat(indexNames.get(0).get("column_name")).isEqualTo("conversation_id"); + assertThat(indexNames.get(1).get("index_name")) + .isEqualTo("ai_chat_memory_conversation_id_timestamp_idx"); + assertThat(indexNames.get(1).get("column_name")).isEqualTo("timestamp"); + }); + } + + @Test + void initializeSchema_withValueFalse_shouldNotCreateSchema() { + this.contextRunner.run(context -> { + var jdbcTemplate = context.getBean(JdbcTemplate.class); + + // Make sure the schema does not exist in the first place + ScriptUtils.executeSqlScript(Objects.requireNonNull(jdbcTemplate.getDataSource()).getConnection(), + new ClassPathResource("schema-drop.mariadb.sql", JdbcChatMemoryConfigIT.class)); + + var config = JdbcChatMemoryConfig.builder() + .setInitializeSchema(false) + .jdbcTemplate(jdbcTemplate) + .build(); + config.initializeSchema(); + + // Verify that the table and index was created + var hasTable = jdbcTemplate.queryForObject( + "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = ?)", boolean.class, + "ai_chat_memory"); + var columnCount = jdbcTemplate.queryForObject( + "SELECT COUNT(*) FROM information_schema.columns WHERE table_name = ?", Integer.class, + "ai_chat_memory"); + var hasIndex = jdbcTemplate.queryForObject( + "SELECT EXISTS (SELECT 1 FROM information_schema.statistics WHERE table_name = ?)", + Boolean.class, "ai_chat_memory"); + + assertThat(hasTable).isFalse(); + assertThat(columnCount).isZero(); + assertThat(hasIndex).isFalse(); + }); + } + + } + + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + static class TestApplication { + + @Bean + public JdbcTemplate jdbcTemplate(DataSource dataSource) { + return new JdbcTemplate(dataSource); + } + + @Bean + @Primary + @ConfigurationProperties("app.datasource") + public DataSourceProperties dataSourceProperties() { + return new DataSourceProperties(); + } + + @Bean + public HikariDataSource dataSource(DataSourceProperties dataSourceProperties) { + return dataSourceProperties.initializeDataSourceBuilder().type(HikariDataSource.class).build(); + } + + } + +} diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryIT.java b/chat-memory/spring-ai-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/JdbcChatMemoryIT.java similarity index 69% rename from vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryIT.java rename to chat-memory/spring-ai-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/JdbcChatMemoryIT.java index 0f6e8e0236a..5d7308a8906 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryIT.java +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/JdbcChatMemoryIT.java @@ -34,7 +34,6 @@ import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.vectorstore.PgVectorImage; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; @@ -51,11 +50,12 @@ * @author Jonathan Leijendekker */ @Testcontainers -class PgVectorChatMemoryIT { +class JdbcChatMemoryIT { @Container @SuppressWarnings("resource") - static PostgreSQLContainer postgresContainer = new PostgreSQLContainer<>(PgVectorImage.DEFAULT_IMAGE) + static PostgreSQLContainer postgresContainer = new PostgreSQLContainer<>("postgres:17") + .withDatabaseName("chat_memory_test") .withUsername("postgres") .withPassword("postgres"); @@ -68,19 +68,12 @@ class PgVectorChatMemoryIT { String.format("app.datasource.password=%s", postgresContainer.getPassword()), "app.datasource.type=com.zaxxer.hikari.HikariDataSource"); - static String schemaName = "public_test"; - static String tableName = "ai_chat_memory_test"; - static String sessionIdColumnName = "session_id_test"; - static String exchangeIdColumnName = "message_timestamp_test"; - static String assistantColumnName = "assistant_test"; - static String userColumnName = "\"user_test\""; - @Test void correctChatMemoryInstance() { this.contextRunner.run(context -> { var chatMemory = context.getBean(ChatMemory.class); - assertThat(chatMemory).isInstanceOf(PgVectorChatMemory.class); + assertThat(chatMemory).isInstanceOf(JdbcChatMemory.class); }); } @@ -90,34 +83,23 @@ void add_shouldInsertSingleMessage(String content, MessageType messageType) { this.contextRunner.run(context -> { var chatMemory = context.getBean(ChatMemory.class); var conversationId = UUID.randomUUID().toString(); - String assistantContent = null; - String userContent = null; var message = switch (messageType) { - case ASSISTANT -> { - assistantContent = content + " - " + conversationId; - yield new AssistantMessage(assistantContent); - } - case USER -> { - userContent = content + " - " + conversationId; - yield new UserMessage(userContent); - } + case ASSISTANT -> new AssistantMessage(content + " - " + conversationId); + case USER -> new UserMessage(content + " - " + conversationId); default -> throw new IllegalArgumentException("Type not supported: " + messageType); }; chatMemory.add(conversationId, message); var jdbcTemplate = context.getBean(JdbcTemplate.class); - var query = String.format("SELECT %s, %s, %s, %s FROM %s.%s WHERE %s = ?", sessionIdColumnName, - exchangeIdColumnName, assistantColumnName, userColumnName, schemaName, tableName, - sessionIdColumnName); + var query = "SELECT conversation_id, content, type, \"timestamp\" FROM ai_chat_memory WHERE conversation_id = ?"; var result = jdbcTemplate.queryForMap(query, conversationId); assertThat(result.size()).isEqualTo(4); - assertThat(result.get(sessionIdColumnName)).isEqualTo(conversationId); - assertThat(result.get(exchangeIdColumnName)).isInstanceOf(Timestamp.class); - assertThat(result.get(exchangeIdColumnName)).isNotNull(); - assertThat(result.get(assistantColumnName)).isEqualTo(assistantContent); - assertThat(result.get(userColumnName.replace("\"", ""))).isEqualTo(userContent); + assertThat(result.get("conversation_id")).isEqualTo(conversationId); + assertThat(result.get("content")).isEqualTo(message.getText()); + assertThat(result.get("type")).isEqualTo(messageType.getValue()); + assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class); }); } @@ -132,9 +114,7 @@ void add_shouldInsertMessages() { chatMemory.add(conversationId, messages); var jdbcTemplate = context.getBean(JdbcTemplate.class); - var query = String.format("SELECT %s, %s, %s, %s FROM %s.%s WHERE %s = ?", sessionIdColumnName, - exchangeIdColumnName, assistantColumnName, userColumnName, schemaName, tableName, - sessionIdColumnName); + var query = "SELECT conversation_id, content, type, \"timestamp\" FROM ai_chat_memory WHERE conversation_id = ?"; var results = jdbcTemplate.queryForList(query, conversationId); assertThat(results.size()).isEqualTo(messages.size()); @@ -143,16 +123,11 @@ void add_shouldInsertMessages() { var message = messages.get(i); var result = results.get(i); - assertThat(result.get(exchangeIdColumnName)).isNotNull(); - assertThat(result.get(sessionIdColumnName)).isEqualTo(conversationId); - assertThat(result.get(exchangeIdColumnName)).isInstanceOf(Timestamp.class); - - if (message.getMessageType() == MessageType.ASSISTANT) { - assertThat(result.get(assistantColumnName)).isEqualTo(message.getContent()); - } - else { - assertThat(result.get(userColumnName.replace("\"", ""))).isEqualTo(message.getContent()); - } + assertThat(result.get("conversation_id")).isNotNull(); + assertThat(result.get("conversation_id")).isEqualTo(conversationId); + assertThat(result.get("content")).isEqualTo(message.getText()); + assertThat(result.get("type")).isEqualTo(message.getMessageType().getValue()); + assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class); } }); } @@ -188,8 +163,8 @@ void clear_shouldDeleteMessages() { chatMemory.clear(conversationId); var jdbcTemplate = context.getBean(JdbcTemplate.class); - var count = jdbcTemplate.queryForObject(String.format("SELECT COUNT(*) FROM %s.%s WHERE %s = ?", schemaName, - tableName, sessionIdColumnName), Integer.class, conversationId); + var count = jdbcTemplate.queryForObject("SELECT COUNT(*) FROM ai_chat_memory WHERE conversation_id = ?", + Integer.class, conversationId); assertThat(count).isZero(); }); @@ -201,18 +176,9 @@ static class TestApplication { @Bean public ChatMemory chatMemory(JdbcTemplate jdbcTemplate) { - var config = PgVectorChatMemoryConfig.builder() - .withInitializeSchema(true) - .withSchemaName(schemaName) - .withTableName(tableName) - .withSessionIdColumnName(sessionIdColumnName) - .withExchangeIdColumnName(exchangeIdColumnName) - .withAssistantColumnName(assistantColumnName) - .withUserColumnName(userColumnName) - .withJdbcTemplate(jdbcTemplate) - .build(); - - return PgVectorChatMemory.create(config); + var config = JdbcChatMemoryConfig.builder().setInitializeSchema(true).jdbcTemplate(jdbcTemplate).build(); + + return JdbcChatMemory.create(config); } @Bean diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/aot/hint/JdbcChatMemoryRuntimeHintsTest.java b/chat-memory/spring-ai-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/aot/hint/JdbcChatMemoryRuntimeHintsTest.java new file mode 100644 index 00000000000..5abf8495828 --- /dev/null +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/aot/hint/JdbcChatMemoryRuntimeHintsTest.java @@ -0,0 +1,82 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory.aot.hint; + +import java.io.IOException; +import java.util.Arrays; +import java.util.stream.Stream; + +import javax.sql.DataSource; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; +import org.springframework.aot.hint.predicate.RuntimeHintsPredicates; +import org.springframework.core.io.Resource; +import org.springframework.core.io.support.PathMatchingResourcePatternResolver; +import org.springframework.core.io.support.SpringFactoriesLoader; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Jonathan Leijendekker + */ +class JdbcChatMemoryRuntimeHintsTest { + + private final RuntimeHints hints = new RuntimeHints(); + + private final JdbcChatMemoryRuntimeHints jdbcChatMemoryRuntimeHints = new JdbcChatMemoryRuntimeHints(); + + @Test + void aotFactoriesContainsRegistrar() { + var match = SpringFactoriesLoader.forResourceLocation("META-INF/spring/aot.factories") + .load(RuntimeHintsRegistrar.class) + .stream() + .anyMatch((registrar) -> registrar instanceof JdbcChatMemoryRuntimeHints); + + assertThat(match).isTrue(); + } + + @ParameterizedTest + @MethodSource("getSchemaFileNames") + void jdbcSchemasHasHints(String schemaFileName) { + this.jdbcChatMemoryRuntimeHints.registerHints(this.hints, getClass().getClassLoader()); + + var predicate = RuntimeHintsPredicates.resource() + .forResource("org/springframework/ai/chat/memory/" + schemaFileName); + + assertThat(predicate).accepts(this.hints); + } + + @Test + void dataSourceHasHints() { + this.jdbcChatMemoryRuntimeHints.registerHints(this.hints, getClass().getClassLoader()); + + assertThat(RuntimeHintsPredicates.reflection().onType(DataSource.class)).accepts(this.hints); + } + + private static Stream getSchemaFileNames() throws IOException { + var resources = new PathMatchingResourcePatternResolver() + .getResources("classpath*:org/springframework/ai/chat/memory/schema.*.sql"); + + return Arrays.stream(resources).map(Resource::getFilename); + } + +} diff --git a/pom.xml b/pom.xml index a1383154451..6cd4909fac1 100644 --- a/pom.xml +++ b/pom.xml @@ -65,6 +65,8 @@ vector-stores/spring-ai-typesense-store vector-stores/spring-ai-weaviate-store + chat-memory/spring-ai-chat-memory-jdbc + spring-ai-spring-boot-starters/spring-ai-starter-aws-opensearch-store spring-ai-spring-boot-starters/spring-ai-starter-azure-cosmos-db-store spring-ai-spring-boot-starters/spring-ai-starter-azure-store @@ -111,6 +113,7 @@ spring-ai-spring-boot-starters/spring-ai-starter-azure-openai spring-ai-spring-boot-starters/spring-ai-starter-bedrock-ai spring-ai-spring-boot-starters/spring-ai-starter-bedrock-converse + spring-ai-spring-boot-starters/spring-ai-starter-chat-memory-jdbc spring-ai-spring-boot-starters/spring-ai-starter-huggingface spring-ai-spring-boot-starters/spring-ai-starter-minimax spring-ai-spring-boot-starters/spring-ai-starter-mistral-ai @@ -647,6 +650,9 @@ --> + + org.springframework.ai.chat.memory/**/*IT.java + org.springframework.ai.anthropic/**/*IT.java org.springframework.ai.azure.openai/**/*IT.java @@ -670,7 +676,6 @@ org.springframework.ai.vectorstore**/CosmosDB**IT.java org.springframework.ai.vectorstore.azure/**IT.java - org.springframework.ai.chat.memory/**/Cassandra**IT.java org.springframework.ai.vectorstore**/Cassandra**IT.java org.springframework.ai.chroma/**IT.java diff --git a/spring-ai-bom/pom.xml b/spring-ai-bom/pom.xml index fb8fee83f4e..73acadbb4a6 100644 --- a/spring-ai-bom/pom.xml +++ b/spring-ai-bom/pom.xml @@ -581,6 +581,12 @@ ${project.version} + + org.springframework.ai + spring-ai-chat-memory-jdbc-spring-boot-starter + ${project.version} + + diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc index 85045b4b443..0035b6abab9 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc @@ -410,7 +410,7 @@ The `FILTER_EXPRESSION` parameter allows you to dynamically filter the search re The interface `ChatMemory` represents a storage for chat conversation history. It provides methods to add messages to a conversation, retrieve messages from a conversation, and clear the conversation history. -There are currently three implementations, `InMemoryChatMemory`, `CassandraChatMemory`, and `PgVectorChatMemory` that provide storage for chat conversation history, in-memory and persisted with `time-to-live`, correspondingly. +There are currently three implementations, `InMemoryChatMemory`, `CassandraChatMemory`, and `JdbcChatMemory` that provide storage for chat conversation history, in-memory and persisted with `time-to-live`, correspondingly. To create a `CassandraChatMemory` with `time-to-live`: @@ -419,11 +419,11 @@ To create a `CassandraChatMemory` with `time-to-live`: CassandraChatMemory.create(CassandraChatMemoryConfig.builder().withTimeToLive(Duration.ofDays(1)).build()); ---- -To create a `PgVectorChatMemory`: +To create a `JdbcChatMemory`: [source,java] ---- -PgVectorChatMemory.create(PgVectorChatMemoryConfig.builder().withJdbcTemplate(jdbcTemplate).build()); +JdbcChatMemory.create(JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build()); ---- The following advisor implementations use the `ChatMemory` interface to advice the prompt with conversation history which differ in the details of how the memory is added to the prompt diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml b/spring-ai-spring-boot-autoconfigure/pom.xml index d2abc8de56e..e865a43acf6 100644 --- a/spring-ai-spring-boot-autoconfigure/pom.xml +++ b/spring-ai-spring-boot-autoconfigure/pom.xml @@ -412,6 +412,14 @@ true + + + org.springframework.ai + spring-ai-chat-memory-jdbc + ${project.parent.version} + true + + diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryAutoConfiguration.java similarity index 55% rename from spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfiguration.java rename to spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryAutoConfiguration.java index ff8f237b1b7..264d20db6fb 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryAutoConfiguration.java @@ -14,12 +14,12 @@ * limitations under the License. */ -package org.springframework.ai.autoconfigure.chat.memory.pgvector; +package org.springframework.ai.autoconfigure.chat.memory.jdbc; import javax.sql.DataSource; -import org.springframework.ai.chat.memory.PgVectorChatMemory; -import org.springframework.ai.chat.memory.PgVectorChatMemoryConfig; +import org.springframework.ai.chat.memory.JdbcChatMemory; +import org.springframework.ai.chat.memory.JdbcChatMemoryConfig; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; @@ -33,25 +33,19 @@ * @since 1.0.0 */ @AutoConfiguration(after = JdbcTemplateAutoConfiguration.class) -@ConditionalOnClass({ PgVectorChatMemory.class, DataSource.class, JdbcTemplate.class }) -@EnableConfigurationProperties(PgVectorChatMemoryProperties.class) -public class PgVectorChatMemoryAutoConfiguration { +@ConditionalOnClass({ JdbcChatMemory.class, DataSource.class, JdbcTemplate.class }) +@EnableConfigurationProperties(JdbcChatMemoryProperties.class) +public class JdbcChatMemoryAutoConfiguration { @Bean @ConditionalOnMissingBean - public PgVectorChatMemory chatMemory(PgVectorChatMemoryProperties properties, JdbcTemplate jdbcTemplate) { - var config = PgVectorChatMemoryConfig.builder() - .withInitializeSchema(properties.isInitializeSchema()) - .withSchemaName(properties.getSchemaName()) - .withTableName(properties.getTableName()) - .withSessionIdColumnName(properties.getSessionIdColumnName()) - .withExchangeIdColumnName(properties.getExchangeIdColumnName()) - .withAssistantColumnName(properties.getAssistantColumnName()) - .withUserColumnName(properties.getUserColumnName()) - .withJdbcTemplate(jdbcTemplate) + public JdbcChatMemory chatMemory(JdbcChatMemoryProperties properties, JdbcTemplate jdbcTemplate) { + var config = JdbcChatMemoryConfig.builder() + .setInitializeSchema(properties.isInitializeSchema()) + .jdbcTemplate(jdbcTemplate) .build(); - return PgVectorChatMemory.create(config); + return JdbcChatMemory.create(config); } } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryProperties.java new file mode 100644 index 00000000000..f13aabe5b25 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryProperties.java @@ -0,0 +1,31 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.chat.memory.jdbc; + +import org.springframework.ai.autoconfigure.chat.memory.CommonChatMemoryProperties; +import org.springframework.boot.context.properties.ConfigurationProperties; + +/** + * @author Jonathan Leijendekker + * @since 1.0.0 + */ +@ConfigurationProperties(JdbcChatMemoryProperties.CONFIG_PREFIX) +public class JdbcChatMemoryProperties extends CommonChatMemoryProperties { + + public static final String CONFIG_PREFIX = "spring.ai.chat.memory.jdbc"; + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryProperties.java deleted file mode 100644 index ab35c7204a8..00000000000 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryProperties.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.autoconfigure.chat.memory.pgvector; - -import org.springframework.ai.autoconfigure.chat.memory.CommonChatMemoryProperties; -import org.springframework.ai.chat.memory.PgVectorChatMemoryConfig; -import org.springframework.boot.context.properties.ConfigurationProperties; - -/** - * @author Jonathan Leijendekker - * @since 1.0.0 - */ -@ConfigurationProperties(PgVectorChatMemoryProperties.CONFIG_PREFIX) -public class PgVectorChatMemoryProperties extends CommonChatMemoryProperties { - - public static final String CONFIG_PREFIX = "spring.ai.chat.memory.pgvector"; - - private String schemaName = PgVectorChatMemoryConfig.DEFAULT_SCHEMA_NAME; - - private String tableName = PgVectorChatMemoryConfig.DEFAULT_TABLE_NAME; - - private String sessionIdColumnName = PgVectorChatMemoryConfig.DEFAULT_SESSION_ID_COLUMN_NAME; - - private String exchangeIdColumnName = PgVectorChatMemoryConfig.DEFAULT_EXCHANGE_ID_COLUMN_NAME; - - private String assistantColumnName = PgVectorChatMemoryConfig.DEFAULT_ASSISTANT_COLUMN_NAME; - - private String userColumnName = PgVectorChatMemoryConfig.DEFAULT_USER_COLUMN_NAME; - - public String getSchemaName() { - return this.schemaName; - } - - public void setSchemaName(String schemaName) { - this.schemaName = schemaName; - } - - public String getTableName() { - return this.tableName; - } - - public void setTableName(String tableName) { - this.tableName = tableName; - } - - public String getSessionIdColumnName() { - return this.sessionIdColumnName; - } - - public void setSessionIdColumnName(String sessionIdColumnName) { - this.sessionIdColumnName = sessionIdColumnName; - } - - public String getExchangeIdColumnName() { - return this.exchangeIdColumnName; - } - - public void setExchangeIdColumnName(String exchangeIdColumnName) { - this.exchangeIdColumnName = exchangeIdColumnName; - } - - public String getAssistantColumnName() { - return this.assistantColumnName; - } - - public void setAssistantColumnName(String assistantColumnName) { - this.assistantColumnName = assistantColumnName; - } - - public String getUserColumnName() { - return this.userColumnName; - } - - public void setUserColumnName(String userColumnName) { - this.userColumnName = userColumnName; - } - -} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index a32af7e3c2c..2a7e2ae859d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -66,4 +66,4 @@ org.springframework.ai.autoconfigure.minimax.MiniMaxAutoConfiguration org.springframework.ai.autoconfigure.vertexai.embedding.VertexAiEmbeddingAutoConfiguration org.springframework.ai.autoconfigure.chat.memory.cassandra.CassandraChatMemoryAutoConfiguration org.springframework.ai.autoconfigure.vectorstore.observation.VectorStoreObservationAutoConfiguration -org.springframework.ai.autoconfigure.chat.memory.pgvector.PgVectorChatMemoryAutoConfiguration +org.springframework.ai.autoconfigure.chat.memory.jdbc.JdbcChatMemoryAutoConfiguration diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryAutoConfigurationIT.java similarity index 87% rename from spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfigurationIT.java rename to spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryAutoConfigurationIT.java index f5b81f667d7..850534d615e 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryAutoConfigurationIT.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.autoconfigure.chat.memory.pgvector; +package org.springframework.ai.autoconfigure.chat.memory.jdbc; import java.util.List; import java.util.UUID; @@ -24,7 +24,7 @@ import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; -import org.springframework.ai.chat.memory.PgVectorChatMemory; +import org.springframework.ai.chat.memory.JdbcChatMemory; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; @@ -39,18 +39,19 @@ * @author Jonathan Leijendekker */ @Testcontainers -class PgVectorChatMemoryAutoConfigurationIT { +class JdbcChatMemoryAutoConfigurationIT { @Container @SuppressWarnings("resource") - static PostgreSQLContainer postgresContainer = new PostgreSQLContainer<>("pgvector/pgvector:pg17") + static PostgreSQLContainer postgresContainer = new PostgreSQLContainer<>("postgres:17") + .withDatabaseName("chat_memory_auto_configuration_test") .withUsername("postgres") .withPassword("postgres"); private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(PgVectorChatMemoryAutoConfiguration.class, + .withConfiguration(AutoConfigurations.of(JdbcChatMemoryAutoConfiguration.class, JdbcTemplateAutoConfiguration.class, DataSourceAutoConfiguration.class)) - .withPropertyValues("spring.ai.chat.memory.pgvector.schemaName=test_autoconfigure", + .withPropertyValues( // JdbcTemplate configuration String.format("spring.datasource.url=%s", postgresContainer.getJdbcUrl()), String.format("spring.datasource.username=%s", postgresContainer.getUsername()), @@ -60,7 +61,7 @@ class PgVectorChatMemoryAutoConfigurationIT { @Test void addGetAndClear_shouldAllExecute() { this.contextRunner.run(context -> { - var chatMemory = context.getBean(PgVectorChatMemory.class); + var chatMemory = context.getBean(JdbcChatMemory.class); var conversationId = UUID.randomUUID().toString(); var userMessage = new UserMessage("Message from the user"); diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryPropertiesTests.java new file mode 100644 index 00000000000..2238389ceb6 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryPropertiesTests.java @@ -0,0 +1,42 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.chat.memory.jdbc; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Jonathan Leijendekker + */ +class JdbcChatMemoryPropertiesTests { + + @Test + void defaultValues() { + var props = new JdbcChatMemoryProperties(); + assertThat(props.isInitializeSchema()).isTrue(); + } + + @Test + void customValues() { + var props = new JdbcChatMemoryProperties(); + props.setInitializeSchema(false); + + assertThat(props.isInitializeSchema()).isFalse(); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryPropertiesTests.java deleted file mode 100644 index be0ed55c54c..00000000000 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/pgvector/PgVectorChatMemoryPropertiesTests.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.autoconfigure.chat.memory.pgvector; - -import org.junit.jupiter.api.Test; - -import org.springframework.ai.chat.memory.PgVectorChatMemoryConfig; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * @author Jonathan Leijendekker - */ -class PgVectorChatMemoryPropertiesTests { - - @Test - void defaultValues() { - var props = new PgVectorChatMemoryProperties(); - assertThat(props.getSchemaName()).isEqualTo(PgVectorChatMemoryConfig.DEFAULT_SCHEMA_NAME); - assertThat(props.getTableName()).isEqualTo(PgVectorChatMemoryConfig.DEFAULT_TABLE_NAME); - assertThat(props.getSessionIdColumnName()).isEqualTo(PgVectorChatMemoryConfig.DEFAULT_SESSION_ID_COLUMN_NAME); - assertThat(props.getExchangeIdColumnName()).isEqualTo(PgVectorChatMemoryConfig.DEFAULT_EXCHANGE_ID_COLUMN_NAME); - assertThat(props.getAssistantColumnName()).isEqualTo(PgVectorChatMemoryConfig.DEFAULT_ASSISTANT_COLUMN_NAME); - assertThat(props.getUserColumnName()).isEqualTo(PgVectorChatMemoryConfig.DEFAULT_USER_COLUMN_NAME); - assertThat(props.isInitializeSchema()).isTrue(); - } - - @Test - void customValues() { - var props = new PgVectorChatMemoryProperties(); - props.setSchemaName("custom_schema_name"); - props.setTableName("custom_table_name"); - props.setSessionIdColumnName("custom_session_id_column_name"); - props.setExchangeIdColumnName("custom_exchange_id_column_name"); - props.setAssistantColumnName("custom_assistant_column_name"); - props.setUserColumnName("custom_user_column_name"); - props.setInitializeSchema(false); - - assertThat(props.getSchemaName()).isEqualTo("custom_schema_name"); - assertThat(props.getTableName()).isEqualTo("custom_table_name"); - assertThat(props.getSessionIdColumnName()).isEqualTo("custom_session_id_column_name"); - assertThat(props.getExchangeIdColumnName()).isEqualTo("custom_exchange_id_column_name"); - assertThat(props.getAssistantColumnName()).isEqualTo("custom_assistant_column_name"); - assertThat(props.getUserColumnName()).isEqualTo("custom_user_column_name"); - assertThat(props.isInitializeSchema()).isFalse(); - } - -} diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-chat-memory-jdbc/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-chat-memory-jdbc/pom.xml new file mode 100644 index 00000000000..08c373c0088 --- /dev/null +++ b/spring-ai-spring-boot-starters/spring-ai-starter-chat-memory-jdbc/pom.xml @@ -0,0 +1,58 @@ + + + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-chat-memory-jdbc-spring-boot-starter + jar + Spring AI Starter - Chat Memory JDBC + Spring AI Chat Memory JDBC Auto Configuration + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + org.springframework.boot + spring-boot-starter + + + + org.springframework.ai + spring-ai-spring-boot-autoconfigure + ${project.parent.version} + + + + org.springframework.ai + spring-ai-chat-memory-jdbc + ${project.parent.version} + + + + diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemory.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemory.java deleted file mode 100644 index 9ce6cc2196c..00000000000 --- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemory.java +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.chat.memory; - -import java.sql.PreparedStatement; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Types; -import java.util.List; - -import org.springframework.ai.chat.messages.AssistantMessage; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.jdbc.core.BatchPreparedStatementSetter; -import org.springframework.jdbc.core.JdbcTemplate; -import org.springframework.jdbc.core.RowMapper; - -/** - * An implementation of {@link ChatMemory} for PgVector. Creating an instance of - * PgVectorChatMemory example: - * PgVectorChatMemory.create(PgVectorChatMemoryConfig.builder().withJdbcTemplate(jdbcTemplate).build()); - * - * @author Jonathan Leijendekker - * @since 1.0.0 - */ -public class PgVectorChatMemory implements ChatMemory { - - private final PgVectorChatMemoryConfig config; - - private final JdbcTemplate jdbcTemplate; - - public PgVectorChatMemory(PgVectorChatMemoryConfig config) { - this.config = config; - this.config.initializeSchema(); - this.jdbcTemplate = this.config.getJdbcTemplate(); - } - - public static PgVectorChatMemory create(PgVectorChatMemoryConfig config) { - return new PgVectorChatMemory(config); - } - - @Override - public void add(String conversationId, List messages) { - var sql = String.format("INSERT INTO %s (%s, %s, %s) VALUES (?, ?, ?)", - this.config.getFullyQualifiedTableName(), this.config.getSessionIdColumnName(), - this.config.getAssistantColumnName(), this.config.getUserColumnName()); - - this.jdbcTemplate.batchUpdate(sql, new AddBatchPreparedStatement(conversationId, messages)); - } - - @Override - public List get(String conversationId, int lastN) { - var sql = String.format("SELECT %s, %s FROM %s WHERE %s = ? ORDER BY %s DESC LIMIT ?", - this.config.getAssistantColumnName(), this.config.getUserColumnName(), - this.config.getFullyQualifiedTableName(), this.config.getSessionIdColumnName(), - this.config.getExchangeIdColumnName()); - - return this.jdbcTemplate.query(sql, new MessageRowMapper(), conversationId, lastN); - } - - @Override - public void clear(String conversationId) { - var sql = String.format("DELETE FROM %s WHERE %s = ?", this.config.getFullyQualifiedTableName(), - this.config.getSessionIdColumnName()); - - this.jdbcTemplate.update(sql, conversationId); - } - - private record AddBatchPreparedStatement(String conversationId, - List messages) implements BatchPreparedStatementSetter { - @Override - public void setValues(PreparedStatement ps, int i) throws SQLException { - var message = this.messages.get(i); - - ps.setString(1, this.conversationId); - - switch (message.getMessageType()) { - case ASSISTANT -> { - ps.setString(2, message.getContent()); - ps.setNull(3, Types.VARCHAR); - } - case USER -> { - ps.setNull(2, Types.VARCHAR); - ps.setString(3, message.getContent()); - } - default -> throw new IllegalArgumentException("Can't add type " + message); - } - } - - @Override - public int getBatchSize() { - return this.messages.size(); - } - } - - private static class MessageRowMapper implements RowMapper { - - @Override - public Message mapRow(ResultSet rs, int i) throws SQLException { - var assistant = rs.getString(1); - - if (assistant != null) { - return new AssistantMessage(assistant); - } - - var user = rs.getString(2); - - if (user != null) { - return new UserMessage(user); - } - - return null; - } - - } - -} diff --git a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfig.java b/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfig.java deleted file mode 100644 index 7b86f03c78e..00000000000 --- a/vector-stores/spring-ai-pgvector-store/src/main/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfig.java +++ /dev/null @@ -1,224 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.chat.memory; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import org.springframework.jdbc.core.JdbcTemplate; -import org.springframework.util.Assert; - -/** - * @author Jonathan Leijendekker - * @since 1.0.0 - */ -public final class PgVectorChatMemoryConfig { - - private static final Logger logger = LoggerFactory.getLogger(PgVectorChatMemoryConfig.class); - - public static final boolean DEFAULT_SCHEMA_INITIALIZATION = false; - - public static final String DEFAULT_SCHEMA_NAME = "public"; - - public static final String DEFAULT_TABLE_NAME = "ai_chat_memory"; - - public static final String DEFAULT_SESSION_ID_COLUMN_NAME = "session_id"; - - public static final String DEFAULT_EXCHANGE_ID_COLUMN_NAME = "message_timestamp"; - - public static final String DEFAULT_ASSISTANT_COLUMN_NAME = "assistant"; - - // "user" is a reserved keyword in postgres, hence the double quotes. - public static final String DEFAULT_USER_COLUMN_NAME = "\"user\""; - - private final boolean initializeSchema; - - private final String schemaName; - - private final String tableName; - - private final String sessionIdColumnName; - - private final String exchangeIdColumnName; - - private final String assistantColumnName; - - private final String userColumnName; - - private final JdbcTemplate jdbcTemplate; - - private PgVectorChatMemoryConfig(Builder builder) { - this.initializeSchema = builder.initializeSchema; - this.schemaName = builder.schemaName; - this.tableName = builder.tableName; - this.sessionIdColumnName = builder.sessionIdColumnName; - this.exchangeIdColumnName = builder.exchangeIdColumnName; - this.assistantColumnName = builder.assistantColumnName; - this.userColumnName = builder.userColumnName; - this.jdbcTemplate = builder.jdbcTemplate; - } - - public static Builder builder() { - return new Builder(); - } - - String getFullyQualifiedTableName() { - return this.schemaName + "." + this.tableName; - } - - String getSchemaName() { - return this.schemaName; - } - - String getTableName() { - return this.tableName; - } - - String getSessionIdColumnName() { - return this.sessionIdColumnName; - } - - String getExchangeIdColumnName() { - return this.exchangeIdColumnName; - } - - String getAssistantColumnName() { - return this.assistantColumnName; - } - - String getUserColumnName() { - return this.userColumnName; - } - - JdbcTemplate getJdbcTemplate() { - return this.jdbcTemplate; - } - - void initializeSchema() { - if (!this.initializeSchema) { - logger.warn("Skipping the schema initialization for table: {}", this.getFullyQualifiedTableName()); - return; - } - - logger.info("Initializing PGChatMemory schema for table: {} in schema: {}", this.getTableName(), - this.getSchemaName()); - - var indexName = String - .format("%s_%s_%s_idx", this.getTableName(), this.getSessionIdColumnName(), this.getExchangeIdColumnName()) - // Keywords in postgres has to be wrapped in double quotes. It is possible - // that the table or column may be a reserved keyword. If so, just remove - // them. - .replaceAll("\"", ""); - - this.jdbcTemplate.execute(String.format("CREATE SCHEMA IF NOT EXISTS %s", this.getSchemaName())); - this.jdbcTemplate.execute(String.format(""" - CREATE TABLE IF NOT EXISTS %s ( - %s character varying(40) NOT NULL, - %s timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, - %s text, - %s text - ) - """, this.getFullyQualifiedTableName(), this.getSessionIdColumnName(), this.getExchangeIdColumnName(), - this.getAssistantColumnName(), this.getUserColumnName())); - this.jdbcTemplate.execute(String.format("CREATE INDEX IF NOT EXISTS %s ON %s(%s, %s DESC)", indexName, - this.getFullyQualifiedTableName(), this.getSessionIdColumnName(), this.getExchangeIdColumnName())); - } - - public static final class Builder { - - private boolean initializeSchema = DEFAULT_SCHEMA_INITIALIZATION; - - private String schemaName = DEFAULT_SCHEMA_NAME; - - private String tableName = DEFAULT_TABLE_NAME; - - private String sessionIdColumnName = DEFAULT_SESSION_ID_COLUMN_NAME; - - private String exchangeIdColumnName = DEFAULT_EXCHANGE_ID_COLUMN_NAME; - - private String assistantColumnName = DEFAULT_ASSISTANT_COLUMN_NAME; - - private String userColumnName = DEFAULT_USER_COLUMN_NAME; - - private JdbcTemplate jdbcTemplate; - - private Builder() { - } - - public Builder withInitializeSchema(boolean initializeSchema) { - this.initializeSchema = initializeSchema; - return this; - } - - public Builder withSchemaName(String schemaName) { - Assert.hasText(schemaName, "schema name must not be empty"); - - this.schemaName = schemaName; - return this; - } - - public Builder withTableName(String tableName) { - Assert.hasText(tableName, "table name must not be empty"); - - this.tableName = tableName; - return this; - } - - public Builder withSessionIdColumnName(String sessionIdColumnName) { - Assert.hasText(sessionIdColumnName, "session id column name must not be empty"); - - this.sessionIdColumnName = sessionIdColumnName; - return this; - } - - public Builder withExchangeIdColumnName(String exchangeIdColumnName) { - Assert.hasText(exchangeIdColumnName, "exchange id column name must not be empty"); - - this.exchangeIdColumnName = exchangeIdColumnName; - return this; - } - - public Builder withAssistantColumnName(String assistantColumnName) { - Assert.hasText(assistantColumnName, "assistant column name must not be empty"); - - this.assistantColumnName = assistantColumnName; - return this; - } - - public Builder withUserColumnName(String userColumnName) { - Assert.hasText(userColumnName, "user column name must not be empty"); - - this.userColumnName = userColumnName; - return this; - } - - public Builder withJdbcTemplate(JdbcTemplate jdbcTemplate) { - Assert.notNull(jdbcTemplate, "jdbc template must not be null"); - - this.jdbcTemplate = jdbcTemplate; - return this; - } - - public PgVectorChatMemoryConfig build() { - Assert.notNull(this.jdbcTemplate, "jdbc template must not be null"); - - return new PgVectorChatMemoryConfig(this); - } - - } - -} diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfigIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfigIT.java deleted file mode 100644 index 0392c557813..00000000000 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/chat/memory/PgVectorChatMemoryConfigIT.java +++ /dev/null @@ -1,174 +0,0 @@ -/* - * Copyright 2024-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.chat.memory; - -import java.util.List; - -import javax.sql.DataSource; - -import com.zaxxer.hikari.HikariDataSource; -import org.junit.jupiter.api.Test; -import org.testcontainers.containers.PostgreSQLContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import org.springframework.ai.vectorstore.PgVectorImage; -import org.springframework.boot.SpringBootConfiguration; -import org.springframework.boot.autoconfigure.EnableAutoConfiguration; -import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; -import org.springframework.boot.autoconfigure.jdbc.DataSourceProperties; -import org.springframework.boot.context.properties.ConfigurationProperties; -import org.springframework.boot.test.context.runner.ApplicationContextRunner; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Primary; -import org.springframework.jdbc.core.JdbcTemplate; - -import static org.assertj.core.api.Assertions.assertThat; - -/** - * @author Jonathan Leijendekker - */ -@Testcontainers -class PgVectorChatMemoryConfigIT { - - @Container - @SuppressWarnings("resource") - static PostgreSQLContainer postgresContainer = new PostgreSQLContainer<>(PgVectorImage.DEFAULT_IMAGE) - .withUsername("postgres") - .withPassword("postgres"); - - private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() - .withUserConfiguration(PgVectorChatMemoryIT.TestApplication.class) - .withPropertyValues( - // JdbcTemplate configuration - String.format("app.datasource.url=%s", postgresContainer.getJdbcUrl()), - String.format("app.datasource.username=%s", postgresContainer.getUsername()), - String.format("app.datasource.password=%s", postgresContainer.getPassword()), - "app.datasource.type=com.zaxxer.hikari.HikariDataSource"); - - static String schemaName = "config_test"; - static String tableName = "ai_chat_config_test"; - static String sessionIdColumnName = "id_config_test"; - static String exchangeIdColumnName = "timestamp_config_test"; - static String assistantColumnName = "assistant_config_test"; - static String userColumnName = "\"user_config_test\""; - - @Test - void initializeSchema_withValueTrue_shouldCreateSchema() { - this.contextRunner.run(context -> { - var jdbcTemplate = context.getBean(JdbcTemplate.class); - var config = PgVectorChatMemoryConfig.builder() - .withInitializeSchema(true) - .withSchemaName(schemaName) - .withTableName(tableName) - .withSessionIdColumnName(sessionIdColumnName) - .withExchangeIdColumnName(exchangeIdColumnName) - .withAssistantColumnName(assistantColumnName) - .withUserColumnName(userColumnName) - .withJdbcTemplate(jdbcTemplate) - .build(); - config.initializeSchema(); - - var expectedColumns = List.of(sessionIdColumnName, exchangeIdColumnName, assistantColumnName, - userColumnName.replace("\"", "")); - - // Verify that the schema, table, and index was created - var hasSchema = jdbcTemplate.queryForObject( - "SELECT EXISTS (SELECT 1 FROM information_schema.schemata WHERE schema_name = ?)", boolean.class, - schemaName); - var hasTable = jdbcTemplate.queryForObject( - "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema = ? AND table_name = ?)", - boolean.class, schemaName, tableName); - var tableColumns = jdbcTemplate.queryForList( - "SELECT column_name FROM information_schema.columns WHERE table_schema = ? AND table_name = ?", - String.class, schemaName, tableName); - var indexName = jdbcTemplate.queryForObject( - "SELECT indexname FROM pg_indexes WHERE schemaname = ? AND tablename = ?", String.class, schemaName, - tableName); - - assertThat(hasSchema).isTrue(); - assertThat(hasTable).isTrue(); - assertThat(expectedColumns.containsAll(tableColumns)).isTrue(); - assertThat(String.format("%s_%s_%s_idx", tableName, sessionIdColumnName, exchangeIdColumnName)) - .isEqualTo(indexName); - - // Cleanup for the other tests - jdbcTemplate.update(String.format("DROP SCHEMA IF EXISTS %s CASCADE", schemaName)); - }); - } - - @Test - void initializeSchema_withValueFalse_shouldNotCreateSchema() { - this.contextRunner.run(context -> { - var jdbcTemplate = context.getBean(JdbcTemplate.class); - var config = PgVectorChatMemoryConfig.builder() - .withInitializeSchema(false) - .withSchemaName(schemaName) - .withTableName(tableName) - .withSessionIdColumnName(sessionIdColumnName) - .withExchangeIdColumnName(exchangeIdColumnName) - .withAssistantColumnName(assistantColumnName) - .withUserColumnName(userColumnName) - .withJdbcTemplate(jdbcTemplate) - .build(); - config.initializeSchema(); - - // Verify that the schema, table, and index was created - var hasSchema = jdbcTemplate.queryForObject( - "SELECT EXISTS (SELECT 1 FROM information_schema.schemata WHERE schema_name = ?)", boolean.class, - schemaName); - var hasTable = jdbcTemplate.queryForObject( - "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_schema = ? AND table_name = ?)", - boolean.class, schemaName, tableName); - var columnCount = jdbcTemplate.queryForObject( - "SELECT COUNT(*) FROM information_schema.columns WHERE table_schema = ? AND table_name = ?", - Integer.class, schemaName, tableName); - var hasIndex = jdbcTemplate.queryForObject( - "SELECT EXISTS (SELECT 1 FROM pg_indexes WHERE schemaname = ? AND tablename = ?)", Boolean.class, - schemaName, tableName); - - assertThat(hasSchema).isFalse(); - assertThat(hasTable).isFalse(); - assertThat(columnCount).isZero(); - assertThat(hasIndex).isFalse(); - }); - } - - @SpringBootConfiguration - @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) - static class TestApplication { - - @Bean - public JdbcTemplate jdbcTemplate(DataSource dataSource) { - return new JdbcTemplate(dataSource); - } - - @Bean - @Primary - @ConfigurationProperties("app.datasource") - public DataSourceProperties dataSourceProperties() { - return new DataSourceProperties(); - } - - @Bean - public HikariDataSource dataSource(DataSourceProperties dataSourceProperties) { - return dataSourceProperties.initializeDataSourceBuilder().type(HikariDataSource.class).build(); - } - - } - -} From 6ff3cb756d3dd45eeb0ecc7be2834ec19649027c Mon Sep 17 00:00:00 2001 From: leijendary Date: Sun, 15 Dec 2024 20:49:55 +0100 Subject: [PATCH 7/8] Used MessageType enum name instead of value --- .../org/springframework/ai/chat/memory/JdbcChatMemory.java | 4 ++-- .../org/springframework/ai/chat/memory/schema.mariadb.sql | 2 +- .../org/springframework/ai/chat/memory/schema.postgresql.sql | 2 +- .../org/springframework/ai/chat/memory/JdbcChatMemoryIT.java | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/JdbcChatMemory.java b/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/JdbcChatMemory.java index 7f6957d7fd9..6b77862a079 100644 --- a/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/JdbcChatMemory.java +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/JdbcChatMemory.java @@ -80,7 +80,7 @@ public void setValues(PreparedStatement ps, int i) throws SQLException { ps.setString(1, this.conversationId); ps.setString(2, message.getText()); - ps.setString(3, message.getMessageType().getValue()); + ps.setString(3, message.getMessageType().name()); } @Override @@ -94,7 +94,7 @@ private static class MessageRowMapper implements RowMapper { @Override public Message mapRow(ResultSet rs, int i) throws SQLException { var content = rs.getString(1); - var type = MessageType.fromValue(rs.getString(2)); + var type = MessageType.valueOf(rs.getString(2)); return switch (type) { case USER -> new UserMessage(content); diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema.mariadb.sql b/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema.mariadb.sql index 3e3256151b8..b524b255a5e 100644 --- a/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema.mariadb.sql +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema.mariadb.sql @@ -3,7 +3,7 @@ CREATE TABLE IF NOT EXISTS ai_chat_memory ( content TEXT NOT NULL, type VARCHAR(10) NOT NULL, `timestamp` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - CONSTRAINT type_check CHECK (type IN ('user', 'assistant')) + CONSTRAINT type_check CHECK (type IN ('USER', 'ASSISTANT')) ); CREATE INDEX ai_chat_memory_conversation_id_timestamp_idx diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema.postgresql.sql b/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema.postgresql.sql index 3c93d24ad77..bd7a8efed8b 100644 --- a/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema.postgresql.sql +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema.postgresql.sql @@ -1,7 +1,7 @@ CREATE TABLE IF NOT EXISTS ai_chat_memory ( conversation_id VARCHAR(36) NOT NULL, content text NOT NULL, - type VARCHAR(10) NOT NULL CHECK (type IN ('user', 'assistant')), + type VARCHAR(10) NOT NULL CHECK (type IN ('USER', 'ASSISTANT')), "timestamp" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ); diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/JdbcChatMemoryIT.java b/chat-memory/spring-ai-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/JdbcChatMemoryIT.java index 5d7308a8906..308babf4a47 100644 --- a/chat-memory/spring-ai-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/JdbcChatMemoryIT.java +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/JdbcChatMemoryIT.java @@ -98,7 +98,7 @@ void add_shouldInsertSingleMessage(String content, MessageType messageType) { assertThat(result.size()).isEqualTo(4); assertThat(result.get("conversation_id")).isEqualTo(conversationId); assertThat(result.get("content")).isEqualTo(message.getText()); - assertThat(result.get("type")).isEqualTo(messageType.getValue()); + assertThat(result.get("type")).isEqualTo(messageType.name()); assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class); }); } @@ -126,7 +126,7 @@ void add_shouldInsertMessages() { assertThat(result.get("conversation_id")).isNotNull(); assertThat(result.get("conversation_id")).isEqualTo(conversationId); assertThat(result.get("content")).isEqualTo(message.getText()); - assertThat(result.get("type")).isEqualTo(message.getMessageType().getValue()); + assertThat(result.get("type")).isEqualTo(message.getMessageType().name()); assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class); } }); From 253dd30cbe6a68a7746305ff2ea1f35d28fa4396 Mon Sep 17 00:00:00 2001 From: leijendary Date: Wed, 18 Dec 2024 13:18:33 +0100 Subject: [PATCH 8/8] Removed SYSTEM message type from MessageRowMapper --- .../org/springframework/ai/chat/memory/JdbcChatMemory.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/JdbcChatMemory.java b/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/JdbcChatMemory.java index 6b77862a079..5e7164ef4b8 100644 --- a/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/JdbcChatMemory.java +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/JdbcChatMemory.java @@ -21,7 +21,10 @@ import java.sql.SQLException; import java.util.List; -import org.springframework.ai.chat.messages.*; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.jdbc.core.BatchPreparedStatementSetter; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.core.RowMapper; @@ -99,7 +102,6 @@ public Message mapRow(ResultSet rs, int i) throws SQLException { return switch (type) { case USER -> new UserMessage(content); case ASSISTANT -> new AssistantMessage(content); - case SYSTEM -> new SystemMessage(content); default -> null; }; }