Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improves the relevance of codestral completion #18

Merged
merged 5 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/completion-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ export class CompletionProvider implements IInlineCompletionProvider {

constructor(options: CompletionProvider.IOptions) {
const { name, settings } = options;
this._requestCompletion = options.requestCompletion;
this.setCompleter(name, settings);
}

Expand All @@ -28,6 +29,9 @@ export class CompletionProvider implements IInlineCompletionProvider {
setCompleter(name: string, settings: ReadonlyPartialJSONObject) {
try {
this._completer = getCompleter(name, settings);
if (this._completer) {
this._completer.requestCompletion = this._requestCompletion;
}
this._name = this._completer === null ? 'None' : name;
} catch (e: any) {
this._completer = null;
Expand Down Expand Up @@ -65,11 +69,13 @@ export class CompletionProvider implements IInlineCompletionProvider {
}

private _name: string = 'None';
private _requestCompletion: () => void;
private _completer: IBaseCompleter | null = null;
}

export namespace CompletionProvider {
export interface IOptions extends BaseCompleter.IOptions {
name: string;
requestCompletion: () => void;
}
}
5 changes: 4 additions & 1 deletion src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@ const aiProviderPlugin: JupyterFrontEndPlugin<IAIProvider> = {
manager: ICompletionProviderManager,
settingRegistry: ISettingRegistry
): IAIProvider => {
const aiProvider = new AIProvider({ completionProviderManager: manager });
const aiProvider = new AIProvider({
completionProviderManager: manager,
requestCompletion: () => app.commands.execute('inline-completer:invoke')
});

settingRegistry
.load(aiProviderPlugin.id)
Expand Down
5 changes: 5 additions & 0 deletions src/llm-models/base-completer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ export interface IBaseCompleter {
*/
provider: LLM;

/**
* The function to fetch a new completion.
*/
requestCompletion?: () => void;

/**
* The fetch request for the LLM completer.
*/
Expand Down
72 changes: 57 additions & 15 deletions src/llm-models/codestral-completer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,67 @@ import { CompletionRequest } from '@mistralai/mistralai';

import { BaseCompleter, IBaseCompleter } from './base-completer';

/*
/**
* The Mistral API has a rate limit of 1 request per second
*/
const INTERVAL = 1000;

/**
* Timeout to avoid endless requests
*/
const REQUEST_TIMEOUT = 3000;

export class CodestralCompleter implements IBaseCompleter {
constructor(options: BaseCompleter.IOptions) {
// this._requestCompletion = options.requestCompletion;
this._mistralProvider = new MistralAI({ ...options.settings });
this._throttler = new Throttler(async (data: CompletionRequest) => {
const response = await this._mistralProvider.completionWithRetry(
data,
{},
false
);
const items = response.choices.map((choice: any) => {
return { insertText: choice.message.content as string };
});
this._throttler = new Throttler(
async (data: CompletionRequest) => {
const invokedData = data;

// Request completion.
const request = this._mistralProvider.completionWithRetry(
data,
{},
false
);
const timeoutPromise = new Promise<null>(resolve => {
return setTimeout(() => resolve(null), REQUEST_TIMEOUT);
});

// Fetch again if the request is too long or if the prompt has changed.
const response = await Promise.race([request, timeoutPromise]);
if (
response === null ||
invokedData.prompt !== this._currentData?.prompt
) {
return {
items: [],
fetchAgain: true
};
}

return {
items
};
}, INTERVAL);
// Extract results of completion request.
const items = response.choices.map((choice: any) => {
return { insertText: choice.message.content as string };
});

return {
items
};
},
{ limit: INTERVAL }
);
}

get provider(): LLM {
return this._mistralProvider;
}

set requestCompletion(value: () => void) {
this._requestCompletion = value;
}

async fetch(
request: CompletionHandler.IRequest,
context: IInlineCompletionContext
Expand All @@ -59,13 +92,22 @@ export class CodestralCompleter implements IBaseCompleter {
};

try {
return this._throttler.invoke(data);
this._currentData = data;
const completionResult = await this._throttler.invoke(data);
if (completionResult.fetchAgain) {
if (this._requestCompletion) {
this._requestCompletion();
}
}
return { items: completionResult.items };
} catch (error) {
console.error('Error fetching completions', error);
return { items: [] };
}
}

private _requestCompletion?: () => void;
private _throttler: Throttler;
private _mistralProvider: MistralAI;
private _currentData: CompletionRequest | null = null;
}
7 changes: 6 additions & 1 deletion src/provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ export class AIProvider implements IAIProvider {
constructor(options: AIProvider.IOptions) {
this._completionProvider = new CompletionProvider({
name: 'None',
settings: {}
settings: {},
requestCompletion: options.requestCompletion
});
options.completionProviderManager.registerInlineProvider(
this._completionProvider
Expand Down Expand Up @@ -103,6 +104,10 @@ export namespace AIProvider {
* The completion provider manager in which register the LLM completer.
*/
completionProviderManager: ICompletionProviderManager;
/**
* The application commands registry.
*/
requestCompletion: () => void;
}

/**
Expand Down