Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JdbcChatMemory #1528

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ One way to run integration tests on part of the code is to first do a quick comp
```shell
./mvnw clean install -DskipTests -Dmaven.javadoc.skip=true
```
Then run the integration test for a specifi module using the `-pl` option
Then run the integration test for a specific module using the `-pl` option
```shell
./mvnw verify -Pintegration-tests -pl spring-ai-spring-boot-autoconfigure
./mvnw verify -Pintegration-tests -pl spring-ai-spring-boot-autoconfigure
```

### Documentation
Expand Down Expand Up @@ -134,4 +134,4 @@ To build with checkstyles enabled.
Checkstyles are currently disabled, but you can enable them by doing the following:
```shell
./mvnw clean package -DskipTests -Ddisable.checks=false
```
```
1 change: 1 addition & 0 deletions chat-memory/spring-ai-chat-memory-jdbc/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[Chat Memory Documentation](https://docs.spring.io/spring-ai/reference/api/chatclient.html#_chat_memory)
107 changes: 107 additions & 0 deletions chat-memory/spring-ai-chat-memory-jdbc/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ Copyright 2023-2024 the original author or authors.
~
~ Licensed under the Apache License, Version 2.0 (the "License");
~ you may not use this file except in compliance with the License.
~ You may obtain a copy of the License at
~
~ https://www.apache.org/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS,
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
~ See the License for the specific language governing permissions and
~ limitations under the License.
-->

<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai</artifactId>
<version>1.0.0-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
</parent>
<artifactId>spring-ai-chat-memory-jdbc</artifactId>
<packaging>jar</packaging>
<name>Spring AI Chat Memory JDBC</name>
<description>Spring AI Chat Memory implementation with JDBC</description>
<url>https://github.com/spring-projects/spring-ai</url>

<scm>
<url>https://github.com/spring-projects/spring-ai</url>
<connection>git://github.com/spring-projects/spring-ai.git</connection>
<developerConnection>[email protected]:spring-projects/spring-ai.git</developerConnection>
</scm>

<properties>
<maven.compiler.source>17</maven.compiler.source>
<maven.compiler.target>17</maven.compiler.target>
</properties>

<dependencies>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-core</artifactId>
<version>${project.parent.version}</version>
</dependency>

<dependency>
<groupId>com.zaxxer</groupId>
<artifactId>HikariCP</artifactId>
</dependency>

<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-jdbc</artifactId>
</dependency>

<dependency>
<groupId>org.postgresql</groupId>
<artifactId>postgresql</artifactId>
<version>${postgresql.version}</version>
<optional>true</optional>
</dependency>

<dependency>
<groupId>org.mariadb.jdbc</groupId>
<artifactId>mariadb-java-client</artifactId>
<version>${mariadb.version}</version>
<optional>true</optional>
</dependency>

<!-- TESTING -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>testcontainers</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>postgresql</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>mariadb</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Copyright 2024-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.chat.memory;

import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;

import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;

/**
* An implementation of {@link ChatMemory} for JDBC. When this class is instantiated,
* {@link JdbcChatMemoryConfig#initializeSchema()} will automatically be called. Creating
* an instance of JdbcChatMemory example:
* <code>JdbcChatMemory.create(JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build());</code>
*
* @author Jonathan Leijendekker
* @since 1.0.0
*/
public class JdbcChatMemory implements ChatMemory {

private static final String QUERY_ADD = """
INSERT INTO ai_chat_memory (conversation_id, content, type) VALUES (?, ?, ?)""";

private static final String QUERY_GET = """
SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" DESC LIMIT ?""";

private static final String QUERY_CLEAR = "DELETE FROM ai_chat_memory WHERE conversation_id = ?";

private final JdbcTemplate jdbcTemplate;

public JdbcChatMemory(JdbcChatMemoryConfig config) {
config.initializeSchema();

this.jdbcTemplate = config.getJdbcTemplate();
}

public static JdbcChatMemory create(JdbcChatMemoryConfig config) {
return new JdbcChatMemory(config);
}

@Override
public void add(String conversationId, List<Message> messages) {
this.jdbcTemplate.batchUpdate(QUERY_ADD, new AddBatchPreparedStatement(conversationId, messages));
}

@Override
public List<Message> get(String conversationId, int lastN) {
return this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId, lastN);
}

@Override
public void clear(String conversationId) {
this.jdbcTemplate.update(QUERY_CLEAR, conversationId);
}

private record AddBatchPreparedStatement(String conversationId,
List<Message> messages) implements BatchPreparedStatementSetter {
@Override
public void setValues(PreparedStatement ps, int i) throws SQLException {
var message = this.messages.get(i);

ps.setString(1, this.conversationId);
ps.setString(2, message.getText());
ps.setString(3, message.getMessageType().name());
}

@Override
public int getBatchSize() {
return this.messages.size();
}
}

private static class MessageRowMapper implements RowMapper<Message> {

@Override
public Message mapRow(ResultSet rs, int i) throws SQLException {
var content = rs.getString(1);
var type = MessageType.valueOf(rs.getString(2));

return switch (type) {
case USER -> new UserMessage(content);
case ASSISTANT -> new AssistantMessage(content);
default -> null;
};
}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Copyright 2024-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.chat.memory;

import java.sql.SQLException;
import java.util.Objects;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.core.io.ClassPathResource;
import org.springframework.jdbc.CannotGetJdbcConnectionException;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.datasource.init.ResourceDatabasePopulator;
import org.springframework.util.Assert;

/**
* Configuration for the JDBC {@link ChatMemory}. When
* {@link JdbcChatMemoryConfig#initializeSchema} is set to {@code true} (default is
* {@code false}) and {@link #initializeSchema()} is called, then the schema based on the
* database will be created.
*
* @author Jonathan Leijendekker
* @since 1.0.0
*/
public final class JdbcChatMemoryConfig {

private static final Logger logger = LoggerFactory.getLogger(JdbcChatMemoryConfig.class);

public static final boolean DEFAULT_SCHEMA_INITIALIZATION = false;

private final boolean initializeSchema;

private final JdbcTemplate jdbcTemplate;

private JdbcChatMemoryConfig(Builder builder) {
this.initializeSchema = builder.initializeSchema;
this.jdbcTemplate = builder.jdbcTemplate;
}

public static Builder builder() {
return new Builder();
}

JdbcTemplate getJdbcTemplate() {
return this.jdbcTemplate;
}

void initializeSchema() {
if (!this.initializeSchema) {
return;
}

logger.info("Initializing JdbcChatMemory schema");

String productName;

try (var connection = Objects.requireNonNull(this.jdbcTemplate.getDataSource()).getConnection()) {
var metadata = connection.getMetaData();
productName = metadata.getDatabaseProductName();
}
catch (SQLException e) {
throw new CannotGetJdbcConnectionException("Failed to obtain JDBC metadata", e);
}

var fileName = String.format("schema.%s.sql", productName.toLowerCase());
var resource = new ClassPathResource(fileName, JdbcChatMemoryConfig.class);
var databasePopulator = new ResourceDatabasePopulator(resource);
databasePopulator.execute(this.jdbcTemplate.getDataSource());
}

public static final class Builder {

private boolean initializeSchema = DEFAULT_SCHEMA_INITIALIZATION;

private JdbcTemplate jdbcTemplate;

private Builder() {
}

public Builder setInitializeSchema(boolean initializeSchema) {
this.initializeSchema = initializeSchema;
return this;
}

public Builder jdbcTemplate(JdbcTemplate jdbcTemplate) {
Assert.notNull(jdbcTemplate, "jdbc template must not be null");

this.jdbcTemplate = jdbcTemplate;
return this;
}

public JdbcChatMemoryConfig build() {
Assert.notNull(this.jdbcTemplate, "jdbc template must not be null");

return new JdbcChatMemoryConfig(this);
}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package org.springframework.ai.chat.memory.aot.hint;

import javax.sql.DataSource;

import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;

/**
* A {@link RuntimeHintsRegistrar} for JDBC Chat Memory hints
*
* @author Jonathan Leijendekker
*/
class JdbcChatMemoryRuntimeHints implements RuntimeHintsRegistrar {

@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
hints.reflection()
.registerType(DataSource.class, (hint) -> hint.withMembers(MemberCategory.INVOKE_DECLARED_METHODS));

hints.resources()
.registerPattern("org/springframework/ai/chat/memory/schema.mariadb.sql")
.registerPattern("org/springframework/ai/chat/memory/schema.postgresql.sql");
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
org.springframework.aot.hint.RuntimeHintsRegistrar=\
org.springframework.ai.chat.memory.aot.hint.JdbcChatMemoryRuntimeHints
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DROP TABLE IF EXISTS ai_chat_memory;
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DROP TABLE IF EXISTS ai_chat_memory;
Loading