diff --git a/src/mcp/types.py b/src/mcp/types.py
index 465fc6ee6..0e3e71014 100644
--- a/src/mcp/types.py
+++ b/src/mcp/types.py
@@ -653,14 +653,6 @@ class ImageContent(BaseModel):
     model_config = ConfigDict(extra="allow")
 
 
-class SamplingMessage(BaseModel):
-    """Describes a message issued to or received from an LLM API."""
-
-    role: Role
-    content: TextContent | ImageContent
-    model_config = ConfigDict(extra="allow")
-
-
 class EmbeddedResource(BaseModel):
     """
     The contents of a resource, embedded into a prompt or tool call result.
@@ -675,6 +667,14 @@ class EmbeddedResource(BaseModel):
     model_config = ConfigDict(extra="allow")
 
 
+class SamplingMessage(BaseModel):
+    """Describes a message issued to or received from an LLM API."""
+
+    role: Role
+    content: TextContent | ImageContent | EmbeddedResource
+    model_config = ConfigDict(extra="allow")
+
+
 class PromptMessage(BaseModel):
     """Describes a message returned as part of a prompt."""
 
@@ -961,7 +961,7 @@ class CreateMessageResult(Result):
     """The client's response to a sampling/create_message request from the server."""
 
     role: Role
-    content: TextContent | ImageContent
+    content: TextContent | ImageContent | EmbeddedResource
     model: str
     """The name of the model that generated the message."""
     stopReason: StopReason | None = None
diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py
index 79285ecb1..ca3573d9e 100644
--- a/tests/server/fastmcp/test_integration.py
+++ b/tests/server/fastmcp/test_integration.py
@@ -24,6 +24,7 @@ from mcp.types import (
     CreateMessageRequestParams,
     CreateMessageResult,
+    EmbeddedResource,
     GetPromptResult,
     InitializeResult,
     ReadResourceResult,
@@ -144,6 +145,37 @@ async def sampling_tool(prompt: str, ctx: Context) -> str:
         else:
             return f"Sampling result: {str(result.content)[:100]}..."
 
+    # Tool with sampling capability
+    @mcp.tool(description="A tool that uses sampling to generate a resource")
+    async def sampling_tool_resource(prompt: str, ctx: Context) -> str:
+        await ctx.info(f"Requesting sampling for prompt: {prompt}")
+
+        # Request sampling from the client
+        result = await ctx.session.create_message(
+            messages=[
+                SamplingMessage(
+                    role="user",
+                    content=EmbeddedResource(
+                        type="resource",
+                        resource=TextResourceContents(
+                            uri=AnyUrl("file://prompt"),
+                            text=prompt,
+                            mimeType="text/plain",
+                        ),
+                    ),
+                )
+            ],
+            max_tokens=100,
+            temperature=0.7,
+        )
+
+        await ctx.info(f"Received sampling result from model: {result.model}")
+        # Handle different content types
+        if result.content.type == "text":
+            return f"Sampling result: {result.content.text[:100]}..."
+        else:
+            return f"Sampling result: {str(result.content)[:100]}..."
+
     # Tool that sends notifications and logging
     @mcp.tool(description="A tool that demonstrates notifications and logging")
     async def notification_tool(message: str, ctx: Context) -> str:
@@ -694,12 +726,13 @@ async def progress_callback(
     assert len(collector.log_messages) > 0
 
     # 3. Test sampling tool
-    prompt = "What is the meaning of life?"
-    sampling_result = await session.call_tool("sampling_tool", {"prompt": prompt})
-    assert len(sampling_result.content) == 1
-    assert isinstance(sampling_result.content[0], TextContent)
-    assert "Sampling result:" in sampling_result.content[0].text
-    assert "This is a simulated LLM response" in sampling_result.content[0].text
+    for tool in ["sampling_tool", "sampling_tool_resource"]:
+        prompt = "What is the meaning of life?"
+        sampling_result = await session.call_tool(tool, {"prompt": prompt})
+        assert len(sampling_result.content) == 1
+        assert isinstance(sampling_result.content[0], TextContent)
+        assert "Sampling result:" in sampling_result.content[0].text
+        assert "This is a simulated LLM response" in sampling_result.content[0].text
 
     # Verify we received log messages from the sampling tool
     assert len(collector.log_messages) > 0
@@ -810,6 +843,12 @@ async def sampling_callback(
     # Simulate LLM response based on the input
     if params.messages and isinstance(params.messages[0].content, TextContent):
         input_text = params.messages[0].content.text
+    elif (
+        params.messages
+        and isinstance(params.messages[0].content, EmbeddedResource)
+        and isinstance(params.messages[0].content.resource, TextResourceContents)
+    ):
+        input_text = params.messages[0].content.resource.text
     else:
         input_text = "No input"
     response_text = f"This is a simulated LLM response to: {input_text}"
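A minimal client-side sampling callback that accepts the widened SamplingMessage.content union might look like the sketch below. Only the mcp.types names are taken from this diff; the callback itself (its name, the untyped context parameter, and the echoed stub response) is illustrative and not part of the change.

from mcp.types import (
    CreateMessageRequestParams,
    CreateMessageResult,
    EmbeddedResource,
    TextContent,
    TextResourceContents,
)


async def sampling_callback(context, params: CreateMessageRequestParams) -> CreateMessageResult:
    # With this patch, the first message's content may be TextContent,
    # ImageContent, or an EmbeddedResource.
    content = params.messages[0].content if params.messages else None
    if isinstance(content, TextContent):
        input_text = content.text
    elif isinstance(content, EmbeddedResource) and isinstance(
        content.resource, TextResourceContents
    ):
        # New case enabled by this change: text carried inside an embedded resource.
        input_text = content.resource.text
    else:
        input_text = "No input"
    # Return a stub completion to the server (model name and text are placeholders).
    return CreateMessageResult(
        role="assistant",
        content=TextContent(type="text", text=f"Echo: {input_text}"),
        model="stub-model",
        stopReason="endTurn",
    )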