Skip to content

Support Sorbet typed tools #10

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion lib/model_context_protocol/server.rb
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ def call_tool(request)
end

begin
call_params = tool.method(:call).parameters.flatten
call_params = tool_call_parameters(tool)

if call_params.include?(:server_context)
tool.call(**arguments.transform_keys(&:to_sym), server_context:).to_h
else
Expand Down Expand Up @@ -258,5 +259,24 @@ def index_resources_by_uri(resources)
hash[resource.uri] = resource
end
end

def tool_call_parameters(tool)
method_def = tool_call_method_def(tool)
method_def.parameters.flatten
end

def tool_call_method_def(tool)
method = tool.method(:call)

if defined?(T::Utils) && T::Utils.respond_to?(:signature_for_method)
sorbet_typed_method_definition = T::Utils.signature_for_method(method)&.method

# Return the Sorbet typed method definition if it exists, otherwise fallback to original method
# definition if Sorbet is defined but not used by this tool.
sorbet_typed_method_definition || method
else
method
end
end
end
end
1 change: 1 addition & 0 deletions model_context_protocol.gemspec
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,5 @@ Gem::Specification.new do |spec|

spec.add_dependency("json_rpc_handler", "~> 0.1")
spec.add_development_dependency("activesupport")
spec.add_development_dependency("sorbet-static-and-runtime")
end
38 changes: 38 additions & 0 deletions test/model_context_protocol/server_test.rb
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# typed: true
# frozen_string_literal: true

require "test_helper"
Expand Down Expand Up @@ -256,6 +257,43 @@ class ServerTest < ActiveSupport::TestCase
assert_instrumentation_data({ method: "tools/call", tool_name: })
end

test "#handle_json tools/call executes tool and returns result, when the tool is typed with Sorbet" do
class TypedTestTool < Tool
tool_name "test_tool"
description "a test tool for testing"
input_schema({ properties: { message: { type: "string" } }, required: ["message"] })

class << self
extend T::Sig

sig { params(message: String, server_context: T.nilable(T.untyped)).returns(Tool::Response) }
def call(message:, server_context: nil)
Tool::Response.new([{ type: "text", content: "OK" }])
end
end
end

request = JSON.generate({
jsonrpc: "2.0",
method: "tools/call",
params: { name: "test_tool", arguments: { message: "Hello, world!" } },
id: 1,
})

server = Server.new(
name: @server_name,
tools: [TypedTestTool],
prompts: [@prompt],
resources: [@resource],
resource_templates: [@resource_template],
)

raw_response = server.handle_json(request)
response = JSON.parse(raw_response, symbolize_names: true) if raw_response

assert_equal({ content: [{ type: "text", content: "OK" }], isError: false }, response[:result])
end

test "#handle tools/call returns internal error and reports exception if the tool raises an error" do
@server.configuration.exception_reporter.expects(:call).with do |exception, server_context|
assert_not_nil exception
Expand Down
27 changes: 25 additions & 2 deletions test/model_context_protocol/tool_test.rb
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# typed: true
# frozen_string_literal: true

require "test_helper"
Expand All @@ -17,7 +18,7 @@ class TestTool < Tool
)

class << self
def call(message, server_context: nil)
def call(message:, server_context: nil)
Tool::Response.new([{ type: "text", content: "OK" }])
end
end
Expand All @@ -42,7 +43,7 @@ def call(message, server_context: nil)

test "#call invokes the tool block and returns the response" do
tool = TestTool
response = tool.call("test")
response = tool.call(message: "test")
assert_equal response.content, [{ type: "text", content: "OK" }]
assert_equal response.is_error, false
end
Expand Down Expand Up @@ -203,5 +204,27 @@ class UpdatableAnnotationsTool < Tool
tool.annotations(title: "Updated")
assert_equal tool.annotations_value.title, "Updated"
end

test "#call with Sorbet typed tools invokes the tool block and returns the response" do
class TypedTestTool < Tool
tool_name "test_tool"
description "a test tool for testing"
input_schema({ properties: { message: { type: "string" } }, required: ["message"] })

class << self
extend T::Sig

sig { params(message: String, server_context: T.nilable(T.untyped)).returns(Tool::Response) }
def call(message:, server_context: nil)
Tool::Response.new([{ type: "text", content: "OK" }])
end
end
end

tool = TypedTestTool
response = tool.call(message: "test")
assert_equal response.content, [{ type: "text", content: "OK" }]
assert_equal response.is_error, false
end
end
end
2 changes: 2 additions & 0 deletions test/test_helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
require "active_support"
require "active_support/test_case"

require "sorbet-runtime"

require_relative "instrumentation_test_helper"

Minitest::Reporters.use!(Minitest::Reporters::ProgressReporter.new)
Expand Down