Skip to content

Commit c126653

Browse files
authored
feat: response writer with stream state recorder (#1012)
Change `EventResponseWriter` from struct to interface.
1 parent c50207b commit c126653

10 files changed

+189
-103
lines changed

go.mod

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ require (
2525
github.com/spf13/viper v1.19.0
2626
github.com/stretchr/testify v1.10.0
2727
github.com/tetratelabs/wazero v1.8.2
28+
github.com/tidwall/gjson v1.18.0
29+
github.com/tidwall/sjson v1.2.5
2830
github.com/vmihailenco/msgpack/v5 v5.4.1
2931
github.com/yomorun/y3 v1.0.5
3032
go.opentelemetry.io/otel v1.34.0
@@ -84,10 +86,8 @@ require (
8486
github.com/spf13/afero v1.12.0 // indirect
8587
github.com/spf13/cast v1.7.1 // indirect
8688
github.com/subosito/gotenv v1.6.0 // indirect
87-
github.com/tidwall/gjson v1.18.0 // indirect
8889
github.com/tidwall/match v1.1.1 // indirect
8990
github.com/tidwall/pretty v1.2.1 // indirect
90-
github.com/tidwall/sjson v1.2.5 // indirect
9191
github.com/tklauser/go-sysconf v0.3.14 // indirect
9292
github.com/tklauser/numcpus v0.9.0 // indirect
9393
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect

pkg/bridge/ai/api_server.go

+23-17
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ import (
2727
const (
2828
// DefaultZipperAddr is the default endpoint of the zipper
2929
DefaultZipperAddr = "localhost:9000"
30+
)
31+
32+
var (
3033
// RequestTimeout is the timeout for the request, default is 90 seconds
3134
RequestTimeout = 90 * time.Second
3235
// RunFunctionTimeout is the timeout for awaiting the function response, default is 60 seconds
@@ -53,11 +56,9 @@ func Serve(config *Config, logger *slog.Logger, source yomo.Source, reducer yomo
5356
}
5457

5558
// NewServeMux creates a new http.ServeMux for the llm bridge server.
56-
func NewServeMux(service *Service) *http.ServeMux {
57-
var (
58-
h = &Handler{service}
59-
mux = http.NewServeMux()
60-
)
59+
func NewServeMux(h *Handler) *http.ServeMux {
60+
mux := http.NewServeMux()
61+
6162
// GET /overview
6263
mux.HandleFunc("/overview", h.HandleOverview)
6364
// POST /invoke
@@ -88,7 +89,7 @@ func NewBasicAPIServer(config *Config, provider provider.LLMProvider, source yom
8889
}
8990
service := NewService(provider, opts)
9091

91-
mux := NewServeMux(service)
92+
mux := NewServeMux(NewHandler(service))
9293

9394
server := &BasicAPIServer{
9495
httpHandler: DecorateHandler(mux, decorateReqContext(service, logger)),
@@ -135,8 +136,8 @@ func decorateReqContext(service *Service, logger *slog.Logger) func(handler http
135136
handler.ServeHTTP(ww, r.WithContext(ctx))
136137

137138
duration := time.Since(start)
138-
if !ww.TTFT.IsZero() {
139-
duration = ww.TTFT.Sub(start)
139+
if ttft := ww.GetTTFT(); !ttft.IsZero() {
140+
duration = ttft.Sub(start)
140141
}
141142

142143
logContent := []any{
@@ -149,8 +150,8 @@ func decorateReqContext(service *Service, logger *slog.Logger) func(handler http
149150
if traceID := span.SpanContext().TraceID(); traceID.IsValid() {
150151
logContent = append(logContent, "traceId", traceID.String())
151152
}
152-
if ww.Err != nil {
153-
logger.Error("llm birdge request", append(logContent, "err", ww.Err)...)
153+
if err := ww.GetError(); err != nil {
154+
logger.Error("llm birdge request", append(logContent, "err", err)...)
154155
} else {
155156
logger.Info("llm birdge request", logContent...)
156157
}
@@ -163,6 +164,11 @@ type Handler struct {
163164
service *Service
164165
}
165166

167+
// NewHandler returns a handler that handles chat completions requests.
168+
func NewHandler(service *Service) *Handler {
169+
return &Handler{service}
170+
}
171+
166172
// HandleOverview is the handler for GET /overview
167173
func (h *Handler) HandleOverview(w http.ResponseWriter, r *http.Request) {
168174
w.Header().Set("Content-Type", "application/json")
@@ -188,13 +194,13 @@ func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) {
188194
var (
189195
ctx = r.Context()
190196
transID = FromTransIDContext(ctx)
191-
ww = w.(*ResponseWriter)
197+
ww = w.(EventResponseWriter)
192198
)
193199
defer r.Body.Close()
194200

195201
req, err := DecodeRequest[ai.InvokeRequest](r, w, h.service.logger)
196202
if err != nil {
197-
ww.Err = errors.New("bad request")
203+
ww.RecordError(errors.New("bad request"))
198204
return
199205
}
200206

@@ -210,7 +216,7 @@ func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) {
210216

211217
res, err := h.service.GetInvoke(ctx, req.Prompt, baseSystemMessage, transID, caller, req.IncludeCallStack, tracer)
212218
if err != nil {
213-
ww.Err = err
219+
ww.RecordError(err)
214220
RespondWithError(w, http.StatusInternalServerError, err, h.service.logger)
215221
return
216222
}
@@ -223,13 +229,13 @@ func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request)
223229
var (
224230
ctx = r.Context()
225231
transID = FromTransIDContext(ctx)
226-
ww = w.(*ResponseWriter)
232+
ww = w.(EventResponseWriter)
227233
)
228234
defer r.Body.Close()
229235

230236
req, err := DecodeRequest[openai.ChatCompletionRequest](r, w, h.service.logger)
231237
if err != nil {
232-
ww.Err = err
238+
ww.RecordError(err)
233239
return
234240
}
235241

@@ -242,11 +248,11 @@ func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request)
242248
)
243249

244250
if err := h.service.GetChatCompletions(ctx, req, transID, caller, ww, tracer); err != nil {
245-
ww.Err = err
251+
ww.RecordError(err)
246252
if err == context.Canceled {
247253
return
248254
}
249-
if ww.IsStream {
255+
if ww.IsStream() {
250256
h.service.logger.Error("bridge server error", "err", err.Error(), "err_type", reflect.TypeOf(err).String())
251257
w.Write([]byte(fmt.Sprintf(`{"error":{"message":"%s"}}`, err.Error())))
252258
return

pkg/bridge/ai/api_server_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ func TestServer(t *testing.T) {
5252
MetadataExchanger: func(_ string) (metadata.M, error) { return metadata.M{"hello": "llm bridge"}, nil },
5353
})
5454

55-
handler := DecorateHandler(NewServeMux(service), decorateReqContext(service, service.logger))
55+
handler := DecorateHandler(NewServeMux(NewHandler(service)), decorateReqContext(service, service.logger))
5656

5757
// create a test server
5858
server := httptest.NewServer(handler)

pkg/bridge/ai/call_syncer.go

+33-15
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,21 @@ type CallSyncer interface {
1616
// Call fires a batch of function calls and waits for the results of these function calls.
1717
// The result only contains the messages with role=="tool".
1818
// If a function call fails, its content is returned as the failure reason.
19-
Call(ctx context.Context, transID string, reqID string, toolCalls map[uint32][]*openai.ToolCall) ([]openai.ChatCompletionMessage, error)
19+
Call(ctx context.Context, transID string, reqID string, toolCalls map[uint32][]*openai.ToolCall) ([]ToolCallResult, error)
2020
// Close closes the CallSyncer. Once closed, the CallSyncer can no longer be used.
2121
Close() error
2222
}
2323

24+
// ToolCallResult is the result of a CallSyncer.Call()
25+
type ToolCallResult struct {
26+
// FunctionName is the name of the called function.
27+
FunctionName string
28+
// ToolCallID is the tool call id.
29+
ToolCallID string
30+
// Content is the result of the function calling.
31+
Content string
32+
}
33+
2434
type callSyncer struct {
2535
ctx context.Context
2636
cancel context.CancelFunc
@@ -78,19 +88,27 @@ func NewCallSyncer(logger *slog.Logger, sourceCh chan<- TagFunctionCall, reduceC
7888
type toolOut struct {
7989
reqID string
8090
toolIDs map[string]struct{}
81-
ch chan openai.ChatCompletionMessage
91+
ch chan ToolCallResult
8292
}
8393

84-
func (f *callSyncer) Call(ctx context.Context, transID, reqID string, tagToolCalls map[uint32][]*openai.ToolCall) ([]openai.ChatCompletionMessage, error) {
94+
func (f *callSyncer) Call(ctx context.Context, transID, reqID string, tagToolCalls map[uint32][]*openai.ToolCall) ([]ToolCallResult, error) {
8595
defer func() {
8696
f.cleanCh <- reqID
8797
}()
8898

99+
toolNameMap := make(map[string]string)
100+
101+
for _, tools := range tagToolCalls {
102+
for _, tool := range tools {
103+
toolNameMap[tool.ID] = tool.Function.Name
104+
}
105+
}
106+
89107
toolIDs, err := f.fire(transID, reqID, tagToolCalls)
90108
if err != nil {
91109
return nil, err
92110
}
93-
ch := make(chan openai.ChatCompletionMessage)
111+
ch := make(chan ToolCallResult)
94112

95113
otherToolIDs := make(map[string]struct{})
96114
for id := range toolIDs {
@@ -105,7 +123,7 @@ func (f *callSyncer) Call(ctx context.Context, transID, reqID string, tagToolCal
105123

106124
f.toolOutCh <- toolOut
107125

108-
var result []openai.ChatCompletionMessage
126+
var result []ToolCallResult
109127
for {
110128
select {
111129
case <-f.ctx.Done():
@@ -123,10 +141,10 @@ func (f *callSyncer) Call(ctx context.Context, transID, reqID string, tagToolCal
123141
}
124142
case <-time.After(f.timeout):
125143
for id := range toolIDs {
126-
result = append(result, openai.ChatCompletionMessage{
127-
ToolCallID: id,
128-
Role: openai.ChatMessageRoleTool,
129-
Content: "timeout in this function calling, you should ignore this.",
144+
result = append(result, ToolCallResult{
145+
FunctionName: toolNameMap[id],
146+
ToolCallID: id,
147+
Content: "timeout in this function calling, you should ignore this.",
130148
})
131149
}
132150
return result, nil
@@ -166,7 +184,7 @@ func (f *callSyncer) Close() error {
166184

167185
func (f *callSyncer) background() {
168186
// buffered stores the messages from the reducer, the key is the reqID
169-
buffered := make(map[string]map[string]openai.ChatCompletionMessage)
187+
buffered := make(map[string]map[string]ToolCallResult)
170188
// singnals stores the result channels keyed by reqID; the value's channel is sent to once the buffer for that request is full.
171189
singnals := make(map[string]toolOut)
172190

@@ -197,18 +215,18 @@ func (f *callSyncer) background() {
197215
f.logger.Warn("recv unexpected message", "msg", msg)
198216
continue
199217
}
200-
result := openai.ChatCompletionMessage{
201-
ToolCallID: msg.Message.ToolCallID,
202-
Role: msg.Message.Role,
203-
Content: msg.Message.Content,
218+
result := ToolCallResult{
219+
FunctionName: msg.Message.Name,
220+
ToolCallID: msg.Message.ToolCallID,
221+
Content: msg.Message.Content,
204222
}
205223

206224
sig, ok := singnals[msg.ReqID]
207225
// The signal that requests a result has not been sent yet, so buffer the data from the reducer.
208226
if !ok {
209227
_, ok := buffered[msg.ReqID]
210228
if !ok {
211-
buffered[msg.ReqID] = make(map[string]openai.ChatCompletionMessage)
229+
buffered[msg.ReqID] = make(map[string]ToolCallResult)
212230
}
213231
buffered[msg.ReqID][msg.Message.ToolCallID] = result
214232
} else {

pkg/bridge/ai/call_syncer_test.go

+8-8
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ func TestTimeoutCallSyncer(t *testing.T) {
3838
reqID = "mock-req-id"
3939
)
4040

41-
want := []openai.ChatCompletionMessage{
41+
want := []ToolCallResult{
4242
{
43-
Role: openai.ChatMessageRoleTool,
44-
ToolCallID: "tool-call-id",
45-
Content: "timeout in this function calling, you should ignore this.",
43+
FunctionName: "timeout-function",
44+
ToolCallID: "tool-call-id",
45+
Content: "timeout in this function calling, you should ignore this.",
4646
},
4747
}
4848

@@ -97,15 +97,15 @@ func (h *handler) handle(c serverless.Context) {
9797
h.ctxs[c.(*mock.MockContext)] = struct{}{}
9898
}
9999

100-
func (h *handler) result() []openai.ChatCompletionMessage {
100+
func (h *handler) result() []ToolCallResult {
101101
h.mu.Lock()
102102
defer h.mu.Unlock()
103103

104-
want := []openai.ChatCompletionMessage{}
104+
want := []ToolCallResult{}
105105
for c := range h.ctxs {
106106
invoke, _ := c.LLMFunctionCall()
107-
want = append(want, openai.ChatCompletionMessage{
108-
Role: openai.ChatMessageRoleTool, Content: invoke.Result, ToolCallID: invoke.ToolCallID,
107+
want = append(want, ToolCallResult{
108+
FunctionName: invoke.FunctionName, Content: invoke.Result, ToolCallID: invoke.ToolCallID,
109109
})
110110
}
111111

pkg/bridge/ai/caller.go

+1
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ func reduceFunc(messages chan ReduceMessage, logger *slog.Logger) core.AsyncHand
9999
logger.Debug("sfn-reducer", "req_id", invoke.ReqID, "tool_call_id", invoke.ToolCallID, "result", string(invoke.Result))
100100

101101
message := openai.ChatCompletionMessage{
102+
Name: invoke.FunctionName,
102103
Role: openai.ChatMessageRoleTool,
103104
Content: invoke.Result,
104105
ToolCallID: invoke.ToolCallID,

0 commit comments

Comments
 (0)