diff --git a/lib/model_context_protocol/server.rb b/lib/model_context_protocol/server.rb index f6de3ee..6461610 100644 --- a/lib/model_context_protocol/server.rb +++ b/lib/model_context_protocol/server.rb @@ -197,7 +197,8 @@ def call_tool(request) end begin - call_params = tool.method(:call).parameters.flatten + call_params = tool_call_parameters(tool) + if call_params.include?(:server_context) tool.call(**arguments.transform_keys(&:to_sym), server_context:).to_h else @@ -258,5 +259,24 @@ def index_resources_by_uri(resources) hash[resource.uri] = resource end end + + def tool_call_parameters(tool) + method_def = tool_call_method_def(tool) + method_def.parameters.flatten + end + + def tool_call_method_def(tool) + method = tool.method(:call) + + if defined?(T::Utils) && T::Utils.respond_to?(:signature_for_method) + sorbet_typed_method_definition = T::Utils.signature_for_method(method)&.method + + # Return the Sorbet typed method definition if it exists, otherwise fallback to original method + # definition if Sorbet is defined but not used by this tool. + sorbet_typed_method_definition || method + else + method + end + end end end diff --git a/model_context_protocol.gemspec b/model_context_protocol.gemspec index 0f8278d..44ee5c3 100644 --- a/model_context_protocol.gemspec +++ b/model_context_protocol.gemspec @@ -29,4 +29,5 @@ Gem::Specification.new do |spec| spec.add_dependency("json_rpc_handler", "~> 0.1") spec.add_development_dependency("activesupport") + spec.add_development_dependency("sorbet-static-and-runtime") end diff --git a/test/model_context_protocol/server_test.rb b/test/model_context_protocol/server_test.rb index 5890d6b..02d0a78 100644 --- a/test/model_context_protocol/server_test.rb +++ b/test/model_context_protocol/server_test.rb @@ -1,3 +1,4 @@ +# typed: true # frozen_string_literal: true require "test_helper" @@ -256,6 +257,43 @@ class ServerTest < ActiveSupport::TestCase assert_instrumentation_data({ method: "tools/call", tool_name: }) end + test "#handle_json tools/call executes tool and returns result, when the tool is typed with Sorbet" do + class TypedTestTool < Tool + tool_name "test_tool" + description "a test tool for testing" + input_schema({ properties: { message: { type: "string" } }, required: ["message"] }) + + class << self + extend T::Sig + + sig { params(message: String, server_context: T.nilable(T.untyped)).returns(Tool::Response) } + def call(message:, server_context: nil) + Tool::Response.new([{ type: "text", content: "OK" }]) + end + end + end + + request = JSON.generate({ + jsonrpc: "2.0", + method: "tools/call", + params: { name: "test_tool", arguments: { message: "Hello, world!" } }, + id: 1, + }) + + server = Server.new( + name: @server_name, + tools: [TypedTestTool], + prompts: [@prompt], + resources: [@resource], + resource_templates: [@resource_template], + ) + + raw_response = server.handle_json(request) + response = JSON.parse(raw_response, symbolize_names: true) if raw_response + + assert_equal({ content: [{ type: "text", content: "OK" }], isError: false }, response[:result]) + end + test "#handle tools/call returns internal error and reports exception if the tool raises an error" do @server.configuration.exception_reporter.expects(:call).with do |exception, server_context| assert_not_nil exception diff --git a/test/model_context_protocol/tool_test.rb b/test/model_context_protocol/tool_test.rb index b137832..1b96022 100644 --- a/test/model_context_protocol/tool_test.rb +++ b/test/model_context_protocol/tool_test.rb @@ -1,3 +1,4 @@ +# typed: true # frozen_string_literal: true require "test_helper" @@ -17,7 +18,7 @@ class TestTool < Tool ) class << self - def call(message, server_context: nil) + def call(message:, server_context: nil) Tool::Response.new([{ type: "text", content: "OK" }]) end end @@ -42,7 +43,7 @@ def call(message, server_context: nil) test "#call invokes the tool block and returns the response" do tool = TestTool - response = tool.call("test") + response = tool.call(message: "test") assert_equal response.content, [{ type: "text", content: "OK" }] assert_equal response.is_error, false end @@ -203,5 +204,27 @@ class UpdatableAnnotationsTool < Tool tool.annotations(title: "Updated") assert_equal tool.annotations_value.title, "Updated" end + + test "#call with Sorbet typed tools invokes the tool block and returns the response" do + class TypedTestTool < Tool + tool_name "test_tool" + description "a test tool for testing" + input_schema({ properties: { message: { type: "string" } }, required: ["message"] }) + + class << self + extend T::Sig + + sig { params(message: String, server_context: T.nilable(T.untyped)).returns(Tool::Response) } + def call(message:, server_context: nil) + Tool::Response.new([{ type: "text", content: "OK" }]) + end + end + end + + tool = TypedTestTool + response = tool.call(message: "test") + assert_equal response.content, [{ type: "text", content: "OK" }] + assert_equal response.is_error, false + end end end diff --git a/test/test_helper.rb b/test/test_helper.rb index e678470..0ba9013 100644 --- a/test/test_helper.rb +++ b/test/test_helper.rb @@ -13,6 +13,8 @@ require "active_support" require "active_support/test_case" +require "sorbet-runtime" + require_relative "instrumentation_test_helper" Minitest::Reporters.use!(Minitest::Reporters::ProgressReporter.new)