From 1ff8eef464027f4cf4c1982d0a35ccc53f0f56f3 Mon Sep 17 00:00:00 2001 From: Nicolas Brichet Date: Mon, 28 Oct 2024 19:12:41 +0100 Subject: [PATCH 1/5] WIP on making the LLM providers more generic --- package.json | 3 +- schema/llm-provider.json | 24 ++++ src/{handler.ts => chat-handler.ts} | 39 ++++-- src/completion-providers/base-provider.ts | 16 +++ .../codestral-provider.ts | 90 ++++++++++++ src/completion-providers/index.ts | 1 + src/index.ts | 129 +++++++++--------- src/provider.ts | 117 ++++++++-------- src/token.ts | 16 +++ tsconfig.json | 2 +- 10 files changed, 302 insertions(+), 135 deletions(-) create mode 100644 schema/llm-provider.json rename src/{handler.ts => chat-handler.ts} (64%) create mode 100644 src/completion-providers/base-provider.ts create mode 100644 src/completion-providers/codestral-provider.ts create mode 100644 src/completion-providers/index.ts create mode 100644 src/token.ts diff --git a/package.json b/package.json index 4816377..539736a 100644 --- a/package.json +++ b/package.json @@ -63,7 +63,8 @@ "@langchain/core": "^0.3.13", "@langchain/mistralai": "^0.1.1", "@lumino/coreutils": "^2.1.2", - "@lumino/polling": "^2.1.2" + "@lumino/polling": "^2.1.2", + "@lumino/signaling": "^2.1.2" }, "devDependencies": { "@jupyterlab/builder": "^4.0.0", diff --git a/schema/llm-provider.json b/schema/llm-provider.json new file mode 100644 index 0000000..2855e34 --- /dev/null +++ b/schema/llm-provider.json @@ -0,0 +1,24 @@ +{ + "title": "LLM provider", + "description": "Provider settings", + "type": "object", + "properties": { + "provider": { + "type": "string", + "title": "The LLM provider", + "description": "The LLM provider to use for chat and completion", + "default": "None", + "enum": [ + "None", + "MistralAI" + ] + }, + "apiKey": { + "type": "string", + "title": "The Codestral API key", + "description": "The API key to use for Codestral", + "default": "" + } + }, + "additionalProperties": false +} diff --git a/src/handler.ts b/src/chat-handler.ts similarity index 64% rename from src/handler.ts rename to src/chat-handler.ts index 5e9d8d3..47c8867 100644 --- a/src/handler.ts +++ b/src/chat-handler.ts @@ -9,8 +9,8 @@ import { IChatMessage, INewMessage } from '@jupyter/chat'; +import type { BaseChatModel } from '@langchain/core/language_models/chat_models'; import { UUID } from '@lumino/coreutils'; -import type { ChatMistralAI } from '@langchain/mistralai'; import { AIMessage, HumanMessage, mergeMessageRuns } from '@langchain/core/messages'; @@ -22,10 +22,17 @@ export type ConnectionMessage = { client_id: string; }; -export class CodestralHandler extends ChatModel { - constructor(options: CodestralHandler.IOptions) { +export class ChatHandler extends ChatModel { + constructor(options: ChatHandler.IOptions) { super(options); - this._mistralClient = options.mistralClient; + this._llmClient = options.llmClient; + } + + get llmClient(): BaseChatModel | null { + return this._llmClient; + } + set llmClient(client: BaseChatModel | null) { + this._llmClient = client; } async sendMessage(message: INewMessage): Promise { @@ -38,6 +45,19 @@ export class CodestralHandler extends ChatModel { type: 'msg' }; this.messageAdded(msg); + + if (this._llmClient === null) { + const botMsg: IChatMessage = { + id: UUID.uuid4(), + body: '**Chat client not configured**', + sender: { username: 'ERROR' }, + time: Date.now(), + type: 'msg' + }; + this.messageAdded(botMsg); + return false; + } + this._history.messages.push(msg); const messages = mergeMessageRuns( @@ -48,13 +68,14 @@ export class CodestralHandler extends ChatModel { return new 
AIMessage(msg.body); }) ); - const response = await this._mistralClient.invoke(messages); + + const response = await this._llmClient.invoke(messages); // TODO: fix deprecated response.text const content = response.text; const botMsg: IChatMessage = { id: UUID.uuid4(), body: content, - sender: { username: 'Codestral' }, + sender: { username: 'Bot' }, time: Date.now(), type: 'msg' }; @@ -75,12 +96,12 @@ export class CodestralHandler extends ChatModel { super.messageAdded(message); } - private _mistralClient: ChatMistralAI; + private _llmClient: BaseChatModel | null; private _history: IChatHistory = { messages: [] }; } -export namespace CodestralHandler { +export namespace ChatHandler { export interface IOptions extends ChatModel.IOptions { - mistralClient: ChatMistralAI; + llmClient: BaseChatModel | null; } } diff --git a/src/completion-providers/base-provider.ts b/src/completion-providers/base-provider.ts new file mode 100644 index 0000000..4a3f6bf --- /dev/null +++ b/src/completion-providers/base-provider.ts @@ -0,0 +1,16 @@ +import { IInlineCompletionProvider } from '@jupyterlab/completer'; +import { LLM } from '@langchain/core/language_models/llms'; +import { JSONValue } from '@lumino/coreutils'; + +export interface IBaseProvider extends IInlineCompletionProvider { + configure(settings: { [property: string]: JSONValue }): void; +} + +// https://stackoverflow.com/questions/54724875/can-we-check-whether-property-is-readonly-in-typescript +export function isWritable(obj: T, key: keyof T) { + const desc = + Object.getOwnPropertyDescriptor(obj, key) || + Object.getOwnPropertyDescriptor(Object.getPrototypeOf(obj), key) || + {}; + return Boolean(desc.writable); +} diff --git a/src/completion-providers/codestral-provider.ts b/src/completion-providers/codestral-provider.ts new file mode 100644 index 0000000..4a5ad90 --- /dev/null +++ b/src/completion-providers/codestral-provider.ts @@ -0,0 +1,90 @@ +import { + CompletionHandler, + IInlineCompletionContext +} from '@jupyterlab/completer'; + +import { Throttler } from '@lumino/polling'; + +import { CompletionRequest } from '@mistralai/mistralai'; + +import type { MistralAI } from '@langchain/mistralai'; +import { JSONValue } from '@lumino/coreutils'; +import { IBaseProvider, isWritable } from './base-provider'; + +/* + * The Mistral API has a rate limit of 1 request per second + */ +const INTERVAL = 1000; + +export class CodestralProvider implements IBaseProvider { + readonly identifier = 'Codestral'; + readonly name = 'Codestral'; + + constructor(options: CodestralProvider.IOptions) { + this._mistralClient = options.mistralClient; + this._throttler = new Throttler(async (data: CompletionRequest) => { + const response = await this._mistralClient.completionWithRetry( + data, + {}, + false + ); + const items = response.choices.map((choice: any) => { + return { insertText: choice.message.content as string }; + }); + + return { + items + }; + }, INTERVAL); + } + + async fetch( + request: CompletionHandler.IRequest, + context: IInlineCompletionContext + ) { + const { text, offset: cursorOffset } = request; + const prompt = text.slice(0, cursorOffset); + const suffix = text.slice(cursorOffset); + + const data = { + prompt, + suffix, + model: 'codestral-latest', + // temperature: 0, + // top_p: 1, + // max_tokens: 1024, + // min_tokens: 0, + stream: false, + // random_seed: 1337, + stop: [] + }; + + try { + return this._throttler.invoke(data); + } catch (error) { + console.error('Error fetching completions', error); + return { items: [] }; + } + } + + 
configure(settings: { [property: string]: JSONValue }): void { + Object.entries(settings).forEach(([key, value], index) => { + if (key in this._mistralClient) { + if (isWritable(this._mistralClient, key as keyof MistralAI)) { + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore + this._mistralClient[key as keyof MistralAI] = value; + } + } + }); + } + + private _throttler: Throttler; + private _mistralClient: MistralAI; +} + +export namespace CodestralProvider { + export interface IOptions { + mistralClient: MistralAI; + } +} diff --git a/src/completion-providers/index.ts b/src/completion-providers/index.ts new file mode 100644 index 0000000..fdb3eeb --- /dev/null +++ b/src/completion-providers/index.ts @@ -0,0 +1 @@ +export * from './codestral-provider'; diff --git a/src/index.ts b/src/index.ts index 1e9ec03..bbbb300 100644 --- a/src/index.ts +++ b/src/index.ts @@ -13,55 +13,37 @@ import { ICompletionProviderManager } from '@jupyterlab/completer'; import { INotebookTracker } from '@jupyterlab/notebook'; import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { ISettingRegistry } from '@jupyterlab/settingregistry'; -import { ChatMistralAI, MistralAI } from '@langchain/mistralai'; -import { CodestralHandler } from './handler'; -import { CodestralProvider } from './provider'; +import { ChatHandler } from './chat-handler'; +import { ILlmProvider } from './token'; +import { LlmProvider } from './provider'; -const inlineProviderPlugin: JupyterFrontEndPlugin = { - id: 'jupyterlab-codestral:inline-provider', - autoStart: true, - requires: [ICompletionProviderManager, ISettingRegistry], - activate: ( - app: JupyterFrontEnd, - manager: ICompletionProviderManager, - settingRegistry: ISettingRegistry - ): void => { - const mistralClient = new MistralAI({ - model: 'codestral-latest', - apiKey: 'TMP' - }); - const provider = new CodestralProvider({ mistralClient }); - manager.registerInlineProvider(provider); - - settingRegistry - .load(inlineProviderPlugin.id) - .then(settings => { - const updateKey = () => { - const apiKey = settings.get('apiKey').composite as string; - mistralClient.apiKey = apiKey; - }; - - settings.changed.connect(() => updateKey()); - updateKey(); - }) - .catch(reason => { - console.error( - `Failed to load settings for ${inlineProviderPlugin.id}`, - reason - ); - }); - } -}; +// const inlineProviderPlugin: JupyterFrontEndPlugin = { +// id: 'jupyterlab-codestral:inline-provider', +// autoStart: true, +// requires: [ICompletionProviderManager, ILlmProvider, ISettingRegistry], +// activate: ( +// app: JupyterFrontEnd, +// manager: ICompletionProviderManager, +// llmProvider: ILlmProvider +// ): void => { +// llmProvider.providerChange.connect(() => { +// if (llmProvider.inlineCompleter !== null) { +// manager.registerInlineProvider(llmProvider.inlineCompleter); +// } +// }); +// } +// }; const chatPlugin: JupyterFrontEndPlugin = { id: 'jupyterlab-codestral:chat', - description: 'Codestral chat extension', + description: 'LLM chat extension', autoStart: true, optional: [INotebookTracker, ISettingRegistry, IThemeManager], - requires: [IRenderMimeRegistry], + requires: [ILlmProvider, IRenderMimeRegistry], activate: async ( app: JupyterFrontEnd, + llmProvider: ILlmProvider, rmRegistry: IRenderMimeRegistry, notebookTracker: INotebookTracker | null, settingsRegistry: ISettingRegistry | null, @@ -75,15 +57,15 @@ const chatPlugin: JupyterFrontEndPlugin = { }); } - const mistralClient = new ChatMistralAI({ - model: 'codestral-latest', - apiKey: 'TMP' 
- }); - const chatHandler = new CodestralHandler({ - mistralClient, + const chatHandler = new ChatHandler({ + llmClient: llmProvider.chatModel, activeCellManager: activeCellManager }); + llmProvider.providerChange.connect(() => { + chatHandler.llmClient = llmProvider.chatModel; + }); + let sendWithShiftEnter = false; let enableCodeToolbar = true; @@ -94,25 +76,6 @@ const chatPlugin: JupyterFrontEndPlugin = { chatHandler.config = { sendWithShiftEnter, enableCodeToolbar }; } - // TODO: handle the apiKey better - settingsRegistry - ?.load(inlineProviderPlugin.id) - .then(settings => { - const updateKey = () => { - const apiKey = settings.get('apiKey').composite as string; - mistralClient.apiKey = apiKey; - }; - - settings.changed.connect(() => updateKey()); - updateKey(); - }) - .catch(reason => { - console.error( - `Failed to load settings for ${inlineProviderPlugin.id}`, - reason - ); - }); - Promise.all([app.restored, settingsRegistry?.load(chatPlugin.id)]) .then(([, settings]) => { if (!settings) { @@ -148,4 +111,38 @@ const chatPlugin: JupyterFrontEndPlugin = { } }; -export default [inlineProviderPlugin, chatPlugin]; +const llmProviderPlugin: JupyterFrontEndPlugin = { + id: 'jupyterlab-codestral:llm-provider', + autoStart: true, + requires: [ICompletionProviderManager, ISettingRegistry], + provides: ILlmProvider, + activate: ( + app: JupyterFrontEnd, + manager: ICompletionProviderManager, + settingRegistry: ISettingRegistry + ): ILlmProvider => { + const llmProvider = new LlmProvider({ completionProviderManager: manager }); + + settingRegistry + .load(llmProviderPlugin.id) + .then(settings => { + const updateProvider = () => { + const provider = settings.get('provider').composite as string; + llmProvider.setProvider(provider, settings.composite); + }; + + settings.changed.connect(() => updateProvider()); + updateProvider(); + }) + .catch(reason => { + console.error( + `Failed to load settings for ${llmProviderPlugin.id}`, + reason + ); + }); + + return llmProvider; + } +}; + +export default [chatPlugin, llmProviderPlugin]; diff --git a/src/provider.ts b/src/provider.ts index 7c4c1e5..1515c50 100644 --- a/src/provider.ts +++ b/src/provider.ts @@ -1,77 +1,78 @@ import { - CompletionHandler, - IInlineCompletionContext, + ICompletionProviderManager, IInlineCompletionProvider } from '@jupyterlab/completer'; +import { BaseChatModel } from '@langchain/core/language_models/chat_models'; +import { ChatMistralAI, MistralAI } from '@langchain/mistralai'; +import { ISignal, Signal } from '@lumino/signaling'; +import { JSONValue, ReadonlyPartialJSONObject } from '@lumino/coreutils'; +import * as completionProviders from './completion-providers'; +import { ILlmProvider } from './token'; +import { IBaseProvider } from './completion-providers/base-provider'; -import { Throttler } from '@lumino/polling'; - -import { CompletionRequest } from '@mistralai/mistralai'; - -import type { MistralAI } from '@langchain/mistralai'; - -/* - * The Mistral API has a rate limit of 1 request per second - */ -const INTERVAL = 1000; +export class LlmProvider implements ILlmProvider { + constructor(options: LlmProvider.IOptions) { + this._completionProviderManager = options.completionProviderManager; + } -export class CodestralProvider implements IInlineCompletionProvider { - readonly identifier = 'Codestral'; - readonly name = 'Codestral'; + get name(): string | null { + return this._name; + } - constructor(options: CodestralProvider.IOptions) { - this._mistralClient = options.mistralClient; - this._throttler = new 
Throttler(async (data: CompletionRequest) => { - const response = await this._mistralClient.completionWithRetry( - data, - {}, - false - ); - const items = response.choices.map((choice: any) => { - return { insertText: choice.message.content as string }; - }); + get inlineProvider(): IInlineCompletionProvider | null { + return this._inlineProvider; + } - return { - items - }; - }, INTERVAL); + get chatModel(): BaseChatModel | null { + return this._chatModel; } - async fetch( - request: CompletionHandler.IRequest, - context: IInlineCompletionContext - ) { - const { text, offset: cursorOffset } = request; - const prompt = text.slice(0, cursorOffset); - const suffix = text.slice(cursorOffset); + setProvider(value: string | null, settings: ReadonlyPartialJSONObject) { + if (value === null) { + this._inlineProvider = null; + this._chatModel = null; + this._providerChange.emit(); + return; + } - const data = { - prompt, - suffix, - model: 'codestral-latest', - // temperature: 0, - // top_p: 1, - // max_tokens: 1024, - // min_tokens: 0, - stream: false, - // random_seed: 1337, - stop: [] - }; + const provider = this._completionProviders.get(value) as IBaseProvider; + if (provider) { + provider.configure(settings as { [property: string]: JSONValue }); + return; + } - try { - return this._throttler.invoke(data); - } catch (error) { - console.error('Error fetching completions', error); - return { items: [] }; + if (value === 'MistralAI') { + this._name = 'MistralAI'; + const mistralClient = new MistralAI({ apiKey: 'TMP', ...settings }); + this._inlineProvider = new completionProviders.CodestralProvider({ + mistralClient + }); + this._completionProviderManager.registerInlineProvider( + this._inlineProvider + ); + this._completionProviders.set(value, this._inlineProvider); + this._chatModel = new ChatMistralAI({ apiKey: 'TMP', ...settings }); + } else { + this._inlineProvider = null; + this._chatModel = null; } + this._providerChange.emit(); + } + + get providerChange(): ISignal { + return this._providerChange; } - private _throttler: Throttler; - private _mistralClient: MistralAI; + private _completionProviderManager: ICompletionProviderManager; + private _completionProviders = new Map(); + private _name: string | null = null; + private _inlineProvider: IBaseProvider | null = null; + private _chatModel: BaseChatModel | null = null; + private _providerChange = new Signal(this); } -export namespace CodestralProvider { +export namespace LlmProvider { export interface IOptions { - mistralClient: MistralAI; + completionProviderManager: ICompletionProviderManager; } } diff --git a/src/token.ts b/src/token.ts new file mode 100644 index 0000000..96ef537 --- /dev/null +++ b/src/token.ts @@ -0,0 +1,16 @@ +import { IInlineCompletionProvider } from '@jupyterlab/completer'; +import { BaseChatModel } from '@langchain/core/language_models/chat_models'; +import { Token } from '@lumino/coreutils'; +import { ISignal } from '@lumino/signaling'; + +export interface ILlmProvider { + name: string | null; + inlineProvider: IInlineCompletionProvider | null; + chatModel: BaseChatModel | null; + providerChange: ISignal; +} + +export const ILlmProvider = new Token( + 'jupyterlab-codestral:LlmProvider', + 'Provider for chat and completion LLM client' +); diff --git a/tsconfig.json b/tsconfig.json index 9897917..bcaac9d 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -19,5 +19,5 @@ "strictNullChecks": true, "target": "ES2018" }, - "include": ["src/*"] + "include": ["src/*", "src/**/*"] } From 
f63cce0a8b7022895c73325b37d993e807075a24 Mon Sep 17 00:00:00 2001 From: Nicolas Brichet Date: Tue, 29 Oct 2024 12:08:00 +0100 Subject: [PATCH 2/5] Update changes in settings to the chat and completion LLM --- src/completion-providers/base-provider.ts | 12 +-- .../codestral-provider.ts | 24 ++--- src/index.ts | 17 ---- src/provider.ts | 95 +++++++++++++------ src/token.ts | 9 +- src/tools.ts | 17 ++++ 6 files changed, 99 insertions(+), 75 deletions(-) create mode 100644 src/tools.ts diff --git a/src/completion-providers/base-provider.ts b/src/completion-providers/base-provider.ts index 4a3f6bf..f312f17 100644 --- a/src/completion-providers/base-provider.ts +++ b/src/completion-providers/base-provider.ts @@ -1,16 +1,6 @@ import { IInlineCompletionProvider } from '@jupyterlab/completer'; import { LLM } from '@langchain/core/language_models/llms'; -import { JSONValue } from '@lumino/coreutils'; export interface IBaseProvider extends IInlineCompletionProvider { - configure(settings: { [property: string]: JSONValue }): void; -} - -// https://stackoverflow.com/questions/54724875/can-we-check-whether-property-is-readonly-in-typescript -export function isWritable(obj: T, key: keyof T) { - const desc = - Object.getOwnPropertyDescriptor(obj, key) || - Object.getOwnPropertyDescriptor(Object.getPrototypeOf(obj), key) || - {}; - return Boolean(desc.writable); + client: LLM; } diff --git a/src/completion-providers/codestral-provider.ts b/src/completion-providers/codestral-provider.ts index 4a5ad90..6143bd6 100644 --- a/src/completion-providers/codestral-provider.ts +++ b/src/completion-providers/codestral-provider.ts @@ -2,14 +2,12 @@ import { CompletionHandler, IInlineCompletionContext } from '@jupyterlab/completer'; - import { Throttler } from '@lumino/polling'; - import { CompletionRequest } from '@mistralai/mistralai'; - import type { MistralAI } from '@langchain/mistralai'; -import { JSONValue } from '@lumino/coreutils'; -import { IBaseProvider, isWritable } from './base-provider'; + +import { IBaseProvider } from './base-provider'; +import { LLM } from '@langchain/core/language_models/llms'; /* * The Mistral API has a rate limit of 1 request per second @@ -38,6 +36,10 @@ export class CodestralProvider implements IBaseProvider { }, INTERVAL); } + get client(): LLM { + return this._mistralClient; + } + async fetch( request: CompletionHandler.IRequest, context: IInlineCompletionContext @@ -67,18 +69,6 @@ export class CodestralProvider implements IBaseProvider { } } - configure(settings: { [property: string]: JSONValue }): void { - Object.entries(settings).forEach(([key, value], index) => { - if (key in this._mistralClient) { - if (isWritable(this._mistralClient, key as keyof MistralAI)) { - // eslint-disable-next-line @typescript-eslint/ban-ts-comment - // @ts-ignore - this._mistralClient[key as keyof MistralAI] = value; - } - } - }); - } - private _throttler: Throttler; private _mistralClient: MistralAI; } diff --git a/src/index.ts b/src/index.ts index bbbb300..b37d29b 100644 --- a/src/index.ts +++ b/src/index.ts @@ -18,23 +18,6 @@ import { ChatHandler } from './chat-handler'; import { ILlmProvider } from './token'; import { LlmProvider } from './provider'; -// const inlineProviderPlugin: JupyterFrontEndPlugin = { -// id: 'jupyterlab-codestral:inline-provider', -// autoStart: true, -// requires: [ICompletionProviderManager, ILlmProvider, ISettingRegistry], -// activate: ( -// app: JupyterFrontEnd, -// manager: ICompletionProviderManager, -// llmProvider: ILlmProvider -// ): void => { -// 
llmProvider.providerChange.connect(() => { -// if (llmProvider.inlineCompleter !== null) { -// manager.registerInlineProvider(llmProvider.inlineCompleter); -// } -// }); -// } -// }; - const chatPlugin: JupyterFrontEndPlugin = { id: 'jupyterlab-codestral:chat', description: 'LLM chat extension', diff --git a/src/provider.ts b/src/provider.ts index 1515c50..a7c26bd 100644 --- a/src/provider.ts +++ b/src/provider.ts @@ -1,14 +1,13 @@ -import { - ICompletionProviderManager, - IInlineCompletionProvider -} from '@jupyterlab/completer'; +import { ICompletionProviderManager } from '@jupyterlab/completer'; import { BaseChatModel } from '@langchain/core/language_models/chat_models'; import { ChatMistralAI, MistralAI } from '@langchain/mistralai'; import { ISignal, Signal } from '@lumino/signaling'; -import { JSONValue, ReadonlyPartialJSONObject } from '@lumino/coreutils'; +import { ReadonlyPartialJSONObject } from '@lumino/coreutils'; import * as completionProviders from './completion-providers'; -import { ILlmProvider } from './token'; +import { ILlmProvider, IProviders } from './token'; import { IBaseProvider } from './completion-providers/base-provider'; +import { isWritable } from './tools'; +import { BaseLanguageModel } from '@langchain/core/language_models/base'; export class LlmProvider implements ILlmProvider { constructor(options: LlmProvider.IOptions) { @@ -19,42 +18,65 @@ export class LlmProvider implements ILlmProvider { return this._name; } - get inlineProvider(): IInlineCompletionProvider | null { - return this._inlineProvider; + get completionProvider(): IBaseProvider | null { + if (this._name === null) { + return null; + } + return ( + this._completionProviders.get(this._name)?.completionProvider || null + ); } get chatModel(): BaseChatModel | null { - return this._chatModel; + if (this._name === null) { + return null; + } + return this._completionProviders.get(this._name)?.chatModel || null; } - setProvider(value: string | null, settings: ReadonlyPartialJSONObject) { - if (value === null) { - this._inlineProvider = null; - this._chatModel = null; + setProvider(name: string | null, settings: ReadonlyPartialJSONObject) { + console.log('SET PROVIDER', name); + if (name === null) { + // TODO: the inline completion is not disabled, it should be removed/disabled + // from the manager. this._providerChange.emit(); return; } - const provider = this._completionProviders.get(value) as IBaseProvider; - if (provider) { - provider.configure(settings as { [property: string]: JSONValue }); + const providers = this._completionProviders.get(name); + if (providers !== undefined) { + console.log('Provider defined'); + // Update the inline completion provider settings. + this._updateConfig(providers.completionProvider.client, settings); + + // Update the chat LLM settings. 
+ this._updateConfig(providers.chatModel, settings); + + if (name !== this._name) { + this._name = name; + this._providerChange.emit(); + } return; } - - if (value === 'MistralAI') { + console.log('Provider undefined'); + if (name === 'MistralAI') { this._name = 'MistralAI'; - const mistralClient = new MistralAI({ apiKey: 'TMP', ...settings }); - this._inlineProvider = new completionProviders.CodestralProvider({ + const mistralClient = new MistralAI({ apiKey: 'TMP' }); + this._updateConfig(mistralClient, settings); + + const completionProvider = new completionProviders.CodestralProvider({ mistralClient }); this._completionProviderManager.registerInlineProvider( - this._inlineProvider + completionProvider ); - this._completionProviders.set(value, this._inlineProvider); - this._chatModel = new ChatMistralAI({ apiKey: 'TMP', ...settings }); + + const chatModel = new ChatMistralAI({ apiKey: 'TMP' }); + this._updateConfig(chatModel as any, settings); + + this._completionProviders.set(name, { completionProvider, chatModel }); } else { - this._inlineProvider = null; - this._chatModel = null; + this._name = null; } this._providerChange.emit(); } @@ -63,11 +85,28 @@ export class LlmProvider implements ILlmProvider { return this._providerChange; } + private _updateConfig( + model: T, + settings: ReadonlyPartialJSONObject + ) { + Object.entries(settings).forEach(([key, value], index) => { + if (key in model) { + const modelKey = key as keyof typeof model; + if (isWritable(model, modelKey)) { + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-ignore + model[modelKey] = value; + } + } + }); + } + private _completionProviderManager: ICompletionProviderManager; - private _completionProviders = new Map(); + // The ICompletionProviderManager does not allow manipulating the providers, + // like getting, removing or recreating them. This map store the created providers to + // be able to modify them. + private _completionProviders = new Map(); private _name: string | null = null; - private _inlineProvider: IBaseProvider | null = null; - private _chatModel: BaseChatModel | null = null; private _providerChange = new Signal(this); } diff --git a/src/token.ts b/src/token.ts index 96ef537..5e1ae1d 100644 --- a/src/token.ts +++ b/src/token.ts @@ -1,15 +1,20 @@ -import { IInlineCompletionProvider } from '@jupyterlab/completer'; import { BaseChatModel } from '@langchain/core/language_models/chat_models'; import { Token } from '@lumino/coreutils'; import { ISignal } from '@lumino/signaling'; +import { IBaseProvider } from './completion-providers/base-provider'; export interface ILlmProvider { name: string | null; - inlineProvider: IInlineCompletionProvider | null; + completionProvider: IBaseProvider | null; chatModel: BaseChatModel | null; providerChange: ISignal; } +export interface IProviders { + completionProvider: IBaseProvider; + chatModel: BaseChatModel; +} + export const ILlmProvider = new Token( 'jupyterlab-codestral:LlmProvider', 'Provider for chat and completion LLM client' diff --git a/src/tools.ts b/src/tools.ts new file mode 100644 index 0000000..4369aec --- /dev/null +++ b/src/tools.ts @@ -0,0 +1,17 @@ +import { BaseLanguageModel } from '@langchain/core/language_models/base'; + +/** + * This function indicates whether a key is writable in an object. + * https://stackoverflow.com/questions/54724875/can-we-check-whether-property-is-readonly-in-typescript + * + * @param obj - An object extending the BaseLanguageModel interface. + * @param key - A string as a key of the object. 
+ * @returns a boolean whether the key is writable or not. + */ +export function isWritable(obj: T, key: keyof T) { + const desc = + Object.getOwnPropertyDescriptor(obj, key) || + Object.getOwnPropertyDescriptor(Object.getPrototypeOf(obj), key) || + {}; + return Boolean(desc.writable); +} From 9f2f2490431043fdf81e6c2ce8cb97286a18ec22 Mon Sep 17 00:00:00 2001 From: Nicolas Brichet Date: Tue, 29 Oct 2024 16:12:26 +0100 Subject: [PATCH 3/5] Cleaning and lint --- schema/inline-provider.json | 14 -------------- schema/llm-provider.json | 5 +---- src/provider.ts | 7 ++----- yarn.lock | 1 + 4 files changed, 4 insertions(+), 23 deletions(-) delete mode 100644 schema/inline-provider.json diff --git a/schema/inline-provider.json b/schema/inline-provider.json deleted file mode 100644 index 12a7219..0000000 --- a/schema/inline-provider.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "title": "Codestral", - "description": "Codestral settings", - "type": "object", - "properties": { - "apiKey": { - "type": "string", - "title": "The Codestral API key", - "description": "The API key to use for Codestral", - "default": "" - } - }, - "additionalProperties": false -} diff --git a/schema/llm-provider.json b/schema/llm-provider.json index 2855e34..c85bf43 100644 --- a/schema/llm-provider.json +++ b/schema/llm-provider.json @@ -8,10 +8,7 @@ "title": "The LLM provider", "description": "The LLM provider to use for chat and completion", "default": "None", - "enum": [ - "None", - "MistralAI" - ] + "enum": ["None", "MistralAI"] }, "apiKey": { "type": "string", diff --git a/src/provider.ts b/src/provider.ts index a7c26bd..7030eed 100644 --- a/src/provider.ts +++ b/src/provider.ts @@ -35,17 +35,15 @@ export class LlmProvider implements ILlmProvider { } setProvider(name: string | null, settings: ReadonlyPartialJSONObject) { - console.log('SET PROVIDER', name); if (name === null) { - // TODO: the inline completion is not disabled, it should be removed/disabled - // from the manager. + // TODO: the inline completion is not disabled. + // It should be removed/disabled from the manager. this._providerChange.emit(); return; } const providers = this._completionProviders.get(name); if (providers !== undefined) { - console.log('Provider defined'); // Update the inline completion provider settings. 
this._updateConfig(providers.completionProvider.client, settings); @@ -58,7 +56,6 @@ export class LlmProvider implements ILlmProvider { } return; } - console.log('Provider undefined'); if (name === 'MistralAI') { this._name = 'MistralAI'; const mistralClient = new MistralAI({ apiKey: 'TMP' }); diff --git a/yarn.lock b/yarn.lock index 7bf63b1..47e6599 100644 --- a/yarn.lock +++ b/yarn.lock @@ -4887,6 +4887,7 @@ __metadata: "@langchain/mistralai": ^0.1.1 "@lumino/coreutils": ^2.1.2 "@lumino/polling": ^2.1.2 + "@lumino/signaling": ^2.1.2 "@types/json-schema": ^7.0.11 "@types/react": ^18.0.26 "@types/react-addons-linked-state-mixin": ^0.14.22 From 6e11c72dc07b77bcea1398f720cb85ab2c9e5ae7 Mon Sep 17 00:00:00 2001 From: Nicolas Brichet Date: Wed, 30 Oct 2024 16:01:33 +0100 Subject: [PATCH 4/5] Provides only one completion provider --- src/chat-handler.ts | 2 +- src/completion-provider.ts | 61 ++++++++ src/completion-providers/base-provider.ts | 6 - src/completion-providers/index.ts | 1 - src/index.ts | 6 +- src/llm-models/base-completer.ts | 20 +++ .../codestral-completer.ts} | 26 ++-- src/llm-models/index.ts | 3 + src/llm-models/utils.ts | 24 +++ src/provider.ts | 147 ++++++++++-------- src/token.ts | 12 +- src/tools.ts | 17 -- 12 files changed, 208 insertions(+), 117 deletions(-) create mode 100644 src/completion-provider.ts delete mode 100644 src/completion-providers/base-provider.ts delete mode 100644 src/completion-providers/index.ts create mode 100644 src/llm-models/base-completer.ts rename src/{completion-providers/codestral-provider.ts => llm-models/codestral-completer.ts} (76%) create mode 100644 src/llm-models/index.ts create mode 100644 src/llm-models/utils.ts delete mode 100644 src/tools.ts diff --git a/src/chat-handler.ts b/src/chat-handler.ts index 47c8867..0191302 100644 --- a/src/chat-handler.ts +++ b/src/chat-handler.ts @@ -10,12 +10,12 @@ import { INewMessage } from '@jupyter/chat'; import type { BaseChatModel } from '@langchain/core/language_models/chat_models'; -import { UUID } from '@lumino/coreutils'; import { AIMessage, HumanMessage, mergeMessageRuns } from '@langchain/core/messages'; +import { UUID } from '@lumino/coreutils'; export type ConnectionMessage = { type: 'connection'; diff --git a/src/completion-provider.ts b/src/completion-provider.ts new file mode 100644 index 0000000..53b2051 --- /dev/null +++ b/src/completion-provider.ts @@ -0,0 +1,61 @@ +import { + CompletionHandler, + IInlineCompletionContext, + IInlineCompletionProvider +} from '@jupyterlab/completer'; +import { LLM } from '@langchain/core/language_models/llms'; + +import { getCompleter, IBaseCompleter } from './llm-models'; + +/** + * The generic completion provider to register to the completion provider manager. + */ +export class CompletionProvider implements IInlineCompletionProvider { + readonly identifier = '@jupyterlite/ai'; + + constructor(options: CompletionProvider.IOptions) { + this.name = options.name; + } + + /** + * Getter and setter of the name. + * The setter will create the appropriate completer, accordingly to the name. + */ + get name(): string { + return this._name; + } + set name(name: string) { + this._name = name; + this._completer = getCompleter(name); + } + + /** + * get the current completer. + */ + get completer(): IBaseCompleter | null { + return this._completer; + } + + /** + * Get the LLM completer. 
+ */ + get llmCompleter(): LLM | null { + return this._completer?.client || null; + } + + async fetch( + request: CompletionHandler.IRequest, + context: IInlineCompletionContext + ) { + return this._completer?.fetch(request, context); + } + + private _name: string = 'None'; + private _completer: IBaseCompleter | null = null; +} + +export namespace CompletionProvider { + export interface IOptions { + name: string; + } +} diff --git a/src/completion-providers/base-provider.ts b/src/completion-providers/base-provider.ts deleted file mode 100644 index f312f17..0000000 --- a/src/completion-providers/base-provider.ts +++ /dev/null @@ -1,6 +0,0 @@ -import { IInlineCompletionProvider } from '@jupyterlab/completer'; -import { LLM } from '@langchain/core/language_models/llms'; - -export interface IBaseProvider extends IInlineCompletionProvider { - client: LLM; -} diff --git a/src/completion-providers/index.ts b/src/completion-providers/index.ts deleted file mode 100644 index fdb3eeb..0000000 --- a/src/completion-providers/index.ts +++ /dev/null @@ -1 +0,0 @@ -export * from './codestral-provider'; diff --git a/src/index.ts b/src/index.ts index b37d29b..fa939a3 100644 --- a/src/index.ts +++ b/src/index.ts @@ -15,8 +15,8 @@ import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { ISettingRegistry } from '@jupyterlab/settingregistry'; import { ChatHandler } from './chat-handler'; -import { ILlmProvider } from './token'; import { LlmProvider } from './provider'; +import { ILlmProvider } from './token'; const chatPlugin: JupyterFrontEndPlugin = { id: 'jupyterlab-codestral:chat', @@ -45,7 +45,7 @@ const chatPlugin: JupyterFrontEndPlugin = { activeCellManager: activeCellManager }); - llmProvider.providerChange.connect(() => { + llmProvider.modelChange.connect(() => { chatHandler.llmClient = llmProvider.chatModel; }); @@ -111,7 +111,7 @@ const llmProviderPlugin: JupyterFrontEndPlugin = { .then(settings => { const updateProvider = () => { const provider = settings.get('provider').composite as string; - llmProvider.setProvider(provider, settings.composite); + llmProvider.setModels(provider, settings.composite); }; settings.changed.connect(() => updateProvider()); diff --git a/src/llm-models/base-completer.ts b/src/llm-models/base-completer.ts new file mode 100644 index 0000000..8374db4 --- /dev/null +++ b/src/llm-models/base-completer.ts @@ -0,0 +1,20 @@ +import { + CompletionHandler, + IInlineCompletionContext +} from '@jupyterlab/completer'; +import { LLM } from '@langchain/core/language_models/llms'; + +export interface IBaseCompleter { + /** + * The LLM completer. + */ + client: LLM; + + /** + * The fetch request for the LLM completer. 
+ */ + fetch( + request: CompletionHandler.IRequest, + context: IInlineCompletionContext + ): Promise; +} diff --git a/src/completion-providers/codestral-provider.ts b/src/llm-models/codestral-completer.ts similarity index 76% rename from src/completion-providers/codestral-provider.ts rename to src/llm-models/codestral-completer.ts index 6143bd6..8f3e6ee 100644 --- a/src/completion-providers/codestral-provider.ts +++ b/src/llm-models/codestral-completer.ts @@ -2,24 +2,24 @@ import { CompletionHandler, IInlineCompletionContext } from '@jupyterlab/completer'; +import { LLM } from '@langchain/core/language_models/llms'; +import { MistralAI } from '@langchain/mistralai'; import { Throttler } from '@lumino/polling'; import { CompletionRequest } from '@mistralai/mistralai'; -import type { MistralAI } from '@langchain/mistralai'; -import { IBaseProvider } from './base-provider'; -import { LLM } from '@langchain/core/language_models/llms'; +import { IBaseCompleter } from './base-completer'; /* * The Mistral API has a rate limit of 1 request per second */ const INTERVAL = 1000; -export class CodestralProvider implements IBaseProvider { - readonly identifier = 'Codestral'; - readonly name = 'Codestral'; - - constructor(options: CodestralProvider.IOptions) { - this._mistralClient = options.mistralClient; +export class CodestralCompleter implements IBaseCompleter { + constructor() { + this._mistralClient = new MistralAI({ + apiKey: 'TMP', + model: 'codestral-latest' + }); this._throttler = new Throttler(async (data: CompletionRequest) => { const response = await this._mistralClient.completionWithRetry( data, @@ -51,7 +51,7 @@ export class CodestralProvider implements IBaseProvider { const data = { prompt, suffix, - model: 'codestral-latest', + model: this._mistralClient.model, // temperature: 0, // top_p: 1, // max_tokens: 1024, @@ -72,9 +72,3 @@ export class CodestralProvider implements IBaseProvider { private _throttler: Throttler; private _mistralClient: MistralAI; } - -export namespace CodestralProvider { - export interface IOptions { - mistralClient: MistralAI; - } -} diff --git a/src/llm-models/index.ts b/src/llm-models/index.ts new file mode 100644 index 0000000..ae6b725 --- /dev/null +++ b/src/llm-models/index.ts @@ -0,0 +1,3 @@ +export * from './base-completer'; +export * from './codestral-completer'; +export * from './utils'; diff --git a/src/llm-models/utils.ts b/src/llm-models/utils.ts new file mode 100644 index 0000000..6d9b9f4 --- /dev/null +++ b/src/llm-models/utils.ts @@ -0,0 +1,24 @@ +import { BaseChatModel } from '@langchain/core/language_models/chat_models'; +import { ChatMistralAI } from '@langchain/mistralai'; +import { IBaseCompleter } from './base-completer'; +import { CodestralCompleter } from './codestral-completer'; + +/** + * Get an LLM completer from the name. + */ +export function getCompleter(name: string): IBaseCompleter | null { + if (name === 'MistralAI') { + return new CodestralCompleter(); + } + return null; +} + +/** + * Get an LLM chat model from the name. 
+ */ +export function getChatModel(name: string): BaseChatModel | null { + if (name === 'MistralAI') { + return new ChatMistralAI({ apiKey: 'TMP' }); + } + return null; +} diff --git a/src/provider.ts b/src/provider.ts index 7030eed..1eed586 100644 --- a/src/provider.ts +++ b/src/provider.ts @@ -1,88 +1,119 @@ import { ICompletionProviderManager } from '@jupyterlab/completer'; +import { BaseLanguageModel } from '@langchain/core/language_models/base'; import { BaseChatModel } from '@langchain/core/language_models/chat_models'; -import { ChatMistralAI, MistralAI } from '@langchain/mistralai'; import { ISignal, Signal } from '@lumino/signaling'; import { ReadonlyPartialJSONObject } from '@lumino/coreutils'; -import * as completionProviders from './completion-providers'; -import { ILlmProvider, IProviders } from './token'; -import { IBaseProvider } from './completion-providers/base-provider'; -import { isWritable } from './tools'; -import { BaseLanguageModel } from '@langchain/core/language_models/base'; + +import { CompletionProvider } from './completion-provider'; +import { getChatModel, IBaseCompleter } from './llm-models'; +import { ILlmProvider } from './token'; export class LlmProvider implements ILlmProvider { constructor(options: LlmProvider.IOptions) { - this._completionProviderManager = options.completionProviderManager; + this._completionProvider = new CompletionProvider({ name: 'None' }); + options.completionProviderManager.registerInlineProvider( + this._completionProvider + ); } - get name(): string | null { + get name(): string { return this._name; } - get completionProvider(): IBaseProvider | null { + /** + * get the current completer of the completion provider. + */ + get completer(): IBaseCompleter | null { if (this._name === null) { return null; } - return ( - this._completionProviders.get(this._name)?.completionProvider || null - ); + return this._completionProvider.completer; } + /** + * get the current llm chat model. + */ get chatModel(): BaseChatModel | null { if (this._name === null) { return null; } - return this._completionProviders.get(this._name)?.chatModel || null; + return this._llmChatModel; } - setProvider(name: string | null, settings: ReadonlyPartialJSONObject) { - if (name === null) { - // TODO: the inline completion is not disabled. - // It should be removed/disabled from the manager. - this._providerChange.emit(); - return; + /** + * Set the models (chat model and completer). + * Creates the models if the name has changed, otherwise only updates their config. + * + * @param name - the name of the model to use. + * @param settings - the settings for the models. + */ + setModels(name: string, settings: ReadonlyPartialJSONObject) { + if (name !== this._name) { + this._name = name; + this._completionProvider.name = name; + this._llmChatModel = getChatModel(name); + this._modelChange.emit(); } - const providers = this._completionProviders.get(name); - if (providers !== undefined) { - // Update the inline completion provider settings. - this._updateConfig(providers.completionProvider.client, settings); - - // Update the chat LLM settings. - this._updateConfig(providers.chatModel, settings); + // Update the inline completion provider settings. + if (this._completionProvider.llmCompleter) { + LlmProvider.updateConfig(this._completionProvider.llmCompleter, settings); + } - if (name !== this._name) { - this._name = name; - this._providerChange.emit(); - } - return; + // Update the chat LLM settings. 
+ if (this._llmChatModel) { + LlmProvider.updateConfig(this._llmChatModel, settings); } - if (name === 'MistralAI') { - this._name = 'MistralAI'; - const mistralClient = new MistralAI({ apiKey: 'TMP' }); - this._updateConfig(mistralClient, settings); + } - const completionProvider = new completionProviders.CodestralProvider({ - mistralClient - }); - this._completionProviderManager.registerInlineProvider( - completionProvider - ); + get modelChange(): ISignal { + return this._modelChange; + } - const chatModel = new ChatMistralAI({ apiKey: 'TMP' }); - this._updateConfig(chatModel as any, settings); + private _completionProvider: CompletionProvider; + private _llmChatModel: BaseChatModel | null = null; + private _name: string = 'None'; + private _modelChange = new Signal(this); - this._completionProviders.set(name, { completionProvider, chatModel }); - } else { - this._name = null; - } - this._providerChange.emit(); +export namespace LlmProvider { + /** + * The options for the LLM provider. + */ + export interface IOptions { + /** + * The completion provider manager in which to register the LLM completer. + */ + completionProviderManager: ICompletionProviderManager; } - get providerChange(): ISignal { - return this._providerChange; + /** + * This function indicates whether a key is writable in an object. + * https://stackoverflow.com/questions/54724875/can-we-check-whether-property-is-readonly-in-typescript + * + * @param obj - An object extending the BaseLanguageModel interface. + * @param key - A string as a key of the object. + * @returns a boolean whether the key is writable or not. + */ + export function isWritable( + obj: T, + key: keyof T + ) { + const desc = + Object.getOwnPropertyDescriptor(obj, key) || + Object.getOwnPropertyDescriptor(Object.getPrototypeOf(obj), key) || + {}; + return Boolean(desc.writable); } - private _updateConfig( + /** + * Update the config of a language model. + * It only updates the writable attributes of the model. + * + * @param model - the model to update. + * @param settings - the configuration as a JSON object. + */ + export function updateConfig( model: T, settings: ReadonlyPartialJSONObject ) { @@ -97,18 +128,4 @@ export class LlmProvider implements ILlmProvider { } }); } - - private _completionProviderManager: ICompletionProviderManager; - // The ICompletionProviderManager does not allow manipulating the providers, - // like getting, removing or recreating them. This map store the created providers to - // be able to modify them. 
- private _completionProviders = new Map(); - private _name: string | null = null; - private _providerChange = new Signal(this); -} - -export namespace LlmProvider { - export interface IOptions { - completionProviderManager: ICompletionProviderManager; - } } diff --git a/src/token.ts b/src/token.ts index 5e1ae1d..3148938 100644 --- a/src/token.ts +++ b/src/token.ts @@ -1,18 +1,14 @@ import { BaseChatModel } from '@langchain/core/language_models/chat_models'; import { Token } from '@lumino/coreutils'; import { ISignal } from '@lumino/signaling'; -import { IBaseProvider } from './completion-providers/base-provider'; + +import { IBaseCompleter } from './llm-models'; export interface ILlmProvider { name: string | null; - completionProvider: IBaseProvider | null; + completer: IBaseCompleter | null; chatModel: BaseChatModel | null; - providerChange: ISignal; -} - -export interface IProviders { - completionProvider: IBaseProvider; - chatModel: BaseChatModel; + modelChange: ISignal; } export const ILlmProvider = new Token( diff --git a/src/tools.ts b/src/tools.ts deleted file mode 100644 index 4369aec..0000000 --- a/src/tools.ts +++ /dev/null @@ -1,17 +0,0 @@ -import { BaseLanguageModel } from '@langchain/core/language_models/base'; - -/** - * This function indicates whether a key is writable in an object. - * https://stackoverflow.com/questions/54724875/can-we-check-whether-property-is-readonly-in-typescript - * - * @param obj - An object extending the BaseLanguageModel interface. - * @param key - A string as a key of the object. - * @returns a boolean whether the key is writable or not. - */ -export function isWritable(obj: T, key: keyof T) { - const desc = - Object.getOwnPropertyDescriptor(obj, key) || - Object.getOwnPropertyDescriptor(Object.getPrototypeOf(obj), key) || - {}; - return Boolean(desc.writable); -} From 9329b590ea397ae486b60d434e07130b5f0ca055 Mon Sep 17 00:00:00 2001 From: Nicolas Brichet Date: Thu, 31 Oct 2024 15:36:13 +0100 Subject: [PATCH 5/5] Rename 'client' to 'provider' and LlmProvider to AIProvider for better readability --- .../{llm-provider.json => ai-provider.json} | 6 ++-- src/chat-handler.ts | 20 +++++------ src/completion-provider.ts | 2 +- src/index.ts | 34 +++++++++---------- src/llm-models/base-completer.ts | 2 +- src/llm-models/codestral-completer.ts | 12 +++---- src/provider.ts | 16 ++++----- src/token.ts | 10 +++--- 8 files changed, 51 insertions(+), 51 deletions(-) rename schema/{llm-provider.json => ai-provider.json} (74%) diff --git a/schema/llm-provider.json b/schema/ai-provider.json similarity index 74% rename from schema/llm-provider.json rename to schema/ai-provider.json index c85bf43..d4b9a04 100644 --- a/schema/llm-provider.json +++ b/schema/ai-provider.json @@ -1,12 +1,12 @@ { - "title": "LLM provider", + "title": "AI provider", "description": "Provider settings", "type": "object", "properties": { "provider": { "type": "string", - "title": "The LLM provider", - "description": "The LLM provider to use for chat and completion", + "title": "The AI provider", + "description": "The AI provider to use for chat and completion", "default": "None", "enum": ["None", "MistralAI"] }, diff --git a/src/chat-handler.ts b/src/chat-handler.ts index 0191302..18417f6 100644 --- a/src/chat-handler.ts +++ b/src/chat-handler.ts @@ -25,14 +25,14 @@ export type ConnectionMessage = { export class ChatHandler extends ChatModel { constructor(options: ChatHandler.IOptions) { super(options); - this._llmClient = options.llmClient; + this._provider = options.provider; } - get 
llmClient(): BaseChatModel | null { - return this._llmClient; + get provider(): BaseChatModel | null { + return this._provider; } - set llmClient(client: BaseChatModel | null) { - this._llmClient = client; + set provider(provider: BaseChatModel | null) { + this._provider = provider; } async sendMessage(message: INewMessage): Promise { @@ -46,10 +46,10 @@ export class ChatHandler extends ChatModel { }; this.messageAdded(msg); - if (this._llmClient === null) { + if (this._provider === null) { const botMsg: IChatMessage = { id: UUID.uuid4(), - body: '**Chat client not configured**', + body: '**AI provider not configured for the chat**', sender: { username: 'ERROR' }, time: Date.now(), type: 'msg' @@ -69,7 +69,7 @@ export class ChatHandler extends ChatModel { }) ); - const response = await this._llmClient.invoke(messages); + const response = await this._provider.invoke(messages); // TODO: fix deprecated response.text const content = response.text; const botMsg: IChatMessage = { @@ -96,12 +96,12 @@ export class ChatHandler extends ChatModel { super.messageAdded(message); } - private _llmClient: BaseChatModel | null; + private _provider: BaseChatModel | null; private _history: IChatHistory = { messages: [] }; } export namespace ChatHandler { export interface IOptions extends ChatModel.IOptions { - llmClient: BaseChatModel | null; + provider: BaseChatModel | null; } } diff --git a/src/completion-provider.ts b/src/completion-provider.ts index 53b2051..b2ac0b1 100644 --- a/src/completion-provider.ts +++ b/src/completion-provider.ts @@ -40,7 +40,7 @@ export class CompletionProvider implements IInlineCompletionProvider { * Get the LLM completer. */ get llmCompleter(): LLM | null { - return this._completer?.client || null; + return this._completer?.provider || null; } async fetch( diff --git a/src/index.ts b/src/index.ts index fa939a3..2cc8bdc 100644 --- a/src/index.ts +++ b/src/index.ts @@ -15,18 +15,18 @@ import { IRenderMimeRegistry } from '@jupyterlab/rendermime'; import { ISettingRegistry } from '@jupyterlab/settingregistry'; import { ChatHandler } from './chat-handler'; -import { LlmProvider } from './provider'; -import { ILlmProvider } from './token'; +import { AIProvider } from './provider'; +import { IAIProvider } from './token'; const chatPlugin: JupyterFrontEndPlugin = { id: 'jupyterlab-codestral:chat', description: 'LLM chat extension', autoStart: true, optional: [INotebookTracker, ISettingRegistry, IThemeManager], - requires: [ILlmProvider, IRenderMimeRegistry], + requires: [IAIProvider, IRenderMimeRegistry], activate: async ( app: JupyterFrontEnd, - llmProvider: ILlmProvider, + aiProvider: IAIProvider, rmRegistry: IRenderMimeRegistry, notebookTracker: INotebookTracker | null, settingsRegistry: ISettingRegistry | null, @@ -41,12 +41,12 @@ const chatPlugin: JupyterFrontEndPlugin = { } const chatHandler = new ChatHandler({ - llmClient: llmProvider.chatModel, + provider: aiProvider.chatModel, activeCellManager: activeCellManager }); - llmProvider.modelChange.connect(() => { - chatHandler.llmClient = llmProvider.chatModel; + aiProvider.modelChange.connect(() => { + chatHandler.provider = aiProvider.chatModel; }); let sendWithShiftEnter = false; @@ -94,24 +94,24 @@ const chatPlugin: JupyterFrontEndPlugin = { } }; -const llmProviderPlugin: JupyterFrontEndPlugin = { - id: 'jupyterlab-codestral:llm-provider', +const aiProviderPlugin: JupyterFrontEndPlugin = { + id: 'jupyterlab-codestral:ai-provider', autoStart: true, requires: [ICompletionProviderManager, ISettingRegistry], - provides: 
ILlmProvider, + provides: IAIProvider, activate: ( app: JupyterFrontEnd, manager: ICompletionProviderManager, settingRegistry: ISettingRegistry - ): ILlmProvider => { - const llmProvider = new LlmProvider({ completionProviderManager: manager }); + ): IAIProvider => { + const aiProvider = new AIProvider({ completionProviderManager: manager }); settingRegistry - .load(llmProviderPlugin.id) + .load(aiProviderPlugin.id) .then(settings => { const updateProvider = () => { const provider = settings.get('provider').composite as string; - llmProvider.setModels(provider, settings.composite); + aiProvider.setModels(provider, settings.composite); }; settings.changed.connect(() => updateProvider()); @@ -119,13 +119,13 @@ const llmProviderPlugin: JupyterFrontEndPlugin = { }) .catch(reason => { console.error( - `Failed to load settings for ${llmProviderPlugin.id}`, + `Failed to load settings for ${aiProviderPlugin.id}`, reason ); }); - return llmProvider; + return aiProvider; } }; -export default [chatPlugin, llmProviderPlugin]; +export default [chatPlugin, aiProviderPlugin]; diff --git a/src/llm-models/base-completer.ts b/src/llm-models/base-completer.ts index 8374db4..498abf6 100644 --- a/src/llm-models/base-completer.ts +++ b/src/llm-models/base-completer.ts @@ -8,7 +8,7 @@ export interface IBaseCompleter { /** * The LLM completer. */ - client: LLM; + provider: LLM; /** * The fetch request for the LLM completer. diff --git a/src/llm-models/codestral-completer.ts b/src/llm-models/codestral-completer.ts index 8f3e6ee..f1168c8 100644 --- a/src/llm-models/codestral-completer.ts +++ b/src/llm-models/codestral-completer.ts @@ -16,12 +16,12 @@ const INTERVAL = 1000; export class CodestralCompleter implements IBaseCompleter { constructor() { - this._mistralClient = new MistralAI({ + this._mistralProvider = new MistralAI({ apiKey: 'TMP', model: 'codestral-latest' }); this._throttler = new Throttler(async (data: CompletionRequest) => { - const response = await this._mistralClient.completionWithRetry( + const response = await this._mistralProvider.completionWithRetry( data, {}, false @@ -36,8 +36,8 @@ export class CodestralCompleter implements IBaseCompleter { }, INTERVAL); } - get client(): LLM { - return this._mistralClient; + get provider(): LLM { + return this._mistralProvider; } async fetch( @@ -51,7 +51,7 @@ export class CodestralCompleter implements IBaseCompleter { const data = { prompt, suffix, - model: this._mistralClient.model, + model: this._mistralProvider.model, // temperature: 0, // top_p: 1, // max_tokens: 1024, @@ -70,5 +70,5 @@ export class CodestralCompleter implements IBaseCompleter { } private _throttler: Throttler; - private _mistralClient: MistralAI; + private _mistralProvider: MistralAI; } diff --git a/src/provider.ts b/src/provider.ts index 1eed586..de88ba3 100644 --- a/src/provider.ts +++ b/src/provider.ts @@ -6,10 +6,10 @@ import { ReadonlyPartialJSONObject } from '@lumino/coreutils'; import { CompletionProvider } from './completion-provider'; import { getChatModel, IBaseCompleter } from './llm-models'; -import { ILlmProvider } from './token'; +import { IAIProvider } from './token'; -export class LlmProvider implements ILlmProvider { - constructor(options: LlmProvider.IOptions) { +export class AIProvider implements IAIProvider { + constructor(options: AIProvider.IOptions) { this._completionProvider = new CompletionProvider({ name: 'None' }); options.completionProviderManager.registerInlineProvider( this._completionProvider @@ -57,26 +57,26 @@ export class LlmProvider implements 
ILlmProvider { // Update the inline completion provider settings. if (this._completionProvider.llmCompleter) { - LlmProvider.updateConfig(this._completionProvider.llmCompleter, settings); + AIProvider.updateConfig(this._completionProvider.llmCompleter, settings); } // Update the chat LLM settings. if (this._llmChatModel) { - LlmProvider.updateConfig(this._llmChatModel, settings); + AIProvider.updateConfig(this._llmChatModel, settings); } } - get modelChange(): ISignal { + get modelChange(): ISignal { return this._modelChange; } private _completionProvider: CompletionProvider; private _llmChatModel: BaseChatModel | null = null; private _name: string = 'None'; - private _modelChange = new Signal(this); + private _modelChange = new Signal(this); } -export namespace LlmProvider { +export namespace AIProvider { /** * The options for the LLM provider. */ diff --git a/src/token.ts b/src/token.ts index 3148938..626be4a 100644 --- a/src/token.ts +++ b/src/token.ts @@ -4,14 +4,14 @@ import { ISignal } from '@lumino/signaling'; import { IBaseCompleter } from './llm-models'; -export interface ILlmProvider { +export interface IAIProvider { name: string | null; completer: IBaseCompleter | null; chatModel: BaseChatModel | null; - modelChange: ISignal; + modelChange: ISignal; } -export const ILlmProvider = new Token( - 'jupyterlab-codestral:LlmProvider', - 'Provider for chat and completion LLM client' +export const IAIProvider = new Token( + 'jupyterlab-codestral:AIProvider', + 'Provider for chat and completion LLM provider' );
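
For reference, a minimal sketch of how a downstream JupyterLab extension could consume the IAIProvider token that patch 5 exports from src/token.ts. The plugin id and the import path below are assumptions for illustration; only the token API (name, chatModel, modelChange) comes from the patches.

import {
  JupyterFrontEnd,
  JupyterFrontEndPlugin
} from '@jupyterlab/application';
// Assumed entry point re-exporting the token; adjust to the real package exports.
import { IAIProvider } from 'jupyterlab-codestral';

const aiConsumerPlugin: JupyterFrontEndPlugin<void> = {
  id: 'my-extension:ai-consumer', // hypothetical plugin id
  autoStart: true,
  requires: [IAIProvider],
  activate: (app: JupyterFrontEnd, aiProvider: IAIProvider): void => {
    // modelChange is emitted whenever setModels() switches or reconfigures
    // the underlying models, so consumers can react to settings updates.
    aiProvider.modelChange.connect(() => {
      if (aiProvider.chatModel === null) {
        console.warn(`No chat model configured (provider: ${aiProvider.name})`);
      }
    });
  }
};

export default aiConsumerPlugin;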
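Likewise, a minimal sketch of the rate-limiting pattern the CodestralCompleter relies on: a Lumino Throttler wraps the request function so bursts of completion requests collapse to at most one underlying call per INTERVAL. Here sendRequest is a hypothetical stand-in for the real completionWithRetry call.

import { Throttler } from '@lumino/polling';

// The Mistral API allows roughly one request per second.
const INTERVAL = 1000;

// Hypothetical stand-in for the real API call (completionWithRetry).
async function sendRequest(prompt: string): Promise<string> {
  return `completion for: ${prompt}`;
}

// The throttler forwards its arguments to the wrapped function and
// guarantees at most one underlying call per INTERVAL milliseconds.
const throttler = new Throttler(sendRequest, INTERVAL);

async function demo(): Promise<void> {
  const result = await throttler.invoke('def fib(n):');
  console.log(result);
}

void demo();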