diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index c714c44bb..fed2efc77 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -271,6 +271,7 @@ async def call_tool( arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, progress_callback: ProgressFnT | None = None, + webhooks: list[types.Webhook] | None = None, ) -> types.CallToolResult: """Send a tools/call request with optional progress callback support.""" @@ -281,6 +282,7 @@ async def call_tool( params=types.CallToolRequestParams( name=name, arguments=arguments, + webhooks=webhooks, ), ) ), diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 3282baae6..691b9dcfe 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -99,6 +99,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]): stateless_http: bool = ( False # If True, uses true stateless mode (new transport per request) ) + webhooks_supported: bool = False # resource settings warn_on_duplicate_resources: bool = True @@ -150,11 +151,10 @@ def __init__( self._mcp_server = MCPServer( name=name or "FastMCP", instructions=instructions, - lifespan=( - lifespan_wrapper(self, self.settings.lifespan) - if self.settings.lifespan - else default_lifespan - ), + lifespan=lifespan_wrapper(self, self.settings.lifespan) + if self.settings.lifespan + else default_lifespan, + webhooks_supported=self.settings.webhooks_supported, ) self._tool_manager = ToolManager( tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools @@ -165,6 +165,7 @@ def __init__( self._prompt_manager = PromptManager( warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts ) + if (self.settings.auth is not None) != (auth_server_provider is not None): # TODO: after we support separate authorization servers (see # https://github.com/modelcontextprotocol/modelcontextprotocol/pull/284) @@ -272,11 +273,17 @@ def get_context(self) -> Context[ServerSession, object]: return Context(request_context=request_context, fastmcp=self) async def call_tool( - self, name: str, arguments: dict[str, Any] + self, + name: str, + arguments: dict[str, Any], ) -> Sequence[TextContent | ImageContent | EmbeddedResource]: """Call a tool by name with arguments.""" context = self.get_context() - result = await self._tool_manager.call_tool(name, arguments, context=context) + result = await self._tool_manager.call_tool( + name, + arguments, + context=context, + ) converted_result = _convert_to_content(result) return converted_result @@ -777,6 +784,8 @@ def streamable_http_app(self) -> Starlette: event_store=self._event_store, json_response=self.settings.json_response, stateless=self.settings.stateless_http, # Use the stateless setting + webhooks_supported=self.settings.webhooks_supported, + # Use the webhooks supported setting ) # Create the ASGI handler @@ -929,6 +938,7 @@ def my_tool(x: int, ctx: Context) -> str: _request_context: RequestContext[ServerSessionT, LifespanContextT] | None _fastmcp: FastMCP | None + has_webhook: bool = False def __init__( self, diff --git a/src/mcp/server/fastmcp/tools/tool_manager.py b/src/mcp/server/fastmcp/tools/tool_manager.py index 6ec4fd151..9e0f28419 100644 --- a/src/mcp/server/fastmcp/tools/tool_manager.py +++ b/src/mcp/server/fastmcp/tools/tool_manager.py @@ -72,4 +72,10 @@ async def call_tool( if not tool: raise ToolError(f"Unknown tool: {name}") + if context is not None: + try: + context.has_webhook = context.request_context.has_webhook + except Exception: + logger.debug("Request context is not available.") + return await tool.run(arguments, context=context) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 876aef817..94dca6720 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -132,6 +132,7 @@ def __init__( lifespan: Callable[ [Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT] ] = lifespan, + webhooks_supported: bool = False, ): self.name = name self.version = version @@ -144,6 +145,7 @@ def __init__( } self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} self.notification_options = NotificationOptions() + self.webhooks_supported = webhooks_supported logger.debug(f"Initializing server '{name}'") def create_initialization_options( @@ -199,7 +201,8 @@ def get_capabilities( # Set tool capabilities if handler exists if types.ListToolsRequest in self.request_handlers: tools_capability = types.ToolsCapability( - listChanged=notification_options.tools_changed + listChanged=notification_options.tools_changed, + webhooksSupported=self.webhooks_supported, ) # Set logging capabilities if handler exists @@ -409,6 +412,8 @@ def decorator( async def handler(req: types.CallToolRequest): try: + if req.params.webhooks is not None and len(req.params.webhooks) > 0: + self.request_context.has_webhook = True results = await func(req.params.name, (req.params.arguments or {})) return types.ServerResult( types.CallToolResult(content=list(results), isError=False) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 8f4a1f512..dc857bc70 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -7,6 +7,7 @@ responses, with streaming support for long-running operations. """ +import asyncio import json import logging import re @@ -24,6 +25,7 @@ from starlette.responses import Response from starlette.types import Receive, Scope, Send +from mcp.shared._httpx_utils import create_mcp_http_client from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.types import ( INTERNAL_ERROR, @@ -36,6 +38,7 @@ JSONRPCRequest, JSONRPCResponse, RequestId, + Webhook, ) logger = logging.getLogger(__name__) @@ -136,6 +139,7 @@ def __init__( self, mcp_session_id: str | None, is_json_response_enabled: bool = False, + is_webhooks_supported: bool = False, event_store: EventStore | None = None, ) -> None: """ @@ -146,6 +150,10 @@ def __init__( Must contain only visible ASCII characters (0x21-0x7E). is_json_response_enabled: If True, return JSON responses for requests instead of SSE streams. Default is False. + is_webhooks_supported: If True and if webhooks are provided in + tools/call request, the client will receive an Accepted + HTTP response and the CallTool response will be sent to + the webhook. Default is False. event_store: Event store for resumability support. If provided, resumability will be enabled, allowing clients to reconnect and resume messages. @@ -162,6 +170,7 @@ def __init__( self.mcp_session_id = mcp_session_id self.is_json_response_enabled = is_json_response_enabled + self.is_webhooks_supported = is_webhooks_supported self._event_store = event_store self._request_streams: dict[ RequestId, @@ -410,9 +419,45 @@ async def _handle_post_request( ](0) request_stream_reader = self._request_streams[request_id][1] + session_message = SessionMessage(message) + webhooks = self._get_webhooks(session_message.message) + if webhooks is not None: + if self.is_webhooks_supported: + result = { + "content": [ + { + "type": "text", + "text": "Response will be forwarded to the webhook.", + } + ], + "isError": False, + } + response = self._create_json_response( + JSONRPCMessage( + root=JSONRPCResponse( + jsonrpc="2.0", id=message.root.id, result=result + ) + ), + HTTPStatus.OK, + ) + asyncio.create_task( + self._send_response_to_webhooks( + request_id, session_message, webhooks, request_stream_reader + ) + ) + else: + logger.exception("Webhooks not supported error") + err = "Webhooks not supported" + response = self._create_error_response( + f"Validation error: {err}", + HTTPStatus.BAD_REQUEST, + INVALID_PARAMS, + ) + await response(scope, receive, send) + return + if self.is_json_response_enabled: # Process the message - session_message = SessionMessage(message) await writer.send(session_message) try: # Process messages from the request-specific stream @@ -531,6 +576,126 @@ async def sse_writer(): await writer.send(Exception(err)) return + async def _send_response_to_webhooks( + self, + request_id: str, + session_message: SessionMessage, + webhooks: list[Webhook], + request_stream_reader: MemoryObjectReceiveStream[EventMessage], + ): + writer = self._read_stream_writer + if writer is None: + raise ValueError( + "No read stream writer available. Ensure connect() is called first." + ) + await writer.send(session_message) + + try: + response_message = JSONRPCError( + jsonrpc="2.0", + id="server-error", # We don't have a request ID for general errors + error=ErrorData( + code=INTERNAL_ERROR, + message="Error processing request: No response received", + ), + ) + + if self.is_json_response_enabled: + # Process messages from the request-specific stream + # We need to collect all messages until we get a response + async for event_message in request_stream_reader: + # If it's a response, this is what we're waiting for + if isinstance( + event_message.message.root, JSONRPCResponse | JSONRPCError + ): + response_message = event_message.message + break + # For notifications and request, keep waiting + else: + logger.debug(f"received: {event_message.message.root.method}") + + await self._send_message_to_webhooks(webhooks, response_message) + else: + # Send each event on the request stream as a separate message + async for event_message in request_stream_reader: + event_data = self._create_event_data(event_message) + await self._send_message_to_webhooks(webhooks, event_data) + + # If response, remove from pending streams and close + if isinstance( + event_message.message.root, + JSONRPCResponse | JSONRPCError, + ): + break + + except Exception as e: + logger.exception(f"Error sending response to webhooks: {e}") + + finally: + await self._clean_up_memory_streams(request_id) + + async def _send_message_to_webhooks( + self, + webhooks: list[Webhook], + message: JSONRPCMessage | JSONRPCError | dict[str, str], + ): + for webhook in webhooks: + headers = {"Content-Type": CONTENT_TYPE_JSON} + # Add authorization headers + if webhook.authentication and webhook.authentication.credentials: + if webhook.authentication.strategy == "bearer": + headers["Authorization"] = ( + f"Bearer {webhook.authentication.credentials}" + ) + elif webhook.authentication.strategy == "apiKey": + headers["X-API-Key"] = webhook.authentication.credentials + elif webhook.authentication.strategy == "basic": + try: + # Try to parse as JSON + creds_dict = json.loads(webhook.authentication.credentials) + if "username" in creds_dict and "password" in creds_dict: + # Create basic auth header from username and password + import base64 + + auth_string = ( + f"{creds_dict['username']}:{creds_dict['password']}" + ) + credentials = base64.b64encode( + auth_string.encode() + ).decode() + headers["Authorization"] = f"Basic {credentials}" + except Exception: + # Not JSON, use as-is + headers["Authorization"] = ( + f"Basic {webhook.authentication.credentials}" + ) + elif ( + webhook.authentication.strategy == "customHeader" + and webhook.authentication.credentials + ): + try: + custom_headers = json.loads(webhook.authentication.credentials) + headers.update(custom_headers) + except Exception as e: + logger.exception(f"Error setting custom headers: {e}") + + async with create_mcp_http_client(headers=headers) as client: + try: + if isinstance(message, JSONRPCMessage | JSONRPCError): + await client.post( + webhook.url, + json=message.model_dump_json( + by_alias=True, exclude_none=True + ), + ) + else: + await client.post(webhook.url, json=message) + + except Exception as e: + logger.exception( + f"Error sending response to webhook {webhook.url}: {e}" + ) + async def _handle_get_request(self, request: Request, send: Send) -> None: """ Handle GET request to establish SSE. @@ -651,6 +816,19 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None: ) await response(request.scope, request.receive, send) + def _get_webhooks(self, message: JSONRPCMessage) -> list[Webhook] | None: + """Return webhooks if the request is a call tool request with webhooks.""" + if ( + isinstance(message.root, JSONRPCRequest) + and message.root.method == "tools/call" + and message.root.params is not None + and "webhooks" in message.root.params + and message.root.params["webhooks"] is not None + and len(message.root.params["webhooks"]) > 0 + ): + return [Webhook(**webhook) for webhook in message.root.params["webhooks"]] + return None + async def _terminate_session(self) -> None: """Terminate the current session, closing all streams. diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index e5ef8b4aa..adebefe9a 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -60,11 +60,13 @@ def __init__( event_store: EventStore | None = None, json_response: bool = False, stateless: bool = False, + webhooks_supported: bool = False, ): self.app = app self.event_store = event_store self.json_response = json_response self.stateless = stateless + self.webhooks_supported = webhooks_supported # Session tracking (only used if not stateless) self._session_creation_lock = anyio.Lock() @@ -161,6 +163,7 @@ async def _handle_stateless_request( http_transport = StreamableHTTPServerTransport( mcp_session_id=None, # No session tracking in stateless mode is_json_response_enabled=self.json_response, + is_webhooks_supported=self.webhooks_supported, event_store=None, # No event store in stateless mode ) @@ -221,6 +224,7 @@ async def _handle_stateful_request( http_transport = StreamableHTTPServerTransport( mcp_session_id=new_session_id, is_json_response_enabled=self.json_response, + is_webhooks_supported=self.webhooks_supported, event_store=self.event_store, # May be None (no resumability) ) diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py index ae85d3a19..b607061be 100644 --- a/src/mcp/shared/context.py +++ b/src/mcp/shared/context.py @@ -16,3 +16,4 @@ class RequestContext(Generic[SessionT, LifespanContextT]): meta: RequestParams.Meta | None session: SessionT lifespan_context: LifespanContextT + has_webhook: bool = False diff --git a/src/mcp/types.py b/src/mcp/types.py index d864b19da..4552b34cf 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -122,6 +122,7 @@ class JSONRPCRequest(Request[dict[str, Any] | None, str]): id: RequestId method: str params: dict[str, Any] | None = None + webhooks: dict[str, Any] | None = None class JSONRPCNotification(Notification[dict[str, Any] | None, str]): @@ -245,6 +246,8 @@ class ToolsCapability(BaseModel): listChanged: bool | None = None """Whether this server supports notifications for changes to the tool list.""" + webhooksSupported: bool | None = None + """Capability for transmitting tool responses to webhooks.""" model_config = ConfigDict(extra="allow") @@ -703,6 +706,27 @@ class PromptListChangedNotification( params: NotificationParams | None = None +class AuthenticationInfo(BaseModel): + """Used to specify authentication mechanism""" + + strategy: Literal["bearer", "apiKey", "basic", "customHeader"] + """Authentication strategy that the server will follow""" + credentials: str | None = None + """ + Static credentials in the case of bearer, apiKey or basic. + In case of basic and customHeader, this can also be a parsable JSON. + """ + + +class Webhook(BaseModel): + """Used to specify a webhook and authentication method to communicate with it""" + + url: str + """Url to which the response will be transmitted""" + authentication: AuthenticationInfo | None = None + """Authentication required to communicate with the webhook""" + + class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]): """Sent from the client to request a list of tools the server has.""" @@ -783,6 +807,7 @@ class CallToolRequestParams(RequestParams): name: str arguments: dict[str, Any] | None = None + webhooks: list[Webhook] | None = None model_config = ConfigDict(extra="allow")