diff --git a/.gitignore b/.gitignore index 750a5a9..521721a 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,6 @@ /spec/reports/ /tmp/ Gemfile.lock + +# Mac stuff +.DS_Store diff --git a/Gemfile b/Gemfile index ecf15d1..e087018 100644 --- a/Gemfile +++ b/Gemfile @@ -13,3 +13,7 @@ gem "rubocop-shopify", require: false gem "minitest-reporters" gem "mocha" gem "debug" + +group :test do + gem "webmock" +end diff --git a/README.md b/README.md index 00acd40..df8390d 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ If you want to build a local command-line application, you can use the stdio tra ```ruby #!/usr/bin/env ruby require "model_context_protocol" -require "model_context_protocol/transports/stdio" +require "model_context_protocol/server/transports/stdio" # Create a simple tool class ExampleTool < ModelContextProtocol::Tool @@ -97,7 +97,7 @@ server = ModelContextProtocol::Server.new( ) # Create and start the transport -transport = ModelContextProtocol::Transports::StdioTransport.new(server) +transport = ModelContextProtocol::Server::Transports::StdioTransport.new(server) transport.open ``` @@ -110,6 +110,65 @@ $ ./stdio_server.rb {"jsonrpc":"2.0","id":"3","result":["ExampleTool"]} ``` +## MCP Client + +The `ModelContextProtocol::Client` module provides client implementations for interacting with MCP servers. Currently, it supports HTTP transport for making JSON-RPC requests to MCP servers. + +### HTTP Client + +The `ModelContextProtocol::Client::Http` class provides a simple HTTP client for interacting with MCP servers: + +```ruby +client = ModelContextProtocol::Client::Http.new(url: "https://api.example.com/mcp") + +# List available tools +tools = client.tools +tools.each do |tool| + puts "Tool: #{tool.name}" + puts "Description: #{tool.description}" + puts "Input Schema: #{tool.input_schema}" +end + +# Call a specific tool +response = client.call_tool( + tool: tools.first, + input: { message: "Hello, world!" } +) +``` + +The HTTP client supports: +- Tool listing via the `tools/list` method +- Tool invocation via the `tools/call` method +- Automatic JSON-RPC 2.0 message formatting +- UUID v7 request ID generation +- Setting headers for things like authorization + +### HTTP Authorization + +By default, the HTTP client has no authentication, but it supports custom headers for authentication. For example, to use Bearer token authentication: + +```ruby +client = ModelContextProtocol::Client::Http.new( + url: "https://api.example.com/mcp", + headers: { + "Authorization" => "Bearer my_token" + } +) + +client.tools # will make the call using Bearer auth +``` + +You can add any custom headers needed for your authentication scheme. The client will include these headers in all requests. + +### Tool Objects + +The client provides wrapper objects for tools returned by the server: + +- `ModelContextProtocol::Client::Tool` - Represents a single tool with its metadata +- `ModelContextProtocol::Client::Tools` - Collection of tools with enumerable functionality + +These objects provide easy access to tool properties like name, description, and input schema. + ## Configuration The gem can be configured using the `ModelContextProtocol.configure` block: diff --git a/examples/stdio_server.rb b/examples/stdio_server.rb index 631822e..bf29eb3 100755 --- a/examples/stdio_server.rb +++ b/examples/stdio_server.rb @@ -3,7 +3,7 @@ $LOAD_PATH.unshift(File.expand_path("../lib", __dir__)) require "model_context_protocol" -require "model_context_protocol/transports/stdio" +require "model_context_protocol/server/transports/stdio" # Create a simple tool class ExampleTool < MCP::Tool @@ -91,5 +91,5 @@ def template(args, server_context:) end # Create and start the transport -transport = MCP::Transports::StdioTransport.new(server) +transport = ModelContextProtocol::Server::Transports::StdioTransport.new(server) transport.open diff --git a/lib/model_context_protocol.rb b/lib/model_context_protocol.rb index 02a7e98..dab8c70 100644 --- a/lib/model_context_protocol.rb +++ b/lib/model_context_protocol.rb @@ -1,23 +1,36 @@ # typed: strict # frozen_string_literal: true +require_relative "model_context_protocol/shared/version" +require_relative "model_context_protocol/shared/configuration" +require_relative "model_context_protocol/shared/instrumentation" +require_relative "model_context_protocol/shared/methods" +require_relative "model_context_protocol/shared/transport" +require_relative "model_context_protocol/shared/content" +require_relative "model_context_protocol/shared/string_utils" + +require_relative "model_context_protocol/shared/resource" +require_relative "model_context_protocol/shared/resource/contents" +require_relative "model_context_protocol/shared/resource/embedded" +require_relative "model_context_protocol/shared/resource_template" + +require_relative "model_context_protocol/shared/tool" +require_relative "model_context_protocol/shared/tool/input_schema" +require_relative "model_context_protocol/shared/tool/response" +require_relative "model_context_protocol/shared/tool/annotations" + +require_relative "model_context_protocol/shared/prompt" +require_relative "model_context_protocol/shared/prompt/argument" +require_relative "model_context_protocol/shared/prompt/message" +require_relative "model_context_protocol/shared/prompt/result" + require_relative "model_context_protocol/server" -require_relative "model_context_protocol/string_utils" -require_relative "model_context_protocol/tool" -require_relative "model_context_protocol/tool/input_schema" -require_relative "model_context_protocol/tool/annotations" -require_relative "model_context_protocol/tool/response" -require_relative "model_context_protocol/content" -require_relative "model_context_protocol/resource" -require_relative "model_context_protocol/resource/contents" -require_relative "model_context_protocol/resource/embedded" -require_relative "model_context_protocol/resource_template" -require_relative "model_context_protocol/prompt" -require_relative "model_context_protocol/prompt/argument" -require_relative "model_context_protocol/prompt/message" -require_relative "model_context_protocol/prompt/result" -require_relative "model_context_protocol/version" -require_relative "model_context_protocol/configuration" +require_relative "model_context_protocol/server/transports/stdio" + +require_relative "model_context_protocol/client" +require_relative "model_context_protocol/client/http" +require_relative "model_context_protocol/client/tools" +require_relative "model_context_protocol/client/tool" module ModelContextProtocol class << self diff --git a/lib/model_context_protocol/client.rb b/lib/model_context_protocol/client.rb new file mode 100644 index 0000000..bd454ca --- /dev/null +++ b/lib/model_context_protocol/client.rb @@ -0,0 +1,22 @@ +# frozen_string_literal: true + +# require "json_rpc_handler" +# require_relative "shared/instrumentation" +# require_relative "shared/methods" + +module ModelContextProtocol + module Client + # Can be made an abstract class if we need shared behavior + + class RequestHandlerError < StandardError + attr_reader :error_type, :original_error, :request + + def initialize(message, request, error_type: :internal_error, original_error: nil) + super(message) + @request = request + @error_type = error_type + @original_error = original_error + end + end + end +end diff --git a/lib/model_context_protocol/client/http.rb b/lib/model_context_protocol/client/http.rb new file mode 100644 index 0000000..05f663d --- /dev/null +++ b/lib/model_context_protocol/client/http.rb @@ -0,0 +1,107 @@ +# frozen_string_literal: true + +module ModelContextProtocol + module Client + class Http + DEFAULT_VERSION = "0.1.0" + + attr_reader :url, :version + + def initialize(url:, version: DEFAULT_VERSION, headers: {}) + @url = url + @version = version + @headers = headers + end + + def tools + response = make_request(method: "tools/list").body + + ::ModelContextProtocol::Client::Tools.new(response) + end + + def call_tool(tool:, input:) + response = make_request( + method: "tools/call", + params: { name: tool.name, arguments: input }, + ).body + + response.dig("result", "content", 0, "text") + end + + private + + attr_reader :headers + + def client + @client ||= Faraday.new(url) do |faraday| + faraday.request(:json) + faraday.response(:json) + faraday.response(:raise_error) + + headers.each do |key, value| + faraday.headers[key] = value + end + end + end + + def make_request(method:, params: nil) + client.post( + "", + { + jsonrpc: "2.0", + id: request_id, + method:, + params:, + mcp: { jsonrpc: "2.0", id: request_id, method:, params: }.compact, + }.compact, + ) + rescue Faraday::BadRequestError => e + raise RequestHandlerError.new( + "The #{method} request is invalid", + { method:, params: }, + error_type: :bad_request, + original_error: e, + ) + rescue Faraday::UnauthorizedError => e + raise RequestHandlerError.new( + "You are unauthorized to make #{method} requests", + { method:, params: }, + error_type: :unauthorized, + original_error: e, + ) + rescue Faraday::ForbiddenError => e + raise RequestHandlerError.new( + "You are forbidden to make #{method} requests", + { method:, params: }, + error_type: :forbidden, + original_error: e, + ) + rescue Faraday::ResourceNotFound => e + raise RequestHandlerError.new( + "The #{method} request is not found", + { method:, params: }, + error_type: :not_found, + original_error: e, + ) + rescue Faraday::UnprocessableEntityError => e + raise RequestHandlerError.new( + "The #{method} request is unprocessable", + { method:, params: }, + error_type: :unprocessable_entity, + original_error: e, + ) + rescue Faraday::Error => e # Catch-all + raise RequestHandlerError.new( + "Internal error handling #{method} request", + { method:, params: }, + error_type: :internal_error, + original_error: e, + ) + end + + def request_id + SecureRandom.uuid_v7 + end + end + end +end diff --git a/lib/model_context_protocol/client/tool.rb b/lib/model_context_protocol/client/tool.rb new file mode 100644 index 0000000..156a5b7 --- /dev/null +++ b/lib/model_context_protocol/client/tool.rb @@ -0,0 +1,26 @@ +# typed: false +# frozen_string_literal: true + +module ModelContextProtocol + module Client + class Tool + attr_reader :payload + + def initialize(payload) + @payload = payload + end + + def name + payload["name"] + end + + def description + payload["description"] + end + + def input_schema + payload["inputSchema"] + end + end + end +end diff --git a/lib/model_context_protocol/client/tools.rb b/lib/model_context_protocol/client/tools.rb new file mode 100644 index 0000000..a63f33f --- /dev/null +++ b/lib/model_context_protocol/client/tools.rb @@ -0,0 +1,30 @@ +# typed: false +# frozen_string_literal: true + +module ModelContextProtocol + module Client + class Tools + include Enumerable + + attr_reader :response + + def initialize(response) + @response = response + end + + def each(&block) + tools.each(&block) + end + + def all + tools + end + + private + + def tools + @tools ||= @response.dig("result", "tools")&.map { |tool| Tool.new(tool) } || [] + end + end + end +end diff --git a/lib/model_context_protocol/server.rb b/lib/model_context_protocol/server.rb index 0bde655..e1932bb 100644 --- a/lib/model_context_protocol/server.rb +++ b/lib/model_context_protocol/server.rb @@ -1,8 +1,8 @@ # frozen_string_literal: true require "json_rpc_handler" -require_relative "instrumentation" -require_relative "methods" +require_relative "shared/instrumentation" +require_relative "shared/methods" module ModelContextProtocol class Server diff --git a/lib/model_context_protocol/server/transports/stdio.rb b/lib/model_context_protocol/server/transports/stdio.rb new file mode 100644 index 0000000..097b87c --- /dev/null +++ b/lib/model_context_protocol/server/transports/stdio.rb @@ -0,0 +1,37 @@ +# frozen_string_literal: true + +require_relative "../../shared/transport" +require "json" + +module ModelContextProtocol + class Server + module Transports + class StdioTransport < Transport + def initialize(server) + @server = server + @open = false + $stdin.set_encoding("UTF-8") + $stdout.set_encoding("UTF-8") + super + end + + def open + @open = true + while @open && (line = $stdin.gets) + handle_json_request(line.strip) + end + end + + def close + @open = false + end + + def send_response(message) + json_message = message.is_a?(String) ? message : JSON.generate(message) + $stdout.puts(json_message) + $stdout.flush + end + end + end + end +end diff --git a/lib/model_context_protocol/configuration.rb b/lib/model_context_protocol/shared/configuration.rb similarity index 100% rename from lib/model_context_protocol/configuration.rb rename to lib/model_context_protocol/shared/configuration.rb diff --git a/lib/model_context_protocol/content.rb b/lib/model_context_protocol/shared/content.rb similarity index 100% rename from lib/model_context_protocol/content.rb rename to lib/model_context_protocol/shared/content.rb diff --git a/lib/model_context_protocol/instrumentation.rb b/lib/model_context_protocol/shared/instrumentation.rb similarity index 100% rename from lib/model_context_protocol/instrumentation.rb rename to lib/model_context_protocol/shared/instrumentation.rb diff --git a/lib/model_context_protocol/methods.rb b/lib/model_context_protocol/shared/methods.rb similarity index 100% rename from lib/model_context_protocol/methods.rb rename to lib/model_context_protocol/shared/methods.rb diff --git a/lib/model_context_protocol/prompt.rb b/lib/model_context_protocol/shared/prompt.rb similarity index 100% rename from lib/model_context_protocol/prompt.rb rename to lib/model_context_protocol/shared/prompt.rb diff --git a/lib/model_context_protocol/prompt/argument.rb b/lib/model_context_protocol/shared/prompt/argument.rb similarity index 100% rename from lib/model_context_protocol/prompt/argument.rb rename to lib/model_context_protocol/shared/prompt/argument.rb diff --git a/lib/model_context_protocol/prompt/message.rb b/lib/model_context_protocol/shared/prompt/message.rb similarity index 100% rename from lib/model_context_protocol/prompt/message.rb rename to lib/model_context_protocol/shared/prompt/message.rb diff --git a/lib/model_context_protocol/prompt/result.rb b/lib/model_context_protocol/shared/prompt/result.rb similarity index 100% rename from lib/model_context_protocol/prompt/result.rb rename to lib/model_context_protocol/shared/prompt/result.rb diff --git a/lib/model_context_protocol/resource.rb b/lib/model_context_protocol/shared/resource.rb similarity index 100% rename from lib/model_context_protocol/resource.rb rename to lib/model_context_protocol/shared/resource.rb diff --git a/lib/model_context_protocol/resource/contents.rb b/lib/model_context_protocol/shared/resource/contents.rb similarity index 100% rename from lib/model_context_protocol/resource/contents.rb rename to lib/model_context_protocol/shared/resource/contents.rb diff --git a/lib/model_context_protocol/resource/embedded.rb b/lib/model_context_protocol/shared/resource/embedded.rb similarity index 100% rename from lib/model_context_protocol/resource/embedded.rb rename to lib/model_context_protocol/shared/resource/embedded.rb diff --git a/lib/model_context_protocol/resource_template.rb b/lib/model_context_protocol/shared/resource_template.rb similarity index 100% rename from lib/model_context_protocol/resource_template.rb rename to lib/model_context_protocol/shared/resource_template.rb diff --git a/lib/model_context_protocol/string_utils.rb b/lib/model_context_protocol/shared/string_utils.rb similarity index 100% rename from lib/model_context_protocol/string_utils.rb rename to lib/model_context_protocol/shared/string_utils.rb diff --git a/lib/model_context_protocol/tool.rb b/lib/model_context_protocol/shared/tool.rb similarity index 100% rename from lib/model_context_protocol/tool.rb rename to lib/model_context_protocol/shared/tool.rb diff --git a/lib/model_context_protocol/tool/annotations.rb b/lib/model_context_protocol/shared/tool/annotations.rb similarity index 100% rename from lib/model_context_protocol/tool/annotations.rb rename to lib/model_context_protocol/shared/tool/annotations.rb diff --git a/lib/model_context_protocol/tool/input_schema.rb b/lib/model_context_protocol/shared/tool/input_schema.rb similarity index 100% rename from lib/model_context_protocol/tool/input_schema.rb rename to lib/model_context_protocol/shared/tool/input_schema.rb diff --git a/lib/model_context_protocol/tool/response.rb b/lib/model_context_protocol/shared/tool/response.rb similarity index 100% rename from lib/model_context_protocol/tool/response.rb rename to lib/model_context_protocol/shared/tool/response.rb diff --git a/lib/model_context_protocol/transport.rb b/lib/model_context_protocol/shared/transport.rb similarity index 100% rename from lib/model_context_protocol/transport.rb rename to lib/model_context_protocol/shared/transport.rb diff --git a/lib/model_context_protocol/version.rb b/lib/model_context_protocol/shared/version.rb similarity index 75% rename from lib/model_context_protocol/version.rb rename to lib/model_context_protocol/shared/version.rb index c2e2323..80a6c26 100644 --- a/lib/model_context_protocol/version.rb +++ b/lib/model_context_protocol/shared/version.rb @@ -1,5 +1,5 @@ # frozen_string_literal: true module ModelContextProtocol - VERSION = "0.7.0" + VERSION = "1.0.0" end diff --git a/lib/model_context_protocol/transports/stdio.rb b/lib/model_context_protocol/transports/stdio.rb deleted file mode 100644 index f72508b..0000000 --- a/lib/model_context_protocol/transports/stdio.rb +++ /dev/null @@ -1,35 +0,0 @@ -# frozen_string_literal: true - -require_relative "../transport" -require "json" - -module ModelContextProtocol - module Transports - class StdioTransport < Transport - def initialize(server) - @server = server - @open = false - $stdin.set_encoding("UTF-8") - $stdout.set_encoding("UTF-8") - super - end - - def open - @open = true - while @open && (line = $stdin.gets) - handle_json_request(line.strip) - end - end - - def close - @open = false - end - - def send_response(message) - json_message = message.is_a?(String) ? message : JSON.generate(message) - $stdout.puts(json_message) - $stdout.flush - end - end - end -end diff --git a/model_context_protocol.gemspec b/model_context_protocol.gemspec index 44ee5c3..6014bb5 100644 --- a/model_context_protocol.gemspec +++ b/model_context_protocol.gemspec @@ -1,6 +1,6 @@ # frozen_string_literal: true -require_relative "lib/model_context_protocol/version" +require_relative "lib/model_context_protocol/shared/version" Gem::Specification.new do |spec| spec.name = "model_context_protocol" @@ -27,6 +27,7 @@ Gem::Specification.new do |spec| spec.executables = spec.files.grep(%r{^exe/}) { |f| File.basename(f) } spec.require_paths = ["lib"] + spec.add_dependency("faraday", ">= 2.0") spec.add_dependency("json_rpc_handler", "~> 0.1") spec.add_development_dependency("activesupport") spec.add_development_dependency("sorbet-static-and-runtime") diff --git a/test/model_context_protocol/client/http_test.rb b/test/model_context_protocol/client/http_test.rb new file mode 100644 index 0000000..d5336e1 --- /dev/null +++ b/test/model_context_protocol/client/http_test.rb @@ -0,0 +1,477 @@ +# frozen_string_literal: true + +require "test_helper" +require "faraday" +require "securerandom" +require "webmock/minitest" + +module ModelContextProtocol + module Client + class HttpTest < Minitest::Test + def test_initialization_with_default_version + assert_equal("0.1.0", client.version) + assert_equal(url, client.url) + end + + def test_initialization_with_custom_version + custom_version = "1.2.3" + client = Http.new(url:, version: custom_version) + assert_equal(custom_version, client.version) + end + + def test_headers_are_added_to_the_request + headers = { "Authorization" => "Bearer token" } + client = Http.new(url:, headers:) + client.stubs(:request_id).returns(mock_request_id) + + stub_request(:post, url) + .with( + headers: { + "Authorization" => "Bearer token", + "Content-Type" => "application/json", + }, + body: { + method: "tools/list", + jsonrpc: "2.0", + id: mock_request_id, + mcp: { + method: "tools/list", + jsonrpc: "2.0", + id: mock_request_id, + }, + }, + ) + .to_return( + status: 200, + headers: { "Content-Type" => "application/json" }, + body: { result: { tools: [] } }.to_json, + ) + + # The test passes if the request is made with the correct headers + # If headers are wrong, the stub_request won't match and will raise + client.tools + end + + def test_tools_returns_tools_instance + stub_request(:post, url) + .with( + body: { + method: "tools/list", + jsonrpc: "2.0", + id: mock_request_id, + mcp: { + method: "tools/list", + jsonrpc: "2.0", + id: mock_request_id, + }, + }, + ) + .to_return( + status: 200, + headers: { + "Content-Type" => "application/json", + }, + body: { + result: { + tools: [ + { + name: "test_tool", + description: "A test tool", + inputSchema: { + type: "object", + properties: {}, + }, + }, + ], + }, + }.to_json, + ) + + tools = client.tools + assert_instance_of(Tools, tools) + assert_equal(1, tools.count) + assert_equal("test_tool", tools.first.name) + end + + def test_call_tool_returns_tool_response + tool = Tool.new( + "name" => "test_tool", + "description" => "A test tool", + "inputSchema" => { + "type" => "object", + "properties" => {}, + }, + ) + input = { "param" => "value" } + + stub_request(:post, url) + .with( + body: { + jsonrpc: "2.0", + id: mock_request_id, + method: "tools/call", + params: { + name: "test_tool", + arguments: input, + }, + mcp: { + jsonrpc: "2.0", + id: mock_request_id, + method: "tools/call", + params: { + name: "test_tool", + arguments: input, + }, + }, + }, + ) + .to_return( + status: 200, + headers: { + "Content-Type" => "application/json", + }, + body: { + result: { + content: [ + { + text: "Tool response", + }, + ], + }, + }.to_json, + ) + + response = client.call_tool(tool: tool, input: input) + assert_equal("Tool response", response) + end + + def test_call_tool_handles_empty_response + tool = Tool.new( + "name" => "test_tool", + "description" => "A test tool", + "inputSchema" => { + "type" => "object", + "properties" => {}, + }, + ) + input = { "param" => "value" } + + stub_request(:post, url) + .with( + body: { + jsonrpc: "2.0", + id: mock_request_id, + method: "tools/call", + params: { + name: "test_tool", + arguments: input, + }, + mcp: { + jsonrpc: "2.0", + id: mock_request_id, + method: "tools/call", + params: { + name: "test_tool", + arguments: input, + }, + }, + }, + ) + .to_return( + status: 200, + headers: { + "Content-Type" => "application/json", + }, + body: { + result: { + content: [], + }, + }.to_json, + ) + + response = client.call_tool(tool: tool, input: input) + assert_nil(response) + end + + def test_raises_bad_request_error + tool = Tool.new( + "name" => "test_tool", + "description" => "A test tool", + "inputSchema" => { + "type" => "object", + "properties" => {}, + }, + ) + input = { "param" => "value" } + + stub_request(:post, url) + .with( + body: { + jsonrpc: "2.0", + id: mock_request_id, + method: "tools/call", + params: { + name: "test_tool", + arguments: input, + }, + mcp: { + jsonrpc: "2.0", + id: mock_request_id, + method: "tools/call", + params: { + name: "test_tool", + arguments: input, + }, + }, + }, + ) + .to_return(status: 400) + + error = assert_raises(RequestHandlerError) do + client.call_tool(tool: tool, input: input) + end + + assert_equal("The tools/call request is invalid", error.message) + assert_equal(:bad_request, error.error_type) + assert_equal({ method: "tools/call", params: { name: "test_tool", arguments: input } }, error.request) + end + + def test_raises_unauthorized_error + tool = Tool.new( + "name" => "test_tool", + "description" => "A test tool", + "inputSchema" => { + "type" => "object", + "properties" => {}, + }, + ) + input = { "param" => "value" } + + stub_request(:post, url) + .with( + body: { + jsonrpc: "2.0", + id: mock_request_id, + method: "tools/call", + params: { + name: "test_tool", + arguments: input, + }, + mcp: { + jsonrpc: "2.0", + id: mock_request_id, + method: "tools/call", + params: { + name: "test_tool", + arguments: input, + }, + }, + }, + ) + .to_return(status: 401) + + error = assert_raises(RequestHandlerError) do + client.call_tool(tool: tool, input: input) + end + + assert_equal("You are unauthorized to make tools/call requests", error.message) + assert_equal(:unauthorized, error.error_type) + assert_equal({ method: "tools/call", params: { name: "test_tool", arguments: input } }, error.request) + end + + def test_raises_forbidden_error + tool = Tool.new( + "name" => "test_tool", + "description" => "A test tool", + "inputSchema" => { + "type" => "object", + "properties" => {}, + }, + ) + input = { "param" => "value" } + + stub_request(:post, url) + .with( + body: { + jsonrpc: "2.0", + id: mock_request_id, + method: "tools/call", + params: { + name: "test_tool", + arguments: input, + }, + mcp: { + jsonrpc: "2.0", + id: mock_request_id, + method: "tools/call", + params: { + name: "test_tool", + arguments: input, + }, + }, + }, + ) + .to_return(status: 403) + + error = assert_raises(RequestHandlerError) do + client.call_tool(tool: tool, input: input) + end + + assert_equal("You are forbidden to make tools/call requests", error.message) + assert_equal(:forbidden, error.error_type) + assert_equal({ method: "tools/call", params: { name: "test_tool", arguments: input } }, error.request) + end + + def test_raises_not_found_error + tool = Tool.new( + "name" => "test_tool", + "description" => "A test tool", + "inputSchema" => { + "type" => "object", + "properties" => {}, + }, + ) + input = { "param" => "value" } + + stub_request(:post, url) + .with( + body: { + jsonrpc: "2.0", + id: mock_request_id, + method: "tools/call", + params: { + name: "test_tool", + arguments: input, + }, + mcp: { + jsonrpc: "2.0", + id: mock_request_id, + method: "tools/call", + params: { + name: "test_tool", + arguments: input, + }, + }, + }, + ) + .to_return(status: 404) + + error = assert_raises(RequestHandlerError) do + client.call_tool(tool: tool, input: input) + end + + assert_equal("The tools/call request is not found", error.message) + assert_equal(:not_found, error.error_type) + assert_equal({ method: "tools/call", params: { name: "test_tool", arguments: input } }, error.request) + end + + def test_raises_unprocessable_entity_error + tool = Tool.new( + "name" => "test_tool", + "description" => "A test tool", + "inputSchema" => { + "type" => "object", + "properties" => {}, + }, + ) + input = { "param" => "value" } + + stub_request(:post, url) + .with( + body: { + jsonrpc: "2.0", + id: mock_request_id, + method: "tools/call", + params: { + name: "test_tool", + arguments: input, + }, + mcp: { + jsonrpc: "2.0", + id: mock_request_id, + method: "tools/call", + params: { + name: "test_tool", + arguments: input, + }, + }, + }, + ) + .to_return(status: 422) + + error = assert_raises(RequestHandlerError) do + client.call_tool(tool: tool, input: input) + end + + assert_equal("The tools/call request is unprocessable", error.message) + assert_equal(:unprocessable_entity, error.error_type) + assert_equal({ method: "tools/call", params: { name: "test_tool", arguments: input } }, error.request) + end + + def test_raises_internal_error + tool = Tool.new( + "name" => "test_tool", + "description" => "A test tool", + "inputSchema" => { + "type" => "object", + "properties" => {}, + }, + ) + input = { "param" => "value" } + + stub_request(:post, url) + .with( + body: { + jsonrpc: "2.0", + id: mock_request_id, + method: "tools/call", + params: { + name: "test_tool", + arguments: input, + }, + mcp: { + jsonrpc: "2.0", + id: mock_request_id, + method: "tools/call", + params: { + name: "test_tool", + arguments: input, + }, + }, + }, + ) + .to_return(status: 500) + + error = assert_raises(RequestHandlerError) do + client.call_tool(tool: tool, input: input) + end + + assert_equal("Internal error handling tools/call request", error.message) + assert_equal(:internal_error, error.error_type) + assert_equal({ method: "tools/call", params: { name: "test_tool", arguments: input } }, error.request) + end + + private + + def stub_request(method, url) + WebMock.stub_request(method, url) + end + + def mock_request_id + "random_request_id" + end + + def url + "http://example.com" + end + + def client + @client ||= begin + client = Http.new(url:) + client.stubs(:request_id).returns(mock_request_id) + client + end + end + end + end +end diff --git a/test/model_context_protocol/client/tool_test.rb b/test/model_context_protocol/client/tool_test.rb new file mode 100644 index 0000000..6dbcbc1 --- /dev/null +++ b/test/model_context_protocol/client/tool_test.rb @@ -0,0 +1,46 @@ +# frozen_string_literal: true + +require "test_helper" + +module ModelContextProtocol + module Client + class ToolTest < Minitest::Test + def test_name_returns_name_from_payload + tool = Tool.new("name" => "test_tool") + assert_equal("test_tool", tool.name) + end + + def test_name_returns_nil_when_not_in_payload + tool = Tool.new({}) + assert_nil(tool.name) + end + + def test_description_returns_description_from_payload + tool = Tool.new("description" => "A test tool") + assert_equal("A test tool", tool.description) + end + + def test_description_returns_nil_when_not_in_payload + tool = Tool.new({}) + assert_nil(tool.description) + end + + def test_input_schema_returns_input_schema_from_payload + schema = { "type" => "object", "properties" => { "foo" => { "type" => "string" } } } + tool = Tool.new("inputSchema" => schema) + assert_equal(schema, tool.input_schema) + end + + def test_input_schema_returns_nil_when_not_in_payload + tool = Tool.new({}) + assert_nil(tool.input_schema) + end + + def test_payload_is_accessible + payload = { "name" => "test", "description" => "desc", "inputSchema" => {} } + tool = Tool.new(payload) + assert_equal(payload, tool.payload) + end + end + end +end diff --git a/test/model_context_protocol/client/tools_test.rb b/test/model_context_protocol/client/tools_test.rb new file mode 100644 index 0000000..c832ecd --- /dev/null +++ b/test/model_context_protocol/client/tools_test.rb @@ -0,0 +1,96 @@ +# frozen_string_literal: true + +require "test_helper" + +module ModelContextProtocol + module Client + class ToolsTest < Minitest::Test + def test_each_iterates_over_tools + response = { + "result" => { + "tools" => [ + { "name" => "tool1", "description" => "First tool" }, + { "name" => "tool2", "description" => "Second tool" }, + ], + }, + } + tools = Tools.new(response) + + tool_names = [] + tools.each { |tool| tool_names << tool.name } + + assert_equal(["tool1", "tool2"], tool_names) + end + + def test_all_returns_array_of_tools + response = { + "result" => { + "tools" => [ + { "name" => "tool1", "description" => "First tool" }, + { "name" => "tool2", "description" => "Second tool" }, + ], + }, + } + tools = Tools.new(response) + + all_tools = tools.all + assert_equal(2, all_tools.length) + assert(all_tools.all? { |tool| tool.is_a?(Tool) }) + assert_equal(["tool1", "tool2"], all_tools.map(&:name)) + end + + def test_handles_empty_tools_array + response = { "result" => { "tools" => [] } } + tools = Tools.new(response) + + assert_equal([], tools.all) + assert_equal(0, tools.count) + end + + def test_handles_missing_tools_key + response = { "result" => {} } + tools = Tools.new(response) + + assert_equal([], tools.all) + assert_equal(0, tools.count) + end + + def test_handles_missing_result_key + response = {} + tools = Tools.new(response) + + assert_equal([], tools.all) + assert_equal(0, tools.count) + end + + def test_tools_are_initialized_with_correct_payload + response = { + "result" => { + "tools" => [ + { + "name" => "test_tool", + "description" => "A test tool", + "inputSchema" => { "type" => "object" }, + }, + ], + }, + } + tools = Tools.new(response) + tool = tools.all.first + + assert_equal("test_tool", tool.name) + assert_equal("A test tool", tool.description) + assert_equal({ "type" => "object" }, tool.input_schema) + end + + def test_includes_enumerable + response = { "result" => { "tools" => [] } } + tools = Tools.new(response) + + assert(tools.respond_to?(:map)) + assert(tools.respond_to?(:select)) + assert(tools.respond_to?(:find)) + end + end + end +end diff --git a/test/model_context_protocol/client_test.rb b/test/model_context_protocol/client_test.rb new file mode 100644 index 0000000..01522bf --- /dev/null +++ b/test/model_context_protocol/client_test.rb @@ -0,0 +1,8 @@ +# frozen_string_literal: true + +require "test_helper" + +module ModelContextProtocol + class ClientTest < Minitest::Test + end +end diff --git a/test/model_context_protocol/server/transports/stdio_transport_test.rb b/test/model_context_protocol/server/transports/stdio_transport_test.rb new file mode 100644 index 0000000..2b05d5f --- /dev/null +++ b/test/model_context_protocol/server/transports/stdio_transport_test.rb @@ -0,0 +1,127 @@ +# frozen_string_literal: true + +require "test_helper" +require "model_context_protocol/server/transports/stdio" +require "json" + +module ModelContextProtocol + class Server + module Transports + class StdioTransportTest < ActiveSupport::TestCase + include InstrumentationTestHelper + + setup do + configuration = ModelContextProtocol::Configuration.new + configuration.instrumentation_callback = instrumentation_helper.callback + @server = Server.new(name: "test_server", configuration: configuration) + @transport = StdioTransport.new(@server) + end + + test "initializes with server and closed state" do + server = @transport.instance_variable_get(:@server) + assert_equal @server.object_id, server.object_id + refute @transport.instance_variable_get(:@open) + end + + test "processes JSON-RPC requests from stdin and sends responses to stdout" do + request = { + jsonrpc: "2.0", + method: "ping", + id: "123", + } + input = StringIO.new(JSON.generate(request) + "\n") + output = StringIO.new + + original_stdin = $stdin + original_stdout = $stdout + + begin + $stdin = input + $stdout = output + + thread = Thread.new { @transport.open } + sleep(0.1) + @transport.close + thread.join + + response = JSON.parse(output.string, symbolize_names: true) + assert_equal("2.0", response[:jsonrpc]) + assert_equal("123", response[:id]) + assert_equal({}, response[:result]) + refute(@transport.instance_variable_get(:@open)) + ensure + $stdin = original_stdin + $stdout = original_stdout + end + end + + test "sends string responses to stdout" do + output = StringIO.new + original_stdout = $stdout + + begin + $stdout = output + @transport.send_response("test response") + assert_equal("test response\n", output.string) + ensure + $stdout = original_stdout + end + end + + test "sends JSON responses to stdout" do + output = StringIO.new + original_stdout = $stdout + + begin + $stdout = output + response = { key: "value" } + @transport.send_response(response) + assert_equal(JSON.generate(response) + "\n", output.string) + ensure + $stdout = original_stdout + end + end + + test "handles valid JSON-RPC requests" do + request = { + jsonrpc: "2.0", + method: "ping", + id: "123", + } + output = StringIO.new + original_stdout = $stdout + + begin + $stdout = output + @transport.send(:handle_request, JSON.generate(request)) + response = JSON.parse(output.string, symbolize_names: true) + assert_equal("2.0", response[:jsonrpc]) + assert_nil(response[:id]) + assert_nil(response[:result]) + ensure + $stdout = original_stdout + end + end + + test "handles invalid JSON requests" do + invalid_json = "invalid json" + output = StringIO.new + original_stdout = $stdout + + begin + $stdout = output + @transport.send(:handle_request, invalid_json) + response = JSON.parse(output.string, symbolize_names: true) + assert_equal("2.0", response[:jsonrpc]) + assert_nil(response[:id]) + assert_equal(-32600, response[:error][:code]) + assert_equal("Invalid Request", response[:error][:message]) + assert_equal("Request must be an array or a hash", response[:error][:data]) + ensure + $stdout = original_stdout + end + end + end + end + end +end diff --git a/test/model_context_protocol/configuration_test.rb b/test/model_context_protocol/shared/configuration_test.rb similarity index 100% rename from test/model_context_protocol/configuration_test.rb rename to test/model_context_protocol/shared/configuration_test.rb diff --git a/test/model_context_protocol/instrumentation_test.rb b/test/model_context_protocol/shared/instrumentation_test.rb similarity index 100% rename from test/model_context_protocol/instrumentation_test.rb rename to test/model_context_protocol/shared/instrumentation_test.rb diff --git a/test/model_context_protocol/methods_test.rb b/test/model_context_protocol/shared/methods_test.rb similarity index 100% rename from test/model_context_protocol/methods_test.rb rename to test/model_context_protocol/shared/methods_test.rb diff --git a/test/model_context_protocol/prompt_test.rb b/test/model_context_protocol/shared/prompt_test.rb similarity index 100% rename from test/model_context_protocol/prompt_test.rb rename to test/model_context_protocol/shared/prompt_test.rb diff --git a/test/model_context_protocol/string_utils_test.rb b/test/model_context_protocol/shared/string_utils_test.rb similarity index 100% rename from test/model_context_protocol/string_utils_test.rb rename to test/model_context_protocol/shared/string_utils_test.rb diff --git a/test/model_context_protocol/tool/input_schema_test.rb b/test/model_context_protocol/shared/tool/input_schema_test.rb similarity index 100% rename from test/model_context_protocol/tool/input_schema_test.rb rename to test/model_context_protocol/shared/tool/input_schema_test.rb diff --git a/test/model_context_protocol/tool_test.rb b/test/model_context_protocol/shared/tool_test.rb similarity index 100% rename from test/model_context_protocol/tool_test.rb rename to test/model_context_protocol/shared/tool_test.rb diff --git a/test/model_context_protocol/transports/stdio_transport_test.rb b/test/model_context_protocol/transports/stdio_transport_test.rb deleted file mode 100644 index 498b0ff..0000000 --- a/test/model_context_protocol/transports/stdio_transport_test.rb +++ /dev/null @@ -1,125 +0,0 @@ -# frozen_string_literal: true - -require "test_helper" -require "model_context_protocol/transports/stdio" -require "json" - -module ModelContextProtocol - module Transports - class StdioTransportTest < ActiveSupport::TestCase - include InstrumentationTestHelper - - setup do - configuration = ModelContextProtocol::Configuration.new - configuration.instrumentation_callback = instrumentation_helper.callback - @server = Server.new(name: "test_server", configuration: configuration) - @transport = StdioTransport.new(@server) - end - - test "initializes with server and closed state" do - server = @transport.instance_variable_get(:@server) - assert_equal @server.object_id, server.object_id - refute @transport.instance_variable_get(:@open) - end - - test "processes JSON-RPC requests from stdin and sends responses to stdout" do - request = { - jsonrpc: "2.0", - method: "ping", - id: "123", - } - input = StringIO.new(JSON.generate(request) + "\n") - output = StringIO.new - - original_stdin = $stdin - original_stdout = $stdout - - begin - $stdin = input - $stdout = output - - thread = Thread.new { @transport.open } - sleep(0.1) - @transport.close - thread.join - - response = JSON.parse(output.string, symbolize_names: true) - assert_equal("2.0", response[:jsonrpc]) - assert_equal("123", response[:id]) - assert_equal({}, response[:result]) - refute(@transport.instance_variable_get(:@open)) - ensure - $stdin = original_stdin - $stdout = original_stdout - end - end - - test "sends string responses to stdout" do - output = StringIO.new - original_stdout = $stdout - - begin - $stdout = output - @transport.send_response("test response") - assert_equal("test response\n", output.string) - ensure - $stdout = original_stdout - end - end - - test "sends JSON responses to stdout" do - output = StringIO.new - original_stdout = $stdout - - begin - $stdout = output - response = { key: "value" } - @transport.send_response(response) - assert_equal(JSON.generate(response) + "\n", output.string) - ensure - $stdout = original_stdout - end - end - - test "handles valid JSON-RPC requests" do - request = { - jsonrpc: "2.0", - method: "ping", - id: "123", - } - output = StringIO.new - original_stdout = $stdout - - begin - $stdout = output - @transport.send(:handle_request, JSON.generate(request)) - response = JSON.parse(output.string, symbolize_names: true) - assert_equal("2.0", response[:jsonrpc]) - assert_nil(response[:id]) - assert_nil(response[:result]) - ensure - $stdout = original_stdout - end - end - - test "handles invalid JSON requests" do - invalid_json = "invalid json" - output = StringIO.new - original_stdout = $stdout - - begin - $stdout = output - @transport.send(:handle_request, invalid_json) - response = JSON.parse(output.string, symbolize_names: true) - assert_equal("2.0", response[:jsonrpc]) - assert_nil(response[:id]) - assert_equal(-32600, response[:error][:code]) - assert_equal("Invalid Request", response[:error][:message]) - assert_equal("Request must be an array or a hash", response[:error][:data]) - ensure - $stdout = original_stdout - end - end - end - end -end