@@ -2,10 +2,14 @@ import {
   CompletionHandler,
   IInlineCompletionContext
 } from '@jupyterlab/completer';
-import { LLM } from '@langchain/core/language_models/llms';
-import { MistralAI } from '@langchain/mistralai';
+import { BaseChatModel } from '@langchain/core/language_models/chat_models';
+import {
+  BaseMessage,
+  HumanMessage,
+  SystemMessage
+} from '@langchain/core/messages';
+import { ChatMistralAI } from '@langchain/mistralai';
 import { Throttler } from '@lumino/polling';
-import { CompletionRequest } from '@mistralai/mistralai';
 
 import { BaseCompleter, IBaseCompleter } from './base-completer';
 import { COMPLETION_SYSTEM_PROMPT } from '../provider';
@@ -15,55 +19,35 @@ import { COMPLETION_SYSTEM_PROMPT } from '../provider';
  */
 const INTERVAL = 1000;
 
-/**
- * Timeout to avoid endless requests
- */
-const REQUEST_TIMEOUT = 3000;
-
 export class CodestralCompleter implements IBaseCompleter {
   constructor(options: BaseCompleter.IOptions) {
-    // this._requestCompletion = options.requestCompletion;
-    this._mistralProvider = new MistralAI({ ...options.settings });
+    this._mistralProvider = new ChatMistralAI({ ...options.settings });
     this._throttler = new Throttler(
-      async (data: CompletionRequest) => {
-        const invokedData = data;
-
-        // Request completion.
-        const request = this._mistralProvider.completionWithRetry(
-          data,
-          {},
-          false
-        );
-        const timeoutPromise = new Promise<null>(resolve => {
-          return setTimeout(() => resolve(null), REQUEST_TIMEOUT);
-        });
-
-        // Fetch again if the request is too long or if the prompt has changed.
-        const response = await Promise.race([request, timeoutPromise]);
-        if (
-          response === null ||
-          invokedData.prompt !== this._currentData?.prompt
-        ) {
-          return {
-            items: [],
-            fetchAgain: true
-          };
-        }
-
+      async (messages: BaseMessage[]) => {
+        const response = await this._mistralProvider.invoke(messages);
         // Extract results of completion request.
-        const items = response.choices.map((choice: any) => {
-          return { insertText: choice.message.content as string };
-        });
-
-        return {
-          items
-        };
+        const items = [];
+        if (typeof response.content === 'string') {
+          items.push({
+            insertText: response.content
+          });
+        } else {
+          response.content.forEach(content => {
+            if (content.type !== 'text') {
+              return;
+            }
+            items.push({
+              insertText: content.text
+            });
+          });
+        }
+        return { items };
       },
       { limit: INTERVAL }
     );
   }
 
-  get provider(): LLM {
+  get provider(): BaseChatModel {
     return this._mistralProvider;
   }
 
@@ -77,49 +61,27 @@ export class CodestralCompleter implements IBaseCompleter {
     this._prompt = value;
   }
 
-  set requestCompletion(value: () => void) {
-    this._requestCompletion = value;
-  }
-
   async fetch(
     request: CompletionHandler.IRequest,
     context: IInlineCompletionContext
   ) {
     const { text, offset: cursorOffset } = request;
     const prompt = text.slice(0, cursorOffset);
-    const suffix = text.slice(cursorOffset);
 
-    const data = {
-      prompt,
-      suffix,
-      model: this._mistralProvider.model,
-      // temperature: 0,
-      // top_p: 1,
-      // max_tokens: 1024,
-      // min_tokens: 0,
-      stream: false,
-      // random_seed: 1337,
-      stop: []
-    };
+    const messages: BaseMessage[] = [
+      new SystemMessage(this._prompt),
+      new HumanMessage(prompt)
+    ];
 
     try {
-      this._currentData = data;
-      const completionResult = await this._throttler.invoke(data);
-      if (completionResult.fetchAgain) {
-        if (this._requestCompletion) {
-          this._requestCompletion();
-        }
-      }
-      return { items: completionResult.items };
+      return await this._throttler.invoke(messages);
     } catch (error) {
       console.error('Error fetching completions', error);
      return { items: [] };
     }
   }
 
-  private _requestCompletion?: () => void;
   private _throttler: Throttler;
-  private _mistralProvider: MistralAI;
+  private _mistralProvider: ChatMistralAI;
   private _prompt: string = COMPLETION_SYSTEM_PROMPT;
-  private _currentData: CompletionRequest | null = null;
 }
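
For anyone who wants to try the new chat-based flow outside JupyterLab, here is a minimal standalone sketch of the same invoke-and-extract logic this diff introduces. The API-key handling, the `codestral-latest` model name, and the system-prompt text are illustrative placeholders, not values taken from this PR:

```typescript
import { ChatMistralAI } from '@langchain/mistralai';
import {
  BaseMessage,
  HumanMessage,
  SystemMessage
} from '@langchain/core/messages';

// Placeholder settings; the extension spreads these in from options.settings.
const model = new ChatMistralAI({
  apiKey: process.env.MISTRAL_API_KEY,
  model: 'codestral-latest'
});

async function complete(prefix: string): Promise<string[]> {
  // Same message shape the completer now builds in fetch().
  const messages: BaseMessage[] = [
    new SystemMessage('Complete the code. Return only the completion.'),
    new HumanMessage(prefix)
  ];
  const response = await model.invoke(messages);

  // A chat model's content is either a plain string or a list of typed
  // parts; keep only the text parts, mirroring the extraction in the diff.
  const items: string[] = [];
  if (typeof response.content === 'string') {
    items.push(response.content);
  } else {
    for (const part of response.content) {
      if (part.type === 'text') {
        items.push(part.text);
      }
    }
  }
  return items;
}

complete('def fibonacci(n):').then(items => console.log(items));
```

Switching from the raw completion request to chat messages is what lets the diff drop the suffix, the `REQUEST_TIMEOUT` race, and the `fetchAgain`/`_requestCompletion` re-request plumbing; the `Throttler` alone now paces requests.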