diff --git a/apps/chat-e2e/src/assertions/api/apiAssertion.ts b/apps/chat-e2e/src/assertions/api/apiAssertion.ts index b64b93150..0b5de1a92 100644 --- a/apps/chat-e2e/src/assertions/api/apiAssertion.ts +++ b/apps/chat-e2e/src/assertions/api/apiAssertion.ts @@ -58,7 +58,7 @@ export class ApiAssertion { expectedModel: DialAIEntityModel, ) { expect - .soft(request.modelId, ExpectedMessages.chatRequestModelIsValid) + .soft(request.model?.id, ExpectedMessages.chatRequestModelIsValid) .toBe(expectedModel.id); } diff --git a/apps/chat-e2e/src/testData/api/chatApiHelper.ts b/apps/chat-e2e/src/testData/api/chatApiHelper.ts index 405e66ad8..5ea414bf5 100644 --- a/apps/chat-e2e/src/testData/api/chatApiHelper.ts +++ b/apps/chat-e2e/src/testData/api/chatApiHelper.ts @@ -1,7 +1,7 @@ import { Conversation } from '@/chat/types/chat'; import { API } from '@/src/testData'; import { BaseApiHelper } from '@/src/testData/api/baseApiHelper'; -import { BucketUtil } from '@/src/utils'; +import { BucketUtil, ModelsUtil } from '@/src/utils'; export class ChatApiHelper extends BaseApiHelper { public buildRequestData(conversation: Conversation) { @@ -25,7 +25,7 @@ export class ChatApiHelper extends BaseApiHelper { const commonData = { id: `conversations/${BucketUtil.getBucket()}/` + conversation.id, messages: [userMessage], - modelId: conversation.model.id, + model: ModelsUtil.getOpenAIEntity(conversation.model.id), prompt: conversation.prompt, temperature: conversation.temperature, selectedAddons: conversation.selectedAddons, @@ -34,7 +34,9 @@ export class ChatApiHelper extends BaseApiHelper { return conversation.assistantModelId ? { ...commonData, - assistantModelId: conversation.assistantModelId, + assistantModel: ModelsUtil.getOpenAIEntity( + conversation.assistantModelId, + ), } : commonData; } diff --git a/apps/chat-e2e/src/tests/chatHeader.test.ts b/apps/chat-e2e/src/tests/chatHeader.test.ts index 309fb530e..464fbd8eb 100644 --- a/apps/chat-e2e/src/tests/chatHeader.test.ts +++ b/apps/chat-e2e/src/tests/chatHeader.test.ts @@ -66,7 +66,7 @@ dialTest( const requestsData = await chat.sendRequestWithKeyboard(request, false); expect - .soft(requestsData.modelId, ExpectedMessages.requestModeIdIsValid) + .soft(requestsData.model.id, ExpectedMessages.requestModeIdIsValid) .toBe(conversation.model.id); expect .soft(requestsData.prompt, ExpectedMessages.requestPromptIsValid) diff --git a/apps/chat-e2e/src/tests/compareMode.test.ts b/apps/chat-e2e/src/tests/compareMode.test.ts index 5fd6fd093..94b6e671c 100644 --- a/apps/chat-e2e/src/tests/compareMode.test.ts +++ b/apps/chat-e2e/src/tests/compareMode.test.ts @@ -551,7 +551,7 @@ dialTest( expect .soft( - requestsData.rightRequest.modelId, + requestsData.rightRequest.model.id, ExpectedMessages.requestModeIdIsValid, ) .toBe(defaultModel.id); @@ -570,7 +570,7 @@ dialTest( expect .soft( - requestsData.leftRequest.modelId, + requestsData.leftRequest.model.id, ExpectedMessages.requestModeIdIsValid, ) .toBe(aModel.id); @@ -1460,13 +1460,13 @@ dialTest( ); expect .soft( - requestsData.rightRequest.modelId, + requestsData.rightRequest.model.id, ExpectedMessages.requestModeIdIsValid, ) .toBe(firstFolderConversation.conversations[0].model.id); expect .soft( - requestsData.leftRequest.modelId, + requestsData.leftRequest.model.id, ExpectedMessages.requestModeIdIsValid, ) .toBe(secondFolderConversation.conversations[0].model.id); diff --git a/apps/chat-e2e/src/tests/replay.test.ts b/apps/chat-e2e/src/tests/replay.test.ts index 5b4e526f9..eb113eacd 100644 --- a/apps/chat-e2e/src/tests/replay.test.ts +++ b/apps/chat-e2e/src/tests/replay.test.ts @@ -308,7 +308,10 @@ dialTest( 'Verify chat API request is sent with correct settings', async () => { expect - .soft(replayRequest.modelId, ExpectedMessages.chatRequestModelIsValid) + .soft( + replayRequest.model?.id, + ExpectedMessages.chatRequestModelIsValid, + ) .toBe(replayModel.id); expect .soft(replayRequest.prompt, ExpectedMessages.chatRequestPromptIsValid) @@ -413,7 +416,10 @@ dialTest( conversation.messages[0].content, ); expect - .soft(replayRequest.modelId, ExpectedMessages.chatRequestModelIsValid) + .soft( + replayRequest.model?.id, + ExpectedMessages.chatRequestModelIsValid, + ) .toBe(conversation.model.id); expect .soft(replayRequest.prompt, ExpectedMessages.chatRequestPromptIsValid) @@ -651,7 +657,10 @@ dialTest( true, ); expect - .soft(replayRequest.modelId, ExpectedMessages.chatRequestModelIsValid) + .soft( + replayRequest.model.id, + ExpectedMessages.chatRequestModelIsValid, + ) .toBe(conversation.model.id); }, ); @@ -674,7 +683,7 @@ dialTest( const newMessage = '2+3'; const newRequest = await chat.sendRequestWithButton(newMessage); expect - .soft(newRequest.modelId, ExpectedMessages.chatRequestModelIsValid) + .soft(newRequest.model.id, ExpectedMessages.chatRequestModelIsValid) .toBe(conversation.model.id); expect .soft( @@ -850,7 +859,10 @@ dialTest( const modelId = i === 1 ? ImportedModelIds.CHAT_BISON : ImportedModelIds.GPT_4; expect - .soft(requests[i].modelId, ExpectedMessages.chatRequestModelIsValid) + .soft( + requests[i].model.id, + ExpectedMessages.chatRequestModelIsValid, + ) .toBe(modelId); } }, diff --git a/apps/chat/src/components/Chat/Chat.tsx b/apps/chat/src/components/Chat/Chat.tsx index eaa26ac31..fdc713702 100644 --- a/apps/chat/src/components/Chat/Chat.tsx +++ b/apps/chat/src/components/Chat/Chat.tsx @@ -310,14 +310,16 @@ export const ChatView = memo(() => { useLayoutEffect(() => { if (selectedConversations.length > 0) { const mergedMessages: MergedMessages[] = []; - const firstConversationMessages = excludeSystemMessages( - selectedConversations[0].messages, + const userMessages = selectedConversations.map((conv) => + excludeSystemMessages(conv.messages), ); - for (let i = 0; i < firstConversationMessages.length; i++) { + const messagesLength = userMessages[0].length; + + for (let i = 0; i < messagesLength; i++) { mergedMessages.push( - selectedConversations.map((conv) => [ + selectedConversations.map((conv, convIndex) => [ conv, - excludeSystemMessages(conv.messages)[i] || { + userMessages[convIndex][i] || { role: Role.Assistant, content: '', }, diff --git a/apps/chat/src/pages/api/chat.ts b/apps/chat/src/pages/api/chat.ts index b7551d777..3ea7cfb8f 100644 --- a/apps/chat/src/pages/api/chat.ts +++ b/apps/chat/src/pages/api/chat.ts @@ -9,7 +9,6 @@ import { getUserMessageCustomContent, limitMessagesByTokens, } from '@/src/utils/server/chat'; -import { getSortedEntities } from '@/src/utils/server/get-sorted-entities'; import { ChatBody } from '@/src/types/chat'; import { EntityType } from '@/src/types/common'; @@ -30,33 +29,22 @@ const handler = async (req: NextApiRequest, res: NextApiResponse) => { } const { - modelId, id, messages, prompt, temperature, selectedAddons, - assistantModelId, + model, + assistantModel, } = req.body as ChatBody; try { const token = await getToken({ req }); - const models = await getSortedEntities(token); - const model = models.find( - ({ id, reference }) => id === modelId || reference === modelId, - ); - const assistantModel = assistantModelId - ? models.find( - ({ id, reference }) => - id === assistantModelId || reference === assistantModelId, - ) - : undefined; if ( !id || !model || - (!!assistantModelId && !assistantModel) || - (!!assistantModelId && model.type !== EntityType.Assistant) || + (!!assistantModel && model.type !== EntityType.Assistant) || (!prompt && !messages?.length) ) { return res.status(400).send(errorsMessages[400]); @@ -119,7 +107,7 @@ const handler = async (req: NextApiRequest, res: NextApiResponse) => { temperature: temperatureToUse, messages: messagesToSend, selectedAddonsIds: selectedAddons?.length ? selectedAddons : undefined, - assistantModelId, + assistantModelId: assistantModel?.id, userJWT: token?.access_token as string, chatId: id, jobTitle: token?.jobTitle as string, @@ -158,7 +146,7 @@ const handler = async (req: NextApiRequest, res: NextApiResponse) => { return chatErrorHandler({ error, res, - msg: `Error while sending chat request to '${modelId}'`, + msg: `Error while sending chat request to '${model?.id}'`, }); } }; diff --git a/apps/chat/src/store/conversations/conversations.epics.ts b/apps/chat/src/store/conversations/conversations.epics.ts index 894cea560..be929b1ef 100644 --- a/apps/chat/src/store/conversations/conversations.epics.ts +++ b/apps/chat/src/store/conversations/conversations.epics.ts @@ -1333,14 +1333,14 @@ const streamMessageEpic: AppEpic = (action$, state$) => } if (conversationModelType === EntityType.Assistant && assistantModelId) { modelAdditionalSettings = { - assistantModelId, + assistantModel: modelsMap[assistantModelId], temperature: payload.conversation.temperature, selectedAddons, }; } const chatBody: ChatBody = { - modelId: payload.conversation.model.id, + model: modelsMap[payload.conversation.model.id], messages: payload.conversation.messages .filter( (message, index) => diff --git a/apps/chat/src/types/chat.ts b/apps/chat/src/types/chat.ts index f9f28723e..08bc79ee4 100644 --- a/apps/chat/src/types/chat.ts +++ b/apps/chat/src/types/chat.ts @@ -1,3 +1,5 @@ +import { DialAIEntityModel } from '@/src/types/models'; + import { ConversationInfo, Message, ShareEntity } from '@epam/ai-dial-shared'; export enum CopyTableType { @@ -7,13 +9,13 @@ export enum CopyTableType { } export interface ChatBody { - modelId: string; messages: Message[]; id: string; prompt?: string; temperature?: number; selectedAddons?: string[]; - assistantModelId?: string; + model?: DialAIEntityModel; + assistantModel?: DialAIEntityModel; } export interface RateBody {