@@ -27,6 +27,9 @@ import (
27
27
const (
28
28
// DefaultZipperAddr is the default endpoint of the zipper
29
29
DefaultZipperAddr = "localhost:9000"
30
+ )
31
+
32
+ var (
30
33
// RequestTimeout is the timeout for the request, default is 90 seconds
31
34
RequestTimeout = 90 * time .Second
32
35
// RunFunctionTimeout is the timeout for awaiting the function response, default is 60 seconds
@@ -53,11 +56,9 @@ func Serve(config *Config, logger *slog.Logger, source yomo.Source, reducer yomo
53
56
}
54
57
55
58
// NewServeMux creates a new http.ServeMux for the llm bridge server.
56
- func NewServeMux (service * Service ) * http.ServeMux {
57
- var (
58
- h = & Handler {service }
59
- mux = http .NewServeMux ()
60
- )
59
+ func NewServeMux (h * Handler ) * http.ServeMux {
60
+ mux := http .NewServeMux ()
61
+
61
62
// GET /overview
62
63
mux .HandleFunc ("/overview" , h .HandleOverview )
63
64
// POST /invoke
@@ -88,7 +89,7 @@ func NewBasicAPIServer(config *Config, provider provider.LLMProvider, source yom
88
89
}
89
90
service := NewService (provider , opts )
90
91
91
- mux := NewServeMux (service )
92
+ mux := NewServeMux (NewHandler ( service ) )
92
93
93
94
server := & BasicAPIServer {
94
95
httpHandler : DecorateHandler (mux , decorateReqContext (service , logger )),
@@ -135,8 +136,8 @@ func decorateReqContext(service *Service, logger *slog.Logger) func(handler http
135
136
handler .ServeHTTP (ww , r .WithContext (ctx ))
136
137
137
138
duration := time .Since (start )
138
- if ! ww .TTFT .IsZero () {
139
- duration = ww . TTFT .Sub (start )
139
+ if ttft := ww .GetTTFT (); ! ttft .IsZero () {
140
+ duration = ttft .Sub (start )
140
141
}
141
142
142
143
logContent := []any {
@@ -149,8 +150,8 @@ func decorateReqContext(service *Service, logger *slog.Logger) func(handler http
149
150
if traceID := span .SpanContext ().TraceID (); traceID .IsValid () {
150
151
logContent = append (logContent , "traceId" , traceID .String ())
151
152
}
152
- if ww .Err != nil {
153
- logger .Error ("llm birdge request" , append (logContent , "err" , ww . Err )... )
153
+ if err := ww .GetError (); err != nil {
154
+ logger .Error ("llm birdge request" , append (logContent , "err" , err )... )
154
155
} else {
155
156
logger .Info ("llm birdge request" , logContent ... )
156
157
}
@@ -163,6 +164,11 @@ type Handler struct {
163
164
service * Service
164
165
}
165
166
167
+ // NewHandler returns a handler that handles chat completions requests.
168
+ func NewHandler (service * Service ) * Handler {
169
+ return & Handler {service }
170
+ }
171
+
166
172
// HandleOverview is the handler for GET /overview
167
173
func (h * Handler ) HandleOverview (w http.ResponseWriter , r * http.Request ) {
168
174
w .Header ().Set ("Content-Type" , "application/json" )
@@ -188,13 +194,13 @@ func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) {
188
194
var (
189
195
ctx = r .Context ()
190
196
transID = FromTransIDContext (ctx )
191
- ww = w .(* ResponseWriter )
197
+ ww = w .(EventResponseWriter )
192
198
)
193
199
defer r .Body .Close ()
194
200
195
201
req , err := DecodeRequest [ai.InvokeRequest ](r , w , h .service .logger )
196
202
if err != nil {
197
- ww .Err = errors .New ("bad request" )
203
+ ww .RecordError ( errors .New ("bad request" ) )
198
204
return
199
205
}
200
206
@@ -210,7 +216,7 @@ func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) {
210
216
211
217
res , err := h .service .GetInvoke (ctx , req .Prompt , baseSystemMessage , transID , caller , req .IncludeCallStack , tracer )
212
218
if err != nil {
213
- ww .Err = err
219
+ ww .RecordError ( err )
214
220
RespondWithError (w , http .StatusInternalServerError , err , h .service .logger )
215
221
return
216
222
}
@@ -223,13 +229,13 @@ func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request)
223
229
var (
224
230
ctx = r .Context ()
225
231
transID = FromTransIDContext (ctx )
226
- ww = w .(* ResponseWriter )
232
+ ww = w .(EventResponseWriter )
227
233
)
228
234
defer r .Body .Close ()
229
235
230
236
req , err := DecodeRequest [openai.ChatCompletionRequest ](r , w , h .service .logger )
231
237
if err != nil {
232
- ww .Err = err
238
+ ww .RecordError ( err )
233
239
return
234
240
}
235
241
@@ -242,11 +248,11 @@ func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request)
242
248
)
243
249
244
250
if err := h .service .GetChatCompletions (ctx , req , transID , caller , ww , tracer ); err != nil {
245
- ww .Err = err
251
+ ww .RecordError ( err )
246
252
if err == context .Canceled {
247
253
return
248
254
}
249
- if ww .IsStream {
255
+ if ww .IsStream () {
250
256
h .service .logger .Error ("bridge server error" , "err" , err .Error (), "err_type" , reflect .TypeOf (err ).String ())
251
257
w .Write ([]byte (fmt .Sprintf (`{"error":{"message":"%s"}}` , err .Error ())))
252
258
return
0 commit comments