diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1c6a81e85..0628a8516 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -23,7 +23,7 @@ The monorepo has the following main projects, each of which correspond to a Java These packages power our RAG applications. - `mongodb-rag-core`: A set of common resources (modules, functions, types, etc.) shared across projects. - - You need to recompile `mongodb-rag-core` by running `npm run build` every time you update it for the changes to be accessible in the other projects that dependend on it. + - You need to recompile `mongodb-rag-core` by running `npm run build` every time you update it for the changes to be accessible in the other projects that depend on it. - `mongodb-rag-ingest`: CLI application that takes data from data sources and converts it to `embedded_content` used by Atlas Vector Search. ### MongoDB Chatbot Framework @@ -40,7 +40,7 @@ general, we publish these as reusable packages on npm. These packages are our production chatbot. They build on top of the Chatbot Framework packages and add MongoDB-specific implementations. -- `chatbot-eval-mongodb-public`: Test suites, evaluators, and reports for the MongoDB AI Chatbot +- `chatbot-eval-mongodb-public`: Test suites, evaluators, and reports for the MongoDB AI Chatbot. - `chatbot-server-mongodb-public`: Chatbot server implementation with our MongoDB-specific configuration. - `ingest-mongodb-public`: RAG ingest service configured to ingest MongoDB Docs, DevCenter, MDBU, MongoDB Press, etc. @@ -132,7 +132,7 @@ npm run dev ## Infrastructure -The projects uses Drone for its CI/CD pipeline. All drone config is located in `.drone.yml`. +The projects use Drone for their CI/CD pipeline. All drone configs are located in `.drone.yml`. Applications are deployed on Kubernetes using the Kanopy developer platform. 
Kubernetes/Kanopy configuration are found in the `/environments` diff --git a/package-lock.json b/package-lock.json index 746e2fcdb..1aa9696fc 100644 --- a/package-lock.json +++ b/package-lock.json @@ -9095,9 +9095,9 @@ } }, "node_modules/@langchain/openai/node_modules/openai": { - "version": "4.95.0", - "resolved": "https://registry.npmjs.org/openai/-/openai-4.95.0.tgz", - "integrity": "sha512-tWHLTA+/HHyWlP8qg0mQLDSpI2NQLhk6zHLJL8yb59qn2pEI8rbEiAGSDPViLvi3BRDoQZIX5scaJ3xYGr2nhw==", + "version": "4.104.0", + "resolved": "https://registry.npmjs.org/openai/-/openai-4.104.0.tgz", + "integrity": "sha512-p99EFNsA/yX6UhVO93f5kJsDRLAg+CTA2RBqdHK4RtK8u5IJw32Hyb2dTGKbnnFmnuoBv5r7Z2CURI9sGZpSuA==", "license": "Apache-2.0", "dependencies": { "@types/node": "^18.11.18", @@ -25183,9 +25183,9 @@ } }, "node_modules/braintrust/node_modules/openai": { - "version": "4.95.0", - "resolved": "https://registry.npmjs.org/openai/-/openai-4.95.0.tgz", - "integrity": "sha512-tWHLTA+/HHyWlP8qg0mQLDSpI2NQLhk6zHLJL8yb59qn2pEI8rbEiAGSDPViLvi3BRDoQZIX5scaJ3xYGr2nhw==", + "version": "4.104.0", + "resolved": "https://registry.npmjs.org/openai/-/openai-4.104.0.tgz", + "integrity": "sha512-p99EFNsA/yX6UhVO93f5kJsDRLAg+CTA2RBqdHK4RtK8u5IJw32Hyb2dTGKbnnFmnuoBv5r7Z2CURI9sGZpSuA==", "license": "Apache-2.0", "optional": true, "peer": true, @@ -35656,9 +35656,9 @@ } }, "node_modules/llamaindex/node_modules/openai": { - "version": "4.95.0", - "resolved": "https://registry.npmjs.org/openai/-/openai-4.95.0.tgz", - "integrity": "sha512-tWHLTA+/HHyWlP8qg0mQLDSpI2NQLhk6zHLJL8yb59qn2pEI8rbEiAGSDPViLvi3BRDoQZIX5scaJ3xYGr2nhw==", + "version": "4.104.0", + "resolved": "https://registry.npmjs.org/openai/-/openai-4.104.0.tgz", + "integrity": "sha512-p99EFNsA/yX6UhVO93f5kJsDRLAg+CTA2RBqdHK4RtK8u5IJw32Hyb2dTGKbnnFmnuoBv5r7Z2CURI9sGZpSuA==", "license": "Apache-2.0", "dependencies": { "@types/node": "^18.11.18", @@ -54877,9 +54877,9 @@ "license": "MIT" }, "packages/benchmarks/node_modules/openai": { - "version": 
"4.47.1", - "resolved": "https://registry.npmjs.org/openai/-/openai-4.47.1.tgz", - "integrity": "sha512-WWSxhC/69ZhYWxH/OBsLEirIjUcfpQ5+ihkXKp06hmeYXgBBIUCa9IptMzYx6NdkiOCsSGYCnTIsxaic3AjRCQ==", + "version": "4.104.0", + "resolved": "https://registry.npmjs.org/openai/-/openai-4.104.0.tgz", + "integrity": "sha512-p99EFNsA/yX6UhVO93f5kJsDRLAg+CTA2RBqdHK4RtK8u5IJw32Hyb2dTGKbnnFmnuoBv5r7Z2CURI9sGZpSuA==", "license": "Apache-2.0", "dependencies": { "@types/node": "^18.11.18", @@ -54888,11 +54888,22 @@ "agentkeepalive": "^4.2.1", "form-data-encoder": "1.7.2", "formdata-node": "^4.3.2", - "node-fetch": "^2.6.7", - "web-streams-polyfill": "^3.2.1" + "node-fetch": "^2.6.7" }, "bin": { "openai": "bin/cli" + }, + "peerDependencies": { + "ws": "^8.18.0", + "zod": "^3.23.8" + }, + "peerDependenciesMeta": { + "ws": { + "optional": true + }, + "zod": { + "optional": true + } } }, "packages/benchmarks/node_modules/openai/node_modules/@types/node": { @@ -56074,21 +56085,12 @@ } }, "packages/mongodb-chatbot-server/node_modules/@langchain/core/node_modules/openai": { - "version": "4.95.0", - "resolved": "https://registry.npmjs.org/openai/-/openai-4.95.0.tgz", - "integrity": "sha512-tWHLTA+/HHyWlP8qg0mQLDSpI2NQLhk6zHLJL8yb59qn2pEI8rbEiAGSDPViLvi3BRDoQZIX5scaJ3xYGr2nhw==", + "version": "5.3.0", + "resolved": "https://registry.npmjs.org/openai/-/openai-5.3.0.tgz", + "integrity": "sha512-VIKmoF7y4oJCDOwP/oHXGzM69+x0dpGFmN9QmYO+uPbLFOmmnwO+x1GbsgUtI+6oraxomGZ566Y421oYVu191w==", "license": "Apache-2.0", "optional": true, "peer": true, - "dependencies": { - "@types/node": "^18.11.18", - "@types/node-fetch": "^2.6.4", - "abort-controller": "^3.0.0", - "agentkeepalive": "^4.2.1", - "form-data-encoder": "1.7.2", - "formdata-node": "^4.3.2", - "node-fetch": "^2.6.7" - }, "bin": { "openai": "bin/cli" }, @@ -56116,17 +56118,6 @@ "uuid": "dist/bin/uuid" } }, - "packages/mongodb-chatbot-server/node_modules/@types/node": { - "version": "18.19.86", - "resolved": 
"https://registry.npmjs.org/@types/node/-/node-18.19.86.tgz", - "integrity": "sha512-fifKayi175wLyKyc5qUfyENhQ1dCNI1UNjp653d8kuYcPQN5JhX3dGuP/XmvPTg/xRBn1VTLpbmi+H/Mr7tLfQ==", - "license": "MIT", - "optional": true, - "peer": true, - "dependencies": { - "undici-types": "~5.26.4" - } - }, "packages/mongodb-chatbot-server/node_modules/ansi-styles": { "version": "5.2.0", "license": "MIT", @@ -56151,14 +56142,6 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "packages/mongodb-chatbot-server/node_modules/form-data-encoder": { - "version": "1.7.2", - "resolved": "https://registry.npmjs.org/form-data-encoder/-/form-data-encoder-1.7.2.tgz", - "integrity": "sha512-qfqtYan3rxrnCk1VYaA4H+Ms9xdpPqvLZa6xmMgFvhO32x7/3J/ExcTd6qpxM0vH2GdMI+poehyBZvqfMTto8A==", - "license": "MIT", - "optional": true, - "peer": true - }, "packages/mongodb-chatbot-server/node_modules/ip-address": { "version": "8.1.0", "license": "MIT", @@ -56447,21 +56430,12 @@ } }, "packages/mongodb-chatbot-server/node_modules/langchain/node_modules/openai": { - "version": "4.95.0", - "resolved": "https://registry.npmjs.org/openai/-/openai-4.95.0.tgz", - "integrity": "sha512-tWHLTA+/HHyWlP8qg0mQLDSpI2NQLhk6zHLJL8yb59qn2pEI8rbEiAGSDPViLvi3BRDoQZIX5scaJ3xYGr2nhw==", + "version": "5.3.0", + "resolved": "https://registry.npmjs.org/openai/-/openai-5.3.0.tgz", + "integrity": "sha512-VIKmoF7y4oJCDOwP/oHXGzM69+x0dpGFmN9QmYO+uPbLFOmmnwO+x1GbsgUtI+6oraxomGZ566Y421oYVu191w==", "license": "Apache-2.0", "optional": true, "peer": true, - "dependencies": { - "@types/node": "^18.11.18", - "@types/node-fetch": "^2.6.4", - "abort-controller": "^3.0.0", - "agentkeepalive": "^4.2.1", - "form-data-encoder": "1.7.2", - "formdata-node": "^4.3.2", - "node-fetch": "^2.6.7" - }, "bin": { "openai": "bin/cli" }, @@ -58322,7 +58296,7 @@ "ignore": "^5.3.2", "langchain": "^0.3.5", "mongodb": "^6.3.0", - "openai": "^4.95.0", + "openai": "^5.2.0", "rimraf": "^6.0.1", "simple-git": "^3.27.0", "toml": "^3.0.0", @@ -58938,6 +58912,45 @@ 
"@langchain/core": ">=0.2.26 <0.4.0" } }, + "packages/mongodb-rag-core/node_modules/@langchain/openai/node_modules/@types/node": { + "version": "18.19.111", + "resolved": "https://registry.npmjs.org/@types/node/-/node-18.19.111.tgz", + "integrity": "sha512-90sGdgA+QLJr1F9X79tQuEut0gEYIfkX9pydI4XGRgvFo9g2JWswefI+WUSUHPYVBHYSEfTEqBxA5hQvAZB3Mw==", + "license": "MIT", + "dependencies": { + "undici-types": "~5.26.4" + } + }, + "packages/mongodb-rag-core/node_modules/@langchain/openai/node_modules/openai": { + "version": "4.104.0", + "resolved": "https://registry.npmjs.org/openai/-/openai-4.104.0.tgz", + "integrity": "sha512-p99EFNsA/yX6UhVO93f5kJsDRLAg+CTA2RBqdHK4RtK8u5IJw32Hyb2dTGKbnnFmnuoBv5r7Z2CURI9sGZpSuA==", + "license": "Apache-2.0", + "dependencies": { + "@types/node": "^18.11.18", + "@types/node-fetch": "^2.6.4", + "abort-controller": "^3.0.0", + "agentkeepalive": "^4.2.1", + "form-data-encoder": "1.7.2", + "formdata-node": "^4.3.2", + "node-fetch": "^2.6.7" + }, + "bin": { + "openai": "bin/cli" + }, + "peerDependencies": { + "ws": "^8.18.0", + "zod": "^3.23.8" + }, + "peerDependenciesMeta": { + "ws": { + "optional": true + }, + "zod": { + "optional": true + } + } + }, "packages/mongodb-rag-core/node_modules/@types/jest": { "version": "26.0.24", "dev": true, @@ -59175,19 +59188,10 @@ } }, "packages/mongodb-rag-core/node_modules/openai": { - "version": "4.95.0", - "resolved": "https://registry.npmjs.org/openai/-/openai-4.95.0.tgz", - "integrity": "sha512-tWHLTA+/HHyWlP8qg0mQLDSpI2NQLhk6zHLJL8yb59qn2pEI8rbEiAGSDPViLvi3BRDoQZIX5scaJ3xYGr2nhw==", + "version": "5.3.0", + "resolved": "https://registry.npmjs.org/openai/-/openai-5.3.0.tgz", + "integrity": "sha512-VIKmoF7y4oJCDOwP/oHXGzM69+x0dpGFmN9QmYO+uPbLFOmmnwO+x1GbsgUtI+6oraxomGZ566Y421oYVu191w==", "license": "Apache-2.0", - "dependencies": { - "@types/node": "^18.11.18", - "@types/node-fetch": "^2.6.4", - "abort-controller": "^3.0.0", - "agentkeepalive": "^4.2.1", - "form-data-encoder": "1.7.2", - "formdata-node": 
"^4.3.2", - "node-fetch": "^2.6.7" - }, "bin": { "openai": "bin/cli" }, @@ -59204,13 +59208,6 @@ } } }, - "packages/mongodb-rag-core/node_modules/openai/node_modules/@types/node": { - "version": "18.19.61", - "license": "MIT", - "dependencies": { - "undici-types": "~5.26.4" - } - }, "packages/mongodb-rag-core/node_modules/path-scurry": { "version": "2.0.0", "license": "BlueOak-1.0.0", @@ -61828,9 +61825,9 @@ } }, "packages/release-notes-generator/node_modules/openai": { - "version": "4.96.0", - "resolved": "https://registry.npmjs.org/openai/-/openai-4.96.0.tgz", - "integrity": "sha512-dKoW56i02Prv2XQolJ9Rl9Svqubqkzg3QpwEOBuSVZLk05Shelu7s+ErRTwFc1Bs3JZ2qBqBfVpXQiJhwOGG8A==", + "version": "4.104.0", + "resolved": "https://registry.npmjs.org/openai/-/openai-4.104.0.tgz", + "integrity": "sha512-p99EFNsA/yX6UhVO93f5kJsDRLAg+CTA2RBqdHK4RtK8u5IJw32Hyb2dTGKbnnFmnuoBv5r7Z2CURI9sGZpSuA==", "license": "Apache-2.0", "dependencies": { "@types/node": "^18.11.18", diff --git a/packages/chatbot-server-mongodb-public/src/config.ts b/packages/chatbot-server-mongodb-public/src/config.ts index 82032a6f9..f79ba18d0 100644 --- a/packages/chatbot-server-mongodb-public/src/config.ts +++ b/packages/chatbot-server-mongodb-public/src/config.ts @@ -19,6 +19,7 @@ import { defaultCreateConversationCustomData, defaultAddMessageToConversationCustomData, makeVerifiedAnswerGenerateResponse, + addMessageToConversationVerifiedAnswerStream, } from "mongodb-chatbot-server"; import cookieParser from "cookie-parser"; import { blockGetRequests } from "./middleware/blockGetRequests"; @@ -53,7 +54,10 @@ import { import { useSegmentIds } from "./middleware/useSegmentIds"; import { makeSearchTool } from "./tools/search"; import { makeMongoDbInputGuardrail } from "./processors/mongoDbInputGuardrail"; -import { makeGenerateResponseWithSearchTool } from "./processors/generateResponseWithSearchTool"; +import { + addMessageToConversationStream, + makeGenerateResponseWithSearchTool, +} from 
"./processors/generateResponseWithSearchTool"; import { makeBraintrustLogger } from "mongodb-rag-core/braintrust"; import { makeMongoDbScrubbedMessageStore } from "./tracing/scrubbedMessages/MongoDbScrubbedMessageStore"; import { MessageAnalysis } from "./tracing/scrubbedMessages/analyzeMessage"; @@ -218,6 +222,7 @@ export const generateResponse = wrapTraced( references: verifiedAnswer.references.map(addReferenceSourceType), }; }, + stream: addMessageToConversationVerifiedAnswerStream, onNoVerifiedAnswerFound: wrapTraced( makeGenerateResponseWithSearchTool({ languageModel, @@ -240,6 +245,7 @@ export const generateResponse = wrapTraced( searchTool: makeSearchTool(findContent), toolChoice: "auto", maxSteps: 5, + stream: addMessageToConversationStream, }), { name: "GenerateResponseWithSearchTool" } ), diff --git a/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.test.ts b/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.test.ts index 561b66133..9778bc398 100644 --- a/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.test.ts +++ b/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.test.ts @@ -337,6 +337,7 @@ describe("generateResponseWithSearchTool", () => { disconnect: mockDisconnect, streamData: mockStreamData, stream: mockStream, + streamResponsesApiPart: mockStream, } as DataStreamer; return dataStreamer; diff --git a/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.ts b/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.ts index bd2436f8d..4ec763904 100644 --- a/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.ts +++ b/packages/chatbot-server-mongodb-public/src/processors/generateResponseWithSearchTool.ts @@ -5,6 +5,7 @@ import { UserMessage, AssistantMessage, ToolMessage, + DataStreamer, } from "mongodb-rag-core"; import 
{ @@ -53,8 +54,70 @@ export interface GenerateResponseWithSearchToolParams { search_content: SearchTool; }>; searchTool: SearchTool; + stream?: { + onTextDelta: ({ + dataStreamer, + delta, + }: { + dataStreamer: DataStreamer; + delta: string; + }) => void; + onReferenceLinks: ({ + dataStreamer, + references, + }: { + dataStreamer: DataStreamer; + references: References; + }) => void; + onLlmRefusal: ({ + dataStreamer, + refusalMessage, + }: { + dataStreamer: DataStreamer; + refusalMessage: string; + }) => void; + onLlmNotWorking: ({ + dataStreamer, + notWorkingMessage, + }: { + dataStreamer: DataStreamer; + notWorkingMessage: string; + }) => void; + }; } +export const addMessageToConversationStream: GenerateResponseWithSearchToolParams["stream"] = + { + onLlmNotWorking({ dataStreamer, notWorkingMessage }) { + dataStreamer?.streamData({ + type: "delta", + data: notWorkingMessage, + }); + }, + onLlmRefusal({ dataStreamer, refusalMessage }) { + dataStreamer?.streamData({ + type: "delta", + data: refusalMessage, + }); + }, + onReferenceLinks({ dataStreamer, references }) { + dataStreamer?.streamData({ + type: "references", + data: references, + }); + }, + onTextDelta({ dataStreamer, delta }) { + dataStreamer?.streamData({ + type: "delta", + data: delta, + }); + }, + }; + +export const responsesApiStream: GenerateResponseWithSearchToolParams["stream"] = + { + // TODO: Add this... + }; /** Generate chatbot response using RAG and a search tool named {@link SEARCH_TOOL_NAME}. 
*/ @@ -70,6 +133,7 @@ export function makeGenerateResponseWithSearchTool({ maxSteps = 2, searchTool, toolChoice, + stream, }: GenerateResponseWithSearchToolParams): GenerateResponse { return async function generateResponseWithSearchTool({ conversation, @@ -81,9 +145,11 @@ export function makeGenerateResponseWithSearchTool({ dataStreamer, request, }) { - if (shouldStream) { - assert(dataStreamer, "dataStreamer is required for streaming"); - } + const streamingModeActive = + shouldStream === true && + dataStreamer !== undefined && + stream !== undefined; + const userMessage: UserMessage = { role: "user", content: latestMessageText, @@ -165,10 +231,10 @@ export function makeGenerateResponseWithSearchTool({ } switch (chunk.type) { case "text-delta": - if (shouldStream) { - dataStreamer?.streamData({ - data: chunk.textDelta, - type: "delta", + if (streamingModeActive) { + stream.onTextDelta({ + dataStreamer, + delta: chunk.textDelta, }); } break; @@ -187,10 +253,10 @@ export function makeGenerateResponseWithSearchTool({ } try { if (references.length > 0) { - if (shouldStream) { - dataStreamer?.streamData({ - data: references, - type: "references", + if (streamingModeActive) { + stream.onReferenceLinks({ + dataStreamer, + references, }); } } @@ -214,10 +280,10 @@ export function makeGenerateResponseWithSearchTool({ ...userMessageCustomData, ...guardrailResult, }; - if (shouldStream) { - dataStreamer?.streamData({ - type: "delta", - data: llmRefusalMessage, + if (streamingModeActive) { + stream.onLlmRefusal({ + dataStreamer, + refusalMessage: llmRefusalMessage, }); } return handleReturnGeneration({ @@ -269,10 +335,10 @@ export function makeGenerateResponseWithSearchTool({ }); } } catch (error: unknown) { - if (shouldStream) { - dataStreamer?.streamData({ - type: "delta", - data: llmNotWorkingMessage, + if (streamingModeActive) { + stream.onLlmNotWorking({ + dataStreamer, + notWorkingMessage: llmNotWorkingMessage, }); } diff --git 
a/packages/mongodb-chatbot-server/src/processors/makeVerifiedAnswerGenerateResponse.test.ts b/packages/mongodb-chatbot-server/src/processors/makeVerifiedAnswerGenerateResponse.test.ts index c5618c9d2..eff74fca0 100644 --- a/packages/mongodb-chatbot-server/src/processors/makeVerifiedAnswerGenerateResponse.test.ts +++ b/packages/mongodb-chatbot-server/src/processors/makeVerifiedAnswerGenerateResponse.test.ts @@ -1,5 +1,8 @@ import { ObjectId } from "mongodb-rag-core/mongodb"; -import { makeVerifiedAnswerGenerateResponse } from "./makeVerifiedAnswerGenerateResponse"; +import { + addMessageToConversationVerifiedAnswerStream, + makeVerifiedAnswerGenerateResponse, +} from "./makeVerifiedAnswerGenerateResponse"; import { VerifiedAnswer, WithScore, DataStreamer } from "mongodb-rag-core"; import { GenerateResponseReturnValue } from "./GenerateResponse"; @@ -55,6 +58,7 @@ describe("makeVerifiedAnswerGenerateResponse", () => { connect: jest.fn(), disconnect: jest.fn(), stream: jest.fn(), + streamResponsesApiPart: jest.fn(), }); // Create base request parameters @@ -79,6 +83,7 @@ describe("makeVerifiedAnswerGenerateResponse", () => { onNoVerifiedAnswerFound: async () => ({ messages: noVerifiedAnswerFoundMessages, }), + stream: addMessageToConversationVerifiedAnswerStream, }); it("uses onNoVerifiedAnswerFound if no verified answer is found", async () => { diff --git a/packages/mongodb-chatbot-server/src/processors/makeVerifiedAnswerGenerateResponse.ts b/packages/mongodb-chatbot-server/src/processors/makeVerifiedAnswerGenerateResponse.ts index 01d3be4f6..469fdd6af 100644 --- a/packages/mongodb-chatbot-server/src/processors/makeVerifiedAnswerGenerateResponse.ts +++ b/packages/mongodb-chatbot-server/src/processors/makeVerifiedAnswerGenerateResponse.ts @@ -1,4 +1,8 @@ -import { VerifiedAnswer, FindVerifiedAnswerFunc } from "mongodb-rag-core"; +import { + VerifiedAnswer, + FindVerifiedAnswerFunc, + DataStreamer, +} from "mongodb-rag-core"; import { strict as assert } from "assert"; 
import { GenerateResponse, @@ -17,8 +21,46 @@ export interface MakeVerifiedAnswerGenerateResponseParams { onVerifiedAnswerFound?: (verifiedAnswer: VerifiedAnswer) => VerifiedAnswer; onNoVerifiedAnswerFound: GenerateResponse; + stream?: { + onVerifiedAnswerFound: ({ + verifiedAnswer, + dataStreamer, + }: { + verifiedAnswer: VerifiedAnswer; + dataStreamer: DataStreamer; + }) => void; + }; } +export const addMessageToConversationVerifiedAnswerStream = { + onVerifiedAnswerFound: ({ + verifiedAnswer, + dataStreamer, + }: { + verifiedAnswer: VerifiedAnswer; + dataStreamer: DataStreamer; + }) => { + dataStreamer.streamData({ + type: "metadata", + data: { + verifiedAnswer: { + _id: verifiedAnswer._id, + created: verifiedAnswer.created, + updated: verifiedAnswer.updated, + }, + }, + }); + dataStreamer.streamData({ + type: "delta", + data: verifiedAnswer.answer, + }); + dataStreamer.streamData({ + type: "references", + data: verifiedAnswer.references, + }); + }, +}; + /** Searches for verified answers for the user query. 
If no verified answer can be found for the given query, the @@ -28,6 +70,7 @@ export const makeVerifiedAnswerGenerateResponse = ({ findVerifiedAnswer, onVerifiedAnswerFound, onNoVerifiedAnswerFound, + stream, }: MakeVerifiedAnswerGenerateResponseParams): GenerateResponse => { return async (args) => { const { latestMessageText, shouldStream, dataStreamer } = args; @@ -54,17 +97,10 @@ export const makeVerifiedAnswerGenerateResponse = ({ if (shouldStream) { assert(dataStreamer, "Must have dataStreamer if shouldStream=true"); - dataStreamer.streamData({ - type: "metadata", - data: metadata, - }); - dataStreamer.streamData({ - type: "delta", - data: answer, - }); - dataStreamer.streamData({ - type: "references", - data: references, + assert(stream, "Must have stream if shouldStream=true"); + stream.onVerifiedAnswerFound({ + verifiedAnswer, + dataStreamer, }); } diff --git a/packages/mongodb-chatbot-server/src/routes/conversations/addMessageToConversation.ts b/packages/mongodb-chatbot-server/src/routes/conversations/addMessageToConversation.ts index 0910be426..500bfd853 100644 --- a/packages/mongodb-chatbot-server/src/routes/conversations/addMessageToConversation.ts +++ b/packages/mongodb-chatbot-server/src/routes/conversations/addMessageToConversation.ts @@ -255,6 +255,7 @@ export function makeAddMessageToConversationRoute({ }), }; + // TODO: resume refactor here... const { messages } = await generateResponseTraced({ conversation: traceConversation, latestMessageText, @@ -276,6 +277,8 @@ export function makeAddMessageToConversationRoute({ const dbAssistantMessage = dbNewMessages[dbNewMessages.length - 1]; assert(dbAssistantMessage !== undefined, "No assistant message found"); + + // TODO: this'll need bunch o refactoring... 
const apiRes = convertMessageFromDbToApi( dbAssistantMessage, conversation._id diff --git a/packages/mongodb-chatbot-server/src/routes/conversations/createResponse.ts b/packages/mongodb-chatbot-server/src/routes/conversations/createResponse.ts new file mode 100644 index 000000000..f9766b3a2 --- /dev/null +++ b/packages/mongodb-chatbot-server/src/routes/conversations/createResponse.ts @@ -0,0 +1,487 @@ +import { strict as assert } from "assert"; +import { + Request as ExpressRequest, + Response as ExpressResponse, +} from "express"; +import { DbMessage } from "mongodb-rag-core"; +import { ObjectId } from "mongodb-rag-core/mongodb"; +import { + ConversationsService, + Conversation, + SomeMessage, + makeDataStreamer, +} from "mongodb-rag-core"; +import { ApiMessage, RequestError, makeRequestError } from "./utils"; +import { getRequestId, sendErrorResponse } from "../../utils"; +import { z } from "zod"; +import { SomeExpressRequest } from "../../middleware/validateRequestSchema"; +import { + AddCustomDataFunc, + ConversationsRouterLocals, +} from "./conversationsRouter"; +import { Logger } from "mongodb-rag-core/braintrust"; +import { UpdateTraceFunc, updateTraceIfExists } from "./UpdateTraceFunc"; +import { GenerateResponse } from "../../processors/GenerateResponse"; +import { OpenAI } from "mongodb-rag-core/openai"; + +export const DEFAULT_MAX_INPUT_LENGTH = 3000; // magic number for max input size for LLM +export const DEFAULT_MAX_USER_MESSAGES_IN_CONVERSATION = 7; // magic number for max messages in a conversation + +export type CreateResponseRequestBody = z.infer< + typeof CreateResponseRequestBodySchema +>; + +const MessageStatusSchema = z + .enum(["in_progress", "completed", "incomplete"]) + .optional() + .describe( + "The status of the item. One of `in_progress`, `completed`, or `incomplete`. Populated when items are returned via API." 
+ ); + +export const CreateResponseRequestBodySchema = z.object({ + model: z.string(), + instructions: z.string().optional(), + input: z.union([ + z.string(), + + z.array( + z.union([ + z.object({ + role: z.enum(["user", "assistant", "system"]), + content: z.string(), + type: z.literal("message").optional(), + }), + // function tool call + z.object({ + arguments: z + .string() + .describe("JSON string of arguments passed to the function"), + name: z.string().describe("Name of the function to run"), + type: z.literal("function_call"), + id: z.string().optional().describe("Unique ID of the function call"), + status: MessageStatusSchema, + }), + // function tool call output + z.object({ + call_id: z + .string() + .describe( + "Unique ID of the function tool call generated by the model" + ), + output: z.string().describe("JSON string of the function tool call"), + type: z.literal("function_call_output"), + id: z + .string() + .optional() + .describe( + "The unique ID of the function tool call output. Populated when this item is returned via API." + ), + status: MessageStatusSchema, + }), + ]) + ), + ]), + max_output_tokens: z.number().max(4000).optional().default(1000), + metadata: z + .record(z.string(), z.string().max(512)) + .optional() + .refine( + (metadata) => Object.keys(metadata ?? {}).length <= 16, + "Too many metadata fields. Max 16." + ), + previous_response_id: z + .string() + .optional() + .describe( + "The unique ID of the previous response to the model. Use this to create multi-turn conversations." + ), + store: z + .boolean() + .optional() + .describe("Whether to store the response in the conversation.") + .default(true), + stream: z.literal(true, { + errorMap: () => ({ message: "'stream' must be true" }), + }), + temperature: z + .union([ + z.literal(0, { + errorMap: () => ({ message: "Temperature must be 0 or unset" }), + }), + z.undefined(), + ]) + .optional() + .describe("Temperature for the model. 
Defaults to 0.") + .default(0), + tool_choice: z + .union([ + z.enum(["none", "only", "auto"]), + z + .object({ + name: z.string(), + type: z.literal("function"), + }) + .describe("Function tool choice"), + ]) + .optional() + .describe("Tool choice for the model. Defaults to 'auto'.") + .default("auto"), + tools: z + .array( + z.object({ + name: z.string(), + description: z.string().optional(), + parameters: z + .record(z.string(), z.unknown()) + .describe( + "A JSON schema object describing the parameters of the function." + ), + }) + ) + .optional() + .describe( + "Tools for the model to use. Required if tool_choice is 'function'." + ), + + user: z.string().optional().describe("The user ID of the user."), +}); + +export const CreateResponseRequest = SomeExpressRequest.merge( + z.object({ + headers: z.object({ + "req-id": z.string(), + }), + body: CreateResponseRequestBodySchema, + }) +); + +export type CreateResponseRequest = z.infer<typeof CreateResponseRequest>; + +export interface CreateResponseRouteParams { + conversations: ConversationsService; + maxInputLengthCharacters?: number; + maxUserMessagesInConversation?: number; + generateResponse: GenerateResponse; + addMessageToConversationCustomData?: AddCustomDataFunc; + openAi: OpenAI; + /** + If present, the route will create a new conversation + when given the `conversationIdPathParam` in the URL. + */ + createConversation?: { + /** + Create a new conversation when the `conversationId` is the string "null". + */ + createOnNullConversationId: boolean; + /** + The custom data to add to the new conversation + when it is created. + */ + addCustomData?: AddCustomDataFunc; + }; + + /** + Custom function to update the Braintrust tracing + after the response has been sent to the user. + Can add additional tags, scores, etc. 
+ */ + updateTrace?: UpdateTraceFunc; + braintrustLogger?: Logger; + supportedModels: string[]; +} + +export function makeCreateResponseRoute({ + conversations, + generateResponse, + maxInputLengthCharacters = DEFAULT_MAX_INPUT_LENGTH, + maxUserMessagesInConversation = DEFAULT_MAX_USER_MESSAGES_IN_CONVERSATION, + addMessageToConversationCustomData, + supportedModels, + updateTrace, +}: CreateResponseRouteParams) { + return async ( + req: ExpressRequest< + CreateResponseRequest["params"], + unknown, + CreateResponseRequest["body"] + >, + res: ExpressResponse + ) => { + const dataStreamer = makeDataStreamer(); + const reqId = getRequestId(req); // TODO: figure this one out... + try { + const { + body: { + input, + model, + metadata, + previous_response_id, + store, + stream, + temperature, + tool_choice, + tools, + user, + max_output_tokens, + instructions, + }, + ip, + } = req; + + // --- MODEL CHECK --- + if (!supportedModels.includes(model)) { + throw makeRequestError({ + httpStatus: 400, + message: `Model ${model} is not supported.`, + }); + } + + // --- MAX INPUT LENGTH CHECK --- + if (JSON.stringify(input).length > maxInputLengthCharacters) { + throw makeRequestError({ + httpStatus: 400, + message: "Message too long", + }); + } + + const customData = await getCustomData({ + req, + res, + addMessageToConversationCustomData, + }); + + // --- LOAD CONVERSATION --- + const conversation = await loadConversationByMessageId({ + messageId: previous_response_id, + conversations, + }); + + // --- MAX CONVERSATION LENGTH CHECK --- + // TODO: both these checks same as in addMessageToConversation, make DRY + const numUserMessages = conversation.messages.reduce( + (acc, message) => (message.role === "user" ? acc + 1 : acc), + 0 + ); + if (numUserMessages >= maxUserMessagesInConversation) { + // Omit the system prompt and assume the user always received one response per message + throw makeRequestError({ + httpStatus: 400, + message: `Too many messages. 
You cannot send more than ${maxUserMessagesInConversation} messages in this conversation.`, + }); + } + + if (stream !== true) { + throw makeRequestError({ + httpStatus: 400, + message: "Stream must be true", + }); + } + + if (stream) { + dataStreamer.connect(res); + } + + const assistantResponseMessageId = new ObjectId(); + + const streamingMetadata = { + conversationId: conversation._id.toString(), + }; + const baseResponse = { + id: assistantResponseMessageId.toHexString(), + model: model, + object: "response" as const, + created_at: Date.now(), + temperature: temperature, + instructions: instructions ?? null, + metadata: { + // ...Any other stuff that should be in here? something for verified answers? + ...streamingMetadata, + }, + parallel_tool_calls: false, + tool_choice: tool_choice === "only" ? "auto" : tool_choice, + tools: + tools?.map((tool) => { + return { + name: tool.name, + description: tool.description, + parameters: tool.parameters, + container: null, + strict: null, + type: "function", + }; + }) ?? [], + top_p: null, + // TODO: below here, i think these all need additional work... 
+ // TODO: 2x check if the following should be different if incomplete + incomplete_details: null, + // TODO: 2x check if the following should be different if incomplete + error: null, + // TODO: output + output: [], + output_text: "", + } satisfies OpenAI.Responses.ResponseCreatedEvent["response"]; + + dataStreamer.streamResponsesApiPart({ + type: "response.created", + // TODO: fix typescript stuff + response: { + ...baseResponse, + }, + } satisfies Omit<OpenAI.Responses.ResponseCreatedEvent, "sequence_number">); + + // TODO: make a Responses API version of generate response + const { messages } = await generateResponse({ + latestMessageText, + clientContext, + customData, + dataStreamer, + shouldStream, + reqId, + conversation, + traceId, + }); + + // --- SAVE QUESTION & RESPONSE --- + const dbNewMessages = await addMessagesToDatabase({ + conversations, + conversation, + messages, + assistantResponseMessageId, + }); + const dbAssistantMessage = dbNewMessages[dbNewMessages.length - 1]; + + assert(dbAssistantMessage !== undefined, "No assistant message found"); + + dataStreamer.streamResponsesApiPart({ + type: "response.completed", + // TODO: better figure out typescripting here... + response: { + ...baseResponse, + output_text: dbAssistantMessage.content, + output: dbMessageToResponseOutputItem(dbNewMessages), + }, + } satisfies Omit<OpenAI.Responses.ResponseCompletedEvent, "sequence_number">); + if (dataStreamer.connected) { + dataStreamer.disconnect(); + } + + await updateTraceIfExists({ + updateTrace, + reqId, + conversations, + conversationId: conversation._id, + assistantResponseMessageId: dbAssistantMessage.id, + }); + } catch (error) { + // TODO: better error handling, in line with the Responses API + const { httpStatus, message } = + (error as Error).name === "RequestError" + ? 
(error as RequestError) + : makeRequestError({ + message: (error as Error).message, + stack: (error as Error).stack, + httpStatus: 500, + }); + + sendErrorResponse({ + res, + reqId, + httpStatus, + errorMessage: message, + }); + } finally { + if (dataStreamer.connected) { + dataStreamer.disconnect(); + } + } + }; +} + +// --- HELPERS --- + +// TODO: implement... +function dbMessageToResponseOutputItem( + dbMessages: DbMessage[] +): OpenAI.Responses.ResponseOutputItem[] { + return []; +} + +// TODO: this is same as in addMessageToConversation +// ...have separate helper imported by both +async function getCustomData({ + req, + res, + addMessageToConversationCustomData, +}: { + req: ExpressRequest; + res: ExpressResponse; + addMessageToConversationCustomData?: AddCustomDataFunc; +}) { + try { + return addMessageToConversationCustomData + ? await addMessageToConversationCustomData(req, res) + : undefined; + } catch (_err) { + throw makeRequestError({ + httpStatus: 500, + message: "Unable to process custom data", + }); + } +} + +// TODO: this is same as in addMessageToConversation +// ...have separate helper imported by both +interface AddMessagesToDatabaseParams { + conversation: Conversation; + conversations: ConversationsService; + messages: SomeMessage[]; + assistantResponseMessageId: ObjectId; + store: boolean; +} +async function addMessagesToDatabase({ + conversation, + conversations, + messages, + assistantResponseMessageId, + // TODO: handle if store is false, only include metadata, no content. 
+ store, +}: AddMessagesToDatabaseParams) { + ( + messages as Parameters< + typeof conversations.addManyConversationMessages + >[0]["messages"] + )[messages.length - 1].id = assistantResponseMessageId; + + const conversationId = conversation._id; + const dbMessages = await conversations.addManyConversationMessages({ + conversationId, + messages, + }); + return dbMessages; +} + +async function loadConversationByMessageId({ + messageId, + conversations, + customData, +}: { + messageId?: string; + conversations: ConversationsService; + customData?: Record; +}): Promise { + if (!messageId) { + return await conversations.create({ + customData, + }); + } + const conversation = await conversations.findByMessageId({ + messageId: ObjectId.createFromHexString(messageId), + }); + if (!conversation) { + throw makeRequestError({ + httpStatus: 404, + message: `Message ${messageId} not found`, + }); + } + return conversation; +} diff --git a/packages/mongodb-rag-core/package.json b/packages/mongodb-rag-core/package.json index 644a650fd..5db57758a 100644 --- a/packages/mongodb-rag-core/package.json +++ b/packages/mongodb-rag-core/package.json @@ -100,7 +100,7 @@ "ignore": "^5.3.2", "langchain": "^0.3.5", "mongodb": "^6.3.0", - "openai": "^4.95.0", + "openai": "^5.2.0", "rimraf": "^6.0.1", "simple-git": "^3.27.0", "toml": "^3.0.0", diff --git a/packages/mongodb-rag-core/src/DataStreamer.ts b/packages/mongodb-rag-core/src/DataStreamer.ts index 423e6ec21..b561affe3 100644 --- a/packages/mongodb-rag-core/src/DataStreamer.ts +++ b/packages/mongodb-rag-core/src/DataStreamer.ts @@ -15,6 +15,7 @@ interface ServerSentEventDispatcher { connect(): void; disconnect(): void; sendData(data: Data): void; + sendResponsesApiStreamData(data: OpenAI.Responses.ResponseStreamEvent): void; sendEvent(eventType: string, data: Data): void; } @@ -43,6 +44,9 @@ function makeServerSentEventDispatcher< res.write(`event: ${eventType}\n`); res.write(`data: ${JSON.stringify(data)}\n\n`); }, + 
sendResponsesApiStreamData(data) { + res.write(`data: ${JSON.stringify(data)}\n\n`); + }, }; } @@ -122,6 +126,9 @@ export interface DataStreamer { disconnect(): void; streamData(data: SomeStreamEvent): void; stream(params: StreamParams): Promise; + streamResponsesApiPart( + data: Omit + ): void; } /** @@ -131,6 +138,7 @@ export function makeDataStreamer(): DataStreamer { let connected = false; let sse: ServerSentEventDispatcher | undefined; + let responseSequenceNumber = 0; return { get connected() { return connected; @@ -170,6 +178,22 @@ export function makeDataStreamer(): DataStreamer { sse?.sendData(data); }, + streamResponsesApiPart( + data: Omit + ) { + if (!this.connected) { + throw new Error( + `Tried to stream data, but there's no SSE connection. Call DataStreamer.connect() first.` + ); + } + sse?.sendResponsesApiStreamData({ + ...data, + sequence_number: responseSequenceNumber, + // TODO: see if can remove the cast + } as OpenAI.Responses.ResponseStreamEvent); + responseSequenceNumber++; + }, + /** Streams all message events in an event stream. */ diff --git a/packages/mongodb-rag-core/src/conversations/ConversationsService.ts b/packages/mongodb-rag-core/src/conversations/ConversationsService.ts index 176155b87..f6b6567ee 100644 --- a/packages/mongodb-rag-core/src/conversations/ConversationsService.ts +++ b/packages/mongodb-rag-core/src/conversations/ConversationsService.ts @@ -213,6 +213,11 @@ export type AddManyConversationMessagesParams = { export interface FindByIdParams { _id: ObjectId; } + +export interface FindByMessageIdParams { + messageId: ObjectId; +} + export interface RateMessageParams { conversationId: ObjectId; messageId: ObjectId; @@ -264,6 +269,10 @@ export interface ConversationsService { ) => Promise; findById: ({ _id }: FindByIdParams) => Promise; + findByMessageId: ({ + messageId, + }: FindByMessageIdParams) => Promise; + /** Rate a {@link Message} in a {@link Conversation}. 
*/ diff --git a/packages/mongodb-rag-core/src/conversations/MongoDbConversations.test.ts b/packages/mongodb-rag-core/src/conversations/MongoDbConversations.test.ts index 23f2176e7..1689ca4be 100644 --- a/packages/mongodb-rag-core/src/conversations/MongoDbConversations.test.ts +++ b/packages/mongodb-rag-core/src/conversations/MongoDbConversations.test.ts @@ -204,6 +204,16 @@ describe("Conversations Service", () => { }); expect(conversationInDb).toBeNull(); }); + + test("should find a conversation by message id", async () => { + // TODO: implement + }); + test("should return null if cannot find a conversation by message id", async () => { + const conversationInDb = await conversationsService.findByMessageId({ + messageId: new BSON.ObjectId(), + }); + expect(conversationInDb).toBeNull(); + }); test("Should rate a message", async () => { const { _id: conversationId } = await conversationsService.create({ initialMessages: [systemPrompt], diff --git a/packages/mongodb-rag-core/src/conversations/MongoDbConversations.ts b/packages/mongodb-rag-core/src/conversations/MongoDbConversations.ts index ea093f2d5..69be5fc8b 100644 --- a/packages/mongodb-rag-core/src/conversations/MongoDbConversations.ts +++ b/packages/mongodb-rag-core/src/conversations/MongoDbConversations.ts @@ -16,6 +16,7 @@ import { SystemMessage, CommentMessageParams, ToolMessage, + FindByMessageIdParams, } from "./ConversationsService"; /** @@ -103,6 +104,13 @@ export function makeMongoDbConversationsService( return conversation; }, + async findByMessageId({ messageId }: FindByMessageIdParams) { + const conversation = await conversationsCollection.findOne({ + "messages.id": messageId, + }); + return conversation; + }, + async rateMessage({ conversationId, messageId,