Skip to content

Commit 6ab9335

Browse files
authored
feat(llm-bridge): support tools params (#1032)
If the user provides `tools` parameters through tools, the LLM Bridge will not invoke llm-sfn.
1 parent 1e3d113 commit 6ab9335

File tree

2 files changed

+32
-15
lines changed

2 files changed

+32
-15
lines changed

pkg/bridge/ai/call_syncer.go

+3
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ type toolOut struct {
8282
}
8383

8484
func (f *callSyncer) Call(ctx context.Context, transID, reqID string, toolCalls []openai.ToolCall) ([]ToolCallResult, error) {
85+
if len(toolCalls) == 0 {
86+
return []ToolCallResult{}, nil
87+
}
8588
defer func() {
8689
f.cleanCh <- reqID
8790
}()

pkg/bridge/ai/service.go

+29-15
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl
228228
return err
229229
}
230230
// 2. add those tools to request
231-
req = srv.addToolsToRequest(req, tools)
231+
req, hasReqTools := srv.addToolsToRequest(req, tools)
232232

233233
// 3. operate system prompt to request
234234
prompt, op := caller.GetSystemPrompt()
@@ -274,7 +274,14 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl
274274
if err != nil {
275275
return err
276276
}
277-
277+
if hasReqTools {
278+
if i == 0 {
279+
respSpan = startRespSpan(reqCtx, reqSpan, tracer, w)
280+
}
281+
w.WriteStreamEvent(streamRes)
282+
i++
283+
continue
284+
}
278285
if len(streamRes.PromptFilterResults) > 0 {
279286
continue
280287
}
@@ -318,13 +325,11 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl
318325
_ = w.WriteStreamEvent(streamRes)
319326
}
320327
if i == 0 && j == 0 && !isFunctionCall {
321-
reqSpan.End()
322-
recordTTFT(ctx, tracer, w)
323-
_, respSpan = tracer.Start(ctx, "response_in_stream(TBT)")
328+
respSpan = startRespSpan(reqCtx, reqSpan, tracer, w)
324329
}
325330
i++
326331
}
327-
if !isFunctionCall {
332+
if !isFunctionCall || hasReqTools {
328333
respSpan.End()
329334
return w.WriteStreamDone()
330335
}
@@ -350,7 +355,7 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl
350355

351356
srv.logger.Debug(" #1 first call", "response", fmt.Sprintf("%+v", resp))
352357
// it is a function call
353-
if resp.Choices[0].FinishReason == openai.FinishReasonToolCalls {
358+
if resp.Choices[0].FinishReason == openai.FinishReasonToolCalls && !hasReqTools {
354359
toolCalls = append(toolCalls, resp.Choices[0].Message.ToolCalls...)
355360
assistantMessage = resp.Choices[0].Message
356361
firstCallSpan.End()
@@ -390,6 +395,8 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl
390395
Content: v.Content,
391396
}
392397
}
398+
// second call should not have tool_choice option
399+
req.ToolChoice = nil
393400
req.Messages = append(reqMessages, assistantMessage)
394401
req.Messages = append(req.Messages, llmCalls...)
395402
// anthropic must define tools
@@ -413,7 +420,6 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl
413420
)
414421
for {
415422
if i == 0 {
416-
recordTTFT(resCtx, tracer, w)
417423
_, secondRespSpan = tracer.Start(resCtx, "second_call_response_in_stream(TBT)")
418424
}
419425
i++
@@ -452,6 +458,13 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl
452458
}
453459
}
454460

461+
func startRespSpan(ctx context.Context, reqSpan trace.Span, tracer trace.Tracer, w EventResponseWriter) trace.Span {
462+
reqSpan.End()
463+
recordTTFT(ctx, tracer, w)
464+
_, respSpan := tracer.Start(ctx, "response_in_stream(TBT)")
465+
return respSpan
466+
}
467+
455468
func (srv *Service) loadOrCreateCaller(credential string) (*Caller, error) {
456469
caller, ok := srv.callers.Get(credential)
457470
if ok {
@@ -476,14 +489,15 @@ func (srv *Service) loadOrCreateCaller(credential string) (*Caller, error) {
476489
return caller, nil
477490
}
478491

479-
func (srv *Service) addToolsToRequest(req openai.ChatCompletionRequest, tools []openai.Tool) openai.ChatCompletionRequest {
480-
if len(tools) > 0 {
481-
req.Tools = tools
492+
func (srv *Service) addToolsToRequest(req openai.ChatCompletionRequest, tools []openai.Tool) (openai.ChatCompletionRequest, bool) {
493+
hasReqTools := len(req.Tools) > 0
494+
if !hasReqTools {
495+
if len(tools) > 0 {
496+
req.Tools = tools
497+
srv.logger.Debug("#1 first call", "request", fmt.Sprintf("%+v", req))
498+
}
482499
}
483-
484-
srv.logger.Debug(" #1 first call", "request", fmt.Sprintf("%+v", req))
485-
486-
return req
500+
return req, hasReqTools
487501
}
488502

489503
func (srv *Service) opSystemPrompt(req openai.ChatCompletionRequest, sysPrompt string, op SystemPromptOp) openai.ChatCompletionRequest {

0 commit comments

Comments
 (0)