diff --git a/README.md b/README.md index e4b11b62c41..462e9eadc92 100644 --- a/README.md +++ b/README.md @@ -101,9 +101,9 @@ One way to run integration tests on part of the code is to first do a quick comp ```shell ./mvnw clean install -DskipTests -Dmaven.javadoc.skip=true ``` -Then run the integration test for a specifi module using the `-pl` option +Then run the integration test for a specific module using the `-pl` option ```shell -./mvnw verify -Pintegration-tests -pl spring-ai-spring-boot-autoconfigure +./mvnw verify -Pintegration-tests -pl spring-ai-spring-boot-autoconfigure ``` ### Documentation @@ -134,4 +134,4 @@ To build with checkstyles enabled. Checkstyles are currently disabled, but you can enable them by doing the following: ```shell ./mvnw clean package -DskipTests -Ddisable.checks=false -``` \ No newline at end of file +``` diff --git a/chat-memory/spring-ai-chat-memory-jdbc/README.md b/chat-memory/spring-ai-chat-memory-jdbc/README.md new file mode 100644 index 00000000000..8e100ad20a3 --- /dev/null +++ b/chat-memory/spring-ai-chat-memory-jdbc/README.md @@ -0,0 +1 @@ +[Chat Memory Documentation](https://docs.spring.io/spring-ai/reference/api/chatclient.html#_chat_memory) diff --git a/chat-memory/spring-ai-chat-memory-jdbc/pom.xml b/chat-memory/spring-ai-chat-memory-jdbc/pom.xml new file mode 100644 index 00000000000..61e1737e263 --- /dev/null +++ b/chat-memory/spring-ai-chat-memory-jdbc/pom.xml @@ -0,0 +1,107 @@ + + + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-chat-memory-jdbc + jar + Spring AI Chat Memory JDBC + Spring AI Chat Memory implementation with JDBC + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + 17 + 17 + + + + + org.springframework.ai + spring-ai-core + ${project.parent.version} + + + + com.zaxxer + HikariCP + + + + org.springframework + spring-jdbc + + + + org.postgresql + postgresql + ${postgresql.version} + true + + + + org.mariadb.jdbc + mariadb-java-client + ${mariadb.version} + true + + + + + org.springframework.boot + spring-boot-starter-test + test + + + + org.testcontainers + testcontainers + test + + + + org.testcontainers + postgresql + test + + + + org.testcontainers + mariadb + test + + + + org.testcontainers + junit-jupiter + test + + + diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/JdbcChatMemory.java b/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/JdbcChatMemory.java new file mode 100644 index 00000000000..5e7164ef4b8 --- /dev/null +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/JdbcChatMemory.java @@ -0,0 +1,111 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory; + +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.List; + +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.jdbc.core.BatchPreparedStatementSetter; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.core.RowMapper; + +/** + * An implementation of {@link ChatMemory} for JDBC. When this class is instantiated, + * {@link JdbcChatMemoryConfig#initializeSchema()} will automatically be called. Creating + * an instance of JdbcChatMemory example: + * JdbcChatMemory.create(JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build()); + * + * @author Jonathan Leijendekker + * @since 1.0.0 + */ +public class JdbcChatMemory implements ChatMemory { + + private static final String QUERY_ADD = """ + INSERT INTO ai_chat_memory (conversation_id, content, type) VALUES (?, ?, ?)"""; + + private static final String QUERY_GET = """ + SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" DESC LIMIT ?"""; + + private static final String QUERY_CLEAR = "DELETE FROM ai_chat_memory WHERE conversation_id = ?"; + + private final JdbcTemplate jdbcTemplate; + + public JdbcChatMemory(JdbcChatMemoryConfig config) { + config.initializeSchema(); + + this.jdbcTemplate = config.getJdbcTemplate(); + } + + public static JdbcChatMemory create(JdbcChatMemoryConfig config) { + return new JdbcChatMemory(config); + } + + @Override + public void add(String conversationId, List messages) { + this.jdbcTemplate.batchUpdate(QUERY_ADD, new AddBatchPreparedStatement(conversationId, messages)); + } + + @Override + public List get(String conversationId, int lastN) { + return this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId, lastN); + } + + @Override + public void clear(String conversationId) { + this.jdbcTemplate.update(QUERY_CLEAR, conversationId); + } + + private record AddBatchPreparedStatement(String conversationId, + List messages) implements BatchPreparedStatementSetter { + @Override + public void setValues(PreparedStatement ps, int i) throws SQLException { + var message = this.messages.get(i); + + ps.setString(1, this.conversationId); + ps.setString(2, message.getText()); + ps.setString(3, message.getMessageType().name()); + } + + @Override + public int getBatchSize() { + return this.messages.size(); + } + } + + private static class MessageRowMapper implements RowMapper { + + @Override + public Message mapRow(ResultSet rs, int i) throws SQLException { + var content = rs.getString(1); + var type = MessageType.valueOf(rs.getString(2)); + + return switch (type) { + case USER -> new UserMessage(content); + case ASSISTANT -> new AssistantMessage(content); + default -> null; + }; + } + + } + +} diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/JdbcChatMemoryConfig.java b/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/JdbcChatMemoryConfig.java new file mode 100644 index 00000000000..cda7fba971f --- /dev/null +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/JdbcChatMemoryConfig.java @@ -0,0 +1,115 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory; + +import java.sql.SQLException; +import java.util.Objects; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.core.io.ClassPathResource; +import org.springframework.jdbc.CannotGetJdbcConnectionException; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.datasource.init.ResourceDatabasePopulator; +import org.springframework.util.Assert; + +/** + * Configuration for the JDBC {@link ChatMemory}. When + * {@link JdbcChatMemoryConfig#initializeSchema} is set to {@code true} (default is + * {@code false}) and {@link #initializeSchema()} is called, then the schema based on the + * database will be created. + * + * @author Jonathan Leijendekker + * @since 1.0.0 + */ +public final class JdbcChatMemoryConfig { + + private static final Logger logger = LoggerFactory.getLogger(JdbcChatMemoryConfig.class); + + public static final boolean DEFAULT_SCHEMA_INITIALIZATION = false; + + private final boolean initializeSchema; + + private final JdbcTemplate jdbcTemplate; + + private JdbcChatMemoryConfig(Builder builder) { + this.initializeSchema = builder.initializeSchema; + this.jdbcTemplate = builder.jdbcTemplate; + } + + public static Builder builder() { + return new Builder(); + } + + JdbcTemplate getJdbcTemplate() { + return this.jdbcTemplate; + } + + void initializeSchema() { + if (!this.initializeSchema) { + return; + } + + logger.info("Initializing JdbcChatMemory schema"); + + String productName; + + try (var connection = Objects.requireNonNull(this.jdbcTemplate.getDataSource()).getConnection()) { + var metadata = connection.getMetaData(); + productName = metadata.getDatabaseProductName(); + } + catch (SQLException e) { + throw new CannotGetJdbcConnectionException("Failed to obtain JDBC metadata", e); + } + + var fileName = String.format("schema.%s.sql", productName.toLowerCase()); + var resource = new ClassPathResource(fileName, JdbcChatMemoryConfig.class); + var databasePopulator = new ResourceDatabasePopulator(resource); + databasePopulator.execute(this.jdbcTemplate.getDataSource()); + } + + public static final class Builder { + + private boolean initializeSchema = DEFAULT_SCHEMA_INITIALIZATION; + + private JdbcTemplate jdbcTemplate; + + private Builder() { + } + + public Builder setInitializeSchema(boolean initializeSchema) { + this.initializeSchema = initializeSchema; + return this; + } + + public Builder jdbcTemplate(JdbcTemplate jdbcTemplate) { + Assert.notNull(jdbcTemplate, "jdbc template must not be null"); + + this.jdbcTemplate = jdbcTemplate; + return this; + } + + public JdbcChatMemoryConfig build() { + Assert.notNull(this.jdbcTemplate, "jdbc template must not be null"); + + return new JdbcChatMemoryConfig(this); + } + + } + +} diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/aot/hint/JdbcChatMemoryRuntimeHints.java b/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/aot/hint/JdbcChatMemoryRuntimeHints.java new file mode 100644 index 00000000000..2ffc99af9f7 --- /dev/null +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/aot/hint/JdbcChatMemoryRuntimeHints.java @@ -0,0 +1,26 @@ +package org.springframework.ai.chat.memory.aot.hint; + +import javax.sql.DataSource; + +import org.springframework.aot.hint.MemberCategory; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; + +/** + * A {@link RuntimeHintsRegistrar} for JDBC Chat Memory hints + * + * @author Jonathan Leijendekker + */ +class JdbcChatMemoryRuntimeHints implements RuntimeHintsRegistrar { + + @Override + public void registerHints(RuntimeHints hints, ClassLoader classLoader) { + hints.reflection() + .registerType(DataSource.class, (hint) -> hint.withMembers(MemberCategory.INVOKE_DECLARED_METHODS)); + + hints.resources() + .registerPattern("org/springframework/ai/chat/memory/schema.mariadb.sql") + .registerPattern("org/springframework/ai/chat/memory/schema.postgresql.sql"); + } + +} diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/META-INF/spring/aot.factories b/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/META-INF/spring/aot.factories new file mode 100644 index 00000000000..1877d6377a4 --- /dev/null +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/META-INF/spring/aot.factories @@ -0,0 +1,2 @@ +org.springframework.aot.hint.RuntimeHintsRegistrar=\ +org.springframework.ai.chat.memory.aot.hint.JdbcChatMemoryRuntimeHints diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema-drop.mariadb.sql b/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema-drop.mariadb.sql new file mode 100644 index 00000000000..72f313114ba --- /dev/null +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema-drop.mariadb.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS ai_chat_memory; diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema-drop.postgresql.sql b/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema-drop.postgresql.sql new file mode 100644 index 00000000000..72f313114ba --- /dev/null +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema-drop.postgresql.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS ai_chat_memory; diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema.mariadb.sql b/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema.mariadb.sql new file mode 100644 index 00000000000..b524b255a5e --- /dev/null +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema.mariadb.sql @@ -0,0 +1,10 @@ +CREATE TABLE IF NOT EXISTS ai_chat_memory ( + conversation_id VARCHAR(36) NOT NULL, + content TEXT NOT NULL, + type VARCHAR(10) NOT NULL, + `timestamp` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + CONSTRAINT type_check CHECK (type IN ('USER', 'ASSISTANT')) +); + +CREATE INDEX ai_chat_memory_conversation_id_timestamp_idx +ON ai_chat_memory(conversation_id, `timestamp` DESC); diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema.postgresql.sql b/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema.postgresql.sql new file mode 100644 index 00000000000..bd7a8efed8b --- /dev/null +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/main/resources/org/springframework/ai/chat/memory/schema.postgresql.sql @@ -0,0 +1,9 @@ +CREATE TABLE IF NOT EXISTS ai_chat_memory ( + conversation_id VARCHAR(36) NOT NULL, + content text NOT NULL, + type VARCHAR(10) NOT NULL CHECK (type IN ('USER', 'ASSISTANT')), + "timestamp" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX IF NOT EXISTS ai_chat_memory_conversation_id_timestamp_idx +ON ai_chat_memory(conversation_id, "timestamp" DESC); diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/JdbcChatMemoryConfigIT.java b/chat-memory/spring-ai-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/JdbcChatMemoryConfigIT.java new file mode 100644 index 00000000000..b6bb75f87c0 --- /dev/null +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/JdbcChatMemoryConfigIT.java @@ -0,0 +1,244 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory; + +import java.util.List; +import java.util.Objects; + +import javax.sql.DataSource; + +import com.zaxxer.hikari.HikariDataSource; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.MariaDBContainer; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceProperties; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Primary; +import org.springframework.core.io.ClassPathResource; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.datasource.init.ScriptUtils; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Jonathan Leijendekker + */ +@Testcontainers +class JdbcChatMemoryConfigIT { + + @Nested + class PostgresChatMemoryConfigIT { + + @Container + @SuppressWarnings("resource") + static PostgreSQLContainer postgresContainer = new PostgreSQLContainer<>("postgres:17") + .withDatabaseName("chat_memory_config_test") + .withUsername("postgres") + .withPassword("postgres"); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class) + .withPropertyValues( + // JdbcTemplate configuration + String.format("app.datasource.url=%s", postgresContainer.getJdbcUrl()), + String.format("app.datasource.username=%s", postgresContainer.getUsername()), + String.format("app.datasource.password=%s", postgresContainer.getPassword()), + "app.datasource.type=com.zaxxer.hikari.HikariDataSource"); + + @Test + void initializeSchema_withValueTrue_shouldCreateSchema() { + this.contextRunner.run(context -> { + var jdbcTemplate = context.getBean(JdbcTemplate.class); + var config = JdbcChatMemoryConfig.builder() + .setInitializeSchema(true) + .jdbcTemplate(jdbcTemplate) + .build(); + config.initializeSchema(); + + var expectedColumns = List.of("conversation_id", "content", "type", "timestamp"); + + // Verify that the table and index are created + var hasTable = jdbcTemplate.queryForObject( + "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = ?)", boolean.class, + "ai_chat_memory"); + var tableColumns = jdbcTemplate.queryForList( + "SELECT column_name FROM information_schema.columns WHERE table_name = ?", String.class, + "ai_chat_memory"); + var indexName = jdbcTemplate.queryForObject("SELECT indexname FROM pg_indexes WHERE tablename = ?", + String.class, "ai_chat_memory"); + + assertThat(hasTable).isTrue(); + assertThat(expectedColumns.containsAll(tableColumns)).isTrue(); + assertThat(indexName).isEqualTo("ai_chat_memory_conversation_id_timestamp_idx"); + }); + } + + @Test + void initializeSchema_withValueFalse_shouldNotCreateSchema() { + this.contextRunner.run(context -> { + var jdbcTemplate = context.getBean(JdbcTemplate.class); + + // Make sure the schema does not exist in the first place + ScriptUtils.executeSqlScript(Objects.requireNonNull(jdbcTemplate.getDataSource()).getConnection(), + new ClassPathResource("schema-drop.postgresql.sql", JdbcChatMemoryConfigIT.class)); + + var config = JdbcChatMemoryConfig.builder() + .setInitializeSchema(false) + .jdbcTemplate(jdbcTemplate) + .build(); + config.initializeSchema(); + + // Verify that the table and index was created + var hasTable = jdbcTemplate.queryForObject( + "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = ?)", boolean.class, + "ai_chat_memory"); + var columnCount = jdbcTemplate.queryForObject( + "SELECT COUNT(*) FROM information_schema.columns WHERE table_name = ?", Integer.class, + "ai_chat_memory"); + var hasIndex = jdbcTemplate.queryForObject( + "SELECT EXISTS (SELECT 1 FROM pg_indexes WHERE tablename = ?)", Boolean.class, + "ai_chat_memory"); + + assertThat(hasTable).isFalse(); + assertThat(columnCount).isZero(); + assertThat(hasIndex).isFalse(); + }); + } + + } + + @Nested + class MariaChatMemoryConfigIT { + + @Container + @SuppressWarnings("resource") + static MariaDBContainer mariaContainer = new MariaDBContainer<>("mariadb:11") + .withDatabaseName("chat_memory_config_test") + .withUsername("mysql") + .withPassword("mysql"); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(JdbcChatMemoryConfigIT.TestApplication.class) + .withPropertyValues( + // JdbcTemplate configuration + String.format("app.datasource.url=%s", mariaContainer.getJdbcUrl()), + String.format("app.datasource.username=%s", mariaContainer.getUsername()), + String.format("app.datasource.password=%s", mariaContainer.getPassword()), + "app.datasource.type=com.zaxxer.hikari.HikariDataSource"); + + @Test + void initializeSchema_withValueTrue_shouldCreateSchema() { + this.contextRunner.run(context -> { + var jdbcTemplate = context.getBean(JdbcTemplate.class); + var config = JdbcChatMemoryConfig.builder() + .setInitializeSchema(true) + .jdbcTemplate(jdbcTemplate) + .build(); + config.initializeSchema(); + + var expectedColumns = List.of("conversation_id", "content", "type", "timestamp"); + + // Verify that the table and index are created + var hasTable = jdbcTemplate.queryForObject( + "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = ?)", boolean.class, + "ai_chat_memory"); + var tableColumns = jdbcTemplate.queryForList( + "SELECT column_name FROM information_schema.columns WHERE table_name = ?", String.class, + "ai_chat_memory"); + var indexNames = jdbcTemplate.queryForList( + "SELECT index_name, column_name FROM information_schema.statistics WHERE table_name = ?", + "ai_chat_memory"); + + assertThat(hasTable).isTrue(); + assertThat(expectedColumns.containsAll(tableColumns)).isTrue(); + assertThat(indexNames).hasSize(2); + assertThat(indexNames.get(0).get("index_name")) + .isEqualTo("ai_chat_memory_conversation_id_timestamp_idx"); + assertThat(indexNames.get(0).get("column_name")).isEqualTo("conversation_id"); + assertThat(indexNames.get(1).get("index_name")) + .isEqualTo("ai_chat_memory_conversation_id_timestamp_idx"); + assertThat(indexNames.get(1).get("column_name")).isEqualTo("timestamp"); + }); + } + + @Test + void initializeSchema_withValueFalse_shouldNotCreateSchema() { + this.contextRunner.run(context -> { + var jdbcTemplate = context.getBean(JdbcTemplate.class); + + // Make sure the schema does not exist in the first place + ScriptUtils.executeSqlScript(Objects.requireNonNull(jdbcTemplate.getDataSource()).getConnection(), + new ClassPathResource("schema-drop.mariadb.sql", JdbcChatMemoryConfigIT.class)); + + var config = JdbcChatMemoryConfig.builder() + .setInitializeSchema(false) + .jdbcTemplate(jdbcTemplate) + .build(); + config.initializeSchema(); + + // Verify that the table and index was created + var hasTable = jdbcTemplate.queryForObject( + "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = ?)", boolean.class, + "ai_chat_memory"); + var columnCount = jdbcTemplate.queryForObject( + "SELECT COUNT(*) FROM information_schema.columns WHERE table_name = ?", Integer.class, + "ai_chat_memory"); + var hasIndex = jdbcTemplate.queryForObject( + "SELECT EXISTS (SELECT 1 FROM information_schema.statistics WHERE table_name = ?)", + Boolean.class, "ai_chat_memory"); + + assertThat(hasTable).isFalse(); + assertThat(columnCount).isZero(); + assertThat(hasIndex).isFalse(); + }); + } + + } + + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + static class TestApplication { + + @Bean + public JdbcTemplate jdbcTemplate(DataSource dataSource) { + return new JdbcTemplate(dataSource); + } + + @Bean + @Primary + @ConfigurationProperties("app.datasource") + public DataSourceProperties dataSourceProperties() { + return new DataSourceProperties(); + } + + @Bean + public HikariDataSource dataSource(DataSourceProperties dataSourceProperties) { + return dataSourceProperties.initializeDataSourceBuilder().type(HikariDataSource.class).build(); + } + + } + +} diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/JdbcChatMemoryIT.java b/chat-memory/spring-ai-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/JdbcChatMemoryIT.java new file mode 100644 index 00000000000..308babf4a47 --- /dev/null +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/JdbcChatMemoryIT.java @@ -0,0 +1,203 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory; + +import java.sql.Timestamp; +import java.util.List; +import java.util.UUID; + +import javax.sql.DataSource; + +import com.zaxxer.hikari.HikariDataSource; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceProperties; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Primary; +import org.springframework.jdbc.core.JdbcTemplate; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Jonathan Leijendekker + */ +@Testcontainers +class JdbcChatMemoryIT { + + @Container + @SuppressWarnings("resource") + static PostgreSQLContainer postgresContainer = new PostgreSQLContainer<>("postgres:17") + .withDatabaseName("chat_memory_test") + .withUsername("postgres") + .withPassword("postgres"); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class) + .withPropertyValues( + // JdbcTemplate configuration + String.format("app.datasource.url=%s", postgresContainer.getJdbcUrl()), + String.format("app.datasource.username=%s", postgresContainer.getUsername()), + String.format("app.datasource.password=%s", postgresContainer.getPassword()), + "app.datasource.type=com.zaxxer.hikari.HikariDataSource"); + + @Test + void correctChatMemoryInstance() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + + assertThat(chatMemory).isInstanceOf(JdbcChatMemory.class); + }); + } + + @ParameterizedTest + @CsvSource({ "Message from assistant,ASSISTANT", "Message from user,USER" }) + void add_shouldInsertSingleMessage(String content, MessageType messageType) { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + var message = switch (messageType) { + case ASSISTANT -> new AssistantMessage(content + " - " + conversationId); + case USER -> new UserMessage(content + " - " + conversationId); + default -> throw new IllegalArgumentException("Type not supported: " + messageType); + }; + + chatMemory.add(conversationId, message); + + var jdbcTemplate = context.getBean(JdbcTemplate.class); + var query = "SELECT conversation_id, content, type, \"timestamp\" FROM ai_chat_memory WHERE conversation_id = ?"; + var result = jdbcTemplate.queryForMap(query, conversationId); + + assertThat(result.size()).isEqualTo(4); + assertThat(result.get("conversation_id")).isEqualTo(conversationId); + assertThat(result.get("content")).isEqualTo(message.getText()); + assertThat(result.get("type")).isEqualTo(messageType.name()); + assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class); + }); + } + + @Test + void add_shouldInsertMessages() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from assistant - " + conversationId), + new UserMessage("Message from user - " + conversationId)); + + chatMemory.add(conversationId, messages); + + var jdbcTemplate = context.getBean(JdbcTemplate.class); + var query = "SELECT conversation_id, content, type, \"timestamp\" FROM ai_chat_memory WHERE conversation_id = ?"; + var results = jdbcTemplate.queryForList(query, conversationId); + + assertThat(results.size()).isEqualTo(messages.size()); + + for (var i = 0; i < messages.size(); i++) { + var message = messages.get(i); + var result = results.get(i); + + assertThat(result.get("conversation_id")).isNotNull(); + assertThat(result.get("conversation_id")).isEqualTo(conversationId); + assertThat(result.get("content")).isEqualTo(message.getText()); + assertThat(result.get("type")).isEqualTo(message.getMessageType().name()); + assertThat(result.get("timestamp")).isInstanceOf(Timestamp.class); + } + }); + } + + @Test + void get_shouldReturnMessages() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from assistant 1 - " + conversationId), + new AssistantMessage("Message from assistant 2 - " + conversationId), + new UserMessage("Message from user - " + conversationId)); + + chatMemory.add(conversationId, messages); + + var results = chatMemory.get(conversationId, Integer.MAX_VALUE); + + assertThat(results.size()).isEqualTo(messages.size()); + assertThat(results).isEqualTo(messages); + }); + } + + @Test + void clear_shouldDeleteMessages() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from assistant - " + conversationId), + new UserMessage("Message from user - " + conversationId)); + + chatMemory.add(conversationId, messages); + + chatMemory.clear(conversationId); + + var jdbcTemplate = context.getBean(JdbcTemplate.class); + var count = jdbcTemplate.queryForObject("SELECT COUNT(*) FROM ai_chat_memory WHERE conversation_id = ?", + Integer.class, conversationId); + + assertThat(count).isZero(); + }); + } + + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + static class TestApplication { + + @Bean + public ChatMemory chatMemory(JdbcTemplate jdbcTemplate) { + var config = JdbcChatMemoryConfig.builder().setInitializeSchema(true).jdbcTemplate(jdbcTemplate).build(); + + return JdbcChatMemory.create(config); + } + + @Bean + public JdbcTemplate jdbcTemplate(DataSource dataSource) { + return new JdbcTemplate(dataSource); + } + + @Bean + @Primary + @ConfigurationProperties("app.datasource") + public DataSourceProperties dataSourceProperties() { + return new DataSourceProperties(); + } + + @Bean + public HikariDataSource dataSource(DataSourceProperties dataSourceProperties) { + return dataSourceProperties.initializeDataSourceBuilder().type(HikariDataSource.class).build(); + } + + } + +} diff --git a/chat-memory/spring-ai-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/aot/hint/JdbcChatMemoryRuntimeHintsTest.java b/chat-memory/spring-ai-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/aot/hint/JdbcChatMemoryRuntimeHintsTest.java new file mode 100644 index 00000000000..5abf8495828 --- /dev/null +++ b/chat-memory/spring-ai-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/aot/hint/JdbcChatMemoryRuntimeHintsTest.java @@ -0,0 +1,82 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.memory.aot.hint; + +import java.io.IOException; +import java.util.Arrays; +import java.util.stream.Stream; + +import javax.sql.DataSource; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.aot.hint.RuntimeHintsRegistrar; +import org.springframework.aot.hint.predicate.RuntimeHintsPredicates; +import org.springframework.core.io.Resource; +import org.springframework.core.io.support.PathMatchingResourcePatternResolver; +import org.springframework.core.io.support.SpringFactoriesLoader; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Jonathan Leijendekker + */ +class JdbcChatMemoryRuntimeHintsTest { + + private final RuntimeHints hints = new RuntimeHints(); + + private final JdbcChatMemoryRuntimeHints jdbcChatMemoryRuntimeHints = new JdbcChatMemoryRuntimeHints(); + + @Test + void aotFactoriesContainsRegistrar() { + var match = SpringFactoriesLoader.forResourceLocation("META-INF/spring/aot.factories") + .load(RuntimeHintsRegistrar.class) + .stream() + .anyMatch((registrar) -> registrar instanceof JdbcChatMemoryRuntimeHints); + + assertThat(match).isTrue(); + } + + @ParameterizedTest + @MethodSource("getSchemaFileNames") + void jdbcSchemasHasHints(String schemaFileName) { + this.jdbcChatMemoryRuntimeHints.registerHints(this.hints, getClass().getClassLoader()); + + var predicate = RuntimeHintsPredicates.resource() + .forResource("org/springframework/ai/chat/memory/" + schemaFileName); + + assertThat(predicate).accepts(this.hints); + } + + @Test + void dataSourceHasHints() { + this.jdbcChatMemoryRuntimeHints.registerHints(this.hints, getClass().getClassLoader()); + + assertThat(RuntimeHintsPredicates.reflection().onType(DataSource.class)).accepts(this.hints); + } + + private static Stream getSchemaFileNames() throws IOException { + var resources = new PathMatchingResourcePatternResolver() + .getResources("classpath*:org/springframework/ai/chat/memory/schema.*.sql"); + + return Arrays.stream(resources).map(Resource::getFilename); + } + +} diff --git a/pom.xml b/pom.xml index 84187c76c60..7ba0c7d5abd 100644 --- a/pom.xml +++ b/pom.xml @@ -65,6 +65,8 @@ vector-stores/spring-ai-typesense-store vector-stores/spring-ai-weaviate-store + chat-memory/spring-ai-chat-memory-jdbc + spring-ai-spring-boot-starters/spring-ai-starter-aws-opensearch-store spring-ai-spring-boot-starters/spring-ai-starter-azure-cosmos-db-store spring-ai-spring-boot-starters/spring-ai-starter-azure-store @@ -111,6 +113,7 @@ spring-ai-spring-boot-starters/spring-ai-starter-azure-openai spring-ai-spring-boot-starters/spring-ai-starter-bedrock-ai spring-ai-spring-boot-starters/spring-ai-starter-bedrock-converse + spring-ai-spring-boot-starters/spring-ai-starter-chat-memory-jdbc spring-ai-spring-boot-starters/spring-ai-starter-huggingface spring-ai-spring-boot-starters/spring-ai-starter-minimax spring-ai-spring-boot-starters/spring-ai-starter-mistral-ai @@ -648,6 +651,9 @@ --> + + org.springframework.ai.chat.memory/**/*IT.java + org.springframework.ai.anthropic/**/*IT.java org.springframework.ai.azure.openai/**/*IT.java @@ -671,7 +677,6 @@ org.springframework.ai.vectorstore**/CosmosDB**IT.java org.springframework.ai.vectorstore.azure/**IT.java - org.springframework.ai.chat.memory/**/Cassandra**IT.java org.springframework.ai.vectorstore**/Cassandra**IT.java org.springframework.ai.chroma/**IT.java diff --git a/spring-ai-bom/pom.xml b/spring-ai-bom/pom.xml index fb8fee83f4e..73acadbb4a6 100644 --- a/spring-ai-bom/pom.xml +++ b/spring-ai-bom/pom.xml @@ -581,6 +581,12 @@ ${project.version} + + org.springframework.ai + spring-ai-chat-memory-jdbc-spring-boot-starter + ${project.version} + + diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc index 94ef91768a8..9ca3b90cd27 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chatclient.adoc @@ -374,7 +374,7 @@ Refer to the xref:_retrieval_augmented_generation[Retrieval Augmented Generation The interface `ChatMemory` represents a storage for chat conversation history. It provides methods to add messages to a conversation, retrieve messages from a conversation, and clear the conversation history. -There are currently two implementations, `InMemoryChatMemory` and `CassandraChatMemory`, that provide storage for chat conversation history, in-memory and persisted with `time-to-live`, correspondingly. +There are currently three implementations, `InMemoryChatMemory`, `CassandraChatMemory`, and `JdbcChatMemory` that provide storage for chat conversation history, in-memory and persisted with `time-to-live`, correspondingly. To create a `CassandraChatMemory` with `time-to-live`: @@ -383,11 +383,18 @@ To create a `CassandraChatMemory` with `time-to-live`: CassandraChatMemory.create(CassandraChatMemoryConfig.builder().withTimeToLive(Duration.ofDays(1)).build()); ---- +To create a `JdbcChatMemory`: + +[source,java] +---- +JdbcChatMemory.create(JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build()); +---- + The following advisor implementations use the `ChatMemory` interface to advice the prompt with conversation history which differ in the details of how the memory is added to the prompt * `MessageChatMemoryAdvisor` : Memory is retrieved and added as a collection of messages to the prompt * `PromptChatMemoryAdvisor` : Memory is retrieved and added into the prompt's system text. -* `VectorStoreChatMemoryAdvisor` : The constructor `VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId, int chatHistoryWindowSize, int order)` This constructor allows you to: +* `VectorStoreChatMemoryAdvisor` : The constructor `VectorStoreChatMemoryAdvisor(VectorStore vectorStore, String defaultConversationId, int chatHistoryWindowSize, int order)` This constructor allows you to: . Specify the VectorStore instance used for managing and querying documents. . Set a default conversation ID to be used if none is provided in the context. diff --git a/spring-ai-spring-boot-autoconfigure/pom.xml b/spring-ai-spring-boot-autoconfigure/pom.xml index 92ac01362c6..c23bb889a8c 100644 --- a/spring-ai-spring-boot-autoconfigure/pom.xml +++ b/spring-ai-spring-boot-autoconfigure/pom.xml @@ -412,6 +412,14 @@ true + + + org.springframework.ai + spring-ai-chat-memory-jdbc + ${project.parent.version} + true + + diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryAutoConfiguration.java new file mode 100644 index 00000000000..264d20db6fb --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryAutoConfiguration.java @@ -0,0 +1,51 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.chat.memory.jdbc; + +import javax.sql.DataSource; + +import org.springframework.ai.chat.memory.JdbcChatMemory; +import org.springframework.ai.chat.memory.JdbcChatMemoryConfig; +import org.springframework.boot.autoconfigure.AutoConfiguration; +import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.jdbc.core.JdbcTemplate; + +/** + * @author Jonathan Leijendekker + * @since 1.0.0 + */ +@AutoConfiguration(after = JdbcTemplateAutoConfiguration.class) +@ConditionalOnClass({ JdbcChatMemory.class, DataSource.class, JdbcTemplate.class }) +@EnableConfigurationProperties(JdbcChatMemoryProperties.class) +public class JdbcChatMemoryAutoConfiguration { + + @Bean + @ConditionalOnMissingBean + public JdbcChatMemory chatMemory(JdbcChatMemoryProperties properties, JdbcTemplate jdbcTemplate) { + var config = JdbcChatMemoryConfig.builder() + .setInitializeSchema(properties.isInitializeSchema()) + .jdbcTemplate(jdbcTemplate) + .build(); + + return JdbcChatMemory.create(config); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryProperties.java new file mode 100644 index 00000000000..f13aabe5b25 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryProperties.java @@ -0,0 +1,31 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.chat.memory.jdbc; + +import org.springframework.ai.autoconfigure.chat.memory.CommonChatMemoryProperties; +import org.springframework.boot.context.properties.ConfigurationProperties; + +/** + * @author Jonathan Leijendekker + * @since 1.0.0 + */ +@ConfigurationProperties(JdbcChatMemoryProperties.CONFIG_PREFIX) +public class JdbcChatMemoryProperties extends CommonChatMemoryProperties { + + public static final String CONFIG_PREFIX = "spring.ai.chat.memory.jdbc"; + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index f3e5633efc0..2a7e2ae859d 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/spring-ai-spring-boot-autoconfigure/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -66,3 +66,4 @@ org.springframework.ai.autoconfigure.minimax.MiniMaxAutoConfiguration org.springframework.ai.autoconfigure.vertexai.embedding.VertexAiEmbeddingAutoConfiguration org.springframework.ai.autoconfigure.chat.memory.cassandra.CassandraChatMemoryAutoConfiguration org.springframework.ai.autoconfigure.vectorstore.observation.VectorStoreObservationAutoConfiguration +org.springframework.ai.autoconfigure.chat.memory.jdbc.JdbcChatMemoryAutoConfiguration diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryAutoConfigurationIT.java new file mode 100644 index 00000000000..850534d615e --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryAutoConfigurationIT.java @@ -0,0 +1,87 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.chat.memory.jdbc; + +import java.util.List; +import java.util.UUID; + +import org.junit.jupiter.api.Test; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import org.springframework.ai.chat.memory.JdbcChatMemory; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.JdbcTemplateAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Jonathan Leijendekker + */ +@Testcontainers +class JdbcChatMemoryAutoConfigurationIT { + + @Container + @SuppressWarnings("resource") + static PostgreSQLContainer postgresContainer = new PostgreSQLContainer<>("postgres:17") + .withDatabaseName("chat_memory_auto_configuration_test") + .withUsername("postgres") + .withPassword("postgres"); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(JdbcChatMemoryAutoConfiguration.class, + JdbcTemplateAutoConfiguration.class, DataSourceAutoConfiguration.class)) + .withPropertyValues( + // JdbcTemplate configuration + String.format("spring.datasource.url=%s", postgresContainer.getJdbcUrl()), + String.format("spring.datasource.username=%s", postgresContainer.getUsername()), + String.format("spring.datasource.password=%s", postgresContainer.getPassword()), + "spring.datasource.type=com.zaxxer.hikari.HikariDataSource"); + + @Test + void addGetAndClear_shouldAllExecute() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(JdbcChatMemory.class); + var conversationId = UUID.randomUUID().toString(); + var userMessage = new UserMessage("Message from the user"); + + chatMemory.add(conversationId, userMessage); + + assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).hasSize(1); + assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEqualTo(List.of(userMessage)); + + chatMemory.clear(conversationId); + + assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEmpty(); + + var multipleMessages = List.of(new UserMessage("Message from the user 1"), + new AssistantMessage("Message from the assistant 1")); + + chatMemory.add(conversationId, multipleMessages); + + assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).hasSize(multipleMessages.size()); + assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEqualTo(multipleMessages); + }); + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryPropertiesTests.java new file mode 100644 index 00000000000..2238389ceb6 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/chat/memory/jdbc/JdbcChatMemoryPropertiesTests.java @@ -0,0 +1,42 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.autoconfigure.chat.memory.jdbc; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Jonathan Leijendekker + */ +class JdbcChatMemoryPropertiesTests { + + @Test + void defaultValues() { + var props = new JdbcChatMemoryProperties(); + assertThat(props.isInitializeSchema()).isTrue(); + } + + @Test + void customValues() { + var props = new JdbcChatMemoryProperties(); + props.setInitializeSchema(false); + + assertThat(props.isInitializeSchema()).isFalse(); + } + +} diff --git a/spring-ai-spring-boot-starters/spring-ai-starter-chat-memory-jdbc/pom.xml b/spring-ai-spring-boot-starters/spring-ai-starter-chat-memory-jdbc/pom.xml new file mode 100644 index 00000000000..08c373c0088 --- /dev/null +++ b/spring-ai-spring-boot-starters/spring-ai-starter-chat-memory-jdbc/pom.xml @@ -0,0 +1,58 @@ + + + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + ../../pom.xml + + spring-ai-chat-memory-jdbc-spring-boot-starter + jar + Spring AI Starter - Chat Memory JDBC + Spring AI Chat Memory JDBC Auto Configuration + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + + + org.springframework.boot + spring-boot-starter + + + + org.springframework.ai + spring-ai-spring-boot-autoconfigure + ${project.parent.version} + + + + org.springframework.ai + spring-ai-chat-memory-jdbc + ${project.parent.version} + + + +