feat(go): Add tool support for ollama models #2796

Merged: 8 commits, May 1, 2025

Changes from all commits

146 changes: 131 additions & 15 deletions go/plugins/ollama/ollama.go
@@ -40,10 +40,20 @@ const provider = "ollama"

var (
	mediaSupportedModels = []string{"llava", "bakllava", "llava-llama3", "llava:13b", "llava:7b", "llava:latest"}
-	roleMapping          = map[ai.Role]string{
+	toolSupportedModels = []string{
+		"qwq", "mistral-small3.1", "llama3.3", "llama3.2", "llama3.1", "mistral",
+		"qwen2.5", "qwen2.5-coder", "qwen2", "mistral-nemo", "mixtral", "smollm2",
+		"mistral-small", "command-r", "hermes3", "mistral-large", "command-r-plus",
+		"phi4-mini", "granite3.1-dense", "granite3-dense", "granite3.2", "athene-v2",
+		"nemotron-mini", "nemotron", "llama3-groq-tool-use", "aya-expanse", "granite3-moe",
+		"granite3.2-vision", "granite3.1-moe", "cogito", "command-r7b", "firefunction-v2",
+		"granite3.3", "command-a", "command-r7b-arabic",
+	}
+	roleMapping = map[ai.Role]string{
		ai.RoleUser:   "user",
		ai.RoleModel:  "assistant",
		ai.RoleSystem: "system",
+		ai.RoleTool:   "tool",
	}
)

@@ -58,12 +68,15 @@ func (o *Ollama) DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.M
	if info != nil {
		mi = *info
	} else {
+		// Check if the model supports tools (must be a chat model and in the supported list)
+		supportsTools := model.Type == "chat" && slices.Contains(toolSupportedModels, model.Name)
		mi = ai.ModelInfo{
			Label: model.Name,
			Supports: &ai.ModelSupports{
				Multiturn:  true,
				SystemRole: true,
				Media:      slices.Contains(mediaSupportedModels, model.Name),
+				Tools:      supportsTools,
			},
			Versions: []string{},
		}
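For orientation, a minimal user-side sketch of how this change surfaces, assuming the plugin's usual Genkit Go initialization (genkit.Init, genkit.WithPlugins, and the "llama3.1" model name are assumptions, not part of this diff):

// Sketch only: wiring assumed from the plugin's public surface.
o := &ollama.Ollama{ServerAddress: "http://localhost:11434"}
g, err := genkit.Init(ctx, genkit.WithPlugins(o))
if err != nil {
	log.Fatal(err)
}
// info == nil, Type == "chat", and "llama3.1" is in toolSupportedModels,
// so the derived ModelInfo reports Supports.Tools == true.
model := o.DefineModel(g, ollama.ModelDefinition{Name: "llama3.1", Type: "chat"}, nil)
_ = model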
@@ -100,9 +113,10 @@ type generator struct {
}

type ollamaMessage struct {
-	Role    string   `json:"role"`
-	Content string   `json:"content"`
-	Images  []string `json:"images,omitempty"`
+	Role      string           `json:"role"`
+	Content   string           `json:"content,omitempty"`
+	Images    []string         `json:"images,omitempty"`
+	ToolCalls []ollamaToolCall `json:"tool_calls,omitempty"`
}

// Ollama has two API endpoints: one with a chat interface and another with a generate interface.
@@ -125,6 +139,7 @@ type ollamaChatRequest struct {
	Model  string       `json:"model"`
	Stream bool         `json:"stream"`
	Format string       `json:"format,omitempty"`
+	Tools  []ollamaTool `json:"tools,omitempty"`
}

type ollamaModelRequest struct {
@@ -136,13 +151,38 @@ type ollamaModelRequest struct {
Format string `json:"format,omitempty"`
}

+// Tool definition from Ollama API
+type ollamaTool struct {
+	Type     string         `json:"type"`
+	Function ollamaFunction `json:"function"`
+}
+
+// Function definition for Ollama API
+type ollamaFunction struct {
+	Name        string         `json:"name"`
+	Description string         `json:"description"`
+	Parameters  map[string]any `json:"parameters"`
+}
+
+// Tool Call from Ollama API
+type ollamaToolCall struct {
+	Function ollamaFunctionCall `json:"function"`
+}
+
+// Function Call for Ollama API
+type ollamaFunctionCall struct {
+	Name      string `json:"name"`
+	Arguments any    `json:"arguments"`
+}

// TODO: Add optional parameters (images, format, options, etc.) based on your use case
type ollamaChatResponse struct {
	Model     string `json:"model"`
	CreatedAt string `json:"created_at"`
	Message   struct {
-		Role    string `json:"role"`
-		Content string `json:"content"`
+		Role      string           `json:"role"`
+		Content   string           `json:"content"`
+		ToolCalls []ollamaToolCall `json:"tool_calls,omitempty"`
	} `json:"message"`
}
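For reference, these types marshal to the tool-calling JSON documented for Ollama's /api/chat endpoint. A minimal in-package sketch; the "add" tool and its schema are invented for illustration, and Messages is assumed to be []*ollamaMessage as built by convertParts below:

req := ollamaChatRequest{
	Messages: []*ollamaMessage{{Role: "user", Content: "What is 2 + 2?"}},
	Model:    "llama3.1",
	Stream:   false,
	Tools: []ollamaTool{{
		Type: "function",
		Function: ollamaFunction{
			Name:        "add",
			Description: "Add two integers",
			Parameters: map[string]any{
				"type": "object",
				"properties": map[string]any{
					"a": map[string]any{"type": "integer"},
					"b": map[string]any{"type": "integer"},
				},
				"required": []string{"a", "b"},
			},
		},
	}},
}
b, _ := json.Marshal(req)
fmt.Println(string(b))
// {"messages":[{"role":"user","content":"What is 2 + 2?"}],"model":"llama3.1",
// "stream":false,"tools":[{"type":"function","function":{"name":"add",...}}]}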

@@ -217,34 +257,47 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
			}
			messages = append(messages, message)
		}
-		payload = ollamaChatRequest{
+		chatReq := ollamaChatRequest{
			Messages: messages,
			Model:    g.model.Name,
			Stream:   stream,
			Images:   images,
		}
+		if len(input.Tools) > 0 {
+			tools, err := convertTools(input.Tools)
+			if err != nil {
+				return nil, fmt.Errorf("failed to convert tools: %v", err)
+			}
+			chatReq.Tools = tools
+		}
+		payload = chatReq
	}

	client := &http.Client{Timeout: 30 * time.Second}
	payloadBytes, err := json.Marshal(payload)
	if err != nil {
		return nil, err
	}
+
+	// Determine the correct endpoint
	endpoint := g.serverAddress + "/api/chat"
	if !isChatModel {
		endpoint = g.serverAddress + "/api/generate"
	}
+
	req, err := http.NewRequest("POST", endpoint, bytes.NewReader(payloadBytes))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req = req.WithContext(ctx)
+
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to send request: %v", err)
	}
	defer resp.Body.Close()
+
if cb == nil {
// Existing behavior for non-streaming responses
var err error
@@ -255,6 +308,7 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
		if resp.StatusCode != http.StatusOK {
			return nil, fmt.Errorf("server returned non-200 status: %d, body: %s", resp.StatusCode, body)
		}
+
		var response *ai.ModelResponse
		if isChatModel {
			response, err = translateChatResponse(body)
@@ -269,8 +323,12 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
	} else {
		var chunks []*ai.ModelResponseChunk
		scanner := bufio.NewScanner(resp.Body)
+		chunkCount := 0
+
		for scanner.Scan() {
			line := scanner.Text()
+			chunkCount++
+
			var chunk *ai.ModelResponseChunk
			if isChatModel {
				chunk, err = translateChatChunk(line)
@@ -283,9 +341,11 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
			chunks = append(chunks, chunk)
			cb(ctx, chunk)
		}
+
		if err := scanner.Err(); err != nil {
			return nil, fmt.Errorf("reading response stream: %v", err)
		}
+
		// Create a final response with the merged chunks
		finalResponse := &ai.ModelResponse{
			Request: input,
@@ -303,13 +363,29 @@ func (g *generator) generate(ctx context.Context, input *ai.ModelRequest, cb fun
	}
}

+// convertTools converts Genkit tool definitions to Ollama tool format
+func convertTools(tools []*ai.ToolDefinition) ([]ollamaTool, error) {
+	ollamaTools := make([]ollamaTool, 0, len(tools))
+	for _, tool := range tools {
+		ollamaTools = append(ollamaTools, ollamaTool{
+			Type: "function",
+			Function: ollamaFunction{
+				Name:        tool.Name,
+				Description: tool.Description,
+				Parameters:  tool.InputSchema,
+			},
+		})
+	}
+	return ollamaTools, nil
+}
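Since ai.ToolDefinition already carries its input schema as a map[string]any, convertTools is a direct re-wrapping into Ollama's {type, function} envelope. A hypothetical in-package example (the "get_weather" tool is invented):

td := &ai.ToolDefinition{
	Name:        "get_weather",
	Description: "Return the current weather for a city",
	InputSchema: map[string]any{
		"type": "object",
		"properties": map[string]any{
			"city": map[string]any{"type": "string"},
		},
		"required": []string{"city"},
	},
}
tools, _ := convertTools([]*ai.ToolDefinition{td})
fmt.Println(tools[0].Type, tools[0].Function.Name)
// Output: function get_weather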

func convertParts(role ai.Role, parts []*ai.Part) (*ollamaMessage, error) {
	message := &ollamaMessage{
		Role: roleMapping[role],
	}
	var contentBuilder strings.Builder
+	var toolCalls []ollamaToolCall
	var images []string

	for _, part := range parts {
		if part.IsText() {
			contentBuilder.WriteString(part.Text)
@@ -320,12 +396,30 @@ func convertParts(role ai.Role, parts []*ai.Part) (*ollamaMessage, error) {
			}
			base64Encoded := base64.StdEncoding.EncodeToString(data)
			images = append(images, base64Encoded)
+		} else if part.IsToolRequest() {
+			toolReq := part.ToolRequest
+			toolCalls = append(toolCalls, ollamaToolCall{
+				Function: ollamaFunctionCall{
+					Name:      toolReq.Name,
+					Arguments: toolReq.Input,
+				},
+			})
+		} else if part.IsToolResponse() {
+			toolResp := part.ToolResponse
+			outputJSON, err := json.Marshal(toolResp.Output)
+			if err != nil {
+				return nil, fmt.Errorf("failed to marshal tool response: %v", err)
+			}
+			contentBuilder.WriteString(string(outputJSON))
		} else {
			return nil, errors.New("unsupported content type")
		}
	}

	message.Content = contentBuilder.String()
+	if len(toolCalls) > 0 {
+		message.ToolCalls = toolCalls
+	}
	if len(images) > 0 {
		message.Images = images
	}
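Note the asymmetry this encodes: Ollama's chat API has no tool-call IDs, so a model's tool request becomes an assistant message with tool_calls, while the tool's result is re-serialized to JSON and sent back as plain content under the "tool" role. A hypothetical in-package sketch (constructors such as ai.NewToolResponsePart are assumed from the Genkit ai package):

msg, _ := convertParts(ai.RoleModel, []*ai.Part{
	ai.NewToolRequestPart(&ai.ToolRequest{
		Name:  "get_weather",
		Input: map[string]any{"city": "Paris"},
	}),
})
fmt.Println(msg.Role, msg.ToolCalls[0].Function.Name)
// Output: assistant get_weather

reply, _ := convertParts(ai.RoleTool, []*ai.Part{
	ai.NewToolResponsePart(&ai.ToolResponse{
		Name:   "get_weather",
		Output: map[string]any{"tempC": 21},
	}),
})
fmt.Println(reply.Role, reply.Content)
// Output: tool {"tempC":21}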
@@ -342,17 +436,27 @@ func translateChatResponse(responseData []byte) (*ai.ModelResponse, error) {
	modelResponse := &ai.ModelResponse{
		FinishReason: ai.FinishReason("stop"),
		Message: &ai.Message{
-			Role: ai.Role(response.Message.Role),
+			Role: ai.RoleModel,
		},
	}

-	aiPart := ai.NewTextPart(response.Message.Content)
-	modelResponse.Message.Content = append(modelResponse.Message.Content, aiPart)
+	if len(response.Message.ToolCalls) > 0 {
+		for _, toolCall := range response.Message.ToolCalls {
+			toolRequest := &ai.ToolRequest{
+				Name:  toolCall.Function.Name,
+				Input: toolCall.Function.Arguments,
+			}
+			toolPart := ai.NewToolRequestPart(toolRequest)
+			modelResponse.Message.Content = append(modelResponse.Message.Content, toolPart)
+		}
+	} else if response.Message.Content != "" {
+		aiPart := ai.NewTextPart(response.Message.Content)
+		modelResponse.Message.Content = append(modelResponse.Message.Content, aiPart)
+	}

	return modelResponse, nil
}
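To make the branching concrete, a hypothetical non-streaming reply carrying a tool call; the JSON mimics the ollamaChatResponse shape above and is invented for illustration:

raw := []byte(`{"model":"llama3.1","created_at":"2025-05-01T12:00:00Z",
 "message":{"role":"assistant","content":"",
  "tool_calls":[{"function":{"name":"get_weather","arguments":{"city":"Paris"}}}]}}`)
resp, err := translateChatResponse(raw)
if err != nil {
	log.Fatal(err)
}
fmt.Println(resp.Message.Role, resp.Message.Content[0].IsToolRequest())
// Output: model true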

-// translateResponse translates Ollama generate response into a genkit response.
+// translateModelResponse translates an Ollama generate response into a genkit response.
func translateModelResponse(responseData []byte) (*ai.ModelResponse, error) {
var response ollamaModelResponse

@@ -380,8 +484,20 @@ func translateChatChunk(input string) (*ai.ModelResponseChunk, error) {
		return nil, fmt.Errorf("failed to parse response JSON: %v", err)
	}
	chunk := &ai.ModelResponseChunk{}
-	aiPart := ai.NewTextPart(response.Message.Content)
-	chunk.Content = append(chunk.Content, aiPart)
+	if len(response.Message.ToolCalls) > 0 {
+		for _, toolCall := range response.Message.ToolCalls {
+			toolRequest := &ai.ToolRequest{
+				Name:  toolCall.Function.Name,
+				Input: toolCall.Function.Arguments,
+			}
+			toolPart := ai.NewToolRequestPart(toolRequest)
+			chunk.Content = append(chunk.Content, toolPart)
+		}
+	} else if response.Message.Content != "" {
+		aiPart := ai.NewTextPart(response.Message.Content)
+		chunk.Content = append(chunk.Content, aiPart)
+	}
+
	return chunk, nil
}
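The streaming path applies the same translation per line: Ollama streams newline-delimited JSON objects, and each line becomes one chunk. A hypothetical single-chunk example:

chunk, err := translateChatChunk(`{"model":"llama3.1","created_at":"2025-05-01T12:00:01Z","message":{"role":"assistant","content":"Hel"}}`)
if err != nil {
	log.Fatal(err)
}
fmt.Println(chunk.Content[0].Text)
// Output: Hel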
