
Commit 4207dd1

Merge pull request #31 from brichet/codestral_completion_chat_model

Use a chat model instead of LLM for codestral completion

Authored Feb 10, 2025 · 2 parents 640f525 + 6b41d23 · commit 4207dd1

3 files changed: +38 −83 lines
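
In LangChain terms, the headline change swaps a text-completion LLM for a chat model: an LLM's `invoke` takes a raw string and returns a string, while a chat model's `invoke` takes a list of messages and returns an `AIMessage`. A minimal sketch of the difference; the model name is an illustrative assumption, not taken from this commit:

```ts
import { ChatMistralAI, MistralAI } from '@langchain/mistralai';
import { HumanMessage } from '@langchain/core/messages';

async function demo(): Promise<void> {
  // Old style: LLM interface, string in / string out.
  const llm = new MistralAI({ model: 'codestral-latest' });
  const text: string = await llm.invoke('def fib(n):');

  // New style: chat model, messages in / AIMessage out.
  const chat = new ChatMistralAI({ model: 'codestral-latest' });
  const message = await chat.invoke([new HumanMessage('def fib(n):')]);
  console.log(text, message.content);
}
```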

src/llm-models/chrome-completer.ts (+1 −12)
```diff
@@ -8,17 +8,6 @@ import { HumanMessage, SystemMessage } from '@langchain/core/messages';
 import { BaseCompleter, IBaseCompleter } from './base-completer';
 import { COMPLETION_SYSTEM_PROMPT } from '../provider';
 
-/**
- * The initial prompt to use for the completion.
- * Add extra instructions to get better results.
- */
-const CUSTOM_SYSTEM_PROMPT = `${COMPLETION_SYSTEM_PROMPT}
-Only give raw strings back, do not format the response using backticks!
-The output should be a single string, and should correspond to what a human users
-would write.
-Do not include the prompt in the output, only the string that should be appended to the current input.
-`;
-
 /**
  * Regular expression to match the '```' string at the start of a string.
  * So the completions returned by the LLM can still be kept after removing the code block formatting.
@@ -97,5 +86,5 @@ export class ChromeCompleter implements IBaseCompleter {
   }
 
   private _chromeProvider: ChromeAI;
-  private _prompt: string = CUSTOM_SYSTEM_PROMPT;
+  private _prompt: string = COMPLETION_SYSTEM_PROMPT;
 }
```
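
Net effect of this file's change: the completer-specific `CUSTOM_SYSTEM_PROMPT` is gone, and its extra instructions move into the shared `COMPLETION_SYSTEM_PROMPT` (see the src/provider.ts diff below), so every completer reads from one prompt. A minimal sketch of the resulting pattern; `SomeCompleter` is a hypothetical name, not from the codebase:

```ts
import { COMPLETION_SYSTEM_PROMPT } from '../provider';

// Hypothetical completer: no per-completer prompt variant, just the
// provider-level prompt as the single default.
class SomeCompleter {
  get prompt(): string {
    return this._prompt;
  }

  private _prompt: string = COMPLETION_SYSTEM_PROMPT;
}
```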

src/llm-models/codestral-completer.ts (+33 −71)
```diff
@@ -2,10 +2,14 @@ import {
   CompletionHandler,
   IInlineCompletionContext
 } from '@jupyterlab/completer';
-import { LLM } from '@langchain/core/language_models/llms';
-import { MistralAI } from '@langchain/mistralai';
+import { BaseChatModel } from '@langchain/core/language_models/chat_models';
+import {
+  BaseMessage,
+  HumanMessage,
+  SystemMessage
+} from '@langchain/core/messages';
+import { ChatMistralAI } from '@langchain/mistralai';
 import { Throttler } from '@lumino/polling';
-import { CompletionRequest } from '@mistralai/mistralai';
 
 import { BaseCompleter, IBaseCompleter } from './base-completer';
 import { COMPLETION_SYSTEM_PROMPT } from '../provider';
@@ -15,55 +19,35 @@ import { COMPLETION_SYSTEM_PROMPT } from '../provider';
  */
 const INTERVAL = 1000;
 
-/**
- * Timeout to avoid endless requests
- */
-const REQUEST_TIMEOUT = 3000;
-
 export class CodestralCompleter implements IBaseCompleter {
   constructor(options: BaseCompleter.IOptions) {
-    // this._requestCompletion = options.requestCompletion;
-    this._mistralProvider = new MistralAI({ ...options.settings });
+    this._mistralProvider = new ChatMistralAI({ ...options.settings });
     this._throttler = new Throttler(
-      async (data: CompletionRequest) => {
-        const invokedData = data;
-
-        // Request completion.
-        const request = this._mistralProvider.completionWithRetry(
-          data,
-          {},
-          false
-        );
-        const timeoutPromise = new Promise<null>(resolve => {
-          return setTimeout(() => resolve(null), REQUEST_TIMEOUT);
-        });
-
-        // Fetch again if the request is too long or if the prompt has changed.
-        const response = await Promise.race([request, timeoutPromise]);
-        if (
-          response === null ||
-          invokedData.prompt !== this._currentData?.prompt
-        ) {
-          return {
-            items: [],
-            fetchAgain: true
-          };
-        }
-
+      async (messages: BaseMessage[]) => {
+        const response = await this._mistralProvider.invoke(messages);
         // Extract results of completion request.
-        const items = response.choices.map((choice: any) => {
-          return { insertText: choice.message.content as string };
-        });
-
-        return {
-          items
-        };
+        const items = [];
+        if (typeof response.content === 'string') {
+          items.push({
+            insertText: response.content
+          });
+        } else {
+          response.content.forEach(content => {
+            if (content.type !== 'text') {
+              return;
+            }
+            items.push({
+              insertText: content.text
+            });
+          });
+        }
+        return { items };
       },
       { limit: INTERVAL }
     );
   }
 
-  get provider(): LLM {
+  get provider(): BaseChatModel {
     return this._mistralProvider;
   }
 
@@ -77,49 +61,27 @@ export class CodestralCompleter implements IBaseCompleter {
     this._prompt = value;
   }
 
-  set requestCompletion(value: () => void) {
-    this._requestCompletion = value;
-  }
-
   async fetch(
     request: CompletionHandler.IRequest,
    context: IInlineCompletionContext
   ) {
     const { text, offset: cursorOffset } = request;
     const prompt = text.slice(0, cursorOffset);
-    const suffix = text.slice(cursorOffset);
 
-    const data = {
-      prompt,
-      suffix,
-      model: this._mistralProvider.model,
-      // temperature: 0,
-      // top_p: 1,
-      // max_tokens: 1024,
-      // min_tokens: 0,
-      stream: false,
-      // random_seed: 1337,
-      stop: []
-    };
+    const messages: BaseMessage[] = [
+      new SystemMessage(this._prompt),
+      new HumanMessage(prompt)
+    ];
 
     try {
-      this._currentData = data;
-      const completionResult = await this._throttler.invoke(data);
-      if (completionResult.fetchAgain) {
-        if (this._requestCompletion) {
-          this._requestCompletion();
-        }
-      }
-      return { items: completionResult.items };
+      return await this._throttler.invoke(messages);
     } catch (error) {
       console.error('Error fetching completions', error);
       return { items: [] };
     }
   }
 
-  private _requestCompletion?: () => void;
   private _throttler: Throttler;
-  private _mistralProvider: MistralAI;
+  private _mistralProvider: ChatMistralAI;
   private _prompt: string = COMPLETION_SYSTEM_PROMPT;
-  private _currentData: CompletionRequest | null = null;
 }
```
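
Taken together: the throttled callback now sends a `SystemMessage`/`HumanMessage` pair through `ChatMistralAI.invoke` and unpacks the returned `AIMessage`, whose `content` is either a plain string or an array of typed parts. A standalone sketch of that flow, assuming a configured `MISTRAL_API_KEY`; the model name and prompt text are illustrative, and the extension itself pulls settings from `options.settings` instead:

```ts
import { ChatMistralAI } from '@langchain/mistralai';
import { HumanMessage, SystemMessage } from '@langchain/core/messages';

async function completions(prefix: string): Promise<string[]> {
  const model = new ChatMistralAI({ model: 'codestral-latest' });
  const response = await model.invoke([
    new SystemMessage('Complete the code. Return raw strings only.'),
    new HumanMessage(prefix)
  ]);
  // AIMessage.content is a string or an array of typed content parts;
  // keep only the text parts, mirroring the branch in the diff above.
  if (typeof response.content === 'string') {
    return [response.content];
  }
  return response.content
    .filter(part => part.type === 'text')
    .map(part => (part as { type: 'text'; text: string }).text);
}
```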

src/provider.ts (+4 −0)
```diff
@@ -30,6 +30,10 @@ programming language comment syntax. Produce clean code.
 The code is written in JupyterLab, a data analysis and code development
 environment which can execute code extended with additional syntax for
 interactive features, such as magics.
+Only give raw strings back, do not format the response using backticks.
+The output should be a single string, and should correspond to what a human users
+would write.
+Do not include the prompt in the output, only the string that should be appended to the current input.
 `;
 
 export class AIProvider implements IAIProvider {
```

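The appended instructions ask for raw strings, but models sometimes wrap their output in backtick fences anyway; that is what the regex kept in chrome-completer.ts guards against. A hedged sketch of that kind of cleanup; the helper name and regex are illustrative, not from this commit:

```ts
// Strip a Markdown code fence if the model added one despite the prompt.
const FENCED = /^```[^\n]*\n([\s\S]*?)\n?```\s*$/;

function stripFence(completion: string): string {
  const match = completion.match(FENCED);
  return match ? match[1] : completion;
}

// stripFence('```python\nprint(1)\n```') === 'print(1)'
```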