diff --git a/kotlin-sdk-test/build.gradle.kts b/kotlin-sdk-test/build.gradle.kts index 62d9f365..9f87efd6 100644 --- a/kotlin-sdk-test/build.gradle.kts +++ b/kotlin-sdk-test/build.gradle.kts @@ -17,5 +17,13 @@ kotlin { implementation(libs.kotlinx.coroutines.test) } } + jvmTest { + dependencies { + implementation(kotlin("test-junit5")) + implementation(libs.kotlin.logging) + implementation(libs.ktor.server.cio) + implementation(libs.ktor.client.cio) + } + } } } diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientIntegrationTest.kt deleted file mode 100644 index 562601aa..00000000 --- a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientIntegrationTest.kt +++ /dev/null @@ -1,41 +0,0 @@ -package io.modelcontextprotocol.kotlin.sdk.client - -import io.modelcontextprotocol.kotlin.sdk.Implementation -import io.modelcontextprotocol.kotlin.sdk.ListToolsResult -import kotlinx.coroutines.test.runTest -import kotlinx.io.asSink -import kotlinx.io.asSource -import kotlinx.io.buffered -import org.junit.jupiter.api.Disabled -import org.junit.jupiter.api.Test -import java.net.Socket - -class ClientIntegrationTest { - - fun createTransport(): StdioClientTransport { - val socket = Socket("localhost", 3000) - - return StdioClientTransport( - socket.inputStream.asSource().buffered(), - socket.outputStream.asSink().buffered(), - ) - } - - @Disabled("This test requires a running server") - @Test - fun testRequestTools() = runTest { - val client = Client( - Implementation("test", "1.0"), - ) - - val transport = createTransport() - try { - client.connect(transport) - - val response: ListToolsResult = client.listTools() - println(response.tools) - } finally { - transport.close() - } - } -} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt new file mode 100644 index 00000000..c367cf12 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt @@ -0,0 +1,102 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin + +import io.ktor.client.HttpClient +import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.sse.SSE +import io.ktor.server.application.install +import io.ktor.server.engine.EmbeddedServer +import io.ktor.server.engine.embeddedServer +import io.ktor.server.routing.routing +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.SseClientTransport +import io.modelcontextprotocol.kotlin.sdk.integration.utils.Retry +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import io.modelcontextprotocol.kotlin.sdk.server.mcp +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach +import kotlin.time.Duration.Companion.seconds +import io.ktor.server.cio.CIO as ServerCIO +import io.ktor.server.sse.SSE as ServerSSE + +@Retry(times = 3) +abstract class KotlinTestBase { + + protected val host = "localhost" + protected abstract val port: Int + + protected lateinit var server: Server + protected lateinit var client: Client + protected lateinit var serverEngine: EmbeddedServer<*, *> + + protected abstract fun configureServerCapabilities(): ServerCapabilities + protected abstract fun configureServer() + + @BeforeEach + fun setUp() { + setupServer() + runBlocking { + setupClient() + } + } + + protected suspend fun setupClient() { + val transport = SseClientTransport( + HttpClient(CIO) { + install(SSE) + }, + "http://$host:$port", + ) + client = Client( + Implementation("test", "1.0"), + ) + client.connect(transport) + } + + protected fun setupServer() { + val capabilities = configureServerCapabilities() + + server = Server( + Implementation(name = "test-server", version = "1.0"), + ServerOptions(capabilities = capabilities), + ) + + configureServer() + + serverEngine = embeddedServer(ServerCIO, host = host, port = port) { + install(ServerSSE) + routing { + mcp { server } + } + }.start(wait = false) + } + + @AfterEach + fun tearDown() { + // close client + if (::client.isInitialized) { + try { + runBlocking { + withTimeout(3.seconds) { + client.close() + } + } + } catch (e: Exception) { + println("Warning: Error during client close: ${e.message}") + } + } + + // stop server + if (::serverEngine.isInitialized) { + try { + serverEngine.stop(500, 1000) + } catch (e: Exception) { + println("Warning: Error during server stop: ${e.message}") + } + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptEdgeCasesTest.kt new file mode 100644 index 00000000..f5e736d7 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptEdgeCasesTest.kt @@ -0,0 +1,412 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin + +import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest +import io.modelcontextprotocol.kotlin.sdk.GetPromptResult +import io.modelcontextprotocol.kotlin.sdk.PromptArgument +import io.modelcontextprotocol.kotlin.sdk.PromptMessage +import io.modelcontextprotocol.kotlin.sdk.Role +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class PromptEdgeCasesTest : KotlinTestBase() { + + override val port = 3008 + + private val basicPromptName = "basic-prompt" + private val basicPromptDescription = "A basic prompt for testing" + + private val complexPromptName = "complex-prompt" + private val complexPromptDescription = "A complex prompt with many arguments" + + private val largePromptName = "large-prompt" + private val largePromptDescription = "A very large prompt for testing" + private val largePromptContent = "X".repeat(100_000) // 100KB of data + + private val specialCharsPromptName = "special-chars-prompt" + private val specialCharsPromptDescription = "A prompt with special characters" + private val specialCharsContent = "!@#$%^&*()_+{}|:\"<>?~`-=[]\\;',./\n\t" + + override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities( + prompts = ServerCapabilities.Prompts( + listChanged = true, + ), + ) + + override fun configureServer() { + server.addPrompt( + name = basicPromptName, + description = basicPromptDescription, + arguments = listOf( + PromptArgument( + name = "name", + description = "The name to greet", + required = true, + ), + ), + ) { request -> + val name = request.arguments?.get("name") ?: "World" + + GetPromptResult( + description = basicPromptDescription, + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent(text = "Hello, $name!"), + ), + PromptMessage( + role = Role.assistant, + content = TextContent(text = "Greetings, $name! How can I assist you today?"), + ), + ), + ) + } + + server.addPrompt( + name = complexPromptName, + description = complexPromptDescription, + arguments = listOf( + PromptArgument(name = "arg1", description = "Argument 1", required = true), + PromptArgument(name = "arg2", description = "Argument 2", required = true), + PromptArgument(name = "arg3", description = "Argument 3", required = true), + PromptArgument(name = "arg4", description = "Argument 4", required = false), + PromptArgument(name = "arg5", description = "Argument 5", required = false), + PromptArgument(name = "arg6", description = "Argument 6", required = false), + PromptArgument(name = "arg7", description = "Argument 7", required = false), + PromptArgument(name = "arg8", description = "Argument 8", required = false), + PromptArgument(name = "arg9", description = "Argument 9", required = false), + PromptArgument(name = "arg10", description = "Argument 10", required = false), + ), + ) { request -> + // validate required arguments + val requiredArgs = listOf("arg1", "arg2", "arg3") + for (argName in requiredArgs) { + if (request.arguments?.get(argName) == null) { + throw IllegalArgumentException("Missing required argument: $argName") + } + } + + val args = mutableMapOf() + for (i in 1..10) { + val argName = "arg$i" + val argValue = request.arguments?.get(argName) + if (argValue != null) { + args[argName] = argValue + } + } + + GetPromptResult( + description = complexPromptDescription, + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent( + text = "Arguments: ${ + args.entries.joinToString { + "${it.key}=${it.value}" + } + }", + ), + ), + PromptMessage( + role = Role.assistant, + content = TextContent(text = "Received ${args.size} arguments"), + ), + ), + ) + } + + // Very large prompt + server.addPrompt( + name = largePromptName, + description = largePromptDescription, + arguments = listOf( + PromptArgument( + name = "size", + description = "Size multiplier", + required = false, + ), + ), + ) { request -> + val size = request.arguments?.get("size")?.toIntOrNull() ?: 1 + val content = largePromptContent.repeat(size) + + GetPromptResult( + description = largePromptDescription, + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent(text = "Generate a large response"), + ), + PromptMessage( + role = Role.assistant, + content = TextContent(text = content), + ), + ), + ) + } + + server.addPrompt( + name = specialCharsPromptName, + description = specialCharsPromptDescription, + arguments = listOf( + PromptArgument( + name = "special", + description = "Special characters to include", + required = false, + ), + ), + ) { request -> + val special = request.arguments?.get("special") ?: specialCharsContent + + GetPromptResult( + description = specialCharsPromptDescription, + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent(text = "Special characters: $special"), + ), + PromptMessage( + role = Role.assistant, + content = TextContent(text = "Received special characters: $special"), + ), + ), + ) + } + } + + @Test + fun testBasicPrompt() { + runTest { + val testName = "Alice" + val result = client.getPrompt( + GetPromptRequest( + name = basicPromptName, + arguments = mapOf("name" to testName), + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals(basicPromptDescription, result.description, "Prompt description should match") + + assertEquals(2, result.messages.size, "Prompt should have 2 messages") + + val userMessage = result.messages.find { it.role == Role.user } + assertNotNull(userMessage, "User message should be in the list") + val userContent = userMessage.content as? TextContent + assertNotNull(userContent, "User message content should be TextContent") + assertEquals("Hello, $testName!", userContent.text, "User message content should match") + + val assistantMessage = result.messages.find { it.role == Role.assistant } + assertNotNull(assistantMessage, "Assistant message should be in the list") + val assistantContent = assistantMessage.content as? TextContent + assertNotNull(assistantContent, "Assistant message content should be TextContent") + assertEquals( + "Greetings, $testName! How can I assist you today?", + assistantContent.text, + "Assistant message content should match", + ) + } + } + + @Test + fun testComplexPromptWithManyArguments() { + runTest { + val arguments = (1..10).associate { i -> "arg$i" to "value$i" } + + val result = client.getPrompt( + GetPromptRequest( + name = complexPromptName, + arguments = arguments, + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals(complexPromptDescription, result.description, "Prompt description should match") + + assertEquals(2, result.messages.size, "Prompt should have 2 messages") + + val userMessage = result.messages.find { it.role == Role.user } + assertNotNull(userMessage, "User message should be in the list") + val userContent = userMessage.content as? TextContent + assertNotNull(userContent, "User message content should be TextContent") + + // verify all arguments + val text = userContent.text ?: "" + for (i in 1..10) { + assertTrue(text.contains("arg$i=value$i"), "Message should contain arg$i=value$i") + } + + val assistantMessage = result.messages.find { it.role == Role.assistant } + assertNotNull(assistantMessage, "Assistant message should be in the list") + val assistantContent = assistantMessage.content as? TextContent + assertNotNull(assistantContent, "Assistant message content should be TextContent") + assertEquals( + "Received 10 arguments", + assistantContent.text, + "Assistant message should indicate 10 arguments", + ) + } + } + + @Test + fun testLargePrompt() { + runTest { + val result = client.getPrompt( + GetPromptRequest( + name = largePromptName, + arguments = mapOf("size" to "1"), + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals(largePromptDescription, result.description, "Prompt description should match") + + assertEquals(2, result.messages.size, "Prompt should have 2 messages") + + val assistantMessage = result.messages.find { it.role == Role.assistant } + assertNotNull(assistantMessage, "Assistant message should be in the list") + val assistantContent = assistantMessage.content as? TextContent + assertNotNull(assistantContent, "Assistant message content should be TextContent") + val text = assistantContent.text ?: "" + assertEquals(100_000, text.length, "Assistant message should be 100KB in size") + } + } + + @Test + fun testSpecialCharacters() { + runTest { + val result = client.getPrompt( + GetPromptRequest( + name = specialCharsPromptName, + arguments = mapOf("special" to specialCharsContent), + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals(specialCharsPromptDescription, result.description, "Prompt description should match") + + assertEquals(2, result.messages.size, "Prompt should have 2 messages") + + val userMessage = result.messages.find { it.role == Role.user } + assertNotNull(userMessage, "User message should be in the list") + val userContent = userMessage.content as? TextContent + assertNotNull(userContent, "User message content should be TextContent") + val userText = userContent.text ?: "" + assertTrue(userText.contains(specialCharsContent), "User message should contain special characters") + + val assistantMessage = result.messages.find { it.role == Role.assistant } + assertNotNull(assistantMessage, "Assistant message should be in the list") + val assistantContent = assistantMessage.content as? TextContent + assertNotNull(assistantContent, "Assistant message content should be TextContent") + val assistantText = assistantContent.text ?: "" + assertTrue( + assistantText.contains(specialCharsContent), + "Assistant message should contain special characters", + ) + } + } + + @Test + fun testMissingRequiredArguments() { + runTest { + val exception = assertThrows { + runBlocking { + client.getPrompt( + GetPromptRequest( + name = complexPromptName, + arguments = mapOf("arg4" to "value4", "arg5" to "value5"), + ), + ) + } + } + + assertTrue( + exception.message?.contains("arg1") == true || + exception.message?.contains("arg2") == true || + exception.message?.contains("arg3") == true || + exception.message?.contains("required") == true, + "Exception should mention missing required arguments", + ) + } + } + + @Test + fun testConcurrentPromptRequests() { + runTest { + val concurrentCount = 10 + val results = mutableListOf() + + runBlocking { + repeat(concurrentCount) { index -> + launch { + val promptName = when (index % 4) { + 0 -> basicPromptName + 1 -> complexPromptName + 2 -> largePromptName + else -> specialCharsPromptName + } + + val arguments = when (promptName) { + basicPromptName -> mapOf("name" to "User$index") + complexPromptName -> mapOf("arg1" to "v1", "arg2" to "v2", "arg3" to "v3") + largePromptName -> mapOf("size" to "1") + else -> mapOf("special" to "!@#$%^&*()") + } + + val result = client.getPrompt( + GetPromptRequest( + name = promptName, + arguments = arguments, + ), + ) + + synchronized(results) { + results.add(result) + } + } + } + } + + assertEquals(concurrentCount, results.size, "All concurrent operations should complete") + + results.forEach { result -> + assertNotNull(result, "Result should not be null") + assertTrue(result.messages.isNotEmpty(), "Result messages should not be empty") + } + } + } + + @Test + fun testNonExistentPrompt() { + runTest { + val nonExistentPromptName = "non-existent-prompt" + + val exception = assertThrows { + runBlocking { + client.getPrompt( + GetPromptRequest( + name = nonExistentPromptName, + arguments = mapOf("name" to "Test"), + ), + ) + } + } + + assertTrue( + exception.message?.contains("not found") == true || + exception.message?.contains("does not exist") == true || + exception.message?.contains("unknown") == true || + exception.message?.contains("error") == true, + "Exception should indicate prompt not found", + ) + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptIntegrationTest.kt new file mode 100644 index 00000000..a609c2ba --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/PromptIntegrationTest.kt @@ -0,0 +1,470 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin + +import io.modelcontextprotocol.kotlin.sdk.GetPromptRequest +import io.modelcontextprotocol.kotlin.sdk.GetPromptResult +import io.modelcontextprotocol.kotlin.sdk.ImageContent +import io.modelcontextprotocol.kotlin.sdk.PromptArgument +import io.modelcontextprotocol.kotlin.sdk.PromptMessage +import io.modelcontextprotocol.kotlin.sdk.PromptMessageContent +import io.modelcontextprotocol.kotlin.sdk.Role +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest +import kotlinx.coroutines.runBlocking +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class PromptIntegrationTest : KotlinTestBase() { + + override val port = 3004 + private val testPromptName = "greeting" + private val testPromptDescription = "A simple greeting prompt" + private val complexPromptName = "multimodal-prompt" + private val complexPromptDescription = "A prompt with multiple content types" + private val conversationPromptName = "conversation" + private val conversationPromptDescription = "A prompt with multiple messages and roles" + private val strictPromptName = "strict-prompt" + private val strictPromptDescription = "A prompt with required arguments" + + override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities( + prompts = ServerCapabilities.Prompts( + listChanged = true, + ), + ) + + override fun configureServer() { + // simple prompt with a name parameter + server.addPrompt( + name = testPromptName, + description = testPromptDescription, + arguments = listOf( + PromptArgument( + name = "name", + description = "The name to greet", + required = true, + ), + ), + ) { request -> + val name = request.arguments?.get("name") ?: "World" + + GetPromptResult( + description = testPromptDescription, + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent(text = "Hello, $name!"), + ), + PromptMessage( + role = Role.assistant, + content = TextContent(text = "Greetings, $name! How can I assist you today?"), + ), + ), + ) + } + + // prompt with multiple content types + server.addPrompt( + name = complexPromptName, + description = complexPromptDescription, + arguments = listOf( + PromptArgument( + name = "topic", + description = "The topic to discuss", + required = false, + ), + PromptArgument( + name = "includeImage", + description = "Whether to include an image", + required = false, + ), + ), + ) { request -> + val topic = request.arguments?.get("topic") ?: "general knowledge" + val includeImage = request.arguments?.get("includeImage")?.toBoolean() ?: true + + val messages = mutableListOf() + + messages.add( + PromptMessage( + role = Role.user, + content = TextContent(text = "I'd like to discuss $topic."), + ), + ) + + val assistantContents = mutableListOf() + assistantContents.add(TextContent(text = "I'd be happy to discuss $topic with you.")) + + if (includeImage) { + assistantContents.add( + ImageContent( + data = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BmMIQAAAABJRU5ErkJggg==", + mimeType = "image/png", + ), + ) + } + + messages.add( + PromptMessage( + role = Role.assistant, + content = assistantContents[0], + ), + ) + + GetPromptResult( + description = complexPromptDescription, + messages = messages, + ) + } + + // prompt with multiple messages and roles + server.addPrompt( + name = conversationPromptName, + description = conversationPromptDescription, + arguments = listOf( + PromptArgument( + name = "topic", + description = "The topic of the conversation", + required = false, + ), + ), + ) { request -> + val topic = request.arguments?.get("topic") ?: "weather" + + GetPromptResult( + description = conversationPromptDescription, + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent(text = "Let's talk about the $topic."), + ), + PromptMessage( + role = Role.assistant, + content = TextContent( + text = "Sure, I'd love to discuss the $topic. What would you like to know?", + ), + ), + PromptMessage( + role = Role.user, + content = TextContent(text = "What's your opinion on the $topic?"), + ), + PromptMessage( + role = Role.assistant, + content = TextContent( + text = "As an AI, I don't have personal opinions," + + " but I can provide information about $topic.", + ), + ), + PromptMessage( + role = Role.user, + content = TextContent(text = "That's helpful, thank you!"), + ), + PromptMessage( + role = Role.assistant, + content = TextContent( + text = "You're welcome! Let me know if you have more questions about $topic.", + ), + ), + ), + ) + } + + // prompt with strict required arguments + server.addPrompt( + name = strictPromptName, + description = strictPromptDescription, + arguments = listOf( + PromptArgument( + name = "requiredArg1", + description = "First required argument", + required = true, + ), + PromptArgument( + name = "requiredArg2", + description = "Second required argument", + required = true, + ), + PromptArgument( + name = "optionalArg", + description = "Optional argument", + required = false, + ), + ), + ) { request -> + val args = request.arguments ?: emptyMap() + val arg1 = args["requiredArg1"] ?: throw IllegalArgumentException( + "Missing required argument: requiredArg1", + ) + val arg2 = args["requiredArg2"] ?: throw IllegalArgumentException( + "Missing required argument: requiredArg2", + ) + val optArg = args["optionalArg"] ?: "default" + + GetPromptResult( + description = strictPromptDescription, + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent(text = "Required arguments: $arg1, $arg2. Optional: $optArg"), + ), + PromptMessage( + role = Role.assistant, + content = TextContent(text = "I received your arguments: $arg1, $arg2, and $optArg"), + ), + ), + ) + } + } + + @Test + fun testListPrompts() = runTest { + val result = client.listPrompts() + + assertNotNull(result, "List prompts result should not be null") + assertTrue(result.prompts.isNotEmpty(), "Prompts list should not be empty") + + val testPrompt = result.prompts.find { it.name == testPromptName } + assertNotNull(testPrompt, "Test prompt should be in the list") + assertEquals( + testPromptDescription, + testPrompt.description, + "Prompt description should match", + ) + + val arguments = testPrompt.arguments ?: error("Prompt arguments should not be null") + assertTrue(arguments.isNotEmpty(), "Prompt arguments should not be empty") + + val nameArg = arguments.find { it.name == "name" } + assertNotNull(nameArg, "Name argument should be in the list") + assertEquals( + "The name to greet", + nameArg.description, + "Argument description should match", + ) + assertEquals(true, nameArg.required, "Argument required flag should match") + } + + @Test + fun testGetPrompt() = runTest { + val testName = "Alice" + val result = client.getPrompt( + GetPromptRequest( + name = testPromptName, + arguments = mapOf("name" to testName), + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals( + testPromptDescription, + result.description, + "Prompt description should match", + ) + + assertTrue(result.messages.isNotEmpty(), "Prompt messages should not be empty") + assertEquals(2, result.messages.size, "Prompt should have 2 messages") + + val userMessage = result.messages.find { it.role == Role.user } + assertNotNull(userMessage, "User message should be in the list") + val userContent = userMessage.content as? TextContent + assertNotNull(userContent, "User message content should be TextContent") + assertNotNull(userContent.text, "User message text should not be null") + assertEquals( + "Hello, $testName!", + userContent.text, + "User message content should match", + ) + + val assistantMessage = result.messages.find { it.role == Role.assistant } + assertNotNull(assistantMessage, "Assistant message should be in the list") + val assistantContent = assistantMessage.content as? TextContent + assertNotNull(assistantContent, "Assistant message content should be TextContent") + assertNotNull(assistantContent.text, "Assistant message text should not be null") + assertEquals( + "Greetings, $testName! How can I assist you today?", + assistantContent.text, + "Assistant message content should match", + ) + } + + @Test + fun testMissingRequiredArguments() = runTest { + val promptsList = client.listPrompts() + assertNotNull(promptsList, "Prompts list should not be null") + val strictPrompt = promptsList.prompts.find { it.name == strictPromptName } + assertNotNull(strictPrompt, "Strict prompt should be in the list") + + val argsDef = strictPrompt.arguments ?: error("Prompt arguments should not be null") + val requiredArgs = argsDef.filter { it.required == true } + assertEquals( + 2, + requiredArgs.size, + "Strict prompt should have 2 required arguments", + ) + + // test missing required arg + val exception = assertThrows { + runBlocking { + client.getPrompt( + GetPromptRequest( + name = strictPromptName, + arguments = mapOf("requiredArg1" to "value1"), + ), + ) + } + } + + assertEquals( + true, + exception.message?.contains("requiredArg2"), + "Exception should mention the missing argument", + ) + + // test with no args + val exception2 = assertThrows { + runBlocking { + client.getPrompt( + GetPromptRequest( + name = strictPromptName, + arguments = emptyMap(), + ), + ) + } + } + + assertEquals( + exception2.message?.contains("requiredArg"), + true, + "Exception should mention a missing required argument", + ) + + // test with all required args + val result = client.getPrompt( + GetPromptRequest( + name = strictPromptName, + arguments = mapOf( + "requiredArg1" to "value1", + "requiredArg2" to "value2", + ), + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals(2, result.messages.size, "Prompt should have 2 messages") + + val userMessage = result.messages.find { it.role == Role.user } + assertNotNull(userMessage, "User message should be in the list") + val userContent = userMessage.content as? TextContent + assertNotNull(userContent, "User message content should be TextContent") + val userText = requireNotNull(userContent.text) + assertTrue(userText.contains("value1"), "Message should contain first argument") + assertTrue(userText.contains("value2"), "Message should contain second argument") + } + + @Test + fun testComplexContentTypes() = runTest { + val topic = "artificial intelligence" + val result = client.getPrompt( + GetPromptRequest( + name = complexPromptName, + arguments = mapOf( + "topic" to topic, + "includeImage" to "true", + ), + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals( + complexPromptDescription, + result.description, + "Prompt description should match", + ) + + assertTrue(result.messages.isNotEmpty(), "Prompt messages should not be empty") + assertEquals(2, result.messages.size, "Prompt should have 2 messages") + + val userMessage = result.messages.find { it.role == Role.user } + assertNotNull(userMessage, "User message should be in the list") + val userContent = userMessage.content as? TextContent + assertNotNull(userContent, "User message content should be TextContent") + val userText2 = requireNotNull(userContent.text) + assertTrue(userText2.contains(topic), "User message should contain the topic") + + val assistantMessage = result.messages.find { it.role == Role.assistant } + assertNotNull(assistantMessage, "Assistant message should be in the list") + val assistantContent = assistantMessage.content as? TextContent + assertNotNull(assistantContent, "Assistant message content should be TextContent") + val assistantText = requireNotNull(assistantContent.text) + assertTrue( + assistantText.contains(topic), + "Assistant message should contain the topic", + ) + + val resultNoImage = client.getPrompt( + GetPromptRequest( + name = complexPromptName, + arguments = mapOf( + "topic" to topic, + "includeImage" to "false", + ), + ), + ) + + assertNotNull(resultNoImage, "Get prompt result (no image) should not be null") + assertEquals(2, resultNoImage.messages.size, "Prompt should have 2 messages") + } + + @Test + fun testMultipleMessagesAndRoles() = runTest { + val topic = "climate change" + val result = client.getPrompt( + GetPromptRequest( + name = conversationPromptName, + arguments = mapOf("topic" to topic), + ), + ) + + assertNotNull(result, "Get prompt result should not be null") + assertEquals( + conversationPromptDescription, + result.description, + "Prompt description should match", + ) + + assertTrue(result.messages.isNotEmpty(), "Prompt messages should not be empty") + assertEquals(6, result.messages.size, "Prompt should have 6 messages") + + val userMessages = result.messages.filter { it.role == Role.user } + val assistantMessages = result.messages.filter { it.role == Role.assistant } + + assertEquals(3, userMessages.size, "Should have 3 user messages") + assertEquals(3, assistantMessages.size, "Should have 3 assistant messages") + + for (i in 0 until result.messages.size) { + val expectedRole = if (i % 2 == 0) Role.user else Role.assistant + assertEquals( + expectedRole, + result.messages[i].role, + "Message $i should have role $expectedRole", + ) + } + + for (message in result.messages) { + val content = message.content as? TextContent + assertNotNull(content, "Message content should be TextContent") + val text = requireNotNull(content.text) + + // Either the message contains the topic or it's a generic conversation message + val containsTopic = text.contains(topic) + val isGenericMessage = text.contains("thank you") || text.contains("welcome") + + assertTrue( + containsTopic || isGenericMessage, + "Message should either contain the topic or be a generic conversation message", + ) + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceEdgeCasesTest.kt new file mode 100644 index 00000000..232ac025 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceEdgeCasesTest.kt @@ -0,0 +1,285 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin + +import io.modelcontextprotocol.kotlin.sdk.BlobResourceContents +import io.modelcontextprotocol.kotlin.sdk.EmptyRequestResult +import io.modelcontextprotocol.kotlin.sdk.Method +import io.modelcontextprotocol.kotlin.sdk.ReadResourceRequest +import io.modelcontextprotocol.kotlin.sdk.ReadResourceResult +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.SubscribeRequest +import io.modelcontextprotocol.kotlin.sdk.TextResourceContents +import io.modelcontextprotocol.kotlin.sdk.UnsubscribeRequest +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import java.util.concurrent.atomic.AtomicBoolean +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class ResourceEdgeCasesTest : KotlinTestBase() { + + override val port = 3007 + + private val testResourceUri = "test://example.txt" + private val testResourceName = "Test Resource" + private val testResourceDescription = "A test resource for integration testing" + private val testResourceContent = "This is the content of the test resource." + + private val binaryResourceUri = "test://image.png" + private val binaryResourceName = "Binary Resource" + private val binaryResourceDescription = "A binary resource for testing" + private val binaryResourceContent = + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==" + + private val largeResourceUri = "test://large.txt" + private val largeResourceName = "Large Resource" + private val largeResourceDescription = "A large text resource for testing" + private val largeResourceContent = "X".repeat(100_000) // 100KB of data + + private val dynamicResourceUri = "test://dynamic.txt" + private val dynamicResourceName = "Dynamic Resource" + private val dynamicResourceContent = AtomicBoolean(false) + + override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities( + resources = ServerCapabilities.Resources( + subscribe = true, + listChanged = true, + ), + ) + + override fun configureServer() { + server.addResource( + uri = testResourceUri, + name = testResourceName, + description = testResourceDescription, + mimeType = "text/plain", + ) { request -> + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = testResourceContent, + uri = request.uri, + mimeType = "text/plain", + ), + ), + ) + } + + server.addResource( + uri = binaryResourceUri, + name = binaryResourceName, + description = binaryResourceDescription, + mimeType = "image/png", + ) { request -> + ReadResourceResult( + contents = listOf( + BlobResourceContents( + blob = binaryResourceContent, + uri = request.uri, + mimeType = "image/png", + ), + ), + ) + } + + server.addResource( + uri = largeResourceUri, + name = largeResourceName, + description = largeResourceDescription, + mimeType = "text/plain", + ) { request -> + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = largeResourceContent, + uri = request.uri, + mimeType = "text/plain", + ), + ), + ) + } + + server.addResource( + uri = dynamicResourceUri, + name = dynamicResourceName, + description = "A resource that can be updated", + mimeType = "text/plain", + ) { request -> + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = if (dynamicResourceContent.get()) "Updated content" else "Original content", + uri = request.uri, + mimeType = "text/plain", + ), + ), + ) + } + + server.setRequestHandler(Method.Defined.ResourcesSubscribe) { _, _ -> + EmptyRequestResult() + } + + server.setRequestHandler(Method.Defined.ResourcesUnsubscribe) { _, _ -> + EmptyRequestResult() + } + } + + @Test + fun testBinaryResource() { + runTest { + val result = client.readResource(ReadResourceRequest(uri = binaryResourceUri)) + + assertNotNull(result, "Read resource result should not be null") + assertTrue(result.contents.isNotEmpty(), "Resource contents should not be empty") + + val content = result.contents.firstOrNull() as? BlobResourceContents + assertNotNull(content, "Resource content should be BlobResourceContents") + assertEquals(binaryResourceContent, content.blob, "Binary resource content should match") + assertEquals("image/png", content.mimeType, "MIME type should match") + } + } + + @Test + fun testLargeResource() { + runTest { + val result = client.readResource(ReadResourceRequest(uri = largeResourceUri)) + + assertNotNull(result, "Read resource result should not be null") + assertTrue(result.contents.isNotEmpty(), "Resource contents should not be empty") + + val content = result.contents.firstOrNull() as? TextResourceContents + assertNotNull(content, "Resource content should be TextResourceContents") + assertEquals(100_000, content.text.length, "Large resource content length should match") + assertEquals("X".repeat(100_000), content.text, "Large resource content should match") + } + } + + @Test + fun testInvalidResourceUri() { + runTest { + val invalidUri = "test://nonexistent.txt" + + val exception = assertThrows { + runBlocking { + client.readResource(ReadResourceRequest(uri = invalidUri)) + } + } + + assertTrue( + exception.message?.contains("not found") == true || + exception.message?.contains("invalid") == true || + exception.message?.contains("error") == true, + "Exception should indicate resource not found or invalid URI", + ) + } + } + + @Test + fun testDynamicResource() { + runTest { + val initialResult = client.readResource(ReadResourceRequest(uri = dynamicResourceUri)) + assertNotNull(initialResult, "Initial read result should not be null") + val initialContent = (initialResult.contents.firstOrNull() as? TextResourceContents)?.text + assertEquals("Original content", initialContent, "Initial content should match") + + // update resource + dynamicResourceContent.set(true) + + val updatedResult = client.readResource(ReadResourceRequest(uri = dynamicResourceUri)) + assertNotNull(updatedResult, "Updated read result should not be null") + val updatedContent = (updatedResult.contents.firstOrNull() as? TextResourceContents)?.text + assertEquals("Updated content", updatedContent, "Updated content should match") + } + } + + @Test + fun testResourceAddAndRemove() { + runTest { + val initialList = client.listResources() + assertNotNull(initialList, "Initial list result should not be null") + val initialCount = initialList.resources.size + + val newResourceUri = "test://new-resource.txt" + server.addResource( + uri = newResourceUri, + name = "New Resource", + description = "A newly added resource", + mimeType = "text/plain", + ) { request -> + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = "New resource content", + uri = request.uri, + mimeType = "text/plain", + ), + ), + ) + } + + val updatedList = client.listResources() + assertNotNull(updatedList, "Updated list result should not be null") + val updatedCount = updatedList.resources.size + + assertEquals(initialCount + 1, updatedCount, "Resource count should increase by 1") + val newResource = updatedList.resources.find { it.uri == newResourceUri } + assertNotNull(newResource, "New resource should be in the list") + + server.removeResource(newResourceUri) + + val finalList = client.listResources() + assertNotNull(finalList, "Final list result should not be null") + val finalCount = finalList.resources.size + + assertEquals(initialCount, finalCount, "Resource count should return to initial value") + val removedResource = finalList.resources.find { it.uri == newResourceUri } + assertEquals(null, removedResource, "Resource should be removed from the list") + } + } + + @Test + fun testConcurrentResourceOperations() { + runTest { + val concurrentCount = 10 + val results = mutableListOf() + + runBlocking { + repeat(concurrentCount) { index -> + launch { + val uri = when (index % 3) { + 0 -> testResourceUri + 1 -> binaryResourceUri + else -> largeResourceUri + } + + val result = client.readResource(ReadResourceRequest(uri = uri)) + synchronized(results) { + results.add(result) + } + } + } + } + + assertEquals(concurrentCount, results.size, "All concurrent operations should complete") + results.forEach { result -> + assertNotNull(result, "Result should not be null") + assertTrue(result.contents.isNotEmpty(), "Result contents should not be empty") + } + } + } + + @Test + fun testSubscribeAndUnsubscribe() { + runTest { + val subscribeResult = client.subscribeResource(SubscribeRequest(uri = testResourceUri)) + assertNotNull(subscribeResult, "Subscribe result should not be null") + + val unsubscribeResult = client.unsubscribeResource(UnsubscribeRequest(uri = testResourceUri)) + assertNotNull(unsubscribeResult, "Unsubscribe result should not be null") + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceIntegrationTest.kt new file mode 100644 index 00000000..c467b2a1 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ResourceIntegrationTest.kt @@ -0,0 +1,94 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin + +import io.modelcontextprotocol.kotlin.sdk.EmptyRequestResult +import io.modelcontextprotocol.kotlin.sdk.Method +import io.modelcontextprotocol.kotlin.sdk.ReadResourceRequest +import io.modelcontextprotocol.kotlin.sdk.ReadResourceResult +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.SubscribeRequest +import io.modelcontextprotocol.kotlin.sdk.TextResourceContents +import io.modelcontextprotocol.kotlin.sdk.UnsubscribeRequest +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest +import org.junit.jupiter.api.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class ResourceIntegrationTest : KotlinTestBase() { + + override val port = 3005 + private val testResourceUri = "test://example.txt" + private val testResourceName = "Test Resource" + private val testResourceDescription = "A test resource for integration testing" + private val testResourceContent = "This is the content of the test resource." + + override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities( + resources = ServerCapabilities.Resources( + subscribe = true, + listChanged = true, + ), + ) + + override fun configureServer() { + server.addResource( + uri = testResourceUri, + name = testResourceName, + description = testResourceDescription, + mimeType = "text/plain", + ) { request -> + ReadResourceResult( + contents = listOf( + TextResourceContents( + text = testResourceContent, + uri = request.uri, + mimeType = "text/plain", + ), + ), + ) + } + + server.setRequestHandler(Method.Defined.ResourcesSubscribe) { _, _ -> + EmptyRequestResult() + } + + server.setRequestHandler(Method.Defined.ResourcesUnsubscribe) { _, _ -> + EmptyRequestResult() + } + } + + @Test + fun testListResources() = runTest { + val result = client.listResources() + + assertNotNull(result, "List resources result should not be null") + assertTrue(result.resources.isNotEmpty(), "Resources list should not be empty") + + val testResource = result.resources.find { it.uri == testResourceUri } + assertNotNull(testResource, "Test resource should be in the list") + assertEquals(testResourceName, testResource.name, "Resource name should match") + assertEquals(testResourceDescription, testResource.description, "Resource description should match") + } + + @Test + fun testReadResource() = runTest { + val result = client.readResource(ReadResourceRequest(uri = testResourceUri)) + + assertNotNull(result, "Read resource result should not be null") + assertTrue(result.contents.isNotEmpty(), "Resource contents should not be empty") + + val content = result.contents.firstOrNull() as? TextResourceContents + assertNotNull(content, "Resource content should be TextResourceContents") + assertEquals(testResourceContent, content.text, "Resource content should match") + } + + @Test + fun testSubscribeAndUnsubscribe() { + runTest { + val subscribeResult = client.subscribeResource(SubscribeRequest(uri = testResourceUri)) + assertNotNull(subscribeResult, "Subscribe result should not be null") + + val unsubscribeResult = client.unsubscribeResource(UnsubscribeRequest(uri = testResourceUri)) + assertNotNull(unsubscribeResult, "Unsubscribe result should not be null") + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt new file mode 100644 index 00000000..0cb8c506 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolEdgeCasesTest.kt @@ -0,0 +1,505 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin + +import io.modelcontextprotocol.kotlin.sdk.CallToolRequest +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.CallToolResultBase +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.Tool +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertCallToolResult +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertJsonProperty +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertTextContent +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.add +import kotlinx.serialization.json.buildJsonArray +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class ToolEdgeCasesTest : KotlinTestBase() { + + override val port = 3009 + + private val basicToolName = "basic-tool" + private val basicToolDescription = "A basic tool for testing" + + private val complexToolName = "complex-tool" + private val complexToolDescription = "A complex tool with nested schema" + + private val largeToolName = "large-tool" + private val largeToolDescription = "A tool that returns a large response" + private val largeToolContent = "X".repeat(100_000) // 100KB of data + + private val slowToolName = "slow-tool" + private val slowToolDescription = "A tool that takes time to respond" + + private val specialCharsToolName = "special-chars-tool" + private val specialCharsToolDescription = "A tool that handles special characters" + private val specialCharsContent = "!@#$%^&*()_+{}|:\"<>?~`-=[]\\;',./\n\t" + + override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities( + tools = ServerCapabilities.Tools( + listChanged = true, + ), + ) + + override fun configureServer() { + server.addTool( + name = basicToolName, + description = basicToolDescription, + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "text", + buildJsonObject { + put("type", "string") + put("description", "The text to echo back") + }, + ) + }, + required = listOf("text"), + ), + ) { request -> + val text = (request.arguments["text"] as? JsonPrimitive)?.content ?: "No text provided" + + CallToolResult( + content = listOf(TextContent(text = "Echo: $text")), + structuredContent = buildJsonObject { + put("result", text) + }, + ) + } + + server.addTool( + name = complexToolName, + description = complexToolDescription, + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "user", + buildJsonObject { + put("type", "object") + put("description", "User information") + put( + "properties", + buildJsonObject { + put( + "name", + buildJsonObject { + put("type", "string") + put("description", "User's name") + }, + ) + put( + "age", + buildJsonObject { + put("type", "integer") + put("description", "User's age") + }, + ) + put( + "address", + buildJsonObject { + put("type", "object") + put("description", "User's address") + put( + "properties", + buildJsonObject { + put( + "street", + buildJsonObject { + put("type", "string") + }, + ) + put( + "city", + buildJsonObject { + put("type", "string") + }, + ) + put( + "country", + buildJsonObject { + put("type", "string") + }, + ) + }, + ) + }, + ) + }, + ) + }, + ) + put( + "options", + buildJsonObject { + put("type", "array") + put("description", "Additional options") + put( + "items", + buildJsonObject { + put("type", "string") + }, + ) + }, + ) + }, + required = listOf("user"), + ), + ) { request -> + val user = request.arguments["user"] as? JsonObject + val name = (user?.get("name") as? JsonPrimitive)?.content ?: "Unknown" + val age = (user?.get("age") as? JsonPrimitive)?.content?.toIntOrNull() ?: 0 + + val address = user?.get("address") as? JsonObject + val street = (address?.get("street") as? JsonPrimitive)?.content ?: "Unknown" + val city = (address?.get("city") as? JsonPrimitive)?.content ?: "Unknown" + val country = (address?.get("country") as? JsonPrimitive)?.content ?: "Unknown" + + val options = (request.arguments["options"] as? JsonArray)?.mapNotNull { + (it as? JsonPrimitive)?.content + } ?: emptyList() + + val summary = + "User: $name, Age: $age, Address: $street, $city, $country, Options: ${options.joinToString(", ")}" + + CallToolResult( + content = listOf(TextContent(text = summary)), + structuredContent = buildJsonObject { + put("name", name) + put("age", age) + put( + "address", + buildJsonObject { + put("street", street) + put("city", city) + put("country", country) + }, + ) + put( + "options", + buildJsonArray { + options.forEach { add(it) } + }, + ) + }, + ) + } + + server.addTool( + name = largeToolName, + description = largeToolDescription, + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "size", + buildJsonObject { + put("type", "integer") + put("description", "Size multiplier") + }, + ) + }, + ), + ) { request -> + val size = (request.arguments["size"] as? JsonPrimitive)?.content?.toIntOrNull() ?: 1 + val content = largeToolContent.take(largeToolContent.length.coerceAtMost(size * 1000)) + + CallToolResult( + content = listOf(TextContent(text = content)), + structuredContent = buildJsonObject { + put("size", content.length) + }, + ) + } + + server.addTool( + name = slowToolName, + description = slowToolDescription, + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "delay", + buildJsonObject { + put("type", "integer") + put("description", "Delay in milliseconds") + }, + ) + }, + ), + ) { request -> + val delay = (request.arguments["delay"] as? JsonPrimitive)?.content?.toIntOrNull() ?: 1000 + + // simulate slow operation + runBlocking { + delay(delay.toLong()) + } + + CallToolResult( + content = listOf(TextContent(text = "Completed after ${delay}ms delay")), + structuredContent = buildJsonObject { + put("delay", delay) + }, + ) + } + + server.addTool( + name = specialCharsToolName, + description = specialCharsToolDescription, + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "special", + buildJsonObject { + put("type", "string") + put("description", "Special characters to process") + }, + ) + }, + ), + ) { request -> + val special = (request.arguments["special"] as? JsonPrimitive)?.content ?: specialCharsContent + + CallToolResult( + content = listOf(TextContent(text = "Received special characters: $special")), + structuredContent = buildJsonObject { + put("special", special) + put("length", special.length) + }, + ) + } + } + + @Test + fun testBasicTool() { + runTest { + val testText = "Hello, world!" + val arguments = mapOf("text" to testText) + + val result = client.callTool(basicToolName, arguments) + + val toolResult = assertCallToolResult(result) + assertTextContent(toolResult.content.firstOrNull(), "Echo: $testText") + + val structuredContent = toolResult.structuredContent as JsonObject + assertJsonProperty(structuredContent, "result", testText) + } + } + + @Test + fun testComplexNestedSchema() { + runTest { + val userJson = buildJsonObject { + put("name", JsonPrimitive("John Doe")) + put("age", JsonPrimitive(30)) + put( + "address", + buildJsonObject { + put("street", JsonPrimitive("123 Main St")) + put("city", JsonPrimitive("New York")) + put("country", JsonPrimitive("USA")) + }, + ) + } + + val optionsJson = buildJsonArray { + add(JsonPrimitive("option1")) + add(JsonPrimitive("option2")) + add(JsonPrimitive("option3")) + } + + val arguments = buildJsonObject { + put("user", userJson) + put("options", optionsJson) + } + + val result = client.callTool( + CallToolRequest( + name = complexToolName, + arguments = arguments, + ), + ) + + val toolResult = assertCallToolResult(result) + val content = toolResult.content.firstOrNull() as? TextContent + assertNotNull(content, "Tool result content should be TextContent") + val text = content.text ?: "" + + assertTrue(text.contains("John Doe"), "Result should contain the name") + assertTrue(text.contains("30"), "Result should contain the age") + assertTrue(text.contains("123 Main St"), "Result should contain the street") + assertTrue(text.contains("New York"), "Result should contain the city") + assertTrue(text.contains("USA"), "Result should contain the country") + assertTrue(text.contains("option1"), "Result should contain option1") + assertTrue(text.contains("option2"), "Result should contain option2") + assertTrue(text.contains("option3"), "Result should contain option3") + + val structuredContent = toolResult.structuredContent as JsonObject + assertJsonProperty(structuredContent, "name", "John Doe") + assertJsonProperty(structuredContent, "age", 30) + + val address = structuredContent["address"] as? JsonObject + assertNotNull(address, "Address should be present in structured content") + assertJsonProperty(address, "street", "123 Main St") + assertJsonProperty(address, "city", "New York") + assertJsonProperty(address, "country", "USA") + + val options = structuredContent["options"] as? JsonArray + assertNotNull(options, "Options should be present in structured content") + assertEquals(3, options.size, "Options should have 3 items") + } + } + + @Test + fun testLargeResponse() { + runTest { + val size = 10 + val arguments = mapOf("size" to size) + + val result = client.callTool(largeToolName, arguments) + + val toolResult = assertCallToolResult(result) + val content = toolResult.content.firstOrNull() as? TextContent + assertNotNull(content, "Tool result content should be TextContent") + val text = content.text ?: "" + + assertEquals(10000, text.length, "Response should be 10KB in size") + + val structuredContent = toolResult.structuredContent as JsonObject + assertJsonProperty(structuredContent, "size", 10000) + } + } + + @Test + fun testSlowTool() { + runTest { + val delay = 500 + val arguments = mapOf("delay" to delay) + + val startTime = System.currentTimeMillis() + val result = client.callTool(slowToolName, arguments) + val endTime = System.currentTimeMillis() + + val toolResult = assertCallToolResult(result) + val content = toolResult.content.firstOrNull() as? TextContent + assertNotNull(content, "Tool result content should be TextContent") + val text = content.text ?: "" + + assertTrue(text.contains("${delay}ms"), "Result should mention the delay") + assertTrue(endTime - startTime >= delay, "Tool should take at least the specified delay") + + val structuredContent = toolResult.structuredContent as JsonObject + assertJsonProperty(structuredContent, "delay", delay) + } + } + + @Test + fun testSpecialCharacters() { + runTest { + val arguments = mapOf("special" to specialCharsContent) + + val result = client.callTool(specialCharsToolName, arguments) + + val toolResult = assertCallToolResult(result) + val content = toolResult.content.firstOrNull() as? TextContent + assertNotNull(content, "Tool result content should be TextContent") + val text = content.text ?: "" + + assertTrue(text.contains(specialCharsContent), "Result should contain the special characters") + + val structuredContent = toolResult.structuredContent as JsonObject + val special = structuredContent["special"]?.toString()?.trim('"') + + assertNotNull(special, "Special characters should be in structured content") + assertTrue(text.contains(specialCharsContent), "Special characters should be in the content") + } + } + + @Test + fun testConcurrentToolCalls() { + runTest { + val concurrentCount = 10 + val results = mutableListOf() + + runBlocking { + repeat(concurrentCount) { index -> + launch { + val toolName = when (index % 5) { + 0 -> basicToolName + 1 -> complexToolName + 2 -> largeToolName + 3 -> slowToolName + else -> specialCharsToolName + } + + val arguments = when (toolName) { + basicToolName -> mapOf("text" to "Concurrent call $index") + + complexToolName -> mapOf( + "user" to mapOf( + "name" to "User $index", + "age" to 20 + index, + "address" to mapOf( + "street" to "Street $index", + "city" to "City $index", + "country" to "Country $index", + ), + ), + ) + + largeToolName -> mapOf("size" to 1) + + slowToolName -> mapOf("delay" to 100) + + else -> mapOf("special" to "!@#$%^&*()") + } + + val result = client.callTool(toolName, arguments) + + synchronized(results) { + results.add(result) + } + } + } + } + + assertEquals(concurrentCount, results.size, "All concurrent operations should complete") + results.forEach { result -> + assertNotNull(result, "Result should not be null") + assertTrue(result.content.isNotEmpty(), "Result content should not be empty") + } + } + } + + @Test + fun testNonExistentTool() { + runTest { + val nonExistentToolName = "non-existent-tool" + val arguments = mapOf("text" to "Test") + + val exception = assertThrows { + runBlocking { + client.callTool(nonExistentToolName, arguments) + } + } + + assertTrue( + exception.message?.contains("not found") == true || + exception.message?.contains("does not exist") == true || + exception.message?.contains("unknown") == true || + exception.message?.contains("error") == true, + "Exception should indicate tool not found", + ) + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt new file mode 100644 index 00000000..c6262a13 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/ToolIntegrationTest.kt @@ -0,0 +1,473 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin + +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.ImageContent +import io.modelcontextprotocol.kotlin.sdk.PromptMessageContent +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.Tool +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertCallToolResult +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertJsonProperty +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.assertTextContent +import io.modelcontextprotocol.kotlin.sdk.integration.utils.TestUtils.runTest +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.add +import kotlinx.serialization.json.buildJsonArray +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +class ToolIntegrationTest : KotlinTestBase() { + + override val port = 3006 + private val testToolName = "echo" + private val testToolDescription = "A simple echo tool that returns the input text" + private val complexToolName = "calculator" + private val complexToolDescription = "A calculator tool that performs operations on numbers" + private val errorToolName = "error-tool" + private val errorToolDescription = "A tool that demonstrates error handling" + private val multiContentToolName = "multi-content" + private val multiContentToolDescription = "A tool that returns multiple content types" + + override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities( + tools = ServerCapabilities.Tools( + listChanged = true, + ), + ) + + override fun configureServer() { + setupEchoTool() + setupCalculatorTool() + setupErrorHandlingTool() + setupMultiContentTool() + } + + private fun setupEchoTool() { + server.addTool( + name = testToolName, + description = testToolDescription, + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "text", + buildJsonObject { + put("type", "string") + put("description", "The text to echo back") + }, + ) + }, + required = listOf("text"), + ), + ) { request -> + val text = (request.arguments["text"] as? JsonPrimitive)?.content ?: "No text provided" + + CallToolResult( + content = listOf(TextContent(text = "Echo: $text")), + structuredContent = buildJsonObject { + put("result", text) + }, + ) + } + } + + private fun setupCalculatorTool() { + server.addTool( + name = complexToolName, + description = complexToolDescription, + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "operation", + buildJsonObject { + put("type", "string") + put("description", "The operation to perform (add, subtract, multiply, divide)") + put( + "enum", + buildJsonArray { + add("add") + add("subtract") + add("multiply") + add("divide") + }, + ) + }, + ) + put( + "a", + buildJsonObject { + put("type", "number") + put("description", "First operand") + }, + ) + put( + "b", + buildJsonObject { + put("type", "number") + put("description", "Second operand") + }, + ) + put( + "precision", + buildJsonObject { + put("type", "integer") + put("description", "Number of decimal places (optional)") + put("default", 2) + }, + ) + put( + "showSteps", + buildJsonObject { + put("type", "boolean") + put("description", "Whether to show calculation steps") + put("default", false) + }, + ) + put( + "tags", + buildJsonObject { + put("type", "array") + put("description", "Optional tags for the calculation") + put( + "items", + buildJsonObject { + put("type", "string") + }, + ) + }, + ) + }, + required = listOf("operation", "a", "b"), + ), + ) { request -> + val operation = (request.arguments["operation"] as? JsonPrimitive)?.content ?: "add" + val a = (request.arguments["a"] as? JsonPrimitive)?.content?.toDoubleOrNull() ?: 0.0 + val b = (request.arguments["b"] as? JsonPrimitive)?.content?.toDoubleOrNull() ?: 0.0 + val precision = (request.arguments["precision"] as? JsonPrimitive)?.content?.toIntOrNull() ?: 2 + val showSteps = (request.arguments["showSteps"] as? JsonPrimitive)?.content?.toBoolean() ?: false + val tags = (request.arguments["tags"] as? JsonArray)?.mapNotNull { + (it as? JsonPrimitive)?.content + } ?: emptyList() + + val result = when (operation) { + "add" -> a + b + "subtract" -> a - b + "multiply" -> a * b + "divide" -> if (b != 0.0) a / b else Double.POSITIVE_INFINITY + else -> 0.0 + } + + val formattedResult = "%.${precision}f".format(result) + + val textContent = if (showSteps) { + "Operation: $operation\nA: $a\nB: $b\nResult: $formattedResult\nTags: ${ + tags.joinToString(", ") + }" + } else { + "Result: $formattedResult" + } + + CallToolResult( + content = listOf(TextContent(text = textContent)), + structuredContent = buildJsonObject { + put("operation", operation) + put("a", a) + put("b", b) + put("result", result) + put("formattedResult", formattedResult) + put("precision", precision) + put("tags", buildJsonArray { tags.forEach { add(it) } }) + }, + ) + } + } + + private fun setupErrorHandlingTool() { + server.addTool( + name = errorToolName, + description = errorToolDescription, + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "errorType", + buildJsonObject { + put("type", "string") + put("description", "Type of error to simulate (none, exception, error)") + put( + "enum", + buildJsonArray { + add("none") + add("exception") + add("error") + }, + ) + }, + ) + put( + "message", + buildJsonObject { + put("type", "string") + put("description", "Custom error message") + put("default", "An error occurred") + }, + ) + }, + required = listOf("errorType"), + ), + ) { request -> + val errorType = (request.arguments["errorType"] as? JsonPrimitive)?.content ?: "none" + val message = (request.arguments["message"] as? JsonPrimitive)?.content ?: "An error occurred" + + when (errorType) { + "exception" -> throw IllegalArgumentException(message) + + "error" -> CallToolResult( + content = listOf(TextContent(text = "Error: $message")), + structuredContent = buildJsonObject { + put("error", true) + put("message", message) + }, + ) + + else -> CallToolResult( + content = listOf(TextContent(text = "No error occurred")), + structuredContent = buildJsonObject { + put("error", false) + put("message", "Success") + }, + ) + } + } + } + + private fun setupMultiContentTool() { + server.addTool( + name = multiContentToolName, + description = multiContentToolDescription, + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "text", + buildJsonObject { + put("type", "string") + put("description", "Text to include in the response") + }, + ) + put( + "includeImage", + buildJsonObject { + put("type", "boolean") + put("description", "Whether to include an image in the response") + put("default", true) + }, + ) + }, + required = listOf("text"), + ), + ) { request -> + val text = (request.arguments["text"] as? JsonPrimitive)?.content ?: "Default text" + val includeImage = (request.arguments["includeImage"] as? JsonPrimitive)?.content?.toBoolean() ?: true + + val content = mutableListOf( + TextContent(text = "Text content: $text"), + ) + + if (includeImage) { + content.add( + ImageContent( + data = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==", + mimeType = "image/png", + ), + ) + } + + CallToolResult( + content = content, + structuredContent = buildJsonObject { + put("text", text) + put("includeImage", includeImage) + }, + ) + } + } + + @Test + fun testListTools() = runTest { + val result = client.listTools() + + assertNotNull(result, "List utils result should not be null") + assertTrue(result.tools.isNotEmpty(), "Tools list should not be empty") + + val testTool = result.tools.find { it.name == testToolName } + assertNotNull(testTool, "Test tool should be in the list") + assertEquals( + testToolDescription, + testTool.description, + "Tool description should match", + ) + } + + @Test + fun testCallTool() = runTest { + val testText = "Hello, world!" + val arguments = mapOf("text" to testText) + + val result = client.callTool(testToolName, arguments) + + val toolResult = assertCallToolResult(result) + assertTextContent(toolResult.content.firstOrNull(), "Echo: $testText") + + val structuredContent = toolResult.structuredContent as JsonObject + assertJsonProperty(structuredContent, "result", testText) + } + + @Test + fun testComplexInputSchemaTool() { + runTest { + val toolsList = client.listTools() + assertNotNull(toolsList, "Tools list should not be null") + val calculatorTool = toolsList.tools.find { it.name == complexToolName } + assertNotNull(calculatorTool, "Calculator tool should be in the list") + + val arguments = mapOf( + "operation" to "multiply", + "a" to 5.5, + "b" to 2.0, + "precision" to 3, + "showSteps" to true, + "tags" to listOf("test", "calculator", "integration"), + ) + + val result = client.callTool(complexToolName, arguments) + + val toolResult = assertCallToolResult(result) + + val content = toolResult.content.firstOrNull() as? TextContent + assertNotNull(content, "Tool result content should be TextContent") + val contentText = requireNotNull(content.text) + + assertTrue(contentText.contains("Operation"), "Result should contain operation") + assertTrue( + contentText.contains("multiply"), + "Result should contain multiply operation", + ) + assertTrue(contentText.contains("5.5"), "Result should contain first operand") + assertTrue(contentText.contains("2.0"), "Result should contain second operand") + assertTrue(contentText.contains("11"), "Result should contain result value") + + val structuredContent = toolResult.structuredContent as JsonObject + assertJsonProperty(structuredContent, "operation", "multiply") + assertJsonProperty(structuredContent, "result", 11.0) + + val formattedResult = structuredContent["formattedResult"]?.toString()?.trim('"') ?: "" + assertTrue( + formattedResult == "11.000" || formattedResult == "11,000", + "Formatted result should be either '11.000' or '11,000', but was '$formattedResult'", + ) + assertJsonProperty(structuredContent, "precision", 3) + + val tags = structuredContent["tags"] as? JsonArray + assertNotNull(tags, "Tags should be present") + } + } + + @Test + fun testToolErrorHandling() = runTest { + val successArgs = mapOf("errorType" to "none") + val successResult = client.callTool(errorToolName, successArgs) + + val successToolResult = assertCallToolResult(successResult, "No error: ") + assertTextContent(successToolResult.content.firstOrNull(), "No error occurred") + + val noErrorStructured = successToolResult.structuredContent as JsonObject + assertJsonProperty(noErrorStructured, "error", false) + + val errorArgs = mapOf( + "errorType" to "error", + "message" to "Custom error message", + ) + val errorResult = client.callTool(errorToolName, errorArgs) + + val errorToolResult = assertCallToolResult(errorResult, "Error: ") + assertTextContent(errorToolResult.content.firstOrNull(), "Error: Custom error message") + + val errorStructured = errorToolResult.structuredContent as JsonObject + assertJsonProperty(errorStructured, "error", true) + assertJsonProperty(errorStructured, "message", "Custom error message") + + val exceptionArgs = mapOf( + "errorType" to "exception", + "message" to "Exception message", + ) + + val exception = assertThrows { + runBlocking { + client.callTool(errorToolName, exceptionArgs) + } + } + + assertEquals( + exception.message?.contains("Exception message"), + true, + "Exception message should contain 'Exception message'", + ) + } + + @Test + fun testMultiContentTool() = runTest { + val testText = "Test multi-content" + val arguments = mapOf( + "text" to testText, + "includeImage" to true, + ) + + val result = client.callTool(multiContentToolName, arguments) + + val toolResult = assertCallToolResult(result) + assertEquals( + 2, + toolResult.content.size, + "Tool result should have 2 content items", + ) + + val textContent = toolResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Result should contain TextContent") + assertNotNull(textContent.text, "Text content should not be null") + assertEquals( + "Text content: $testText", + textContent.text, + "Text content should match", + ) + + val imageContent = toolResult.content.firstOrNull { it is ImageContent } as? ImageContent + assertNotNull(imageContent, "Result should contain ImageContent") + assertEquals("image/png", imageContent.mimeType, "Image MIME type should match") + assertTrue(imageContent.data.isNotEmpty(), "Image data should not be empty") + + val structuredContent = toolResult.structuredContent as JsonObject + assertJsonProperty(structuredContent, "text", testText) + assertJsonProperty(structuredContent, "includeImage", true) + + val textOnlyArgs = mapOf( + "text" to testText, + "includeImage" to false, + ) + + val textOnlyResult = client.callTool(multiContentToolName, textOnlyArgs) + + val textOnlyToolResult = assertCallToolResult(textOnlyResult, "Text-only: ") + assertEquals( + 1, + textOnlyToolResult.content.size, + "Text-only result should have 1 content item", + ) + + assertTextContent(textOnlyToolResult.content.firstOrNull(), "Text content: $testText") + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt new file mode 100644 index 00000000..3905a20a --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerEdgeCasesTest.kt @@ -0,0 +1,258 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript + +import io.ktor.client.HttpClient +import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.sse.SSE +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.mcpStreamableHttp +import kotlinx.coroutines.Deferred +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withContext +import kotlinx.coroutines.withTimeout +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout +import org.junit.jupiter.api.assertThrows +import java.util.concurrent.TimeUnit +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.seconds + +class KotlinClientTypeScriptServerEdgeCasesTest : TypeScriptTestBase() { + + private var port: Int = 0 + private val host = "localhost" + private lateinit var serverUrl: String + + private lateinit var client: Client + private lateinit var tsServerProcess: Process + + @BeforeEach + fun setUp() { + port = findFreePort() + serverUrl = "http://$host:$port/mcp" + tsServerProcess = startTypeScriptServer(port) + println("TypeScript server started on port $port") + } + + @AfterEach + fun tearDown() { + if (::client.isInitialized) { + try { + runBlocking { + withTimeout(3.seconds) { + client.close() + } + } + } catch (e: Exception) { + println("Warning: Error during client close: ${e.message}") + } + } + + if (::tsServerProcess.isInitialized) { + try { + println("Stopping TypeScript server") + stopProcess(tsServerProcess) + } catch (e: Exception) { + println("Warning: Error during TypeScript server stop: ${e.message}") + } + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testNonExistentTool() { + runBlocking { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val nonExistentToolName = "non-existent-tool" + val arguments = mapOf("name" to "TestUser") + + val exception = assertThrows { + client.callTool(nonExistentToolName, arguments) + } + + val errorMessage = exception.message ?: "" + assertTrue( + errorMessage.contains("not found") || + errorMessage.contains("unknown") || + errorMessage.contains("error"), + "Exception should indicate tool not found: $errorMessage", + ) + } + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testSpecialCharactersInArguments() { + runBlocking { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val specialChars = "!@#$%^&*()_+{}[]|\\:;\"'<>,.?/" + val arguments = mapOf("name" to specialChars) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null") + + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present in the result") + + val text = textContent.text ?: "" + assertTrue( + text.contains(specialChars), + "Tool response should contain the special characters", + ) + } + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testLargePayload() { + runBlocking { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val largeName = "A".repeat(10 * 1024) + val arguments = mapOf("name" to largeName) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null") + + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present in the result") + + val text = textContent.text ?: "" + assertTrue( + text.contains("Hello,") && text.contains("A"), + "Tool response should contain the greeting with the large name", + ) + } + } + } + + @Test + @Timeout(60, unit = TimeUnit.SECONDS) + fun testConcurrentRequests() { + runBlocking { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val concurrentCount = 5 + val results = mutableListOf>() + + for (i in 1..concurrentCount) { + val deferred = async { + val name = "ConcurrentClient$i" + val arguments = mapOf("name" to name) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null for client $i") + + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present for client $i") + + textContent.text ?: "" + } + results.add(deferred) + } + + val responses = results.awaitAll() + + for (i in 1..concurrentCount) { + val expectedName = "ConcurrentClient$i" + val matchingResponses = responses.filter { it.contains("Hello, $expectedName!") } + assertEquals( + 1, + matchingResponses.size, + "Should have exactly one response for $expectedName", + ) + } + } + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testInvalidArguments() { + runBlocking { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val invalidArguments = mapOf( + "name" to JsonObject(mapOf("nested" to JsonPrimitive("value"))), + ) + + try { + val result = client.callTool("greet", invalidArguments) + assertNotNull(result, "Tool call result should not be null") + + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present in the result") + } catch (e: Exception) { + assertTrue( + e.message?.contains("invalid") == true || + e.message?.contains("error") == true, + "Exception should indicate invalid arguments: ${e.message}", + ) + } + } + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testMultipleToolCalls() { + runBlocking { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + repeat(10) { i -> + val name = "SequentialClient$i" + val arguments = mapOf("name" to name) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null for call $i") + + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present for call $i") + + assertEquals( + "Hello, $name!", + textContent.text, + "Tool response should contain the greeting with the provided name", + ) + } + } + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt new file mode 100644 index 00000000..f4cf8ffc --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/KotlinClientTypeScriptServerTest.kt @@ -0,0 +1,172 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript + +import io.ktor.client.HttpClient +import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.sse.SSE +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.mcpStreamableHttp +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withContext +import kotlinx.coroutines.withTimeout +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout +import java.util.concurrent.TimeUnit +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.seconds + +class KotlinClientTypeScriptServerTest : TypeScriptTestBase() { + + private var port: Int = 0 + private val host = "localhost" + private lateinit var serverUrl: String + + private lateinit var client: Client + private lateinit var tsServerProcess: Process + + @BeforeEach + fun setUp() { + port = findFreePort() + serverUrl = "http://$host:$port/mcp" + tsServerProcess = startTypeScriptServer(port) + println("TypeScript server started on port $port") + } + + @AfterEach + fun tearDown() { + if (::client.isInitialized) { + try { + runBlocking { + withTimeout(3.seconds) { + client.close() + } + } + } catch (e: Exception) { + println("Warning: Error during client close: ${e.message}") + } + } + + if (::tsServerProcess.isInitialized) { + try { + println("Stopping TypeScript server") + stopProcess(tsServerProcess) + } catch (e: Exception) { + println("Warning: Error during TypeScript server stop: ${e.message}") + } + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testKotlinClientConnectsToTypeScriptServer() { + runBlocking { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + assertNotNull(client, "Client should be initialized") + + val pingResult = client.ping() + assertNotNull(pingResult, "Ping result should not be null") + + val serverImpl = client.serverVersion + assertNotNull(serverImpl, "Server implementation should not be null") + println("Connected to TypeScript server: ${serverImpl.name} v${serverImpl.version}") + } + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testListTools() { + runBlocking { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val result = client.listTools() + assertNotNull(result, "Tools list should not be null") + assertTrue(result.tools.isNotEmpty(), "Tools list should not be empty") + + // Verify specific utils are available + val toolNames = result.tools.map { it.name } + assertTrue(toolNames.contains("greet"), "Greet tool should be available") + assertTrue(toolNames.contains("multi-greet"), "Multi-greet tool should be available") + assertTrue(toolNames.contains("collect-user-info"), "Collect-user-info tool should be available") + + println("Available utils: ${toolNames.joinToString()}") + } + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testToolCall() { + runBlocking { + withContext(Dispatchers.IO) { + client = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val testName = "TestUser" + val arguments = mapOf("name" to testName) + + val result = client.callTool("greet", arguments) + assertNotNull(result, "Tool call result should not be null") + + val callResult = result as CallToolResult + val textContent = callResult.content.firstOrNull { it is TextContent } as? TextContent + assertNotNull(textContent, "Text content should be present in the result") + assertEquals( + "Hello, $testName!", + textContent.text, + "Tool response should contain the greeting with the provided name", + ) + } + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testMultipleClients() { + runBlocking { + withContext(Dispatchers.IO) { + // First client connection + val client1 = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val tools1 = client1.listTools() + assertNotNull(tools1, "Tools list for first client should not be null") + assertTrue(tools1.tools.isNotEmpty(), "Tools list for first client should not be empty") + + val client2 = HttpClient(CIO) { + install(SSE) + }.mcpStreamableHttp(serverUrl) + + val tools2 = client2.listTools() + assertNotNull(tools2, "Tools list for second client should not be null") + assertTrue(tools2.tools.isNotEmpty(), "Tools list for second client should not be empty") + + val toolNames1 = tools1.tools.map { it.name } + val toolNames2 = tools2.tools.map { it.name } + + assertTrue(toolNames1.contains("greet"), "Greet tool should be available to first client") + assertTrue(toolNames1.contains("multi-greet"), "Multi-greet tool should be available to first client") + assertTrue(toolNames2.contains("greet"), "Greet tool should be available to second client") + assertTrue(toolNames2.contains("multi-greet"), "Multi-greet tool should be available to second client") + + client1.close() + client2.close() + } + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt new file mode 100644 index 00000000..9459dcd1 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptClientKotlinServerTest.kt @@ -0,0 +1,189 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript + +import io.modelcontextprotocol.kotlin.sdk.integration.utils.KotlinServerForTypeScriptClient +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout +import java.util.concurrent.TimeUnit +import kotlin.test.assertTrue + +class TypeScriptClientKotlinServerTest : TypeScriptTestBase() { + + private var port: Int = 0 + private lateinit var serverUrl: String + private var httpServer: KotlinServerForTypeScriptClient? = null + + @BeforeEach + fun setUp() { + port = findFreePort() + serverUrl = "http://localhost:$port/mcp" + killProcessOnPort(port) + httpServer = KotlinServerForTypeScriptClient() + httpServer?.start(port) + if (!waitForPort(port = port)) { + throw IllegalStateException("Kotlin test server did not become ready on localhost:$port within timeout") + } + println("Kotlin server started on port $port") + } + + @AfterEach + fun tearDown() { + try { + httpServer?.stop() + println("HTTP server stopped") + } catch (e: Exception) { + println("Error during server shutdown: ${e.message}") + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testToolCall() { + val testName = "TestUser" + val command = "npx tsx myClient.ts $serverUrl greet $testName" + val output = executeCommand(command, tsClientDir) + + assertTrue( + output.contains("Hello, $testName!"), + "Tool response should contain the greeting with the provided name", + ) + assertTrue(output.contains("Tool result:"), "Output should indicate a successful tool call") + assertTrue(output.contains("Text content:"), "Output should contain the text content section") + assertTrue(output.contains("Structured content:"), "Output should contain the structured content section") + assertTrue( + output.contains("\"greeting\": \"Hello, $testName!\""), + "Structured content should contain the greeting", + ) + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testNotifications() { + val name = "NotifUser" + val command = "npx tsx myClient.ts $serverUrl multi-greet $name" + val output = executeCommand(command, tsClientDir) + + assertTrue( + output.contains("Multiple greetings") || output.contains("greeting"), + "Tool response should contain greeting message", + ) + // verify that the server sent 3 notifications + assertTrue( + output.contains("\"notificationCount\": 3") || output.contains("notificationCount: 3"), + "Structured content should indicate that 3 notifications were emitted by the server.\nOutput:\n$output", + ) + } + + @Test + @Timeout(120, unit = TimeUnit.SECONDS) + fun testMultipleClientSequence() { + val testName1 = "FirstClient" + val command1 = "npx tsx myClient.ts $serverUrl greet $testName1" + val output1 = executeCommand(command1, tsClientDir) + + assertTrue(output1.contains("Connected to server"), "First client should connect to server") + assertTrue(output1.contains("Hello, $testName1!"), "Tool response should contain the greeting for first client") + assertTrue(output1.contains("Disconnected from server"), "First client should disconnect cleanly") + + val testName2 = "SecondClient" + val command2 = "npx tsx myClient.ts $serverUrl multi-greet $testName2" + val output2 = executeCommand(command2, tsClientDir) + + assertTrue(output2.contains("Connected to server"), "Second client should connect to server") + assertTrue( + output2.contains("Multiple greetings") || output2.contains("greeting"), + "Tool response should contain greeting message", + ) + assertTrue(output2.contains("Disconnected from server"), "Second client should disconnect cleanly") + + val command3 = "npx tsx myClient.ts $serverUrl" + val output3 = executeCommand(command3, tsClientDir) + + assertTrue(output3.contains("Connected to server"), "Third client should connect to server") + assertTrue(output3.contains("Available utils:"), "Third client should list available utils") + assertTrue(output3.contains("greet"), "Greet tool should be available to third client") + assertTrue(output3.contains("multi-greet"), "Multi-greet tool should be available to third client") + assertTrue(output3.contains("Disconnected from server"), "Third client should disconnect cleanly") + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testMultipleClientParallel() { + val clientCount = 3 + val clients = listOf( + "FirstClient" to "greet", + "SecondClient" to "multi-greet", + "ThirdClient" to "", + ) + + val threads = mutableListOf() + val outputs = mutableListOf>() + val exceptions = mutableListOf() + + for (i in 0 until clientCount) { + val (clientName, toolName) = clients[i] + val thread = Thread { + try { + val command = if (toolName.isEmpty()) { + "npx tsx myClient.ts $serverUrl" + } else { + "npx tsx myClient.ts $serverUrl $toolName $clientName" + } + + val output = executeCommand(command, tsClientDir) + synchronized(outputs) { + outputs.add(i to output) + } + } catch (e: Exception) { + synchronized(exceptions) { + exceptions.add(e) + } + } + } + threads.add(thread) + thread.start() + Thread.sleep(500) + } + + threads.forEach { it.join() } + + if (exceptions.isNotEmpty()) { + println( + "Exceptions occurred in parallel clients: ${ + exceptions.joinToString { + it.message ?: it.toString() + } + }", + ) + } + + val sortedOutputs = outputs.sortedBy { it.first }.map { it.second } + + sortedOutputs.forEachIndexed { index, output -> + val clientName = clients[index].first + val toolName = clients[index].second + + when (toolName) { + "greet" -> { + val containsGreeting = output.contains("Hello, $clientName!") || + output.contains("\"greeting\": \"Hello, $clientName!\"") + assertTrue( + containsGreeting, + "Tool response should contain the greeting for $clientName", + ) + } + + "multi-greet" -> { + val containsGreeting = output.contains("Multiple greetings") || + output.contains("greeting") || + output.contains("greet") + assertTrue( + containsGreeting, + "Tool response should contain greeting message for $clientName", + ) + } + } + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt new file mode 100644 index 00000000..86c9b8fe --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptEdgeCasesTest.kt @@ -0,0 +1,187 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript + +import io.modelcontextprotocol.kotlin.sdk.integration.utils.KotlinServerForTypeScriptClient +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout +import org.junit.jupiter.api.condition.EnabledOnOs +import org.junit.jupiter.api.condition.OS +import java.io.File +import java.util.concurrent.TimeUnit +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class TypeScriptEdgeCasesTest : TypeScriptTestBase() { + + private var port: Int = 0 + private lateinit var serverUrl: String + private var httpServer: KotlinServerForTypeScriptClient? = null + + @BeforeEach + fun setUp() { + port = findFreePort() + serverUrl = "http://localhost:$port/mcp" + killProcessOnPort(port) + httpServer = KotlinServerForTypeScriptClient() + httpServer?.start(port) + if (!waitForPort(port = port)) { + throw IllegalStateException("Kotlin test server did not become ready on localhost:$port within timeout") + } + println("Kotlin server started on port $port") + } + + @AfterEach + fun tearDown() { + try { + httpServer?.stop() + println("HTTP server stopped") + } catch (e: Exception) { + println("Error during server shutdown: ${e.message}") + } + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testErrorHandling() { + val nonExistentToolCommand = "npx tsx myClient.ts $serverUrl non-existent-tool" + val nonExistentToolOutput = executeCommandAllowingFailure(nonExistentToolCommand, tsClientDir) + + assertTrue( + nonExistentToolOutput.contains("Tool \"non-existent-tool\" not found"), + "Client should handle non-existent tool gracefully", + ) + + val invalidUrlCommand = "npx tsx myClient.ts http://localhost:${port + 1000}/mcp greet TestUser" + val invalidUrlOutput = executeCommandAllowingFailure(invalidUrlCommand, tsClientDir) + + assertTrue( + invalidUrlOutput.contains("Error:") || invalidUrlOutput.contains("ECONNREFUSED"), + "Client should handle connection errors gracefully", + ) + } + + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + fun testSpecialCharacters() { + val specialChars = "!@#$+-[].,?" + + val tempFile = File.createTempFile("special_chars", ".txt") + tempFile.writeText(specialChars) + tempFile.deleteOnExit() + + val specialCharsContent = tempFile.readText() + val specialCharsCommand = "npx tsx myClient.ts $serverUrl greet \"$specialCharsContent\"" + val specialCharsOutput = executeCommand(specialCharsCommand, tsClientDir) + + assertTrue( + specialCharsOutput.contains("Hello, $specialChars!"), + "Tool should handle special characters in arguments", + ) + assertTrue( + specialCharsOutput.contains("Disconnected from server"), + "Client should disconnect cleanly after handling special characters", + ) + } + + // skip on windows as it can't handle long commands + @Test + @Timeout(30, unit = TimeUnit.SECONDS) + @EnabledOnOs(OS.MAC, OS.LINUX) + fun testLargePayload() { + val largeName = "A".repeat(10 * 1024) + + val tempFile = File.createTempFile("large_name", ".txt") + tempFile.writeText(largeName) + tempFile.deleteOnExit() + + val largeNameContent = tempFile.readText() + val largePayloadCommand = "npx tsx myClient.ts $serverUrl greet \"$largeNameContent\"" + val largePayloadOutput = executeCommand(largePayloadCommand, tsClientDir) + + tempFile.delete() + + assertTrue( + largePayloadOutput.contains("Hello,") && largePayloadOutput.contains("A".repeat(20)), + "Tool should handle large payloads", + ) + assertTrue( + largePayloadOutput.contains("Disconnected from server"), + "Client should disconnect cleanly after handling large payload", + ) + } + + @Test + @Timeout(60, unit = TimeUnit.SECONDS) + fun testComplexConcurrentRequests() { + val commands = listOf( + "npx tsx myClient.ts $serverUrl greet \"Client1\"", + "npx tsx myClient.ts $serverUrl multi-greet \"Client2\"", + "npx tsx myClient.ts $serverUrl greet \"Client3\"", + "npx tsx myClient.ts $serverUrl", + "npx tsx myClient.ts $serverUrl multi-greet \"Client5\"", + ) + + val threads = commands.mapIndexed { index, command -> + Thread { + println("Starting client $index") + val output = executeCommand(command, tsClientDir) + println("Client $index completed") + + assertTrue( + output.contains("Connected to server"), + "Client $index should connect to server", + ) + assertTrue( + output.contains("Disconnected from server"), + "Client $index should disconnect cleanly", + ) + + when { + command.contains("greet \"Client1\"") -> + assertTrue(output.contains("Hello, Client1!"), "Client 1 should receive correct greeting") + + command.contains("multi-greet \"Client2\"") -> + assertTrue(output.contains("Multiple greetings"), "Client 2 should receive multiple greetings") + + command.contains("greet \"Client3\"") -> + assertTrue(output.contains("Hello, Client3!"), "Client 3 should receive correct greeting") + + !command.contains("greet") && !command.contains("multi-greet") -> + assertTrue(output.contains("Available utils:"), "Client 4 should list available tools") + + command.contains("multi-greet \"Client5\"") -> + assertTrue(output.contains("Multiple greetings"), "Client 5 should receive multiple greetings") + } + }.apply { start() } + } + + threads.forEach { it.join() } + } + + @Test + @Timeout(120, unit = TimeUnit.SECONDS) + fun testRapidSequentialRequests() { + val outputs = (1..10).map { i -> + val command = "npx tsx myClient.ts $serverUrl greet \"RapidClient$i\"" + val output = executeCommand(command, tsClientDir) + + assertTrue( + output.contains("Connected to server"), + "Client $i should connect to server", + ) + assertTrue( + output.contains("Hello, RapidClient$i!"), + "Client $i should receive correct greeting", + ) + assertTrue( + output.contains("Disconnected from server"), + "Client $i should disconnect cleanly", + ) + + output + } + + assertEquals(10, outputs.size, "All 10 rapid requests should complete successfully") + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt new file mode 100644 index 00000000..0cfcef60 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/TypeScriptTestBase.kt @@ -0,0 +1,210 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.typescript + +import io.modelcontextprotocol.kotlin.sdk.integration.utils.Retry +import org.junit.jupiter.api.BeforeAll +import java.io.BufferedReader +import java.io.File +import java.io.InputStreamReader +import java.net.ServerSocket +import java.net.Socket +import java.nio.file.Files +import java.util.concurrent.TimeUnit + +@Retry(times = 3) +abstract class TypeScriptTestBase { + + protected val projectRoot: File get() = File(System.getProperty("user.dir")) + protected val tsClientDir: File + get() = File( + projectRoot, + "src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils", + ) + + companion object { + @JvmStatic + private val tempRootDir: File = Files.createTempDirectory("typescript-sdk-").toFile().apply { deleteOnExit() } + + @JvmStatic + protected val sdkDir: File = File(tempRootDir, "typescript-sdk") + + @JvmStatic + @BeforeAll + fun setupTypeScriptSdk() { + println("Cloning TypeScript SDK repository") + + if (!sdkDir.exists()) { + val process = ProcessBuilder( + "git", + "clone", + "--depth", + "1", + "https://github.com/modelcontextprotocol/typescript-sdk.git", + sdkDir.absolutePath, + ) + .redirectErrorStream(true) + .start() + val exitCode = process.waitFor() + if (exitCode != 0) { + throw RuntimeException("Failed to clone TypeScript SDK repository: exit code $exitCode") + } + } + + println("Installing TypeScript SDK dependencies") + executeCommand("npm install", sdkDir) + } + + @JvmStatic + protected fun executeCommand(command: String, workingDir: File): String = + runCommand(command, workingDir, allowFailure = false, timeoutSeconds = null) + + @JvmStatic + protected fun killProcessOnPort(port: Int) { + val isWindows = System.getProperty("os.name").lowercase().contains("windows") + val killCommand = if (isWindows) { + "netstat -ano | findstr :$port | for /f \"tokens=5\" %a in ('more') do taskkill /F /PID %a 2>nul || echo No process found" + } else { + "lsof -ti:$port | xargs kill -9 2>/dev/null || true" + } + runCommand(killCommand, File("."), allowFailure = true, timeoutSeconds = null) + } + + @JvmStatic + protected fun findFreePort(): Int { + ServerSocket(0).use { socket -> + return socket.localPort + } + } + + private fun runCommand( + command: String, + workingDir: File, + allowFailure: Boolean, + timeoutSeconds: Long?, + ): String { + if (!workingDir.exists()) { + if (!workingDir.mkdirs()) { + throw RuntimeException("Failed to create working directory: ${workingDir.absolutePath}") + } + } + + if (!workingDir.isDirectory || !workingDir.canRead()) { + throw RuntimeException("Working directory is not accessible: ${workingDir.absolutePath}") + } + + val isWindows = System.getProperty("os.name").lowercase().contains("windows") + val processBuilder = if (isWindows) { + ProcessBuilder() + .command("cmd.exe", "/c", "set TYPESCRIPT_SDK_DIR=${sdkDir.absolutePath} && $command") + } else { + ProcessBuilder() + .command("bash", "-c", "TYPESCRIPT_SDK_DIR='${sdkDir.absolutePath}' $command") + } + + val process = processBuilder + .directory(workingDir) + .redirectErrorStream(true) + .start() + + val output = StringBuilder() + BufferedReader(InputStreamReader(process.inputStream)).use { reader -> + var line: String? + while (reader.readLine().also { line = it } != null) { + println(line) + output.append(line).append("\n") + } + } + + if (timeoutSeconds == null) { + val exitCode = process.waitFor() + if (!allowFailure && exitCode != 0) { + throw RuntimeException( + "Command execution failed with exit code $exitCode: $command\nWorking dir: ${workingDir.absolutePath}\nOutput:\n$output", + ) + } + } else { + process.waitFor(timeoutSeconds, TimeUnit.SECONDS) + } + + return output.toString() + } + } + + protected fun waitForProcessTermination(process: Process, timeoutSeconds: Long): Boolean { + if (process.isAlive && !process.waitFor(timeoutSeconds, TimeUnit.SECONDS)) { + process.destroyForcibly() + process.waitFor(2, TimeUnit.SECONDS) + return false + } + return true + } + + protected fun createProcessOutputReader(process: Process, prefix: String = "TS-SERVER"): Thread { + val outputReader = Thread { + try { + process.inputStream.bufferedReader().useLines { lines -> + for (line in lines) { + println("[$prefix] $line") + } + } + } catch (e: Exception) { + println("Warning: Error reading process output: ${e.message}") + } + } + outputReader.isDaemon = true + return outputReader + } + + protected fun waitForPort(host: String = "localhost", port: Int, timeoutSeconds: Long = 10): Boolean { + val deadline = System.currentTimeMillis() + timeoutSeconds * 1000 + while (System.currentTimeMillis() < deadline) { + try { + Socket(host, port).use { return true } + } catch (_: Exception) { + Thread.sleep(100) + } + } + return false + } + + protected fun executeCommandAllowingFailure(command: String, workingDir: File, timeoutSeconds: Long = 20): String = + runCommand(command, workingDir, allowFailure = true, timeoutSeconds = timeoutSeconds) + + protected fun startTypeScriptServer(port: Int): Process { + killProcessOnPort(port) + + if (!sdkDir.exists() || !sdkDir.isDirectory) { + throw IllegalStateException( + "TypeScript SDK directory does not exist or is not accessible: ${sdkDir.absolutePath}", + ) + } + + val isWindows = System.getProperty("os.name").lowercase().contains("windows") + val processBuilder = if (isWindows) { + ProcessBuilder() + .command("cmd.exe", "/c", "set MCP_PORT=$port && npx tsx src/examples/server/simpleStreamableHttp.ts") + } else { + ProcessBuilder() + .command("bash", "-c", "MCP_PORT=$port npx tsx src/examples/server/simpleStreamableHttp.ts") + } + + val process = processBuilder + .directory(sdkDir) + .redirectErrorStream(true) + .start() + + if (!waitForPort(port = port)) { + throw IllegalStateException("TypeScript server did not become ready on localhost:$port within timeout") + } + createProcessOutputReader(process).start() + return process + } + + protected fun stopProcess(process: Process, waitSeconds: Long = 3, name: String = "TypeScript server") { + process.destroy() + if (waitForProcessTermination(process, waitSeconds)) { + println("$name stopped gracefully") + } else { + println("$name did not stop gracefully, forced termination") + } + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt new file mode 100644 index 00000000..535304a8 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/KotlinServerForTypeScriptClient.kt @@ -0,0 +1,492 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.utils + +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.http.ContentType +import io.ktor.http.HttpStatusCode +import io.ktor.server.application.ApplicationCall +import io.ktor.server.cio.CIO +import io.ktor.server.engine.EmbeddedServer +import io.ktor.server.engine.embeddedServer +import io.ktor.server.request.header +import io.ktor.server.request.receiveText +import io.ktor.server.response.header +import io.ktor.server.response.respond +import io.ktor.server.response.respondText +import io.ktor.server.response.respondTextWriter +import io.ktor.server.routing.delete +import io.ktor.server.routing.get +import io.ktor.server.routing.post +import io.ktor.server.routing.routing +import io.modelcontextprotocol.kotlin.sdk.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.ErrorCode +import io.modelcontextprotocol.kotlin.sdk.GetPromptResult +import io.modelcontextprotocol.kotlin.sdk.Implementation +import io.modelcontextprotocol.kotlin.sdk.JSONRPCError +import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.JSONRPCNotification +import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.JSONRPCResponse +import io.modelcontextprotocol.kotlin.sdk.LoggingLevel +import io.modelcontextprotocol.kotlin.sdk.LoggingMessageNotification +import io.modelcontextprotocol.kotlin.sdk.PromptArgument +import io.modelcontextprotocol.kotlin.sdk.PromptMessage +import io.modelcontextprotocol.kotlin.sdk.ReadResourceResult +import io.modelcontextprotocol.kotlin.sdk.RequestId +import io.modelcontextprotocol.kotlin.sdk.Role +import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.TextContent +import io.modelcontextprotocol.kotlin.sdk.TextResourceContents +import io.modelcontextprotocol.kotlin.sdk.Tool +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport +import io.modelcontextprotocol.kotlin.sdk.shared.McpJson +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeoutOrNull +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.contentOrNull +import kotlinx.serialization.json.decodeFromJsonElement +import kotlinx.serialization.json.jsonPrimitive +import java.util.UUID +import java.util.concurrent.ConcurrentHashMap + +private val logger = KotlinLogging.logger {} + +class KotlinServerForTypeScriptClient { + private val serverTransports = ConcurrentHashMap() + private val jsonFormat = Json { ignoreUnknownKeys = true } + private var server: EmbeddedServer<*, *>? = null + + fun start(port: Int = 3000) { + logger.info { "Starting HTTP server on port $port" } + + server = embeddedServer(CIO, port = port) { + routing { + get("/mcp") { + val sessionId = call.request.header("mcp-session-id") + if (sessionId == null) { + call.respond(HttpStatusCode.BadRequest, "Missing mcp-session-id header") + return@get + } + val transport = serverTransports[sessionId] + if (transport == null) { + call.respond(HttpStatusCode.BadRequest, "Invalid mcp-session-id") + return@get + } + transport.stream(call) + } + + post("/mcp") { + val sessionId = call.request.header("mcp-session-id") + val requestBody = call.receiveText() + + logger.debug { "Received request with sessionId: $sessionId" } + logger.trace { "Request body: $requestBody" } + + val jsonElement = try { + jsonFormat.parseToJsonElement(requestBody) + } catch (e: Exception) { + logger.error(e) { "Failed to parse request body as JSON" } + call.respond( + HttpStatusCode.BadRequest, + jsonFormat.encodeToString( + JsonObject.serializer(), + JsonObject( + mapOf( + "jsonrpc" to JsonPrimitive("2.0"), + "error" to JsonObject( + mapOf( + "code" to JsonPrimitive(-32700), + "message" to JsonPrimitive("Parse error: ${e.message}"), + ), + ), + "id" to JsonNull, + ), + ), + ), + ) + return@post + } + + if (sessionId != null && serverTransports.containsKey(sessionId)) { + logger.debug { "Using existing transport for session: $sessionId" } + val transport = serverTransports[sessionId]!! + transport.handleRequest(call, jsonElement) + } else { + if (isInitializeRequest(jsonElement)) { + val newSessionId = UUID.randomUUID().toString() + logger.info { "Creating new session with ID: $newSessionId" } + + val transport = HttpServerTransport(newSessionId) + + serverTransports[newSessionId] = transport + + val mcpServer = createMcpServer() + + call.response.header("mcp-session-id", newSessionId) + + val serverThread = Thread { + runBlocking { + mcpServer.connect(transport) + } + } + serverThread.start() + + Thread.sleep(500) + + transport.handleRequest(call, jsonElement) + } else { + logger.warn { "Invalid request: no session ID or not an initialization request" } + call.respond( + HttpStatusCode.BadRequest, + jsonFormat.encodeToString( + JsonObject.serializer(), + JsonObject( + mapOf( + "jsonrpc" to JsonPrimitive("2.0"), + "error" to JsonObject( + mapOf( + "code" to JsonPrimitive(-32000), + "message" to + JsonPrimitive("Bad Request: No valid session ID provided"), + ), + ), + "id" to JsonNull, + ), + ), + ), + ) + } + } + } + + delete("/mcp") { + val sessionId = call.request.header("mcp-session-id") + if (sessionId != null && serverTransports.containsKey(sessionId)) { + logger.info { "Terminating session: $sessionId" } + val transport = serverTransports[sessionId]!! + serverTransports.remove(sessionId) + runBlocking { + transport.close() + } + call.respond(HttpStatusCode.OK) + } else { + logger.warn { "Invalid session termination request: $sessionId" } + call.respond(HttpStatusCode.BadRequest, "Invalid or missing session ID") + } + } + } + } + + server?.start(wait = false) + } + + fun stop() { + logger.info { "Stopping HTTP server" } + server?.stop(500, 1000) + server = null + } + + private fun createMcpServer(): Server { + val server = Server( + Implementation( + name = "kotlin-http-server", + version = "1.0.0", + ), + ServerOptions( + capabilities = ServerCapabilities( + prompts = ServerCapabilities.Prompts(listChanged = true), + resources = ServerCapabilities.Resources(subscribe = true, listChanged = true), + tools = ServerCapabilities.Tools(listChanged = true), + ), + ), + ) + + server.addTool( + name = "greet", + description = "A simple greeting tool", + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "name", + buildJsonObject { + put("type", JsonPrimitive("string")) + put("description", JsonPrimitive("Name to greet")) + }, + ) + }, + required = listOf("name"), + ), + ) { request -> + val name = (request.arguments["name"] as? JsonPrimitive)?.content ?: "World" + CallToolResult( + content = listOf(TextContent("Hello, $name!")), + structuredContent = buildJsonObject { + put("greeting", JsonPrimitive("Hello, $name!")) + }, + ) + } + + server.addTool( + name = "multi-greet", + description = "A greeting tool that sends multiple notifications", + inputSchema = Tool.Input( + properties = buildJsonObject { + put( + "name", + buildJsonObject { + put("type", JsonPrimitive("string")) + put("description", JsonPrimitive("Name to greet")) + }, + ) + }, + required = listOf("name"), + ), + ) { request -> + val name = (request.arguments["name"] as? JsonPrimitive)?.content ?: "World" + + server.sendToolListChanged() + server.sendLoggingMessage( + LoggingMessageNotification( + LoggingMessageNotification.Params( + level = LoggingLevel.info, + data = JsonPrimitive("Preparing greeting for $name"), + ), + ), + ) + server.sendLoggingMessage( + LoggingMessageNotification( + LoggingMessageNotification.Params( + level = LoggingLevel.info, + data = JsonPrimitive("Halfway there for $name"), + ), + ), + ) + server.sendLoggingMessage( + LoggingMessageNotification( + LoggingMessageNotification.Params( + level = LoggingLevel.info, + data = JsonPrimitive("Done sending greetings to $name"), + ), + ), + ) + + CallToolResult( + content = listOf(TextContent("Multiple greetings sent to $name!")), + structuredContent = buildJsonObject { + put("greeting", JsonPrimitive("Multiple greetings sent to $name!")) + put("notificationCount", JsonPrimitive(3)) + }, + ) + } + + server.addPrompt( + name = "greeting-template", + description = "A simple greeting prompt template", + arguments = listOf( + PromptArgument( + name = "name", + description = "Name to include in greeting", + required = true, + ), + ), + ) { request -> + GetPromptResult( + "Greeting for ${request.name}", + messages = listOf( + PromptMessage( + role = Role.user, + content = TextContent( + "Please greet ${request.arguments?.get("name") ?: "someone"} in a friendly manner.", + ), + ), + ), + ) + } + + server.addResource( + uri = "https://example.com/greetings/default", + name = "Default Greeting", + description = "A simple greeting resource", + mimeType = "text/plain", + ) { request -> + ReadResourceResult( + contents = listOf( + TextResourceContents("Hello, world!", request.uri, "text/plain"), + ), + ) + } + + return server + } + + private fun isInitializeRequest(json: JsonElement): Boolean { + if (json !is JsonObject) return false + + val method = json["method"]?.jsonPrimitive?.contentOrNull + return method == "initialize" + } +} + +class HttpServerTransport(private val sessionId: String) : AbstractTransport() { + private val logger = KotlinLogging.logger {} + private val pendingResponses = ConcurrentHashMap>() + private val messageQueue = Channel(Channel.UNLIMITED) + + suspend fun stream(call: ApplicationCall) { + logger.debug { "Starting SSE stream for session: $sessionId" } + call.response.header("Cache-Control", "no-cache") + call.response.header("Connection", "keep-alive") + call.respondTextWriter(ContentType.Text.EventStream) { + try { + while (true) { + val result = messageQueue.receiveCatching() + val msg = result.getOrNull() ?: break + val json = McpJson.encodeToString(msg) + write("event: message\n") + write("data: ") + write(json) + write("\n\n") + flush() + } + } catch (e: Exception) { + logger.warn(e) { "SSE stream terminated for session: $sessionId" } + } finally { + logger.debug { "SSE stream closed for session: $sessionId" } + } + } + } + + suspend fun handleRequest(call: ApplicationCall, requestBody: JsonElement) { + try { + logger.info { "Handling request body: $requestBody" } + val message = McpJson.decodeFromJsonElement(requestBody) + logger.info { "Decoded message: $message" } + + if (message is JSONRPCRequest) { + val id = message.id.toString() + logger.info { "Received request with ID: $id, method: ${message.method}" } + val responseDeferred = CompletableDeferred() + pendingResponses[id] = responseDeferred + logger.info { "Created deferred response for ID: $id" } + + logger.info { "Invoking onMessage handler" } + _onMessage.invoke(message) + logger.info { "onMessage handler completed" } + + try { + val response = withTimeoutOrNull(10000) { + responseDeferred.await() + } + + if (response != null) { + val jsonResponse = McpJson.encodeToString(response) + call.respondText(jsonResponse, ContentType.Application.Json) + } else { + logger.warn { "Timeout waiting for response to request ID: $id" } + call.respondText( + McpJson.encodeToString( + JSONRPCResponse( + id = message.id, + error = JSONRPCError( + code = ErrorCode.Defined.RequestTimeout, + message = "Request timed out", + ), + ), + ), + ContentType.Application.Json, + ) + } + } catch (_: CancellationException) { + logger.warn { "Request cancelled for ID: $id" } + pendingResponses.remove(id) + if (!call.response.isCommitted) { + call.respondText( + McpJson.encodeToString( + JSONRPCResponse( + id = message.id, + error = JSONRPCError( + code = ErrorCode.Defined.ConnectionClosed, + message = "Request cancelled", + ), + ), + ), + ContentType.Application.Json, + HttpStatusCode.ServiceUnavailable, + ) + } + } + } else { + call.respondText("", ContentType.Application.Json, HttpStatusCode.Accepted) + } + } catch (e: Exception) { + logger.error(e) { "Error handling request" } + if (!call.response.isCommitted) { + try { + val errorResponse = JSONRPCResponse( + id = RequestId.NumberId(0), + error = JSONRPCError( + code = ErrorCode.Defined.InternalError, + message = "Internal server error: ${e.message}", + ), + ) + + call.respondText( + McpJson.encodeToString(errorResponse), + ContentType.Application.Json, + HttpStatusCode.InternalServerError, + ) + } catch (responseEx: Exception) { + logger.error(responseEx) { "Failed to send error response" } + } + } + } + } + + override suspend fun start() { + logger.debug { "Starting HTTP server transport for session: $sessionId" } + } + + override suspend fun send(message: JSONRPCMessage) { + logger.info { "Sending message: $message" } + + if (message is JSONRPCResponse) { + val id = message.id.toString() + logger.info { "Sending response for request ID: $id" } + val deferred = pendingResponses.remove(id) + if (deferred != null) { + logger.info { "Found pending response for ID: $id, completing deferred" } + deferred.complete(message) + return + } else { + logger.warn { "No pending response found for ID: $id" } + } + } else if (message is JSONRPCRequest) { + logger.info { "Sending request with ID: ${message.id}" } + } else if (message is JSONRPCNotification) { + logger.info { "Sending notification: ${message.method}" } + } + + logger.info { "Queueing message for next client request" } + messageQueue.send(message) + } + + override suspend fun close() { + logger.debug { "Closing HTTP server transport for session: $sessionId" } + messageQueue.close() + _onClose.invoke() + } +} + +fun main() { + val server = KotlinServerForTypeScriptClient() + server.start() +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/Retry.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/Retry.kt new file mode 100644 index 00000000..32f20534 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/Retry.kt @@ -0,0 +1,97 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.utils + +import org.junit.jupiter.api.extension.ExtendWith +import org.junit.jupiter.api.extension.ExtensionContext +import org.junit.jupiter.api.extension.InvocationInterceptor +import org.junit.jupiter.api.extension.InvocationInterceptor.Invocation +import org.junit.jupiter.api.extension.ReflectiveInvocationContext +import org.opentest4j.TestAbortedException +import java.lang.reflect.AnnotatedElement +import java.lang.reflect.Method +import java.util.Optional + +@Target(AnnotationTarget.CLASS) +@Retention(AnnotationRetention.RUNTIME) +@ExtendWith(RetryExtension::class) +annotation class Retry(val times: Int = 3, val delayMs: Long = 1000) + +class RetryExtension : InvocationInterceptor { + override fun interceptTestMethod( + invocation: Invocation, + invocationContext: ReflectiveInvocationContext, + extensionContext: ExtensionContext, + ) { + executeWithRetry(invocation, extensionContext) + } + + private fun resolveRetryAnnotation(extensionContext: ExtensionContext): Retry? { + val classAnn = extensionContext.testClass.flatMap { findRetry(it) } + return classAnn.orElse(null) + } + + private fun findRetry(element: AnnotatedElement): Optional = + Optional.ofNullable(element.getAnnotation(Retry::class.java)) + + private fun executeWithRetry(invocation: Invocation, extensionContext: ExtensionContext) { + val retry = resolveRetryAnnotation(extensionContext) + if (retry == null || retry.times <= 1) { + invocation.proceed() + return + } + + val maxAttempts = retry.times + val delay = retry.delayMs + var lastError: Throwable? = null + + for (attempt in 1..maxAttempts) { + if (attempt > 1 && delay > 0) { + try { + Thread.sleep(delay) + } catch (_: InterruptedException) { + Thread.currentThread().interrupt() + break + } + } + + try { + if (attempt == 1) { + invocation.proceed() + } else { + val instance = extensionContext.requiredTestInstance + val testMethod = extensionContext.requiredTestMethod + testMethod.isAccessible = true + testMethod.invoke(instance) + } + return + } catch (t: Throwable) { + if (t is TestAbortedException) throw t + lastError = if (t is java.lang.reflect.InvocationTargetException) t.targetException ?: t else t + if (attempt == maxAttempts) { + println( + "[Retry] Giving up after $attempt attempts for ${ + describeTest( + extensionContext, + ) + }: ${lastError.message}", + ) + throw lastError + } + println( + "[Retry] Failure on attempt $attempt/$maxAttempts for ${ + describeTest( + extensionContext, + ) + }: ${lastError.message}", + ) + } + } + + throw lastError ?: IllegalStateException("Unexpected state in retry logic") + } + + private fun describeTest(ctx: ExtensionContext): String { + val methodName = ctx.testMethod.map(Method::getName).orElse("") + val className = ctx.testClass.map { it.name }.orElse("") + return "$className#$methodName" + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/TestUtils.kt b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/TestUtils.kt new file mode 100644 index 00000000..bed66cd4 --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/TestUtils.kt @@ -0,0 +1,91 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.utils + +import io.modelcontextprotocol.kotlin.sdk.CallToolResultBase +import io.modelcontextprotocol.kotlin.sdk.PromptMessageContent +import io.modelcontextprotocol.kotlin.sdk.TextContent +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withContext +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertTrue + +object TestUtils { + fun runTest(block: suspend () -> T): T = runBlocking { + withContext(Dispatchers.IO) { + block() + } + } + + fun assertTextContent(content: PromptMessageContent?, expectedText: String) { + assertNotNull(content, "Content should not be null") + assertTrue(content is TextContent, "Content should be TextContent") + assertNotNull(content.text, "Text content should not be null") + assertEquals(expectedText, content.text, "Text content should match") + } + + fun assertCallToolResult(result: Any?, message: String = ""): CallToolResultBase { + assertNotNull(result, "${message}Call tool result should not be null") + assertTrue(result is CallToolResultBase, "${message}Result should be CallToolResultBase") + assertTrue(result.content.isNotEmpty(), "${message}Tool result content should not be empty") + assertNotNull(result.structuredContent, "${message}Tool result structured content should not be null") + + return result + } + + /** + * Asserts that a JSON property has the expected string value. + */ + fun assertJsonProperty( + json: JsonObject, + property: String, + expectedValue: String, + message: String = "", + ) { + assertEquals(expectedValue, json[property]?.toString()?.trim('"'), "${message}$property should match") + } + + /** + * Asserts that a JSON property has the expected numeric value. + */ + fun assertJsonProperty( + json: JsonObject, + property: String, + expectedValue: Number, + message: String = "", + ) { + when (expectedValue) { + is Int -> assertEquals( + expectedValue, + (json[property] as? JsonPrimitive)?.content?.toIntOrNull(), + "${message}$property should match", + ) + + is Double -> assertEquals( + expectedValue, + (json[property] as? JsonPrimitive)?.content?.toDoubleOrNull(), + "${message}$property should match", + ) + + else -> assertEquals( + expectedValue.toString(), + json[property]?.toString()?.trim('"'), + "${message}$property should match", + ) + } + } + + /** + * Asserts that a JSON property has the expected boolean value. + */ + fun assertJsonProperty( + json: JsonObject, + property: String, + expectedValue: Boolean, + message: String = "", + ) { + assertEquals(expectedValue.toString(), json[property].toString(), "${message}$property should match") + } +} diff --git a/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/myClient.ts b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/myClient.ts new file mode 100644 index 00000000..42a14f5f --- /dev/null +++ b/kotlin-sdk-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/utils/myClient.ts @@ -0,0 +1,129 @@ +// @ts-ignore +const args = process.argv.slice(2); +const serverUrl = args[0] || 'http://localhost:3001/mcp'; +const toolName = args[1]; +const toolArgs = args.slice(2); +const PROTOCOL_VERSION = "2024-11-05"; + +// @ts-ignore +async function main() { + // @ts-ignore + const sdkDirRaw = process.env.TYPESCRIPT_SDK_DIR; + const sdkDir = sdkDirRaw ? sdkDirRaw.trim() : undefined; + let Client: any; + let StreamableHTTPClientTransport: any; + if (sdkDir) { + // @ts-ignore + const path = await import('path'); + // @ts-ignore + const { pathToFileURL } = await import('url'); + const clientUrl = pathToFileURL(path.join(sdkDir, 'src', 'client', 'index.ts')).href; + const streamUrl = pathToFileURL(path.join(sdkDir, 'src', 'client', 'streamableHttp.js')).href; + // @ts-ignore + ({ Client } = await import(clientUrl)); + // @ts-ignore + ({ StreamableHTTPClientTransport } = await import(streamUrl)); + } else { + // @ts-ignore + ({Client} = await import("../../../../../../../resources/typescript-sdk/src/client")); + // @ts-ignore + ({StreamableHTTPClientTransport} = await import("../../../../../../../resources/typescript-sdk/src/client/streamableHttp.js")); + } + if (!toolName) { + console.log('Usage: npx tsx myClient.ts [server-url] [tool-args...]'); + console.log('Using default server URL:', serverUrl); + console.log('Available utils will be listed after connection'); + } + + console.log(`Connecting to server at ${serverUrl}`); + if (toolName) { + console.log(`Will call tool: ${toolName} with args: ${toolArgs.join(', ')}`); + } + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(new URL(serverUrl)); + + try { + await client.connect(transport, {protocolVersion: PROTOCOL_VERSION}); + console.log('Connected to server'); + + try { + if (typeof (client as any).on === 'function') { + (client as any).on('notification', (n: any) => { + try { + const method = (n && (n.method || (n.params && n.params.method))) || 'unknown'; + console.log('Notification:', method, JSON.stringify(n)); + } catch { + console.log('Notification: '); + } + }); + } + } catch { + } + + const toolsResult = await client.listTools(); + const tools = toolsResult.tools; + console.log('Available utils:', tools.map((t: { name: any; }) => t.name).join(', ')); + + if (!toolName) { + await client.close(); + return; + } + + const tool = tools.find((t: { name: string; }) => t.name === toolName); + if (!tool) { + console.error(`Tool "${toolName}" not found`); + // @ts-ignore + process.exit(1); + } + + const toolArguments = {}; + + if (toolName === "greet" && toolArgs.length > 0) { + toolArguments["name"] = toolArgs[0]; + } else if (tool.input && tool.input.properties) { + const propNames = Object.keys(tool.input.properties); + if (propNames.length > 0 && toolArgs.length > 0) { + toolArguments[propNames[0]] = toolArgs[0]; + } + } + + console.log(`Calling tool ${toolName} with arguments:`, toolArguments); + + const result = await client.callTool({ + name: toolName, + arguments: toolArguments + }); + console.log('Tool result:', result); + + if (result.content) { + for (const content of result.content) { + if (content.type === 'text') { + console.log('Text content:', content.text); + } + } + } + + if (result.structuredContent) { + console.log('Structured content:', JSON.stringify(result.structuredContent, null, 2)); + } + + } catch (error) { + console.error('Error:', error); + // @ts-ignore + process.exit(1); + } finally { + await client.close(); + console.log('Disconnected from server'); + } +} + +main().catch(error => { + console.error('Unhandled error:', error); + // @ts-ignore + process.exit(1); +});