Skip to content

Commit f68be53

Browse files
authored
Merge pull request #6 from jtpio/langchain
Switch to using langchain.js
2 parents 6b2cae1 + a7b0ca1 commit f68be53

File tree

5 files changed

+253
-33
lines changed

5 files changed

+253
-33
lines changed

package.json

+3-2
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,10 @@
6060
"@jupyterlab/notebook": "^4.2.0",
6161
"@jupyterlab/rendermime": "^4.2.0",
6262
"@jupyterlab/settingregistry": "^4.2.0",
63+
"@langchain/core": "^0.3.13",
64+
"@langchain/mistralai": "^0.1.1",
6365
"@lumino/coreutils": "^2.1.2",
64-
"@lumino/polling": "^2.1.2",
65-
"@mistralai/mistralai": "^0.5.0"
66+
"@lumino/polling": "^2.1.2"
6667
},
6768
"devDependencies": {
6869
"@jupyterlab/builder": "^4.0.0",

src/handler.ts

+21-16
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,12 @@ import {
1010
INewMessage
1111
} from '@jupyter/chat';
1212
import { UUID } from '@lumino/coreutils';
13-
import MistralClient from '@mistralai/mistralai';
13+
import type { ChatMistralAI } from '@langchain/mistralai';
14+
import {
15+
AIMessage,
16+
HumanMessage,
17+
mergeMessageRuns
18+
} from '@langchain/core/messages';
1419

1520
export type ConnectionMessage = {
1621
type: 'connection';
@@ -34,27 +39,27 @@ export class CodestralHandler extends ChatModel {
3439
};
3540
this.messageAdded(msg);
3641
this._history.messages.push(msg);
37-
const response = await this._mistralClient.chat({
38-
model: 'codestral-latest',
39-
messages: this._history.messages.map(msg => {
40-
return {
41-
role: msg.sender.username === 'User' ? 'user' : 'assistant',
42-
content: msg.body
43-
};
42+
43+
const messages = mergeMessageRuns(
44+
this._history.messages.map(msg => {
45+
if (msg.sender.username === 'User') {
46+
return new HumanMessage(msg.body);
47+
}
48+
return new AIMessage(msg.body);
4449
})
45-
});
46-
if (response.choices.length === 0) {
47-
return false;
48-
}
49-
const botMessage = response.choices[0].message;
50+
);
51+
const response = await this._mistralClient.invoke(messages);
52+
// TODO: fix deprecated response.text
53+
const content = response.text;
5054
const botMsg: IChatMessage = {
5155
id: UUID.uuid4(),
52-
body: botMessage.content as string,
56+
body: content,
5357
sender: { username: 'Codestral' },
5458
time: Date.now(),
5559
type: 'msg'
5660
};
5761
this.messageAdded(botMsg);
62+
this._history.messages.push(botMsg);
5863
return true;
5964
}
6065

@@ -70,12 +75,12 @@ export class CodestralHandler extends ChatModel {
7075
super.messageAdded(message);
7176
}
7277

73-
private _mistralClient: MistralClient;
78+
private _mistralClient: ChatMistralAI;
7479
private _history: IChatHistory = { messages: [] };
7580
}
7681

7782
export namespace CodestralHandler {
7883
export interface IOptions extends ChatModel.IOptions {
79-
mistralClient: MistralClient;
84+
mistralClient: ChatMistralAI;
8085
}
8186
}

src/index.ts

+29-4
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,10 @@ import { ICompletionProviderManager } from '@jupyterlab/completer';
1313
import { INotebookTracker } from '@jupyterlab/notebook';
1414
import { IRenderMimeRegistry } from '@jupyterlab/rendermime';
1515
import { ISettingRegistry } from '@jupyterlab/settingregistry';
16-
import { CodestralProvider } from './provider';
17-
import MistralClient from '@mistralai/mistralai';
16+
import { ChatMistralAI, MistralAI } from '@langchain/mistralai';
1817

1918
import { CodestralHandler } from './handler';
20-
21-
const mistralClient = new MistralClient();
19+
import { CodestralProvider } from './provider';
2220

2321
const inlineProviderPlugin: JupyterFrontEndPlugin<void> = {
2422
id: 'jupyterlab-codestral:inline-provider',
@@ -29,6 +27,10 @@ const inlineProviderPlugin: JupyterFrontEndPlugin<void> = {
2927
manager: ICompletionProviderManager,
3028
settingRegistry: ISettingRegistry
3129
): void => {
30+
const mistralClient = new MistralAI({
31+
model: 'codestral-latest',
32+
apiKey: 'TMP'
33+
});
3234
const provider = new CodestralProvider({ mistralClient });
3335
manager.registerInlineProvider(provider);
3436

@@ -73,6 +75,10 @@ const chatPlugin: JupyterFrontEndPlugin<void> = {
7375
});
7476
}
7577

78+
const mistralClient = new ChatMistralAI({
79+
model: 'codestral-latest',
80+
apiKey: 'TMP'
81+
});
7682
const chatHandler = new CodestralHandler({
7783
mistralClient,
7884
activeCellManager: activeCellManager
@@ -88,6 +94,25 @@ const chatPlugin: JupyterFrontEndPlugin<void> = {
8894
chatHandler.config = { sendWithShiftEnter, enableCodeToolbar };
8995
}
9096

97+
// TODO: handle the apiKey better
98+
settingsRegistry
99+
?.load(inlineProviderPlugin.id)
100+
.then(settings => {
101+
const updateKey = () => {
102+
const apiKey = settings.get('apiKey').composite as string;
103+
mistralClient.apiKey = apiKey;
104+
};
105+
106+
settings.changed.connect(() => updateKey());
107+
updateKey();
108+
})
109+
.catch(reason => {
110+
console.error(
111+
`Failed to load settings for ${inlineProviderPlugin.id}`,
112+
reason
113+
);
114+
});
115+
91116
Promise.all([app.restored, settingsRegistry?.load(chatPlugin.id)])
92117
.then(([, settings]) => {
93118
if (!settings) {

src/provider.ts

+10-4
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ import {
66

77
import { Throttler } from '@lumino/polling';
88

9-
import MistralClient, { CompletionRequest } from '@mistralai/mistralai';
9+
import { CompletionRequest } from '@mistralai/mistralai';
10+
11+
import type { MistralAI } from '@langchain/mistralai';
1012

1113
/*
1214
* The Mistral API has a rate limit of 1 request per second
@@ -20,7 +22,11 @@ export class CodestralProvider implements IInlineCompletionProvider {
2022
constructor(options: CodestralProvider.IOptions) {
2123
this._mistralClient = options.mistralClient;
2224
this._throttler = new Throttler(async (data: CompletionRequest) => {
23-
const response = await this._mistralClient.completion(data);
25+
const response = await this._mistralClient.completionWithRetry(
26+
data,
27+
{},
28+
false
29+
);
2430
const items = response.choices.map((choice: any) => {
2531
return { insertText: choice.message.content as string };
2632
});
@@ -61,11 +67,11 @@ export class CodestralProvider implements IInlineCompletionProvider {
6167
}
6268

6369
private _throttler: Throttler;
64-
private _mistralClient: MistralClient;
70+
private _mistralClient: MistralAI;
6571
}
6672

6773
export namespace CodestralProvider {
6874
export interface IOptions {
69-
mistralClient: MistralClient;
75+
mistralClient: MistralAI;
7076
}
7177
}

0 commit comments

Comments
 (0)