diff --git a/pkg/api/http/handlers.go b/pkg/api/http/handlers.go index 26b96535..b85e07ad 100644 --- a/pkg/api/http/handlers.go +++ b/pkg/api/http/handlers.go @@ -6,7 +6,6 @@ import ( "github.com/EinStack/glide/pkg/router" - "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/telemetry" "github.com/gofiber/contrib/websocket" "github.com/gofiber/fiber/v2" @@ -35,16 +34,16 @@ type Handler = func(c *fiber.Ctx) error func LangChatHandler(routerManager *router.Manager) Handler { return func(c *fiber.Ctx) error { if !c.Is("json") { - return c.Status(fiber.StatusBadRequest).JSON(schemas.ErrUnsupportedMediaType) + return c.Status(fiber.StatusBadRequest).JSON(schema.ErrUnsupportedMediaType) } // Unmarshal request body - req := schemas.GetChatRequest() - defer schemas.ReleaseChatRequest(req) + req := schema.GetChatRequest() + defer schema.ReleaseChatRequest(req) err := c.BodyParser(&req) if err != nil { - return c.Status(fiber.StatusBadRequest).JSON(schemas.NewPayloadParseErr(err)) + return c.Status(fiber.StatusBadRequest).JSON(schema.NewPayloadParseErr(err)) } // Get router ID from path @@ -52,18 +51,18 @@ func LangChatHandler(routerManager *router.Manager) Handler { r, err := routerManager.GetLangRouter(routerID) if err != nil { - httpErr := schemas.FromErr(err) + httpErr := schema.FromErr(err) return c.Status(httpErr.Status).JSON(httpErr) } // Chat with router - resp := schemas.GetChatResponse() - defer schemas.ReleaseChatResponse(resp) + resp := schema.GetChatResponse() + defer schema.ReleaseChatResponse(resp) resp, err = r.Chat(c.Context(), req) if err != nil { - httpErr := schemas.FromErr(err) + httpErr := schema.FromErr(err) return c.Status(httpErr.Status).JSON(httpErr) } @@ -80,7 +79,7 @@ func LangStreamRouterValidator(routerManager *router.Manager) Handler { _, err := routerManager.GetLangRouter(routerID) if err != nil { - httpErr := schemas.FromErr(err) + httpErr := schema.FromErr(err) return c.Status(httpErr.Status).JSON(httpErr) } @@ -119,7 +118,7 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *router.Manag wg sync.WaitGroup ) - chatStreamC := make(chan *schemas.ChatStreamMessage) + chatStreamC := make(chan *schema.ChatStreamMessage) r, _ := routerManager.GetLangRouter(routerID) @@ -139,7 +138,7 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *router.Manag }() for { - var chatRequest schemas.ChatStreamRequest + var chatRequest schema.ChatStreamRequest if err = c.ReadJSON(&chatRequest); err != nil { // TODO: handle bad request schemas gracefully and return back validation errors @@ -155,7 +154,7 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *router.Manag // TODO: handle termination gracefully wg.Add(1) - go func(chatRequest schemas.ChatStreamRequest) { + go func(chatRequest schema.ChatStreamRequest) { defer wg.Done() r.ChatStream(context.Background(), &chatRequest, chatStreamC) @@ -185,7 +184,7 @@ func LangRoutersHandler(routerManager *router.Manager) Handler { cfgs = append(cfgs, r.Config) } - return c.Status(fiber.StatusOK).JSON(schemas.RouterListSchema{Routers: cfgs}) + return c.Status(fiber.StatusOK).JSON(schema.RouterListSchema{Routers: cfgs}) } } @@ -200,9 +199,9 @@ func LangRoutersHandler(routerManager *router.Manager) Handler { // @Success 200 {object} schemas.HealthSchema // @Router /v1/health/ [get] func HealthHandler(c *fiber.Ctx) error { - return c.Status(fiber.StatusOK).JSON(schemas.HealthSchema{Healthy: true}) + return c.Status(fiber.StatusOK).JSON(schema.HealthSchema{Healthy: true}) } func NotFoundHandler(c *fiber.Ctx) error { - return c.Status(fiber.StatusNotFound).JSON(schemas.ErrRouteNotFound) + return c.Status(fiber.StatusNotFound).JSON(schema.ErrRouteNotFound) } diff --git a/pkg/api/schemas/chat.go b/pkg/api/schema/chat.go similarity index 99% rename from pkg/api/schemas/chat.go rename to pkg/api/schema/chat.go index bb846043..b833b367 100644 --- a/pkg/api/schemas/chat.go +++ b/pkg/api/schema/chat.go @@ -1,4 +1,4 @@ -package schemas +package schema // ChatRequest defines Glide's Chat Request Schema unified across all language models type ChatRequest struct { diff --git a/pkg/api/schemas/chat_stream.go b/pkg/api/schema/chat_stream.go similarity index 99% rename from pkg/api/schemas/chat_stream.go rename to pkg/api/schema/chat_stream.go index f7cf8b27..ee1cd228 100644 --- a/pkg/api/schemas/chat_stream.go +++ b/pkg/api/schema/chat_stream.go @@ -1,4 +1,4 @@ -package schemas +package schema import "time" diff --git a/pkg/api/schemas/chat_test.go b/pkg/api/schema/chat_test.go similarity index 99% rename from pkg/api/schemas/chat_test.go rename to pkg/api/schema/chat_test.go index 9b5ce407..9d77da62 100644 --- a/pkg/api/schemas/chat_test.go +++ b/pkg/api/schema/chat_test.go @@ -1,4 +1,4 @@ -package schemas +package schema import ( "testing" diff --git a/pkg/api/schemas/embed.go b/pkg/api/schema/embed.go similarity index 86% rename from pkg/api/schemas/embed.go rename to pkg/api/schema/embed.go index 16fd70c7..5698d330 100644 --- a/pkg/api/schemas/embed.go +++ b/pkg/api/schema/embed.go @@ -1,4 +1,4 @@ -package schemas +package schema type EmbedRequest struct { // TODO: implement diff --git a/pkg/api/schemas/errors.go b/pkg/api/schema/errors.go similarity index 99% rename from pkg/api/schemas/errors.go rename to pkg/api/schema/errors.go index 2765f93e..0eecf0b5 100644 --- a/pkg/api/schemas/errors.go +++ b/pkg/api/schema/errors.go @@ -1,4 +1,4 @@ -package schemas +package schema import ( "fmt" diff --git a/pkg/api/schemas/health_checks.go b/pkg/api/schema/health_checks.go similarity index 79% rename from pkg/api/schemas/health_checks.go rename to pkg/api/schema/health_checks.go index 6078e769..896e00c5 100644 --- a/pkg/api/schemas/health_checks.go +++ b/pkg/api/schema/health_checks.go @@ -1,4 +1,4 @@ -package schemas +package schema type HealthSchema struct { Healthy bool `json:"healthy"` diff --git a/pkg/api/schemas/pool.go b/pkg/api/schema/pool.go similarity index 97% rename from pkg/api/schemas/pool.go rename to pkg/api/schema/pool.go index dcd9ccf8..4b5c38ba 100755 --- a/pkg/api/schemas/pool.go +++ b/pkg/api/schema/pool.go @@ -1,4 +1,4 @@ -package schemas +package schema import ( "sync" diff --git a/pkg/api/schemas/routers.go b/pkg/api/schema/routers.go similarity index 95% rename from pkg/api/schemas/routers.go rename to pkg/api/schema/routers.go index 9111a319..18dcee02 100644 --- a/pkg/api/schemas/routers.go +++ b/pkg/api/schema/routers.go @@ -1,4 +1,4 @@ -package schemas +package schema // RouterListSchema returns list of active configured routers. // diff --git a/pkg/clients/stream.go b/pkg/clients/stream.go index 913bbddc..4ab55fb0 100644 --- a/pkg/clients/stream.go +++ b/pkg/clients/stream.go @@ -1,21 +1,19 @@ package clients -import ( - "github.com/EinStack/glide/pkg/api/schemas" -) +import "github.com/EinStack/glide/pkg/api/schema" type ChatStream interface { Open() error - Recv() (*schemas.ChatStreamChunk, error) + Recv() (*schema.ChatStreamChunk, error) Close() error } type ChatStreamResult struct { - chunk *schemas.ChatStreamChunk + chunk *schema.ChatStreamChunk err error } -func (r *ChatStreamResult) Chunk() *schemas.ChatStreamChunk { +func (r *ChatStreamResult) Chunk() *schema.ChatStreamChunk { return r.chunk } @@ -23,7 +21,7 @@ func (r *ChatStreamResult) Error() error { return r.err } -func NewChatStreamResult(chunk *schemas.ChatStreamChunk, err error) *ChatStreamResult { +func NewChatStreamResult(chunk *schema.ChatStreamChunk, err error) *ChatStreamResult { return &ChatStreamResult{ chunk: chunk, err: err, diff --git a/pkg/extmodel/lang.go b/pkg/extmodel/lang.go index 0c29870b..7c95282a 100644 --- a/pkg/extmodel/lang.go +++ b/pkg/extmodel/lang.go @@ -5,6 +5,8 @@ import ( "io" "time" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/provider" "github.com/EinStack/glide/pkg/clients" @@ -13,16 +15,14 @@ import ( "github.com/EinStack/glide/pkg/config/fields" "github.com/EinStack/glide/pkg/router/latency" - - "github.com/EinStack/glide/pkg/api/schemas" ) type LangModel interface { Interface Provider() string ModelName() string - Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) - ChatStream(ctx context.Context, params *schemas.ChatParams) (<-chan *clients.ChatStreamResult, error) + Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) + ChatStream(ctx context.Context, params *schema.ChatParams) (<-chan *clients.ChatStreamResult, error) } // LanguageModel wraps provider client and expend it with health & latency tracking @@ -79,7 +79,7 @@ func (m LanguageModel) ChatStreamLatency() *latency.MovingAverage { return m.chatStreamLatency } -func (m *LanguageModel) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (m *LanguageModel) Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) { startedAt := time.Now() resp, err := m.client.Chat(ctx, params) @@ -98,7 +98,7 @@ func (m *LanguageModel) Chat(ctx context.Context, params *schemas.ChatParams) (* return resp, err } -func (m *LanguageModel) ChatStream(ctx context.Context, params *schemas.ChatParams) (<-chan *clients.ChatStreamResult, error) { +func (m *LanguageModel) ChatStream(ctx context.Context, params *schema.ChatParams) (<-chan *clients.ChatStreamResult, error) { stream, err := m.client.ChatStream(ctx, params) if err != nil { m.healthTracker.TrackErr(err) diff --git a/pkg/provider/anthropic/chat.go b/pkg/provider/anthropic/chat.go index c45efb76..bb0559ad 100644 --- a/pkg/provider/anthropic/chat.go +++ b/pkg/provider/anthropic/chat.go @@ -9,27 +9,28 @@ import ( "net/http" "time" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/api/schemas" "go.uber.org/zap" ) // ChatRequest is an Anthropic-specific request schema type ChatRequest struct { - Model string `json:"model"` - Messages []schemas.ChatMessage `json:"messages"` - System string `json:"system,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Stream bool `json:"stream,omitempty"` - Metadata *string `json:"metadata,omitempty"` - StopSequences []string `json:"stop_sequences,omitempty"` + Model string `json:"model"` + Messages []schema.ChatMessage `json:"messages"` + System string `json:"system,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Stream bool `json:"stream,omitempty"` + Metadata *string `json:"metadata,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` } -func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) { +func (r *ChatRequest) ApplyParams(params *schema.ChatParams) { r.Messages = params.Messages } @@ -51,7 +52,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { // Chat sends a chat request to the specified anthropic model. // // Ref: https://docs.anthropic.com/claude/reference/messages_post -func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Client) Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) { // Create a new chat request // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate @@ -67,7 +68,7 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas return chatResponse, nil } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schema.ChatResponse, error) { // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -130,19 +131,19 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche usage := anthropicResponse.Usage // Map response to ChatResponse schema - response := schemas.ChatResponse{ + response := schema.ChatResponse{ ID: anthropicResponse.ID, Created: int(time.Now().UTC().Unix()), // not provided by anthropic Provider: ProviderID, ModelName: anthropicResponse.Model, Cached: false, - ModelResponse: schemas.ModelResponse{ + ModelResponse: schema.ModelResponse{ Metadata: map[string]string{}, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: completion.Type, Content: completion.Text, }, - TokenUsage: schemas.TokenUsage{ + TokenUsage: schema.TokenUsage{ PromptTokens: usage.InputTokens, ResponseTokens: usage.OutputTokens, TotalTokens: usage.InputTokens + usage.OutputTokens, diff --git a/pkg/provider/anthropic/chat_stream.go b/pkg/provider/anthropic/chat_stream.go index dbb0b8ff..1a9f88a4 100644 --- a/pkg/provider/anthropic/chat_stream.go +++ b/pkg/provider/anthropic/chat_stream.go @@ -3,15 +3,15 @@ package anthropic import ( "context" - clients2 "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" ) func (c *Client) SupportChatStream() bool { return false } -func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients2.ChatStream, error) { - return nil, clients2.ErrChatStreamNotImplemented +func (c *Client) ChatStream(_ context.Context, _ *schema.ChatParams) (clients.ChatStream, error) { + return nil, clients.ErrChatStreamNotImplemented } diff --git a/pkg/provider/anthropic/client_test.go b/pkg/provider/anthropic/client_test.go index 70977bb0..2fe33334 100644 --- a/pkg/provider/anthropic/client_test.go +++ b/pkg/provider/anthropic/client_test.go @@ -10,9 +10,9 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -56,7 +56,7 @@ func TestAnthropicClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "human", Content: "What's the biggest animal?", }}} @@ -86,7 +86,7 @@ func TestAnthropicClient_BadChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "human", Content: "What's the biggest animal?", }}} diff --git a/pkg/provider/azureopenai/chat.go b/pkg/provider/azureopenai/chat.go index 86aab1f2..d2f1200e 100644 --- a/pkg/provider/azureopenai/chat.go +++ b/pkg/provider/azureopenai/chat.go @@ -8,12 +8,12 @@ import ( "io" "net/http" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/provider/openai" - "github.com/EinStack/glide/pkg/api/schemas" - "go.uber.org/zap" ) @@ -38,7 +38,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } // Chat sends a chat request to the specified azure openai model. -func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Client) Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) { // Create a new chat request // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate // hoping to get a copy of the template @@ -54,7 +54,7 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas return chatResponse, nil } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schema.ChatResponse, error) { // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -110,19 +110,19 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Map response to UnifiedChatResponse schema - response := schemas.ChatResponse{ + response := schema.ChatResponse{ ID: chatCompletion.ID, Created: chatCompletion.Created, Provider: providerName, ModelName: chatCompletion.ModelName, Cached: false, - ModelResponse: schemas.ModelResponse{ + ModelResponse: schema.ModelResponse{ Metadata: map[string]string{}, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: modelChoice.Message.Role, Content: modelChoice.Message.Content, }, - TokenUsage: schemas.TokenUsage{ + TokenUsage: schema.TokenUsage{ PromptTokens: chatCompletion.Usage.PromptTokens, ResponseTokens: chatCompletion.Usage.CompletionTokens, TotalTokens: chatCompletion.Usage.TotalTokens, diff --git a/pkg/provider/azureopenai/chat_stream.go b/pkg/provider/azureopenai/chat_stream.go index 8e12c8e3..f75fae4c 100644 --- a/pkg/provider/azureopenai/chat_stream.go +++ b/pkg/provider/azureopenai/chat_stream.go @@ -8,6 +8,8 @@ import ( "io" "net/http" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -17,8 +19,6 @@ import ( "github.com/r3labs/sse/v2" "go.uber.org/zap" - - "github.com/EinStack/glide/pkg/api/schemas" ) // TODO: Think about reducing the number of copy-pasted code btw OpenAI and Azure providers @@ -73,7 +73,7 @@ func (s *ChatStream) Open() error { } // Recv receives a chat stream chunk from the ChatStream and returns a ChatStreamChunk object. -func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { +func (s *ChatStream) Recv() (*schema.ChatStreamChunk, error) { var completionChunk ChatCompletionChunk for { @@ -130,16 +130,16 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { responseChunk := completionChunk.Choices[0] // TODO: use objectpool here - return &schemas.ChatStreamChunk{ + return &schema.ChatStreamChunk{ Cached: false, Provider: providerName, ModelName: completionChunk.ModelName, - ModelResponse: schemas.ModelChunkResponse{ - Metadata: &schemas.Metadata{ + ModelResponse: schema.ModelChunkResponse{ + Metadata: &schema.Metadata{ "response_id": completionChunk.ID, "system_fingerprint": completionChunk.SystemFingerprint, }, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: responseChunk.Delta.Role, Content: responseChunk.Delta.Content, }, @@ -161,7 +161,7 @@ func (c *Client) SupportChatStream() bool { return true } -func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) { +func (c *Client) ChatStream(ctx context.Context, params *schema.ChatParams) (clients.ChatStream, error) { // Create a new chat request httpRequest, err := c.makeStreamReq(ctx, params) if err != nil { @@ -177,7 +177,7 @@ func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (cl ), nil } -func (c *Client) makeStreamReq(ctx context.Context, params *schemas.ChatParams) (*http.Request, error) { +func (c *Client) makeStreamReq(ctx context.Context, params *schema.ChatParams) (*http.Request, error) { chatReq := *c.chatRequestTemplate chatReq.ApplyParams(params) diff --git a/pkg/provider/azureopenai/chat_stream_test.go b/pkg/provider/azureopenai/chat_stream_test.go index 39a5b93e..f056d599 100644 --- a/pkg/provider/azureopenai/chat_stream_test.go +++ b/pkg/provider/azureopenai/chat_stream_test.go @@ -10,9 +10,9 @@ import ( "path/filepath" "testing" - clients2 "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + clients2 "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -71,7 +71,7 @@ func TestAzureOpenAIClient_ChatStreamRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the capital of the United Kingdom?", }}} @@ -139,7 +139,7 @@ func TestAzureOpenAIClient_ChatStreamRequestInterrupted(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the biggest animal?", }}} diff --git a/pkg/provider/azureopenai/client_test.go b/pkg/provider/azureopenai/client_test.go index 5c390114..accca38d 100644 --- a/pkg/provider/azureopenai/client_test.go +++ b/pkg/provider/azureopenai/client_test.go @@ -10,9 +10,9 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -55,7 +55,7 @@ func TestAzureOpenAIClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the capital of the United Kingdom?", }}} @@ -88,7 +88,7 @@ func TestAzureOpenAIClient_ChatError(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "human", Content: "What's the biggest animal?", }}} @@ -115,7 +115,7 @@ func TestDoChatRequest_ErrorResponse(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the dealio?", }}} diff --git a/pkg/provider/azureopenai/schemas.go b/pkg/provider/azureopenai/schemas.go index 5940648c..2ce12eb5 100644 --- a/pkg/provider/azureopenai/schemas.go +++ b/pkg/provider/azureopenai/schemas.go @@ -1,27 +1,27 @@ package azureopenai -import "github.com/EinStack/glide/pkg/api/schemas" +import "github.com/EinStack/glide/pkg/api/schema" // ChatRequest is an Azure openai-specific request schema type ChatRequest struct { - Messages []schemas.ChatMessage `json:"messages"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - N int `json:"n,omitempty"` - StopWords []string `json:"stop,omitempty"` - Stream bool `json:"stream,omitempty"` - FrequencyPenalty int `json:"frequency_penalty,omitempty"` - PresencePenalty int `json:"presence_penalty,omitempty"` - LogitBias *map[int]float64 `json:"logit_bias,omitempty"` - User *string `json:"user,omitempty"` - Seed *int `json:"seed,omitempty"` - Tools []string `json:"tools,omitempty"` - ToolChoice interface{} `json:"tool_choice,omitempty"` - ResponseFormat interface{} `json:"response_format,omitempty"` + Messages []schema.ChatMessage `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` + StopWords []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + FrequencyPenalty int `json:"frequency_penalty,omitempty"` + PresencePenalty int `json:"presence_penalty,omitempty"` + LogitBias *map[int]float64 `json:"logit_bias,omitempty"` + User *string `json:"user,omitempty"` + Seed *int `json:"seed,omitempty"` + Tools []string `json:"tools,omitempty"` + ToolChoice interface{} `json:"tool_choice,omitempty"` + ResponseFormat interface{} `json:"response_format,omitempty"` } -func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) { +func (r *ChatRequest) ApplyParams(params *schema.ChatParams) { r.Messages = params.Messages } @@ -38,10 +38,10 @@ type ChatCompletion struct { } type Choice struct { - Index int `json:"index"` - Message schemas.ChatMessage `json:"message"` - Logprobs interface{} `json:"logprobs"` - FinishReason string `json:"finish_reason"` + Index int `json:"index"` + Message schema.ChatMessage `json:"message"` + Logprobs interface{} `json:"logprobs"` + FinishReason string `json:"finish_reason"` } type Usage struct { @@ -62,7 +62,7 @@ type ChatCompletionChunk struct { } type StreamChoice struct { - Index int `json:"index"` - Delta schemas.ChatMessage `json:"delta"` - FinishReason string `json:"finish_reason"` + Index int `json:"index"` + Delta schema.ChatMessage `json:"delta"` + FinishReason string `json:"finish_reason"` } diff --git a/pkg/provider/bedrock/chat.go b/pkg/provider/bedrock/chat.go index 658c1769..cd51027b 100644 --- a/pkg/provider/bedrock/chat.go +++ b/pkg/provider/bedrock/chat.go @@ -6,7 +6,7 @@ import ( "fmt" "time" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/api/schema" "go.uber.org/zap" @@ -22,7 +22,7 @@ type ChatRequest struct { TextGenerationConfig TextGenerationConfig `json:"textGenerationConfig"` } -func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) { +func (r *ChatRequest) ApplyParams(params *schema.ChatParams) { // message history not yet supported for AWS models // TODO: do something about lack of message history. Maybe just concatenate all messages? // in any case, this is not a way to go to ignore message history @@ -51,7 +51,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } // Chat sends a chat request to the specified bedrock model. -func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Client) Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) { // Create a new chat request // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate // hoping to get a copy of the template @@ -65,7 +65,7 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas return chatResponse, nil } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schema.ChatResponse, error) { rawPayload, err := json.Marshal(payload) if err != nil { return nil, fmt.Errorf("unable to marshal chat request payload: %w", err) @@ -96,18 +96,18 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche return nil, ErrEmptyResponse } - response := schemas.ChatResponse{ + response := schema.ChatResponse{ ID: uuid.NewString(), Created: int(time.Now().Unix()), Provider: providerName, ModelName: c.config.ModelName, Cached: false, - ModelResponse: schemas.ModelResponse{ - Message: schemas.ChatMessage{ + ModelResponse: schema.ModelResponse{ + Message: schema.ChatMessage{ Role: "assistant", Content: modelResult.OutputText, }, - TokenUsage: schemas.TokenUsage{ + TokenUsage: schema.TokenUsage{ // TODO: what would happen if there is a few responses? We need to sum that up PromptTokens: modelResult.TokenCount, ResponseTokens: -1, diff --git a/pkg/provider/bedrock/chat_stream.go b/pkg/provider/bedrock/chat_stream.go index 57413043..6bb87905 100644 --- a/pkg/provider/bedrock/chat_stream.go +++ b/pkg/provider/bedrock/chat_stream.go @@ -3,15 +3,15 @@ package bedrock import ( "context" - clients2 "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" ) func (c *Client) SupportChatStream() bool { return false } -func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients2.ChatStream, error) { - return nil, clients2.ErrChatStreamNotImplemented +func (c *Client) ChatStream(_ context.Context, _ *schema.ChatParams) (clients.ChatStream, error) { + return nil, clients.ErrChatStreamNotImplemented } diff --git a/pkg/provider/bedrock/client_test.go b/pkg/provider/bedrock/client_test.go index e99f8d9c..f6081966 100644 --- a/pkg/provider/bedrock/client_test.go +++ b/pkg/provider/bedrock/client_test.go @@ -11,9 +11,9 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -61,7 +61,7 @@ func TestBedrockClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the biggest animal?", }}} diff --git a/pkg/provider/cohere/chat.go b/pkg/provider/cohere/chat.go index 754d8537..7b2ebbb9 100644 --- a/pkg/provider/cohere/chat.go +++ b/pkg/provider/cohere/chat.go @@ -9,9 +9,9 @@ import ( "net/http" "time" - "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" "go.uber.org/zap" ) @@ -30,7 +30,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } // Chat sends a chat request to the specified cohere model. -func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Client) Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) { // Create a new chat request // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate @@ -44,7 +44,7 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas return chatResponse, nil } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schema.ChatResponse, error) { // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -115,22 +115,22 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Map response to ChatResponse schema - response := schemas.ChatResponse{ + response := schema.ChatResponse{ ID: cohereCompletion.ResponseID, Created: int(time.Now().UTC().Unix()), // Cohere doesn't provide this Provider: ProviderID, ModelName: c.config.ModelName, Cached: false, - ModelResponse: schemas.ModelResponse{ + ModelResponse: schema.ModelResponse{ Metadata: map[string]string{ "generationId": cohereCompletion.GenerationID, "responseId": cohereCompletion.ResponseID, }, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: "assistant", Content: cohereCompletion.Text, }, - TokenUsage: schemas.TokenUsage{ + TokenUsage: schema.TokenUsage{ PromptTokens: cohereCompletion.TokenCount.PromptTokens, ResponseTokens: cohereCompletion.TokenCount.ResponseTokens, TotalTokens: cohereCompletion.TokenCount.TotalTokens, diff --git a/pkg/provider/cohere/chat_stream.go b/pkg/provider/cohere/chat_stream.go index 51bd8045..46b07598 100644 --- a/pkg/provider/cohere/chat_stream.go +++ b/pkg/provider/cohere/chat_stream.go @@ -8,13 +8,13 @@ import ( "io" "net/http" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" "go.uber.org/zap" - - "github.com/EinStack/glide/pkg/api/schemas" ) // SupportedEventType Cohere has other types too: @@ -83,7 +83,7 @@ func (s *ChatStream) Open() error { return nil } -func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { +func (s *ChatStream) Recv() (*schema.ChatStreamChunk, error) { if s.streamFinished { return nil, io.EOF } @@ -135,16 +135,16 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { s.streamFinished = true // TODO: use objectpool here - return &schemas.ChatStreamChunk{ + return &schema.ChatStreamChunk{ Cached: false, Provider: ProviderID, ModelName: s.modelName, - ModelResponse: schemas.ModelChunkResponse{ - Metadata: &schemas.Metadata{ + ModelResponse: schema.ModelChunkResponse{ + Metadata: &schema.Metadata{ "generation_id": s.generationID, "response_id": responseChunk.Response.ResponseID, }, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: "model", Content: responseChunk.Text, }, @@ -154,15 +154,15 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { } // TODO: use objectpool here - return &schemas.ChatStreamChunk{ + return &schema.ChatStreamChunk{ Cached: false, Provider: ProviderID, ModelName: s.modelName, - ModelResponse: schemas.ModelChunkResponse{ - Metadata: &schemas.Metadata{ + ModelResponse: schema.ModelChunkResponse{ + Metadata: &schema.Metadata{ "generation_id": s.generationID, }, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: "model", Content: responseChunk.Text, }, @@ -183,7 +183,7 @@ func (c *Client) SupportChatStream() bool { return true } -func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) { +func (c *Client) ChatStream(ctx context.Context, params *schema.ChatParams) (clients.ChatStream, error) { // Create a new chat request httpRequest, err := c.makeStreamReq(ctx, params) if err != nil { @@ -200,7 +200,7 @@ func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (cl ), nil } -func (c *Client) makeStreamReq(ctx context.Context, params *schemas.ChatParams) (*http.Request, error) { +func (c *Client) makeStreamReq(ctx context.Context, params *schema.ChatParams) (*http.Request, error) { // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate chatReq.ApplyParams(params) diff --git a/pkg/provider/cohere/chat_stream_test.go b/pkg/provider/cohere/chat_stream_test.go index 82060f84..e40eed14 100644 --- a/pkg/provider/cohere/chat_stream_test.go +++ b/pkg/provider/cohere/chat_stream_test.go @@ -10,9 +10,9 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -71,7 +71,7 @@ func TestCohere_ChatStreamRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the capital of the United Kingdom?", }}} @@ -138,7 +138,7 @@ func TestCohere_ChatStreamRequestInterrupted(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the capital of the United Kingdom?", }}} diff --git a/pkg/provider/cohere/client_test.go b/pkg/provider/cohere/client_test.go index bb4f99e4..721ceda7 100644 --- a/pkg/provider/cohere/client_test.go +++ b/pkg/provider/cohere/client_test.go @@ -11,9 +11,9 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -55,7 +55,7 @@ func TestCohereClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "human", Content: "What's the biggest animal?", }}} diff --git a/pkg/provider/cohere/finish_reason.go b/pkg/provider/cohere/finish_reason.go index 139498e6..4d156875 100644 --- a/pkg/provider/cohere/finish_reason.go +++ b/pkg/provider/cohere/finish_reason.go @@ -3,9 +3,10 @@ package cohere import ( "strings" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/api/schemas" "go.uber.org/zap" ) @@ -27,27 +28,27 @@ type FinishReasonMapper struct { tel *telemetry.Telemetry } -func (m *FinishReasonMapper) Map(finishReason *string) *schemas.FinishReason { +func (m *FinishReasonMapper) Map(finishReason *string) *schema.FinishReason { if finishReason == nil || len(*finishReason) == 0 { return nil } - var reason *schemas.FinishReason + var reason *schema.FinishReason switch strings.ToLower(*finishReason) { case CompleteReason: - reason = &schemas.ReasonComplete + reason = &schema.ReasonComplete case MaxTokensReason: - reason = &schemas.ReasonMaxTokens + reason = &schema.ReasonMaxTokens case FilteredReason: - reason = &schemas.ReasonContentFiltered + reason = &schema.ReasonContentFiltered default: m.tel.Logger.Warn( "Unknown finish reason, other is going to used", zap.String("unknown_reason", *finishReason), ) - reason = &schemas.ReasonOther + reason = &schema.ReasonOther } return reason diff --git a/pkg/provider/cohere/schemas.go b/pkg/provider/cohere/schemas.go index 9dc9bb09..c224ec0e 100644 --- a/pkg/provider/cohere/schemas.go +++ b/pkg/provider/cohere/schemas.go @@ -1,6 +1,6 @@ package cohere -import "github.com/EinStack/glide/pkg/api/schemas" +import "github.com/EinStack/glide/pkg/api/schema" // Cohere Chat Response type ChatCompletion struct { @@ -90,25 +90,25 @@ type FinalResponse struct { // ChatRequest is a request to complete a chat completion // Ref: https://docs.cohere.com/reference/chat type ChatRequest struct { - Model string `json:"model"` - Message string `json:"message"` - ChatHistory []schemas.ChatMessage `json:"chat_history"` - Temperature float64 `json:"temperature,omitempty"` - Preamble string `json:"preamble,omitempty"` - PromptTruncation *string `json:"prompt_truncation,omitempty"` - Connectors []string `json:"connectors,omitempty"` - SearchQueriesOnly bool `json:"search_queries_only,omitempty"` - Stream bool `json:"stream,omitempty"` - Seed *int `json:"seed,omitempty"` - MaxTokens *int `json:"max_tokens,omitempty"` - K int `json:"k"` - P float32 `json:"p"` - FrequencyPenalty float32 `json:"frequency_penalty"` - PresencePenalty float32 `json:"presence_penalty"` - StopSequences []string `json:"stop_sequences"` + Model string `json:"model"` + Message string `json:"message"` + ChatHistory []schema.ChatMessage `json:"chat_history"` + Temperature float64 `json:"temperature,omitempty"` + Preamble string `json:"preamble,omitempty"` + PromptTruncation *string `json:"prompt_truncation,omitempty"` + Connectors []string `json:"connectors,omitempty"` + SearchQueriesOnly bool `json:"search_queries_only,omitempty"` + Stream bool `json:"stream,omitempty"` + Seed *int `json:"seed,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + K int `json:"k"` + P float32 `json:"p"` + FrequencyPenalty float32 `json:"frequency_penalty"` + PresencePenalty float32 `json:"presence_penalty"` + StopSequences []string `json:"stop_sequences"` } -func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) { +func (r *ChatRequest) ApplyParams(params *schema.ChatParams) { message := params.Messages[len(params.Messages)-1] messageHistory := params.Messages[:len(params.Messages)-1] diff --git a/pkg/provider/interface.go b/pkg/provider/interface.go index b2e4ffbd..1171f486 100644 --- a/pkg/provider/interface.go +++ b/pkg/provider/interface.go @@ -4,7 +4,8 @@ import ( "context" "errors" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/clients" ) @@ -22,13 +23,13 @@ type ModelProvider interface { type LangProvider interface { ModelProvider SupportChatStream() bool - Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) - ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) + Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) + ChatStream(ctx context.Context, params *schema.ChatParams) (clients.ChatStream, error) } // EmbeddingProvider defines an interface a provider should fulfill to be able to generate embeddings type EmbeddingProvider interface { ModelProvider SupportEmbedding() bool - Embed(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) + Embed(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) } diff --git a/pkg/provider/octoml/chat.go b/pkg/provider/octoml/chat.go index 9b2237f3..3648cd95 100644 --- a/pkg/provider/octoml/chat.go +++ b/pkg/provider/octoml/chat.go @@ -8,27 +8,27 @@ import ( "io" "net/http" - "github.com/EinStack/glide/pkg/provider/openai" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/provider/openai" "go.uber.org/zap" ) // ChatRequest is an octoml-specific request schema type ChatRequest struct { - Model string `json:"model"` - Messages []schemas.ChatMessage `json:"messages"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - StopWords []string `json:"stop,omitempty"` - Stream bool `json:"stream,omitempty"` - FrequencyPenalty int `json:"frequency_penalty,omitempty"` - PresencePenalty int `json:"presence_penalty,omitempty"` + Model string `json:"model"` + Messages []schema.ChatMessage `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + StopWords []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + FrequencyPenalty int `json:"frequency_penalty,omitempty"` + PresencePenalty int `json:"presence_penalty,omitempty"` } -func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) { +func (r *ChatRequest) ApplyParams(params *schema.ChatParams) { // TODO(185): set other params r.Messages = params.Messages } @@ -47,7 +47,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } // Chat sends a chat request to the specified octoml model. -func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Client) Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) { // Create a new chat request // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate // hoping to get a copy of the template @@ -63,7 +63,7 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas return chatResponse, nil } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schema.ChatResponse, error) { // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -119,21 +119,21 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Map response to UnifiedChatResponse schema - response := schemas.ChatResponse{ + response := schema.ChatResponse{ ID: completion.ID, Created: completion.Created, Provider: providerName, ModelName: completion.ModelName, Cached: false, - ModelResponse: schemas.ModelResponse{ + ModelResponse: schema.ModelResponse{ Metadata: map[string]string{ "system_fingerprint": completion.SystemFingerprint, }, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: modelChoice.Message.Role, Content: modelChoice.Message.Content, }, - TokenUsage: schemas.TokenUsage{ + TokenUsage: schema.TokenUsage{ PromptTokens: completion.Usage.PromptTokens, ResponseTokens: completion.Usage.CompletionTokens, TotalTokens: completion.Usage.TotalTokens, diff --git a/pkg/provider/octoml/chat_stream.go b/pkg/provider/octoml/chat_stream.go index 7b8a1766..22ead76a 100644 --- a/pkg/provider/octoml/chat_stream.go +++ b/pkg/provider/octoml/chat_stream.go @@ -3,15 +3,15 @@ package octoml import ( "context" - clients2 "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" ) func (c *Client) SupportChatStream() bool { return false } -func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients2.ChatStream, error) { - return nil, clients2.ErrChatStreamNotImplemented +func (c *Client) ChatStream(_ context.Context, _ *schema.ChatParams) (clients.ChatStream, error) { + return nil, clients.ErrChatStreamNotImplemented } diff --git a/pkg/provider/octoml/client_test.go b/pkg/provider/octoml/client_test.go index 128fd1f0..fcc266c1 100644 --- a/pkg/provider/octoml/client_test.go +++ b/pkg/provider/octoml/client_test.go @@ -10,9 +10,9 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -55,7 +55,7 @@ func TestOctoMLClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "human", Content: "What's the biggest animal?", }}} @@ -88,7 +88,7 @@ func TestOctoMLClient_Chat_Error(t *testing.T) { require.NoError(t, err) // Create a chat request - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "human", Content: "What's the biggest animal?", }}} @@ -120,7 +120,7 @@ func TestDoChatRequest_ErrorResponse(t *testing.T) { require.NoError(t, err) // Create a chat request payload - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the dealeo?", }}} diff --git a/pkg/provider/ollama/chat.go b/pkg/provider/ollama/chat.go index 42ee1f99..463899c8 100644 --- a/pkg/provider/ollama/chat.go +++ b/pkg/provider/ollama/chat.go @@ -9,38 +9,38 @@ import ( "net/http" "time" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/clients" "github.com/google/uuid" - "github.com/EinStack/glide/pkg/api/schemas" - "go.uber.org/zap" ) // ChatRequest is an ollama-specific request schema type ChatRequest struct { - Model string `json:"model"` - Messages []schemas.ChatMessage `json:"messages"` - Microstat int `json:"microstat,omitempty"` - MicrostatEta float64 `json:"microstat_eta,omitempty"` - MicrostatTau float64 `json:"microstat_tau,omitempty"` - NumCtx int `json:"num_ctx,omitempty"` - NumGqa int `json:"num_gqa,omitempty"` - NumGpu int `json:"num_gpu,omitempty"` - NumThread int `json:"num_thread,omitempty"` - RepeatLastN int `json:"repeat_last_n,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - Seed int `json:"seed,omitempty"` - StopWords []string `json:"stop,omitempty"` - Tfsz float64 `json:"tfs_z,omitempty"` - NumPredict int `json:"num_predict,omitempty"` - TopK int `json:"top_k,omitempty"` - TopP float64 `json:"top_p,omitempty"` - Stream bool `json:"stream"` + Model string `json:"model"` + Messages []schema.ChatMessage `json:"messages"` + Microstat int `json:"microstat,omitempty"` + MicrostatEta float64 `json:"microstat_eta,omitempty"` + MicrostatTau float64 `json:"microstat_tau,omitempty"` + NumCtx int `json:"num_ctx,omitempty"` + NumGqa int `json:"num_gqa,omitempty"` + NumGpu int `json:"num_gpu,omitempty"` + NumThread int `json:"num_thread,omitempty"` + RepeatLastN int `json:"repeat_last_n,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + Seed int `json:"seed,omitempty"` + StopWords []string `json:"stop,omitempty"` + Tfsz float64 `json:"tfs_z,omitempty"` + NumPredict int `json:"num_predict,omitempty"` + TopK int `json:"top_k,omitempty"` + TopP float64 `json:"top_p,omitempty"` + Stream bool `json:"stream"` } -func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) { +func (r *ChatRequest) ApplyParams(params *schema.ChatParams) { // TODO(185): set other params r.Messages = params.Messages } @@ -68,7 +68,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } // Chat sends a chat request to the specified ollama model. -func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Client) Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) { // Create a new chat request // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate // hoping to get a copy of the template @@ -84,7 +84,7 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas return chatResponse, nil } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { //nolint:cyclop +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schema.ChatResponse, error) { //nolint:cyclop // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -164,18 +164,18 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Map response to UnifiedChatResponse schema - response := schemas.ChatResponse{ + response := schema.ChatResponse{ ID: uuid.NewString(), Created: int(time.Now().Unix()), Provider: providerName, ModelName: ollamaCompletion.Model, Cached: false, - ModelResponse: schemas.ModelResponse{ - Message: schemas.ChatMessage{ + ModelResponse: schema.ModelResponse{ + Message: schema.ChatMessage{ Role: ollamaCompletion.Message.Role, Content: ollamaCompletion.Message.Content, }, - TokenUsage: schemas.TokenUsage{ + TokenUsage: schema.TokenUsage{ PromptTokens: ollamaCompletion.EvalCount, ResponseTokens: ollamaCompletion.EvalCount, TotalTokens: ollamaCompletion.EvalCount, diff --git a/pkg/provider/ollama/chat_stream.go b/pkg/provider/ollama/chat_stream.go index 15d220e9..d43f88c0 100644 --- a/pkg/provider/ollama/chat_stream.go +++ b/pkg/provider/ollama/chat_stream.go @@ -3,15 +3,15 @@ package ollama import ( "context" - clients2 "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" ) func (c *Client) SupportChatStream() bool { return false } -func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients2.ChatStream, error) { - return nil, clients2.ErrChatStreamNotImplemented +func (c *Client) ChatStream(_ context.Context, _ *schema.ChatParams) (clients.ChatStream, error) { + return nil, clients.ErrChatStreamNotImplemented } diff --git a/pkg/provider/ollama/client_test.go b/pkg/provider/ollama/client_test.go index 1c9dad49..e371c39d 100644 --- a/pkg/provider/ollama/client_test.go +++ b/pkg/provider/ollama/client_test.go @@ -10,9 +10,9 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -56,7 +56,7 @@ func TestOllamaClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the biggest animal?", }}} @@ -84,7 +84,7 @@ func TestOllamaClient_ChatRequest_Non200Response(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the capital of the United Kingdom?", }}} @@ -121,7 +121,7 @@ func TestOllamaClient_ChatRequest_SuccessfulResponse(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the capital of the United Kingdom?", }}} diff --git a/pkg/provider/openai/chat.go b/pkg/provider/openai/chat.go index 86bce6f1..be2bcbf3 100644 --- a/pkg/provider/openai/chat.go +++ b/pkg/provider/openai/chat.go @@ -8,9 +8,10 @@ import ( "io" "net/http" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/api/schemas" "go.uber.org/zap" ) @@ -36,7 +37,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } // Chat sends a chat request to the specified OpenAI model. -func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Client) Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) { // Create a new chat request // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate // hoping to get a copy of the template @@ -52,7 +53,7 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas return chatResponse, nil } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schema.ChatResponse, error) { // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -123,21 +124,21 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Map response to ChatResponse schema - response := schemas.ChatResponse{ + response := schema.ChatResponse{ ID: chatCompletion.ID, Created: chatCompletion.Created, Provider: ProviderID, ModelName: chatCompletion.ModelName, Cached: false, - ModelResponse: schemas.ModelResponse{ + ModelResponse: schema.ModelResponse{ Metadata: map[string]string{ "system_fingerprint": chatCompletion.SystemFingerprint, }, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: modelChoice.Message.Role, Content: modelChoice.Message.Content, }, - TokenUsage: schemas.TokenUsage{ + TokenUsage: schema.TokenUsage{ PromptTokens: chatCompletion.Usage.PromptTokens, ResponseTokens: chatCompletion.Usage.CompletionTokens, TotalTokens: chatCompletion.Usage.TotalTokens, diff --git a/pkg/provider/openai/chat_stream.go b/pkg/provider/openai/chat_stream.go index 0e4a341e..9d50295b 100644 --- a/pkg/provider/openai/chat_stream.go +++ b/pkg/provider/openai/chat_stream.go @@ -8,12 +8,12 @@ import ( "io" "net/http" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/clients" "github.com/r3labs/sse/v2" "go.uber.org/zap" - - "github.com/EinStack/glide/pkg/api/schemas" ) var StreamDoneMarker = []byte("[DONE]") @@ -66,7 +66,7 @@ func (s *ChatStream) Open() error { return nil } -func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { +func (s *ChatStream) Recv() (*schema.ChatStreamChunk, error) { var completionChunk ChatCompletionChunk for { @@ -115,17 +115,17 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { responseChunk := completionChunk.Choices[0] // TODO: use objectpool here - return &schemas.ChatStreamChunk{ + return &schema.ChatStreamChunk{ Cached: false, Provider: ProviderID, ModelName: completionChunk.ModelName, - ModelResponse: schemas.ModelChunkResponse{ - Metadata: &schemas.Metadata{ + ModelResponse: schema.ModelChunkResponse{ + Metadata: &schema.Metadata{ "response_id": completionChunk.ID, "system_fingerprint": completionChunk.SystemFingerprint, "generated_at": completionChunk.Created, }, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: "assistant", // doesn't present in all chunks Content: responseChunk.Delta.Content, }, @@ -147,7 +147,7 @@ func (c *Client) SupportChatStream() bool { return true } -func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) { +func (c *Client) ChatStream(ctx context.Context, params *schema.ChatParams) (clients.ChatStream, error) { // Create a new chat request httpRequest, err := c.makeStreamReq(ctx, params) if err != nil { @@ -163,7 +163,7 @@ func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (cl ), nil } -func (c *Client) makeStreamReq(ctx context.Context, params *schemas.ChatParams) (*http.Request, error) { +func (c *Client) makeStreamReq(ctx context.Context, params *schema.ChatParams) (*http.Request, error) { // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate // hoping to get a copy of the template chatReq.ApplyParams(params) diff --git a/pkg/provider/openai/chat_stream_test.go b/pkg/provider/openai/chat_stream_test.go index 6928e6f0..2934df3f 100644 --- a/pkg/provider/openai/chat_stream_test.go +++ b/pkg/provider/openai/chat_stream_test.go @@ -10,9 +10,9 @@ import ( "path/filepath" "testing" - clients2 "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + clients2 "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -71,7 +71,7 @@ func TestOpenAIClient_ChatStreamRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the capital of the United Kingdom?", }}} @@ -139,7 +139,7 @@ func TestOpenAIClient_ChatStreamRequestInterrupted(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the capital of the United Kingdom?", }}} diff --git a/pkg/provider/openai/chat_test.go b/pkg/provider/openai/chat_test.go index 0aae4d0e..65dde4f6 100644 --- a/pkg/provider/openai/chat_test.go +++ b/pkg/provider/openai/chat_test.go @@ -10,9 +10,9 @@ import ( "path/filepath" "testing" - clients2 "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + clients2 "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -56,7 +56,7 @@ func TestOpenAIClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the capital of the United Kingdom?", }}} @@ -85,7 +85,7 @@ func TestOpenAIClient_RateLimit(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "human", Content: "What's the biggest animal?", }}} diff --git a/pkg/provider/openai/embed.go b/pkg/provider/openai/embed.go index 69f9aa27..ba054adc 100644 --- a/pkg/provider/openai/embed.go +++ b/pkg/provider/openai/embed.go @@ -3,11 +3,11 @@ package openai import ( "context" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/api/schema" ) // Embed sends an embedding request to the specified OpenAI model. -func (c *Client) Embed(_ context.Context, _ *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Client) Embed(_ context.Context, _ *schema.ChatParams) (*schema.ChatResponse, error) { // TODO: implement return nil, nil } diff --git a/pkg/provider/openai/finish_reasons.go b/pkg/provider/openai/finish_reasons.go index 28b5f675..65196946 100644 --- a/pkg/provider/openai/finish_reasons.go +++ b/pkg/provider/openai/finish_reasons.go @@ -1,9 +1,9 @@ package openai import ( + "github.com/EinStack/glide/pkg/api/schema" "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/api/schemas" "go.uber.org/zap" ) @@ -25,27 +25,27 @@ type FinishReasonMapper struct { tel *telemetry.Telemetry } -func (m *FinishReasonMapper) Map(finishReason string) *schemas.FinishReason { +func (m *FinishReasonMapper) Map(finishReason string) *schema.FinishReason { if len(finishReason) == 0 { return nil } - var reason *schemas.FinishReason + var reason *schema.FinishReason switch finishReason { case CompleteReason: - reason = &schemas.ReasonComplete + reason = &schema.ReasonComplete case MaxTokensReason: - reason = &schemas.ReasonMaxTokens + reason = &schema.ReasonMaxTokens case FilteredReason: - reason = &schemas.ReasonContentFiltered + reason = &schema.ReasonContentFiltered default: m.tel.Logger.Warn( "Unknown finish reason, other is going to used", zap.String("unknown_reason", finishReason), ) - reason = &schemas.ReasonOther + reason = &schema.ReasonOther } return reason diff --git a/pkg/provider/openai/schemas.go b/pkg/provider/openai/schemas.go index bde0ba81..31af6f9c 100644 --- a/pkg/provider/openai/schemas.go +++ b/pkg/provider/openai/schemas.go @@ -1,28 +1,28 @@ package openai -import "github.com/EinStack/glide/pkg/api/schemas" +import "github.com/EinStack/glide/pkg/api/schema" // ChatRequest is an OpenAI-specific request schema type ChatRequest struct { - Model string `json:"model"` - Messages []schemas.ChatMessage `json:"messages"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - N int `json:"n,omitempty"` - StopWords []string `json:"stop,omitempty"` - Stream bool `json:"stream,omitempty"` - FrequencyPenalty int `json:"frequency_penalty,omitempty"` - PresencePenalty int `json:"presence_penalty,omitempty"` - LogitBias *map[int]float64 `json:"logit_bias,omitempty"` - User *string `json:"user,omitempty"` - Seed *int `json:"seed,omitempty"` - Tools []string `json:"tools,omitempty"` - ToolChoice interface{} `json:"tool_choice,omitempty"` - ResponseFormat interface{} `json:"response_format,omitempty"` + Model string `json:"model"` + Messages []schema.ChatMessage `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` + StopWords []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + FrequencyPenalty int `json:"frequency_penalty,omitempty"` + PresencePenalty int `json:"presence_penalty,omitempty"` + LogitBias *map[int]float64 `json:"logit_bias,omitempty"` + User *string `json:"user,omitempty"` + Seed *int `json:"seed,omitempty"` + Tools []string `json:"tools,omitempty"` + ToolChoice interface{} `json:"tool_choice,omitempty"` + ResponseFormat interface{} `json:"response_format,omitempty"` } -func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) { +func (r *ChatRequest) ApplyParams(params *schema.ChatParams) { // TODO(185): set other params r.Messages = params.Messages } @@ -40,10 +40,10 @@ type ChatCompletion struct { } type Choice struct { - Index int `json:"index"` - Message schemas.ChatMessage `json:"message"` - Logprobs interface{} `json:"logprobs"` - FinishReason string `json:"finish_reason"` + Index int `json:"index"` + Message schema.ChatMessage `json:"message"` + Logprobs interface{} `json:"logprobs"` + FinishReason string `json:"finish_reason"` } type Usage struct { @@ -64,8 +64,8 @@ type ChatCompletionChunk struct { } type StreamChoice struct { - Index int `json:"index"` - Delta schemas.ChatMessage `json:"delta"` - Logprobs interface{} `json:"logprobs"` - FinishReason string `json:"finish_reason"` + Index int `json:"index"` + Delta schema.ChatMessage `json:"delta"` + Logprobs interface{} `json:"logprobs"` + FinishReason string `json:"finish_reason"` } diff --git a/pkg/provider/testing.go b/pkg/provider/testing.go index cef49cdb..f7bc64f0 100644 --- a/pkg/provider/testing.go +++ b/pkg/provider/testing.go @@ -4,7 +4,8 @@ import ( "context" "io" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/config/fields" "github.com/EinStack/glide/pkg/telemetry" @@ -37,24 +38,24 @@ type RespMock struct { Err error } -func (m *RespMock) Resp() *schemas.ChatResponse { - return &schemas.ChatResponse{ +func (m *RespMock) Resp() *schema.ChatResponse { + return &schema.ChatResponse{ ID: "rsp0001", - ModelResponse: schemas.ModelResponse{ + ModelResponse: schema.ModelResponse{ Metadata: map[string]string{ "ID": "0001", }, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Content: m.Msg, }, }, } } -func (m *RespMock) RespChunk() *schemas.ChatStreamChunk { - return &schemas.ChatStreamChunk{ - ModelResponse: schemas.ModelChunkResponse{ - Message: schemas.ChatMessage{ +func (m *RespMock) RespChunk() *schema.ChatStreamChunk { + return &schema.ChatStreamChunk{ + ModelResponse: schema.ModelChunkResponse{ + Message: schema.ChatMessage{ Content: m.Msg, }, }, @@ -97,7 +98,7 @@ func (m *RespStreamMock) Open() error { return nil } -func (m *RespStreamMock) Recv() (*schemas.ChatStreamChunk, error) { +func (m *RespStreamMock) Recv() (*schema.ChatStreamChunk, error) { if m.Chunks != nil && m.idx >= len(*m.Chunks) { return nil, io.EOF } @@ -154,7 +155,7 @@ func (c *Mock) SupportChatStream() bool { return c.supportStreaming } -func (c *Mock) Chat(_ context.Context, _ *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Mock) Chat(_ context.Context, _ *schema.ChatParams) (*schema.ChatResponse, error) { if c.chatResps == nil { return nil, clients.ErrProviderUnavailable } @@ -171,7 +172,7 @@ func (c *Mock) Chat(_ context.Context, _ *schemas.ChatParams) (*schemas.ChatResp return response.Resp(), nil } -func (c *Mock) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) { +func (c *Mock) ChatStream(_ context.Context, _ *schema.ChatParams) (clients.ChatStream, error) { if c.chatStreams == nil || c.idx >= len(*c.chatStreams) { return nil, clients.ErrProviderUnavailable } diff --git a/pkg/router/embed_router.go b/pkg/router/embed_router.go index 5d62b522..b2ad8a59 100644 --- a/pkg/router/embed_router.go +++ b/pkg/router/embed_router.go @@ -3,7 +3,8 @@ package router import ( "context" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/telemetry" ) @@ -20,6 +21,6 @@ func NewEmbedRouter(_ *EmbedRouterConfig, _ *telemetry.Telemetry) (*EmbedRouter, return &EmbedRouter{}, nil } -func (r *EmbedRouter) Embed(ctx context.Context, req *schemas.EmbedRequest) (*schemas.EmbedResponse, error) { +func (r *EmbedRouter) Embed(ctx context.Context, req *schema.EmbedRequest) (*schema.EmbedResponse, error) { // TODO: implement } diff --git a/pkg/router/lang_router.go b/pkg/router/lang_router.go index 04c64e9e..ec2d9113 100644 --- a/pkg/router/lang_router.go +++ b/pkg/router/lang_router.go @@ -4,9 +4,10 @@ import ( "context" "errors" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/extmodel" - "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/resiliency/retry" "github.com/EinStack/glide/pkg/router/routing" "github.com/EinStack/glide/pkg/telemetry" @@ -59,7 +60,7 @@ func (r *LangRouter) ID() ID { return r.routerID } -func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error) { +func (r *LangRouter) Chat(ctx context.Context, req *schema.ChatRequest) (*schema.ChatResponse, error) { if len(r.chatModels) == 0 { return nil, ErrNoModels } @@ -112,22 +113,22 @@ func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schem // if we reach this part, then we are in trouble r.logger.Error("No model was available to handle chat request") - return nil, &schemas.ErrNoModelAvailable + return nil, &schema.ErrNoModelAvailable } func (r *LangRouter) ChatStream( ctx context.Context, - req *schemas.ChatStreamRequest, - respC chan<- *schemas.ChatStreamMessage, + req *schema.ChatStreamRequest, + respC chan<- *schema.ChatStreamMessage, ) { if len(r.chatStreamModels) == 0 { - respC <- schemas.NewChatStreamError( + respC <- schema.NewChatStreamError( req.ID, r.routerID, - schemas.NoModelConfigured, + schema.NoModelConfigured, ErrNoModels.Error(), req.Metadata, - &schemas.ReasonError, + &schema.ReasonError, ) return @@ -175,10 +176,10 @@ func (r *LangRouter) ChatStream( // It's challenging to hide an error in case of streaming chat as consumer apps // may have already used all chunks we streamed this far (e.g. showed them to their users like OpenAI UI does), // so we cannot easily restart that process from scratch - respC <- schemas.NewChatStreamError( + respC <- schema.NewChatStreamError( req.ID, r.routerID, - schemas.ModelUnavailable, + schema.ModelUnavailable, err.Error(), req.Metadata, nil, @@ -189,7 +190,7 @@ func (r *LangRouter) ChatStream( chunk := chunkResult.Chunk() - respC <- schemas.NewChatStreamChunk( + respC <- schema.NewChatStreamChunk( req.ID, r.routerID, req.Metadata, @@ -207,10 +208,10 @@ func (r *LangRouter) ChatStream( err := retryIterator.WaitNext(ctx) if err != nil { // something has cancelled the context - respC <- schemas.NewChatStreamError( + respC <- schema.NewChatStreamError( req.ID, r.routerID, - schemas.UnknownError, + schema.UnknownError, err.Error(), req.Metadata, nil, @@ -226,12 +227,12 @@ func (r *LangRouter) ChatStream( "Try to configure more fallback models to avoid this", ) - respC <- schemas.NewChatStreamError( + respC <- schema.NewChatStreamError( req.ID, r.routerID, - schemas.ErrNoModelAvailable.Name, - schemas.ErrNoModelAvailable.Message, + schema.ErrNoModelAvailable.Name, + schema.ErrNoModelAvailable.Message, req.Metadata, - &schemas.ReasonError, + &schema.ReasonError, ) } diff --git a/pkg/router/lang_router_test.go b/pkg/router/lang_router_test.go index 92eb5d4b..68c31cab 100644 --- a/pkg/router/lang_router_test.go +++ b/pkg/router/lang_router_test.go @@ -13,7 +13,6 @@ import ( "github.com/EinStack/glide/pkg/resiliency/health" "github.com/EinStack/glide/pkg/resiliency/retry" - "github.com/EinStack/glide/pkg/api/schemas" "github.com/EinStack/glide/pkg/router/latency" "github.com/EinStack/glide/pkg/router/routing" "github.com/EinStack/glide/pkg/telemetry" @@ -56,7 +55,7 @@ func TestLangRouter_Chat_PickFistHealthy(t *testing.T) { } ctx := context.Background() - req := schemas.NewChatFromStr("tell me a dad joke") + req := schema.NewChatFromStr("tell me a dad joke") for i := 0; i < 2; i++ { resp, err := router.Chat(ctx, req) @@ -73,14 +72,14 @@ func TestLangRouter_Chat_PickThirdHealthy(t *testing.T) { langModels := []*extmodel.LanguageModel{ extmodel.NewLangModel( "first", - provider.NewMock(nil, []provider.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "3"}}), + provider.NewMock(nil, []provider.RespMock{{Err: &schema.ErrNoModelAvailable}, {Msg: "3"}}), budget, *latConfig, 1, ), extmodel.NewLangModel( "second", - provider.NewMock(nil, []provider.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "4"}}), + provider.NewMock(nil, []provider.RespMock{{Err: &schema.ErrNoModelAvailable}, {Msg: "4"}}), budget, *latConfig, 1, @@ -113,7 +112,7 @@ func TestLangRouter_Chat_PickThirdHealthy(t *testing.T) { } ctx := context.Background() - req := schemas.NewChatFromStr("tell me a dad joke") + req := schema.NewChatFromStr("tell me a dad joke") for _, modelID := range expectedModels { resp, err := router.Chat(ctx, req) @@ -130,14 +129,14 @@ func TestLangRouter_Chat_SuccessOnRetry(t *testing.T) { langModels := []*extmodel.LanguageModel{ extmodel.NewLangModel( "first", - provider.NewMock(nil, []provider.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "2"}}), + provider.NewMock(nil, []provider.RespMock{{Err: &schema.ErrNoModelAvailable}, {Msg: "2"}}), budget, *latConfig, 1, ), extmodel.NewLangModel( "second", - provider.NewMock(nil, []provider.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "1"}}), + provider.NewMock(nil, []provider.RespMock{{Err: &schema.ErrNoModelAvailable}, {Msg: "1"}}), budget, *latConfig, 1, @@ -160,7 +159,7 @@ func TestLangRouter_Chat_SuccessOnRetry(t *testing.T) { logger: telemetry.NewLoggerMock(), } - resp, err := router.Chat(context.Background(), schemas.NewChatFromStr("tell me a dad joke")) + resp, err := router.Chat(context.Background(), schema.NewChatFromStr("tell me a dad joke")) require.NoError(t, err) require.Equal(t, "first", resp.ModelID) @@ -204,7 +203,7 @@ func TestLangRouter_Chat_UnhealthyModelInThePool(t *testing.T) { } for i := 0; i < 2; i++ { - resp, err := router.Chat(context.Background(), schemas.NewChatFromStr("tell me a dad joke")) + resp, err := router.Chat(context.Background(), schema.NewChatFromStr("tell me a dad joke")) require.NoError(t, err) require.Equal(t, "second", resp.ModelID) @@ -218,14 +217,14 @@ func TestLangRouter_Chat_AllModelsUnavailable(t *testing.T) { langModels := []*extmodel.LanguageModel{ extmodel.NewLangModel( "first", - provider.NewMock(nil, []provider.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Err: &schemas.ErrNoModelAvailable}}), + provider.NewMock(nil, []provider.RespMock{{Err: &schema.ErrNoModelAvailable}, {Err: &schema.ErrNoModelAvailable}}), budget, *latConfig, 1, ), extmodel.NewLangModel( "second", - provider.NewMock(nil, []provider.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Err: &schemas.ErrNoModelAvailable}}), + provider.NewMock(nil, []provider.RespMock{{Err: &schema.ErrNoModelAvailable}, {Err: &schema.ErrNoModelAvailable}}), budget, *latConfig, 1, @@ -248,7 +247,7 @@ func TestLangRouter_Chat_AllModelsUnavailable(t *testing.T) { logger: telemetry.NewLoggerMock(), } - _, err := router.Chat(context.Background(), schemas.NewChatFromStr("tell me a dad joke")) + _, err := router.Chat(context.Background(), schema.NewChatFromStr("tell me a dad joke")) require.Error(t, err) } @@ -305,8 +304,8 @@ func TestLangRouter_ChatStream(t *testing.T) { } ctx := context.Background() - req := schemas.NewChatStreamFromStr("tell me a dad joke") - respC := make(chan *schemas.ChatStreamMessage) + req := schema.NewChatStreamFromStr("tell me a dad joke") + respC := make(chan *schema.ChatStreamMessage) defer close(respC) @@ -374,8 +373,8 @@ func TestLangRouter_ChatStream_FailOnFirst(t *testing.T) { } ctx := context.Background() - req := schemas.NewChatStreamFromStr("tell me a dad joke") - respC := make(chan *schemas.ChatStreamMessage) + req := schema.NewChatStreamFromStr("tell me a dad joke") + respC := make(chan *schema.ChatStreamMessage) defer close(respC) @@ -442,10 +441,10 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) { logger: telemetry.NewLoggerMock(), } - respC := make(chan *schemas.ChatStreamMessage) + respC := make(chan *schema.ChatStreamMessage) defer close(respC) - go router.ChatStream(context.Background(), schemas.NewChatStreamFromStr("tell me a dad joke"), respC) + go router.ChatStream(context.Background(), schema.NewChatStreamFromStr("tell me a dad joke"), respC) errs := make([]string, 0, 3) @@ -457,5 +456,5 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) { errs = append(errs, result.Error.Name) } - require.Equal(t, []string{schemas.ModelUnavailable, schemas.ModelUnavailable, schemas.AllModelsUnavailable}, errs) + require.Equal(t, []string{schema.ModelUnavailable, schema.ModelUnavailable, schema.AllModelsUnavailable}, errs) } diff --git a/pkg/router/manager.go b/pkg/router/manager.go index b30afbcf..b0d2fd69 100644 --- a/pkg/router/manager.go +++ b/pkg/router/manager.go @@ -1,7 +1,7 @@ package router import ( - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/api/schema" "github.com/EinStack/glide/pkg/telemetry" ) @@ -45,5 +45,5 @@ func (r *Manager) GetLangRouter(routerID string) (*LangRouter, error) { return router, nil } - return nil, &schemas.ErrRouterNotFound + return nil, &schema.ErrRouterNotFound }