feat: support embedded resources in sampling #727

Status: Draft. Wants to merge 4 commits into base: main
src/mcp/types.py: 18 changes (9 additions, 9 deletions)
@@ -653,14 +653,6 @@ class ImageContent(BaseModel):
     model_config = ConfigDict(extra="allow")


-class SamplingMessage(BaseModel):
-    """Describes a message issued to or received from an LLM API."""
-
-    role: Role
-    content: TextContent | ImageContent
-    model_config = ConfigDict(extra="allow")
-
-
 class EmbeddedResource(BaseModel):
     """
     The contents of a resource, embedded into a prompt or tool call result.
@@ -675,6 +667,14 @@ class EmbeddedResource(BaseModel):
     model_config = ConfigDict(extra="allow")


+class SamplingMessage(BaseModel):
+    """Describes a message issued to or received from an LLM API."""
+
+    role: Role
+    content: TextContent | ImageContent | EmbeddedResource
+    model_config = ConfigDict(extra="allow")
+
+
 class PromptMessage(BaseModel):
     """Describes a message returned as part of a prompt."""

@@ -961,7 +961,7 @@ class CreateMessageResult(Result):
     """The client's response to a sampling/create_message request from the server."""

     role: Role
-    content: TextContent | ImageContent
+    content: TextContent | ImageContent | EmbeddedResource
     model: str
     """The name of the model that generated the message."""
     stopReason: StopReason | None = None
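
Taken together, the types.py changes let SamplingMessage.content (and the content of a CreateMessageResult) carry an EmbeddedResource alongside TextContent and ImageContent, so a server can pass resource contents directly into a sampling request. A minimal sketch of constructing such a message; the file://example URI and the prompt text are illustrative only, not part of this diff:

    from pydantic import AnyUrl
    from mcp.types import EmbeddedResource, SamplingMessage, TextResourceContents

    # A sampling message whose content is an embedded text resource
    # rather than plain TextContent or ImageContent.
    message = SamplingMessage(
        role="user",
        content=EmbeddedResource(
            type="resource",
            resource=TextResourceContents(
                uri=AnyUrl("file://example"),
                text="Summarize this document.",
                mimeType="text/plain",
            ),
        ),
    )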
tests/server/fastmcp/test_integration.py: 51 changes (45 additions, 6 deletions)
@@ -24,6 +24,7 @@
 from mcp.types import (
     CreateMessageRequestParams,
     CreateMessageResult,
+    EmbeddedResource,
     GetPromptResult,
     InitializeResult,
     ReadResourceResult,
@@ -144,6 +145,37 @@ async def sampling_tool(prompt: str, ctx: Context) -> str:
         else:
             return f"Sampling result: {str(result.content)[:100]}..."

+    # Tool that uses sampling with an embedded resource
+    @mcp.tool(description="A tool that uses sampling to generate a resource")
+    async def sampling_tool_resource(prompt: str, ctx: Context) -> str:
+        await ctx.info(f"Requesting sampling for prompt: {prompt}")
+
+        # Request sampling from the client, embedding the prompt as a resource
+        result = await ctx.session.create_message(
+            messages=[
+                SamplingMessage(
+                    role="user",
+                    content=EmbeddedResource(
+                        type="resource",
+                        resource=TextResourceContents(
+                            uri=AnyUrl("file://prompt"),
+                            text=prompt,
+                            mimeType="text/plain",
+                        ),
+                    ),
+                )
+            ],
+            max_tokens=100,
+            temperature=0.7,
+        )
+
+        await ctx.info(f"Received sampling result from model: {result.model}")
+        # Handle different content types
+        if result.content.type == "text":
+            return f"Sampling result: {result.content.text[:100]}..."
+        else:
+            return f"Sampling result: {str(result.content)[:100]}..."
+
     # Tool that sends notifications and logging
     @mcp.tool(description="A tool that demonstrates notifications and logging")
     async def notification_tool(message: str, ctx: Context) -> str:
@@ -694,12 +726,13 @@ async def progress_callback(
     assert len(collector.log_messages) > 0

     # 3. Test sampling tool
-    prompt = "What is the meaning of life?"
-    sampling_result = await session.call_tool("sampling_tool", {"prompt": prompt})
-    assert len(sampling_result.content) == 1
-    assert isinstance(sampling_result.content[0], TextContent)
-    assert "Sampling result:" in sampling_result.content[0].text
-    assert "This is a simulated LLM response" in sampling_result.content[0].text
+    for tool in ["sampling_tool", "sampling_tool_resource"]:
+        prompt = "What is the meaning of life?"
+        sampling_result = await session.call_tool(tool, {"prompt": prompt})
+        assert len(sampling_result.content) == 1
+        assert isinstance(sampling_result.content[0], TextContent)
+        assert "Sampling result:" in sampling_result.content[0].text
+        assert "This is a simulated LLM response" in sampling_result.content[0].text

     # Verify we received log messages from the sampling tool
     assert len(collector.log_messages) > 0
@@ -810,6 +843,12 @@ async def sampling_callback(
         # Simulate LLM response based on the input
         if params.messages and isinstance(params.messages[0].content, TextContent):
             input_text = params.messages[0].content.text
+        elif (
+            params.messages
+            and isinstance(params.messages[0].content, EmbeddedResource)
+            and isinstance(params.messages[0].content.resource, TextResourceContents)
+        ):
+            input_text = params.messages[0].content.resource.text
         else:
             input_text = "No input"
         response_text = f"This is a simulated LLM response to: {input_text}"
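
For context, the sampling_callback above only takes effect once it is registered on the client session, which is how the integration test exercises both sampling tools end to end. A minimal sketch of that wiring, assuming the SDK's standard ClientSession constructor and placeholder read_stream/write_stream transport handles (neither is part of this diff):

    from mcp import ClientSession

    async def run_client(read_stream, write_stream):
        # sampling_callback services the sampling requests the server
        # issues while a tool such as sampling_tool_resource is running.
        async with ClientSession(
            read_stream,
            write_stream,
            sampling_callback=sampling_callback,
        ) as session:
            await session.initialize()
            await session.call_tool("sampling_tool_resource", {"prompt": "hello"})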