Skip to content

Initial implementation of google search for Gemini models #49

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions lib/ruby_llm/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ def with_tools(*tools)
self
end

def with_google_search
raise UnsupportedFunctionsError, "Model #{@model.id} doesn't support function calling" unless @model.supports_functions
raise UnsupportedFunctionsError, "Google search is only supported with Gemini models" unless @model.provider == 'gemini'

@tools = [{
google_search: {}
}]
self
end
Comment on lines +50 to +58
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

chat.rb shouldn't have any code that's provider specific.

I think we should implement a provider overrides API, e.g. .with_provider_overrides({tools: :google_search}) This name is not set in stone, let's discuss naming and implementation.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think of :

in lib/ruby_llm/chat.rb

def with_provider_overrides(overrides)
      unless @model.supports_functions
        raise UnsupportedFunctionsError, "Model #{@model.id} doesn't support function calling"
      end

      case overrides
      when Hash
        @tools = @provider.format_provider_overrides(overrides)
      when Array
        @tools = overrides
      else
        raise ArgumentError, 'Provider overrides must be a Hash or Array'
      end
      self
end

and then in gemini.rb

def format_provider_overrides(overrides)
        case overrides
        when { tools: :google_search }
          [{ google_search: {} }]
        else
          raise ArgumentError, "Unsupported provider override: #{overrides}"
        end
end

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The best interface IMO would be chat.with_tool(RubyLLM::Providers::Gemini::GoogleSearchTool)


def model=(model_id)
@model = Models.find model_id
@provider = Models.provider_for model_id
Expand Down Expand Up @@ -109,6 +119,10 @@ def handle_tool_calls(response, &)
end

def execute_tool(tool_call)
if tool_call.name.to_sym == :google_search
return tool_call.arguments
end

Comment on lines 121 to +125
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here, no provider specific code.

tool = tools[tool_call.name.to_sym]
args = tool_call.arguments
tool.call(args)
Expand Down
6 changes: 6 additions & 0 deletions lib/ruby_llm/providers/gemini/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ def complete(messages, tools:, temperature:, model:, &block) # rubocop:disable M
end
end

def with_google_search
@tools ||= []
@tools << { google_search: {} }
self
end

# Format methods can be private
private

Expand Down
14 changes: 11 additions & 3 deletions lib/ruby_llm/providers/gemini/tools.rb
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,17 @@ module Tools
def format_tools(tools)
return [] if tools.empty?

[{
functionDeclarations: tools.values.map { |tool| function_declaration_for(tool) }
}]
formatted_tools = tools.map do |tool|
if tool.is_a?(Hash) && tool.key?(:google_search)
tool
else
{
functionDeclarations: [function_declaration_for(tool)]
}
end
end

formatted_tools.flatten
end

# Extract tool calls from response data
Expand Down