@@ -228,7 +228,7 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl
228
228
return err
229
229
}
230
230
// 2. add those tools to request
231
- req = srv .addToolsToRequest (req , tools )
231
+ req , hasReqTools : = srv .addToolsToRequest (req , tools )
232
232
233
233
// 3. operate system prompt to request
234
234
prompt , op := caller .GetSystemPrompt ()
@@ -274,7 +274,14 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl
274
274
if err != nil {
275
275
return err
276
276
}
277
-
277
+ if hasReqTools {
278
+ if i == 0 {
279
+ respSpan = startRespSpan (reqCtx , reqSpan , tracer , w )
280
+ }
281
+ w .WriteStreamEvent (streamRes )
282
+ i ++
283
+ continue
284
+ }
278
285
if len (streamRes .PromptFilterResults ) > 0 {
279
286
continue
280
287
}
@@ -318,13 +325,11 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl
318
325
_ = w .WriteStreamEvent (streamRes )
319
326
}
320
327
if i == 0 && j == 0 && ! isFunctionCall {
321
- reqSpan .End ()
322
- recordTTFT (ctx , tracer , w )
323
- _ , respSpan = tracer .Start (ctx , "response_in_stream(TBT)" )
328
+ respSpan = startRespSpan (reqCtx , reqSpan , tracer , w )
324
329
}
325
330
i ++
326
331
}
327
- if ! isFunctionCall {
332
+ if ! isFunctionCall || hasReqTools {
328
333
respSpan .End ()
329
334
return w .WriteStreamDone ()
330
335
}
@@ -350,7 +355,7 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl
350
355
351
356
srv .logger .Debug (" #1 first call" , "response" , fmt .Sprintf ("%+v" , resp ))
352
357
// it is a function call
353
- if resp .Choices [0 ].FinishReason == openai .FinishReasonToolCalls {
358
+ if resp .Choices [0 ].FinishReason == openai .FinishReasonToolCalls && ! hasReqTools {
354
359
toolCalls = append (toolCalls , resp .Choices [0 ].Message .ToolCalls ... )
355
360
assistantMessage = resp .Choices [0 ].Message
356
361
firstCallSpan .End ()
@@ -390,6 +395,8 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl
390
395
Content : v .Content ,
391
396
}
392
397
}
398
+ // second call should not have tool_choice option
399
+ req .ToolChoice = nil
393
400
req .Messages = append (reqMessages , assistantMessage )
394
401
req .Messages = append (req .Messages , llmCalls ... )
395
402
// anthropic must define tools
@@ -413,7 +420,6 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl
413
420
)
414
421
for {
415
422
if i == 0 {
416
- recordTTFT (resCtx , tracer , w )
417
423
_ , secondRespSpan = tracer .Start (resCtx , "second_call_response_in_stream(TBT)" )
418
424
}
419
425
i ++
@@ -452,6 +458,13 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl
452
458
}
453
459
}
454
460
461
+ func startRespSpan (ctx context.Context , reqSpan trace.Span , tracer trace.Tracer , w EventResponseWriter ) trace.Span {
462
+ reqSpan .End ()
463
+ recordTTFT (ctx , tracer , w )
464
+ _ , respSpan := tracer .Start (ctx , "response_in_stream(TBT)" )
465
+ return respSpan
466
+ }
467
+
455
468
func (srv * Service ) loadOrCreateCaller (credential string ) (* Caller , error ) {
456
469
caller , ok := srv .callers .Get (credential )
457
470
if ok {
@@ -476,14 +489,15 @@ func (srv *Service) loadOrCreateCaller(credential string) (*Caller, error) {
476
489
return caller , nil
477
490
}
478
491
479
- func (srv * Service ) addToolsToRequest (req openai.ChatCompletionRequest , tools []openai.Tool ) openai.ChatCompletionRequest {
480
- if len (tools ) > 0 {
481
- req .Tools = tools
492
+ func (srv * Service ) addToolsToRequest (req openai.ChatCompletionRequest , tools []openai.Tool ) (openai.ChatCompletionRequest , bool ) {
493
+ hasReqTools := len (req .Tools ) > 0
494
+ if ! hasReqTools {
495
+ if len (tools ) > 0 {
496
+ req .Tools = tools
497
+ srv .logger .Debug ("#1 first call" , "request" , fmt .Sprintf ("%+v" , req ))
498
+ }
482
499
}
483
-
484
- srv .logger .Debug (" #1 first call" , "request" , fmt .Sprintf ("%+v" , req ))
485
-
486
- return req
500
+ return req , hasReqTools
487
501
}
488
502
489
503
func (srv * Service ) opSystemPrompt (req openai.ChatCompletionRequest , sysPrompt string , op SystemPromptOp ) openai.ChatCompletionRequest {
0 commit comments