Skip to content

Commit

Permalink
Add LLMStream function to nodejs host
Browse files Browse the repository at this point in the history
  • Loading branch information
qianlifeng committed Jun 14, 2024
1 parent 87a272d commit 6b1e6d8
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 18 deletions.
22 changes: 22 additions & 0 deletions Wox.Plugin.Host.Nodejs/src/jsonrpc.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { PluginAPI } from "./pluginAPI"
import { Context, Plugin, PluginInitParams, Query, QueryEnv, RefreshableResult, Result, ResultAction, Selection } from "@wox-launcher/wox-plugin"
import { WebSocket } from "ws"
import * as crypto from "crypto"
import { llm } from "@wox-launcher/wox-plugin/types/llm"

const pluginInstances = new Map<PluginJsonRpcRequest["PluginId"], PluginInstance>()

Expand Down Expand Up @@ -62,6 +63,8 @@ export async function handleRequestFromWox(ctx: Context, request: PluginJsonRpcR
return onPluginSettingChange(ctx, request)
case "onGetDynamicSetting":
return onGetDynamicSetting(ctx, request)
case "onLLMStream":
return onLLMStream(ctx, request)
default:
logger.info(ctx, `unknown method handler: ${request.Method}`)
throw new Error(`unknown method handler: ${request.Method}`)
Expand Down Expand Up @@ -162,6 +165,25 @@ async function onGetDynamicSetting(ctx: Context, request: PluginJsonRpcRequest)
return setting
}

async function onLLMStream(ctx: Context, request: PluginJsonRpcRequest) {
const plugin = pluginInstances.get(request.PluginId)
if (plugin === undefined || plugin === null) {
logger.error(ctx, `plugin not found: ${request.PluginName}, forget to load plugin?`)
throw new Error(`plugin not found: ${request.PluginName}, forget to load plugin?`)
}

const callbackId = request.Params.CallbackId
const streamType = request.Params.StreamType
const data = request.Params.Data
const callbackFunc = plugin.API.llmStreamCallbacks.get(callbackId)
if (callbackFunc === undefined || callbackFunc === null) {
logger.error(ctx, `llm stream callback not found: ${callbackId}`)
throw new Error(`llm stream callback not found: ${callbackId}`)
}

callbackFunc(<llm.ChatStreamDataType>streamType, data)
}

