diff --git a/go/plugins/ollama/ollama.go b/go/plugins/ollama/ollama.go
index fcd97aa71..a0e8d5d82 100644
--- a/go/plugins/ollama/ollama.go
+++ b/go/plugins/ollama/ollama.go
@@ -40,10 +40,20 @@ const provider = "ollama"
 
 var (
 	mediaSupportedModels = []string{"llava", "bakllava", "llava-llama3", "llava:13b", "llava:7b", "llava:latest"}
-	roleMapping          = map[ai.Role]string{
+	toolSupportedModels  = []string{
+		"qwq", "mistral-small3.1", "llama3.3", "llama3.2", "llama3.1", "mistral",
+		"qwen2.5", "qwen2.5-coder", "qwen2", "mistral-nemo", "mixtral", "smollm2",
+		"mistral-small", "command-r", "hermes3", "mistral-large", "command-r-plus",
+		"phi4-mini", "granite3.1-dense", "granite3-dense", "granite3.2", "athene-v2",
+		"nemotron-mini", "nemotron", "llama3-groq-tool-use", "aya-expanse", "granite3-moe",
+		"granite3.2-vision", "granite3.1-moe", "cogito", "command-r7b", "firefunction-v2",
+		"granite3.3", "command-a", "command-r7b-arabic",
+	}
+	roleMapping = map[ai.Role]string{
 		ai.RoleUser:   "user",
 		ai.RoleModel:  "assistant",
 		ai.RoleSystem: "system",
+		ai.RoleTool:   "tool",
 	}
 )
@@ -58,12 +68,15 @@ func (o *Ollama) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.M
 	if info != nil {
 		mi = *info
 	} else {
+		// Check if the model supports tools (must be a chat model and in the supported list)
+		supportsTools := model.Type == "chat" && slices.Contains(toolSupportedModels, model.Name)
 		mi = ai.ModelInfo{
 			Label: model.Name,
 			Supports: &ai.ModelSupports{
 				Multiturn:  true,
 				SystemRole: true,
 				Media:      slices.Contains(mediaSupportedModels, model.Name),
+				Tools:      supportsTools,
 			},
 			Versions: []string{},
 		}
@@ -100,9 +113,10 @@ type generator struct {
 }
 
 type ollamaMessage struct {
-	Role    string   `json:"role"`
-	Content string   `json:"content"`
-	Images  []string `json:"images,omitempty"`
+	Role      string           `json:"role"`
+	Content   string           `json:"content,omitempty"`
+	Images    []string         `json:"images,omitempty"`
+	ToolCalls []ollamaToolCall `json:"tool_calls,omitempty"`
 }
 
 // Ollama has two API endpoints, one with a chat interface and another with a generate response interface.
@@ -125,6 +139,7 @@ type ollamaChatRequest struct {
 	Model  string `json:"model"`
 	Stream bool   `json:"stream"`
 	Format string `json:"format,omitempty"`
+	Tools  []ollamaTool `json:"tools,omitempty"`
 }
 
 type ollamaModelRequest struct {
@@ -136,13 +151,38 @@ type ollamaModelRequest struct {
 	Format string `json:"format,omitempty"`
 }
 
+// Tool definition from Ollama API
+type ollamaTool struct {
+	Type     string         `json:"type"`
+	Function ollamaFunction `json:"function"`
+}
+
+// Function definition for Ollama API
+type ollamaFunction struct {
+	Name        string         `json:"name"`
+	Description string         `json:"description"`
+	Parameters  map[string]any `json:"parameters"`
+}
+
+// Tool Call from Ollama API
+type ollamaToolCall struct {
+	Function ollamaFunctionCall `json:"function"`
+}
+
+// Function Call for Ollama API
+type ollamaFunctionCall struct {
+	Name      string `json:"name"`
+	Arguments any    `json:"arguments"`
+}
+
 // TODO: Add optional parameters (images, format, options, etc.) based on your use case
 type ollamaChatResponse struct {
 	Model     string `json:"model"`
 	CreatedAt string `json:"created_at"`
 	Message   struct {
-		Role    string `json:"role"`
-		Content string `json:"content"`
+		Role      string           `json:"role"`
+		Content   string           `json:"content"`
+		ToolCalls []ollamaToolCall `json:"tool_calls,omitempty"`
 	} `json:"message"`
 }
@@ -217,34 +257,47 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
 			}
 			messages = append(messages, message)
 		}
-		payload = ollamaChatRequest{
+		chatReq := ollamaChatRequest{
 			Messages: messages,
 			Model:    g.model.Name,
 			Stream:   stream,
 			Images:   images,
 		}
+		if len(input.Tools) > 0 {
+			tools, err := convertTools(input.Tools)
+			if err != nil {
+				return nil, fmt.Errorf("failed to convert tools: %v", err)
+			}
+			chatReq.Tools = tools
+		}
+		payload = chatReq
 	}
+
 	client := &http.Client{Timeout: 30 * time.Second}
 	payloadBytes, err := json.Marshal(payload)
 	if err != nil {
 		return nil, err
 	}
+
+	// Determine the correct endpoint
 	endpoint := g.serverAddress + "/api/chat"
 	if !isChatModel {
 		endpoint = g.serverAddress + "/api/generate"
 	}
+
 	req, err := http.NewRequest("POST", endpoint, bytes.NewReader(payloadBytes))
 	if err != nil {
 		return nil, fmt.Errorf("failed to create request: %v", err)
 	}
 	req.Header.Set("Content-Type", "application/json")
 	req = req.WithContext(ctx)
+
 	resp, err := client.Do(req)
 	if err != nil {
 		return nil, fmt.Errorf("failed to send request: %v", err)
 	}
 	defer resp.Body.Close()
+
 	if cb == nil {
 		// Existing behavior for non-streaming responses
 		var err error
@@ -255,6 +308,7 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
 		if resp.StatusCode != http.StatusOK {
 			return nil, fmt.Errorf("server returned non-200 status: %d, body: %s", resp.StatusCode, body)
 		}
+
 		var response *ai.ModelResponse
 		if isChatModel {
 			response, err = translateChatResponse(body)
@@ -269,8 +323,12 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
 	} else {
 		var chunks []*ai.ModelResponseChunk
 		scanner := bufio.NewScanner(resp.Body)
+		chunkCount := 0
+
 		for scanner.Scan() {
 			line := scanner.Text()
+			chunkCount++
+
 			var chunk *ai.ModelResponseChunk
 			if isChatModel {
 				chunk, err = translateChatChunk(line)
@@ -283,9 +341,11 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
 			chunks = append(chunks, chunk)
 			cb(ctx, chunk)
 		}
+
 		if err := scanner.Err(); err != nil {
 			return nil, fmt.Errorf("reading response stream: %v", err)
 		}
+
 		// Create a final response with the merged chunks
 		finalResponse := &ai.ModelResponse{
 			Request: input,
@@ -303,13 +363,29 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
 	}
 }
 
+// convertTools converts Genkit tool definitions to Ollama tool format
+func convertTools(tools []*ai.ToolDefinition) ([]ollamaTool, error) {
+	ollamaTools := make([]ollamaTool, 0, len(tools))
+	for _, tool := range tools {
+		ollamaTools = append(ollamaTools, ollamaTool{
+			Type: "function",
+			Function: ollamaFunction{
+				Name:        tool.Name,
+				Description: tool.Description,
+				Parameters:  tool.InputSchema,
+			},
+		})
+	}
+	return ollamaTools, nil
+}
+
 func convertParts(role ai.Role, parts []*ai.Part) (*ollamaMessage, error) {
 	message := &ollamaMessage{
 		Role: roleMapping[role],
 	}
 	var contentBuilder strings.Builder
+	var toolCalls []ollamaToolCall
 	var images []string
-
 	for _, part := range parts {
 		if part.IsText() {
 			contentBuilder.WriteString(part.Text)
@@ -320,12 +396,30 @@ func convertParts(role ai.Role, parts []*ai.Part) (*ollamaMessage, error) {
 			}
 			base64Encoded := base64.StdEncoding.EncodeToString(data)
 			images = append(images, base64Encoded)
+		} else if part.IsToolRequest() {
+			toolReq := part.ToolRequest
+			toolCalls = append(toolCalls, ollamaToolCall{
+				Function: ollamaFunctionCall{
+					Name:      toolReq.Name,
+					Arguments: toolReq.Input,
+				},
+			})
+		} else if part.IsToolResponse() {
+			toolResp := part.ToolResponse
+			outputJSON, err := json.Marshal(toolResp.Output)
+			if err != nil {
+				return nil, fmt.Errorf("failed to marshal tool response: %v", err)
+			}
+			contentBuilder.WriteString(string(outputJSON))
 		} else {
 			return nil, errors.New("unsupported content type")
 		}
 	}
 	message.Content = contentBuilder.String()
+	if len(toolCalls) > 0 {
+		message.ToolCalls = toolCalls
+	}
 	if len(images) > 0 {
 		message.Images = images
 	}
@@ -342,17 +436,27 @@ func translateChatResponse(responseData []byte) (*ai.ModelResponse, error) {
 	modelResponse := &ai.ModelResponse{
 		FinishReason: ai.FinishReason("stop"),
 		Message: &ai.Message{
-			Role: ai.Role(response.Message.Role),
+			Role: ai.RoleModel,
 		},
 	}
-
-	aiPart := ai.NewTextPart(response.Message.Content)
-	modelResponse.Message.Content = append(modelResponse.Message.Content, aiPart)
+	if len(response.Message.ToolCalls) > 0 {
+		for _, toolCall := range response.Message.ToolCalls {
+			toolRequest := &ai.ToolRequest{
+				Name:  toolCall.Function.Name,
+				Input: toolCall.Function.Arguments,
+			}
+			toolPart := ai.NewToolRequestPart(toolRequest)
+			modelResponse.Message.Content = append(modelResponse.Message.Content, toolPart)
+		}
+	} else if response.Message.Content != "" {
+		aiPart := ai.NewTextPart(response.Message.Content)
+		modelResponse.Message.Content = append(modelResponse.Message.Content, aiPart)
+	}
 
 	return modelResponse, nil
 }
 
-// translateResponse translates Ollama generate response into a genkit response.
+// translateModelResponse translates Ollama generate response into a genkit response.
 func translateModelResponse(responseData []byte) (*ai.ModelResponse, error) {
 	var response ollamaModelResponse
 
@@ -380,8 +484,20 @@ func translateChatChunk(input string) (*ai.ModelResponseChunk, error) {
 		return nil, fmt.Errorf("failed to parse response JSON: %v", err)
 	}
 	chunk := &ai.ModelResponseChunk{}
-	aiPart := ai.NewTextPart(response.Message.Content)
-	chunk.Content = append(chunk.Content, aiPart)
+	if len(response.Message.ToolCalls) > 0 {
+		for _, toolCall := range response.Message.ToolCalls {
+			toolRequest := &ai.ToolRequest{
+				Name:  toolCall.Function.Name,
+				Input: toolCall.Function.Arguments,
+			}
+			toolPart := ai.NewToolRequestPart(toolRequest)
+			chunk.Content = append(chunk.Content, toolPart)
+		}
+	} else if response.Message.Content != "" {
+		aiPart := ai.NewTextPart(response.Message.Content)
+		chunk.Content = append(chunk.Content, aiPart)
+	}
+
	return chunk, nil
 }
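For reference, the `tools` payload that `convertTools` builds follows Ollama's `/api/chat` function-calling schema. A minimal standalone sketch of the resulting JSON; the struct definitions mirror the unexported `ollamaTool`/`ollamaFunction` types above and are re-declared here only for illustration, and the "weather" tool is hypothetical:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local mirrors of the unexported ollamaTool/ollamaFunction types (illustrative only).
type tool struct {
	Type     string       `json:"type"`
	Function toolFunction `json:"function"`
}

type toolFunction struct {
	Name        string         `json:"name"`
	Description string         `json:"description"`
	Parameters  map[string]any `json:"parameters"`
}

func main() {
	// A hypothetical "weather" tool, matching the sample app below.
	t := tool{
		Type: "function",
		Function: toolFunction{
			Name:        "weather",
			Description: "Get current weather for a location",
			Parameters: map[string]any{
				"type": "object",
				"properties": map[string]any{
					"location": map[string]any{"type": "string"},
				},
				"required": []string{"location"},
			},
		},
	}
	out, _ := json.MarshalIndent([]tool{t}, "", "  ")
	fmt.Println(string(out)) // prints the "tools" array as sent to /api/chat
}
```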
diff --git a/go/samples/ollama-tools/main.go b/go/samples/ollama-tools/main.go
new file mode 100644
index 000000000..85d1c6c68
--- /dev/null
+++ b/go/samples/ollama-tools/main.go
@@ -0,0 +1,121 @@
+// Copyright 2025 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//	http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package main
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	"github.com/firebase/genkit/go/ai"
+	"github.com/firebase/genkit/go/genkit"
+	"github.com/firebase/genkit/go/plugins/ollama"
+)
+
+// WeatherInput defines the input structure for the weather tool
+type WeatherInput struct {
+	Location string `json:"location"`
+}
+
+// WeatherData represents weather information
+type WeatherData struct {
+	Location  string  `json:"location"`
+	TempC     float64 `json:"temp_c"`
+	TempF     float64 `json:"temp_f"`
+	Condition string  `json:"condition"`
+}
+
+func main() {
+	ctx := context.Background()
+
+	// Initialize Genkit with the Ollama plugin
+	ollamaPlugin := &ollama.Ollama{
+		ServerAddress: "http://localhost:11434", // Default Ollama server address
+	}
+
+	g, err := genkit.Init(ctx, genkit.WithPlugins(ollamaPlugin))
+	if err != nil {
+		fmt.Printf("Failed to initialize Genkit: %v\n", err)
+		return
+	}
+
+	// Define the Ollama model
+	model := ollamaPlugin.DefineModel(g,
+		ollama.ModelDefinition{
+			Name: "llama3.1", // Choose an appropriate model
+			Type: "chat",     // Must be chat for tool support
+		},
+		nil)
+
+	// Define tools
+	weatherTool := genkit.DefineTool(g, "weather", "Get current weather for a location",
+		func(ctx *ai.ToolContext, input WeatherInput) (WeatherData, error) {
+			// Get weather data (simulated)
+			return simulateWeather(input.Location), nil
+		},
+	)
+
+	// Create system message
+	systemMsg := ai.NewSystemTextMessage(
+		"You are a helpful assistant that can look up weather. " +
+			"When providing weather information, use the appropriate tool.")
+
+	// Create user message
+	userMsg := ai.NewUserTextMessage("I'd like to know the weather in Tokyo.")
+
+	// Generate response with tools
+	fmt.Println("Generating response with weather tool...")
+
+	resp, err := genkit.Generate(ctx, g,
+		ai.WithModel(model),
+		ai.WithMessages(systemMsg, userMsg),
+		ai.WithTools(weatherTool),
+		ai.WithToolChoice(ai.ToolChoiceAuto),
+	)
+
+	if err != nil {
+		fmt.Printf("Error: %v\n", err)
+		return
+	}
+
+	// Print the final response
+	fmt.Println("\n----- Final Response -----")
+	fmt.Printf("%s\n", resp.Text())
+	fmt.Println("--------------------------")
+}
+
+// simulateWeather returns simulated weather data for a location
+func simulateWeather(location string) WeatherData {
+	// In a real app, this would call a weather API
+	// For demonstration, we'll return mock data
+	tempC := 22.5
+	if location == "Tokyo" || location == "Tokyo, Japan" {
+		tempC = 24.0
+	} else if location == "Paris" || location == "Paris, France" {
+		tempC = 18.5
+	} else if location == "New York" || location == "New York, USA" {
+		tempC = 15.0
+	}
+
+	conditions := []string{"Sunny", "Partly Cloudy", "Cloudy", "Rainy", "Stormy"}
+	condition := conditions[time.Now().Unix()%int64(len(conditions))]
+
+	return WeatherData{
+		Location:  location,
+		TempC:     tempC,
+		TempF:     tempC*9/5 + 32,
+		Condition: condition,
+	}
+}
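The wire shape this sample exercises: when the model elects to call a tool, `/api/chat` returns a message with empty `content` and a populated `tool_calls` array, which `translateChatResponse` and `translateChatChunk` above map to `ai.ToolRequest` parts. A small sketch of decoding such a line; the types mirror the unexported `ollamaChatResponse` message shape, and the JSON is illustrative rather than captured output:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Local mirror of the message shape in ollamaChatResponse above (illustrative only).
type chatResponse struct {
	Model   string `json:"model"`
	Message struct {
		Role      string `json:"role"`
		Content   string `json:"content"`
		ToolCalls []struct {
			Function struct {
				Name      string `json:"name"`
				Arguments any    `json:"arguments"`
			} `json:"function"`
		} `json:"tool_calls,omitempty"`
	} `json:"message"`
}

func main() {
	// An illustrative response line for a model that decided to call the weather tool.
	line := `{"model":"llama3.1","message":{"role":"assistant","content":"",` +
		`"tool_calls":[{"function":{"name":"weather","arguments":{"location":"Tokyo"}}}]}}`

	var resp chatResponse
	if err := json.Unmarshal([]byte(line), &resp); err != nil {
		panic(err)
	}
	for _, tc := range resp.Message.ToolCalls {
		fmt.Printf("tool call: %s(%v)\n", tc.Function.Name, tc.Function.Arguments)
	}
}
```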