async function query(ctx: Context, request: PluginJsonRpcRequest) {
const plugin = pluginInstances.get(request.PluginId)
if (plugin === undefined || plugin === null) {
Expand Down
8 changes: 4 additions & 4 deletions Wox.Plugin.Host.Nodejs/src/pluginAPI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ export class PluginAPI implements PublicAPI {
pluginName: string
settingChangeCallbacks: Map<string, (key: string, value: string) => void>
getDynamicSettingCallbacks: Map<string, (key: string) => PluginSettingDefinitionItem>
llmStreamCallbacks: Map<string, llm.ChatStreamFunc>

constructor(ws: WebSocket, pluginId: string, pluginName: string) {
this.ws = ws
this.pluginId = pluginId
this.pluginName = pluginName
this.settingChangeCallbacks = new Map<string, (key: string, value: string) => void>()
this.getDynamicSettingCallbacks = new Map<string, (key: string) => PluginSettingDefinitionItem>()
this.llmStreamCallbacks = new Map<string, llm.ChatStreamFunc>()
}

async invokeMethod(ctx: Context, method: string, params: { [key: string]: string }): Promise<unknown> {
Expand Down Expand Up @@ -105,9 +107,7 @@ export class PluginAPI implements PublicAPI {

async LLMStream(ctx: Context, conversations: llm.Conversation[], callback: llm.ChatStreamFunc): Promise<void> {
const callbackId = crypto.randomUUID()
await this.invokeMethod(ctx, "LLMStream", { callbackId })

//TODO: implement LLMStream
return null
this.llmStreamCallbacks.set(callbackId, callback)
await this.invokeMethod(ctx, "LLMStream", { callbackId, conversations: JSON.stringify(conversations) })
}
}
28 changes: 14 additions & 14 deletions Wox.UI.Flutter/wox/pubspec.lock
Original file line number Diff line number Diff line change
Expand Up @@ -313,10 +313,10 @@ packages:
dependency: transitive
description:
name: intl
sha256: "3bc132a9dbce73a7e4a21a17d06e1878839ffbf975568bc875c60537824b0c4d"
sha256: d6f56758b7d3014a48af9701c085700aac781a92a87a62b1333b46d8879661cf
url: "https://pub.dev"
source: hosted
version: "0.18.1"
version: "0.19.0"
js:
dependency: transitive
description:
Expand All @@ -337,26 +337,26 @@ packages:
dependency: transitive
description:
name: leak_tracker
sha256: "78eb209deea09858f5269f5a5b02be4049535f568c07b275096836f01ea323fa"
sha256: "7f0df31977cb2c0b88585095d168e689669a2cc9b97c309665e3386f3e9d341a"
url: "https://pub.dev"
source: hosted
version: "10.0.0"
version: "10.0.4"
leak_tracker_flutter_testing:
dependency: transitive
description:
name: leak_tracker_flutter_testing
sha256: b46c5e37c19120a8a01918cfaf293547f47269f7cb4b0058f21531c2465d6ef0
sha256: "06e98f569d004c1315b991ded39924b21af84cf14cc94791b8aea337d25b57f8"
url: "https://pub.dev"
source: hosted
version: "2.0.1"
version: "3.0.3"
leak_tracker_testing:
dependency: transitive
description:
name: leak_tracker_testing
sha256: a597f72a664dbd293f3bfc51f9ba69816f84dcd403cdac7066cb3f6003f3ab47
sha256: "6ba465d5d76e67ddf503e1161d1f4a6bc42306f9d66ca1e8f079a47290fb06d3"
url: "https://pub.dev"
source: hosted
version: "2.0.1"
version: "3.0.1"
lints:
dependency: transitive
description:
Expand Down Expand Up @@ -425,10 +425,10 @@ packages:
dependency: transitive
description:
name: meta
sha256: d584fa6707a52763a52446f02cc621b077888fb63b93bbcb1143a7be5a0c0c04
sha256: "7687075e408b093f36e6bbf6c91878cc0d4cd10f409506f7bc996f68220b9136"
url: "https://pub.dev"
source: hosted
version: "1.11.0"
version: "1.12.0"
path:
dependency: "direct main"
description:
Expand Down Expand Up @@ -710,10 +710,10 @@ packages:
dependency: transitive
description:
name: test_api
sha256: "5c2f730018264d276c20e4f1503fd1308dfbbae39ec8ee63c5236311ac06954b"
sha256: "9955ae474176f7ac8ee4e989dadfb411a58c30415bcfb648fa04b2b8a03afa7f"
url: "https://pub.dev"
source: hosted
version: "0.6.1"
version: "0.7.0"
typed_data:
dependency: transitive
description:
Expand Down Expand Up @@ -838,10 +838,10 @@ packages:
dependency: transitive
description:
name: vm_service
sha256: b3d56ff4341b8f182b96aceb2fa20e3dcb336b9f867bc0eafc0de10f1048e957
sha256: "3923c89304b715fb1eb6423f017651664a03bf5f4b29983627c4da791f74a4ec"
url: "https://pub.dev"
source: hosted
version: "13.0.0"
version: "14.2.1"
web:
dependency: transitive
description:
Expand Down
32 changes: 32 additions & 0 deletions Wox/plugin/host/host_websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strings"
"time"
"wox/plugin"
"wox/plugin/llm"
"wox/setting/definition"
"wox/share"
"wox/util"
Expand Down Expand Up @@ -379,6 +380,37 @@ func (w *WebsocketHost) handleRequestFromPlugin(ctx context.Context, request Jso
}

pluginInstance.API.RegisterQueryCommands(ctx, commands)
w.sendResponseToHost(ctx, request, "")
case "LLMStream":
callbackId, exist := request.Params["callbackId"]
if !exist {
util.GetLogger().Error(ctx, fmt.Sprintf("[%s] LLMStream method must have a callbackId parameter", request.PluginName))
return
}
conversationsStr, exist := request.Params["conversations"]
if !exist {
util.GetLogger().Error(ctx, fmt.Sprintf("[%s] LLMStream method must have a conversations parameter", request.PluginName))
return
}

var conversations []llm.Conversation
unmarshalErr := json.Unmarshal([]byte(conversationsStr), &conversations)
if unmarshalErr != nil {
util.GetLogger().Error(ctx, fmt.Sprintf("[%s] failed to unmarshal conversations: %s", request.PluginName, unmarshalErr))
return
}

llmErr := pluginInstance.API.LLMStream(ctx, conversations, func(streamType llm.ChatStreamDataType, data string) {
w.invokeMethod(ctx, pluginInstance.Metadata, "onLLMStream", map[string]string{
"CallbackId": callbackId,
"StreamType": string(streamType),
"Data": data,
})
})
if llmErr != nil {
util.GetLogger().Error(ctx, fmt.Sprintf("[%s] failed to start LLM stream: %s", request.PluginName, llmErr))
}

w.sendResponseToHost(ctx, request, "")
}
}
Expand Down

0 comments on commit 6b1e6d8

Please sign in to comment.