diff --git a/examples/clients/mqtt-clients/client_apis_demo.py b/examples/clients/mqtt-clients/client_apis_demo.py new file mode 100644 index 0000000000..31da377f63 --- /dev/null +++ b/examples/clients/mqtt-clients/client_apis_demo.py @@ -0,0 +1,49 @@ +import logging +import anyio +import mcp.client.mqtt as mcp_mqtt +from mcp.shared.mqtt import configure_logging + +configure_logging(level="DEBUG") +logger = logging.getLogger(__name__) + +async def on_mcp_server_discovered(client, server_name): + logger.info(f"Discovered {server_name}, connecting ...") + await client.initialize_mcp_server(server_name) + +async def on_mcp_connect(client, server_name, connect_result): + capabilities = client.get_session(server_name).server_info.capabilities + logger.info(f"Capabilities of {server_name}: {capabilities}") + if capabilities.prompts: + prompts = await client.list_prompts(server_name) + logger.info(f"Prompts of {server_name}: {prompts}") + if capabilities.resources: + resources = await client.list_resources(server_name) + logger.info(f"Resources of {server_name}: {resources}") + resource_templates = await client.list_resource_templates(server_name) + logger.info(f"Resources templates of {server_name}: {resource_templates}") + if capabilities.tools: + toolsResult = await client.list_tools(server_name) + tools = toolsResult.tools + logger.info(f"Tools of {server_name}: {tools}") + +async def on_mcp_disconnect(client, server_name): + logger.info(f"Disconnected from {server_name}") + +async def main(): + async with mcp_mqtt.MqttTransportClient( + "test_client", + auto_connect_to_mcp_server = True, + on_mcp_server_discovered = on_mcp_server_discovered, + on_mcp_connect = on_mcp_connect, + on_mcp_disconnect = on_mcp_disconnect, + mqtt_options = mcp_mqtt.MqttOptions( + host="broker.emqx.io", + ) + ) as client: + client.start() + while True: + logger.info("Other works while the MQTT transport client is running in the background...") + await anyio.sleep(10) + +if __name__ == "__main__": + anyio.run(main) diff --git a/examples/fastmcp/mqtt_simple_echo.py b/examples/fastmcp/mqtt_simple_echo.py new file mode 100644 index 0000000000..430128f1ca --- /dev/null +++ b/examples/fastmcp/mqtt_simple_echo.py @@ -0,0 +1,20 @@ +""" +FastMCP Echo Server +""" + +from mcp.server.fastmcp import FastMCP + +# Create server +mcp = FastMCP( + "demo_server/echo", + log_level="DEBUG", + mqtt_server_description="A simple FastMCP server that echoes back the input text.", + mqtt_options={ + "host": "broker.emqx.io", + }, +) + +@mcp.tool() +def echo(text: str) -> str: + """Echo the input text""" + return text diff --git a/pyproject.toml b/pyproject.toml index 25514cd6b0..9da6b75ec4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "sse-starlette>=1.6.1", "pydantic-settings>=2.5.2", "uvicorn>=0.23.1", + "paho-mqtt", ] [project.optional-dependencies] @@ -101,10 +102,11 @@ target-version = "py310" "tests/server/fastmcp/test_func_metadata.py" = ["E501"] [tool.uv.workspace] -members = ["examples/servers/*"] +members = ["examples/servers/*", "examples/clients/mqtt-clients/smart-home"] [tool.uv.sources] mcp = { workspace = true } +paho-mqtt = { git = "https://github.com/eclipse-paho/paho.mqtt.python.git", tag = "v2.1.0" } [tool.pytest.ini_options] xfail_strict = true diff --git a/src/mcp/client/mqtt.py b/src/mcp/client/mqtt.py new file mode 100644 index 0000000000..b9fe0bcbff --- /dev/null +++ b/src/mcp/client/mqtt.py @@ -0,0 +1,477 @@ +""" +This module implements the MQTT transport for the MCP server. +""" +from contextlib import AsyncExitStack +import json +from uuid import uuid4 +from datetime import timedelta +import random +from pydantic import AnyUrl, BaseModel +from mcp.client.session import ClientSession, SamplingFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT +from mcp.shared.exceptions import McpError +from mcp.shared.mqtt import MqttTransportBase, MqttOptions, QOS, MCP_SERVER_NAME_FILTERS +import asyncio +import anyio.to_thread as anyio_to_thread +import anyio.from_thread as anyio_from_thread +import traceback +import mcp.shared.mqtt_topic as mqtt_topic +import paho.mqtt.client as mqtt +import logging +from paho.mqtt.reasoncodes import ReasonCode +from paho.mqtt.properties import Properties +from paho.mqtt.subscribeoptions import SubscribeOptions +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from typing import Any, Literal, TypeAlias, Callable, Awaitable +import mcp.types as types + +RcvStream : TypeAlias = MemoryObjectReceiveStream[types.JSONRPCMessage] +SndStream : TypeAlias = MemoryObjectSendStream[types.JSONRPCMessage] +RcvStreamEx : TypeAlias = MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] +SndStreamEX : TypeAlias = MemoryObjectSendStream[types.JSONRPCMessage | Exception] +ServerRun : TypeAlias = Callable[[RcvStreamEx, SndStream], Awaitable[Any]] + +ServerName : TypeAlias = str +ServerId : TypeAlias = str +InitializeResult : TypeAlias = Literal["ok"] | Literal["already_connected"] | tuple[Literal["error"], str] +ConnectResult : TypeAlias = tuple[Literal["ok"], types.InitializeResult] | tuple[Literal["error"], Any] + +logger = logging.getLogger(__name__) + +class ServerDefinition(BaseModel): + description: str + meta: dict[str, Any] = {} + +class ServerOnlineNotification(BaseModel): + jsonrpc: Literal["2.0"] + method: str = "notifications/server/online" + params: ServerDefinition + +class MqttClientSession(ClientSession): + def __init__( + self, + server_id: ServerId, + server_name: ServerName, + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], + write_stream: MemoryObjectSendStream[types.JSONRPCMessage], + read_timeout_seconds: timedelta | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + ) -> None: + super().__init__( + read_stream, + write_stream, + read_timeout_seconds, + sampling_callback, + list_roots_callback, + logging_callback, + message_handler, + ) + self.server_id = server_id + self.server_name = server_name + self.server_info: types.InitializeResult | None = None + +class MqttTransportClient(MqttTransportBase): + + def __init__(self, + mcp_client_name: str, + client_id: str | None = None, + server_name_filter: str | list[str] = '#', + auto_connect_to_mcp_server: bool = False, + on_mcp_connect: Callable[["MqttTransportClient", ServerName, ConnectResult], Awaitable[Any]] | None = None, + on_mcp_disconnect: Callable[["MqttTransportClient", ServerName], Awaitable[Any]] | None = None, + on_mcp_server_discovered: Callable[["MqttTransportClient", ServerName], Awaitable[Any]] | None = None, + mqtt_options: MqttOptions = MqttOptions()): + uuid = uuid4().hex + mqtt_clientid = client_id if client_id else uuid + self._current_server_id: dict[ServerName, ServerId] = {} + self.server_list: dict[ServerName, dict[ServerId, ServerDefinition]] = {} + self.client_sessions: dict[ServerName, MqttClientSession] = {} + self.mcp_client_id = mqtt_clientid + self.mcp_client_name = mcp_client_name + if isinstance(server_name_filter, str): + self.server_name_filter = [server_name_filter] + else: + self.server_name_filter = server_name_filter + self.auto_connect_to_mcp_server = auto_connect_to_mcp_server + self.on_mcp_connect = on_mcp_connect + self.on_mcp_disconnect = on_mcp_disconnect + self.on_mcp_server_discovered = on_mcp_server_discovered + self.client_capability_change_topic = mqtt_topic.get_client_capability_change_topic(self.mcp_client_id) + ## Send disconnected notification when disconnects + self._disconnected_msg = types.JSONRPCMessage( + types.JSONRPCNotification( + jsonrpc="2.0", + method = "notifications/disconnected" + ) + ) + super().__init__("mcp-client", mqtt_clientid = mqtt_clientid, + mqtt_options = mqtt_options, + disconnected_msg = self._disconnected_msg, + disconnected_msg_retain = False) + + def get_presence_topic(self) -> str: + return mqtt_topic.get_client_presence_topic(self.mcp_client_id) + + async def start(self, timeout: timedelta | None = None) -> bool | str: + try: + connect_result = self.connect() + asyncio.create_task(anyio_to_thread.run_sync(self.client.loop_forever)) + if connect_result != mqtt.MQTT_ERR_SUCCESS: + logger.error(f"Failed to connect to MQTT broker, error code: {connect_result}") + return mqtt.error_string(connect_result) + # test if the client is connected and wait until it is connected + if timeout: + while not self.is_connected(): + await asyncio.sleep(0.1) + if timeout.total_seconds() <= 0: + last_fail_reason = self.get_last_connect_fail_reason() + if last_fail_reason: + return last_fail_reason.getName() + return "timeout" + timeout -= timedelta(seconds=0.1) + return True + except asyncio.CancelledError: + logger.debug("MQTT transport (MCP client) got cancelled") + return "cancelled" + except ConnectionRefusedError as exc: + logger.error(f"MQTT transport (MCP client) failed to connect: {exc}") + return "connection_refused" + except TimeoutError as exc: + logger.error(f"MQTT transport (MCP client) timed out: {exc}") + return "timeout" + except Exception as exc: + logger.error(f"MQTT transport (MCP client) failed: {exc}") + return f"connect mqtt error: {str(exc)}" + + def get_session(self, server_name: ServerName) -> MqttClientSession | None: + return self.client_sessions.get(server_name, None) + + async def initialize_mcp_server( + self, server_name: str, + read_timeout_seconds: timedelta | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None) -> InitializeResult: + if server_name in self.client_sessions: + return "already_connected" + if server_name not in self.server_list: + logger.error(f"MCP server not found, server name: {server_name}") + return ("error", "MCP server not found") + server_id = self.pick_server_id(server_name) + + async def after_subscribed( + subscribe_result: Literal["success", "error"] + ): + if subscribe_result == "error": + if self.on_mcp_connect: + self._task_group.start_soon(self.on_mcp_connect, self, server_name, ("error", "subscribe_mcp_server_topics_failed")) + client_session = self._create_session( + server_id, + server_name, + read_timeout_seconds, + sampling_callback, + list_roots_callback, + logging_callback, + message_handler + ) + self.client_sessions[server_name] = client_session + try: + logger.debug(f"before initialize: {server_name}") + async def after_initialize(): + exit_stack = AsyncExitStack() + try: + session = await exit_stack.enter_async_context(client_session) + init_result = await session.initialize() + session.server_info = init_result + if self.on_mcp_connect: + self._task_group.start_soon(self.on_mcp_connect, self, server_name, ("ok", init_result)) + except Exception as e: + self.client_sessions.pop(server_name) + logging.error(f"Failed to initialize server {server_name}: {e}") + await exit_stack.aclose() + self._task_group.start_soon(after_initialize) + logger.debug(f"after initialize: {server_name}") + except McpError as exc: + self.client_sessions.pop(server_name) + logger.error(f"Failed to connect to MCP server: {exc}") + if self.on_mcp_connect: + self._task_group.start_soon(self.on_mcp_connect, self, server_name, ("error", McpError)) + + if self._subscribe_mcp_server_topics(server_id, server_name, after_subscribed): + return "ok" + else: + return ("error", "send_subscribe_request_failed") + + async def deinitialize_mcp_server(self, server_name: ServerName) -> None: + server_id = self._current_server_id[server_name] + topic = mqtt_topic.get_rpc_topic(self.mcp_client_id, server_id, server_name) + self.publish_json_rpc_message(topic, message = self._disconnected_msg, retain=False) + self._remove_server(server_id, server_name) + + async def send_ping(self, server_name: ServerName) -> bool | types.EmptyResult: + return await self._with_session(server_name, lambda s: s.send_ping()) + + async def send_progress_notification( + self, server_name: ServerName, progress_token: str | int, + progress: float, total: float | None = None + ) -> bool | None: + return await self._with_session(server_name, lambda s: s.send_progress_notification(progress_token, progress, total)) + + async def set_logging_level(self, server_name: ServerName, + level: types.LoggingLevel) -> bool | types.EmptyResult: + return await self._with_session(server_name, lambda s: s.set_logging_level(level)) + + async def list_resources(self, server_name: ServerName) -> bool | types.ListResourcesResult: + return await self._with_session(server_name, lambda s: s.list_resources()) + + async def list_resource_templates(self, server_name: ServerName) -> bool | types.ListResourceTemplatesResult: + return await self._with_session(server_name, lambda s: s.list_resource_templates()) + + async def read_resource(self, server_name: ServerName, + uri: AnyUrl) -> bool | types.ReadResourceResult: + return await self._with_session(server_name, lambda s: s.read_resource(uri)) + + async def subscribe_resource(self, server_name: ServerName, + uri: AnyUrl) -> bool | types.EmptyResult: + return await self._with_session(server_name, lambda s: s.subscribe_resource(uri)) + + async def unsubscribe_resource(self, server_name: ServerName, + uri: AnyUrl) -> bool | types.EmptyResult: + return await self._with_session(server_name, lambda s: s.unsubscribe_resource(uri)) + + async def call_tool( + self, server_name: ServerName, name: str, arguments: dict[str, Any] | None = None + ) -> bool | types.CallToolResult: + return await self._with_session(server_name, lambda s: s.call_tool(name, arguments)) + + async def list_prompts(self, server_name: ServerName) -> bool | types.ListPromptsResult: + return await self._with_session(server_name, lambda s: s.list_prompts()) + + async def get_prompt( + self, server_name: ServerName, name: str, arguments: dict[str, str] | None = None + ) -> bool | types.GetPromptResult: + return await self._with_session(server_name, lambda s: s.get_prompt(name, arguments)) + + async def complete( + self, + server_name: ServerName, + ref: types.ResourceReference | types.PromptReference, + argument: dict[str, str], + ) -> bool | types.CompleteResult: + return await self._with_session(server_name, lambda s: s.complete(ref, argument)) + + async def list_tools(self, server_name: ServerName) -> bool | types.ListToolsResult: + return await self._with_session(server_name, lambda s: s.list_tools()) + + async def send_roots_list_changed(self, server_name: ServerName) -> bool | None: + return await self._with_session(server_name, lambda s: s.send_roots_list_changed()) + + async def _with_session( + self, server_name: ServerName, + async_callback: Callable[[MqttClientSession], Awaitable[bool | Any]]) -> bool | Any: + if not (client_session := self.client_sessions.get(server_name, None)): + logger.error(f"No session for server_name: {server_name}") + return False + return await async_callback(client_session) + + def _create_session( + self, server_id: ServerId, server_name: ServerName, + read_timeout_seconds: timedelta | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None + ): + ## Streams are used to communicate between the MqttTransportClient and the MCPSession: + ## 1. (msg) --> MqttBroker --> MqttTransportClient -->[read_stream_writer]-->[read_stream]--> MCPSession + ## 2. MqttBroker <-- MqttTransportClient <--[write_stream_reader]--[write_stream]-- MCPSession <-- (msg) + read_stream: RcvStreamEx + read_stream_writer: SndStreamEX + write_stream: SndStream + write_stream_reader: RcvStream + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) # type: ignore + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) # type: ignore + self._read_stream_writers[server_id] = read_stream_writer + self._task_group.start_soon(self._receieved_from_session, server_id, server_name, write_stream_reader) + logger.debug(f"Created new session for server_id: {server_id}") + return MqttClientSession( + server_id, + server_name, + read_stream, + write_stream, + read_timeout_seconds, + sampling_callback, + list_roots_callback, + logging_callback, + message_handler + ) + + def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.ConnectFlags, reason_code : ReasonCode, properties: Properties | None): + super()._on_connect(client, userdata, connect_flags, reason_code, properties) + if properties and hasattr(properties, "UserProperty"): + user_properties: dict[str, Any] = dict(properties.UserProperty) # type: ignore + if MCP_SERVER_NAME_FILTERS in user_properties: + self.server_name_filter = json.loads(user_properties[MCP_SERVER_NAME_FILTERS]) + logger.debug(f"Use broker suggested server name filters: {self.server_name_filter}") + if reason_code == 0: + ## Subscribe to the MCP server's presence topic + for snf in self.server_name_filter: + logger.debug(f"Subscribing to server presence topic for server_name_filter: {snf}") + client.subscribe(mqtt_topic.get_server_presence_topic('+', snf), qos=QOS) + + def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage): + logger.debug(f"Received message on topic {msg.topic}: {msg.payload.decode()}") + match msg.topic: + case str() as t if t.startswith(mqtt_topic.SERVER_PRESENCE_BASE): + self._handle_server_presence_message(msg) + case str() as t if t.startswith(mqtt_topic.RPC_BASE): + self._handle_rpc_message(msg) + case str() as t if t.startswith(mqtt_topic.SERVER_CAPABILITY_CHANGE_BASE): + self._handle_server_capability_message(msg) + case _: + logger.error(f"Received message on unexpected topic: {msg.topic}") + + def _on_subscribe(self, client: mqtt.Client, userdata: Any, mid: int, + reason_code_list: list[ReasonCode], properties: Properties | None): + if mid in userdata.get("pending_subs", {}): + server_name, server_id, after_subscribed = userdata["pending_subs"].pop(mid) + ## only create session if all topic subscribed successfully + if all([rc.value == QOS for rc in reason_code_list]): + logger.debug(f"Subscribed to topics for server_name: {server_name}, server_id: {server_id}") + anyio_from_thread.run(after_subscribed, "success") + else: + anyio_from_thread.run(after_subscribed, "error") + logger.error(f"Failed to subscribe to topics for server_name: {server_name}, server_id: {server_id}, reason_codes: {reason_code_list}") + + def _handle_server_presence_message(self, msg: mqtt.MQTTMessage) -> None: + topic_words = msg.topic.split("/") + server_id = topic_words[2] + server_name = "/".join(topic_words[3:]) + if msg.payload: + newly_added_server = False if server_name in self.server_list else True + server_notif = ServerOnlineNotification.model_validate_json(msg.payload.decode()) + self.server_list.setdefault(server_name, {})[server_id] = server_notif.params + logger.debug(f"Server {server_name} with id {server_id} is online") + if newly_added_server: + if self.auto_connect_to_mcp_server: + logger.debug(f"Auto connecting to MCP server {server_name}") + anyio_from_thread.run(self.initialize_mcp_server, server_name) + if self.on_mcp_server_discovered: + anyio_from_thread.run(self.on_mcp_server_discovered, self, server_name) + else: + # server is offline if the payload is empty + logger.debug(f"Server {server_name} with id {server_id} is offline") + self._remove_server(server_id, server_name) + + def _remove_server(self, server_id: ServerId, server_name: ServerName) -> None: + if server_id in self.server_list.get(server_name, {}): + if server_id in self._read_stream_writers: + logger.debug(f"Closing stream writer for server_id: {server_id}") + self._read_stream_writers[server_id].close() + + def _handle_rpc_message(self, msg: mqtt.MQTTMessage) -> None: + server_name = "/".join(msg.topic.split("/")[3:]) + anyio_from_thread.run(self._send_message_to_session, server_name, msg) + + def _handle_server_capability_message(self, msg: mqtt.MQTTMessage) -> None: + server_name = "/".join(msg.topic.split("/")[4:]) + anyio_from_thread.run(self._send_message_to_session, server_name, msg) + + def _subscribe_mcp_server_topics(self, server_id: ServerId, server_name: ServerName, + after_subscribed: Callable[[Any], Awaitable[None]]): + topic_filters = [ + (mqtt_topic.get_server_capability_change_topic(server_id, server_name), SubscribeOptions(qos=QOS)), + (mqtt_topic.get_rpc_topic(self.mcp_client_id, server_id, server_name), SubscribeOptions(qos=QOS, noLocal=True)) + ] + ret, mid = self.client.subscribe(topic=topic_filters) + if ret != mqtt.MQTT_ERR_SUCCESS: + logger.error(f"Failed to subscribe to topics for server_name: {server_name}") + return False + userdata = self.client.user_data_get() + pending_subs = userdata.get("pending_subs", {}) + pending_subs[mid] = (server_name, server_id, after_subscribed) + userdata["pending_subs"] = pending_subs + return True + + async def _send_message_to_session(self, server_name: ServerName, msg: mqtt.MQTTMessage): + if not (client_session := self.client_sessions.get(server_name, None)): + logger.error(f"_send_message_to_session: No session for server_name: {server_name}") + return + payload = msg.payload.decode() + server_id = client_session.server_id + if server_id not in self._read_stream_writers: + logger.error(f"No session for server_id: {server_id}") + return + read_stream_writer = self._read_stream_writers[server_id] + try: + message = types.JSONRPCMessage.model_validate_json(payload) + logger.debug(f"Sending msg to session for server_id: {server_id}, msg: {message}") + with anyio.fail_after(3): + await read_stream_writer.send(message) + except Exception as exc: + logger.error(f"Failed to send msg to session for server_id: {server_id}, exception: {exc}") + traceback.print_exc() + ## TODO: the session does not handle exceptions for now + #await read_stream_writer.send(exc) + async def _receieved_from_session(self, server_id: ServerId, server_name: ServerName, + write_stream_reader: RcvStream): + async with write_stream_reader: + async for msg in write_stream_reader: + logger.debug(f"Got msg from session for server_id: {server_id}, msg: {msg}") + match msg.model_dump(): + case {"method": method} if method == "notifications/initialized": + logger.debug(f"Session initialized for server_id: {server_id}") + topic = mqtt_topic.get_rpc_topic(self.mcp_client_id, server_id, server_name) + case {"method": method} if method.endswith("/list_changed"): + topic = None + logger.warning("Resource updates should not be sent from the session. Ignoring.") + case {"method": method} if method == "initialize": + topic = mqtt_topic.get_server_control_topic(server_id, server_name) + case _: + topic = mqtt_topic.get_rpc_topic(self.mcp_client_id, server_id, server_name) + if topic: + self.publish_json_rpc_message(topic, message = msg) + # cleanup + if server_id in self._read_stream_writers: + logger.debug(f"Removing session for server_id: {server_id}") + stream = self._read_stream_writers.pop(server_id) + await stream.aclose() + + # unsubscribe from the topics + logger.debug(f"Unsubscribing from topics for server_id: {server_id}, server_name: {server_name}") + topic_filters = [ + mqtt_topic.get_server_capability_change_topic(server_id, server_name), + mqtt_topic.get_rpc_topic(self.mcp_client_id, server_id, server_name) + ] + self.client.unsubscribe(topic=topic_filters) + + if server_id in self.server_list.get(server_name, {}): + _ = self.server_list[server_name].pop(server_id) + if not self.server_list[server_name]: + _ = self.server_list.pop(server_name) + if self.on_mcp_disconnect: + self._task_group.start_soon(self.on_mcp_disconnect, self, server_name) + + if server_name in self.client_sessions: + _ = self.client_sessions.pop(server_name) + + if server_name in self._current_server_id: + _ = self._current_server_id.pop(server_name) + logger.debug(f"Session stream closed for server_id: {server_id}") + + def pick_server_id(self, server_name: str) -> ServerId: + server_id = random.choice(list(self.server_list[server_name].keys())) + self._current_server_id[server_name] = server_id + return server_id + +def validate_server_name(name: str): + if "/" not in name: + raise ValueError(f"Invalid server name: {name}, must contain a '/'") + elif ("+" in name) or ("#" in name): + raise ValueError(f"Invalid server name: {name}, must not contain '+' or '#'") + elif name[0] == "/": + raise ValueError(f"Invalid server name: {name}, must not start with '/'") diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index bf0ce880a5..2108b1cf61 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -35,6 +35,7 @@ from mcp.server.lowlevel.server import lifespan as default_lifespan from mcp.server.session import ServerSession, ServerSessionT from mcp.server.sse import SseServerTransport +from mcp.server.mqtt import validate_server_name, start_mqtt, MqttOptions from mcp.server.stdio import stdio_server from mcp.shared.context import LifespanContextT, RequestContext from mcp.types import ( @@ -52,7 +53,6 @@ logger = get_logger(__name__) - class Settings(BaseSettings, Generic[LifespanResultT]): """FastMCP server settings. @@ -76,6 +76,12 @@ class Settings(BaseSettings, Generic[LifespanResultT]): sse_path: str = "/sse" message_path: str = "/messages/" + # MQTT settings + mqtt_server_description: str = '' + mqtt_server_meta: dict[str, Any] = {} + mqtt_client_id: str | None = None + mqtt_options: MqttOptions = MqttOptions() + # resource settings warn_on_duplicate_resources: bool = True @@ -145,18 +151,21 @@ def name(self) -> str: def instructions(self) -> str | None: return self._mcp_server.instructions - def run(self, transport: Literal["stdio", "sse"] = "stdio") -> None: + def run(self, transport: Literal["stdio", "sse", "mqtt"] = "stdio") -> None: """Run the FastMCP server. Note this is a synchronous function. Args: transport: Transport protocol to use ("stdio" or "sse") """ - TRANSPORTS = Literal["stdio", "sse"] + TRANSPORTS = Literal["stdio", "sse", "mqtt"] if transport not in TRANSPORTS.__args__: # type: ignore raise ValueError(f"Unknown transport: {transport}") if transport == "stdio": anyio.run(self.run_stdio_async) + elif transport == "mqtt": + validate_server_name(self._mcp_server.name) + anyio.run(self.run_mqtt_async) else: # transport == "sse" anyio.run(self.run_sse_async) @@ -477,6 +486,23 @@ async def run_sse_async(self) -> None: server = uvicorn.Server(config) await server.serve() + async def run_mqtt_async(self) -> None: + """Run the server using MQTT transport.""" + def server_session_run(read_stream: Any, write_stream: Any): + return self._mcp_server.run( + read_stream, + write_stream, + self._mcp_server.create_initialization_options(), + ) + await start_mqtt( + server_session_run, + server_name = self._mcp_server.name, + server_description=self.settings.mqtt_server_description, + server_meta = self.settings.mqtt_server_meta, + client_id = self.settings.mqtt_client_id, + mqtt_options = self.settings.mqtt_options + ) + def sse_app(self) -> Starlette: """Return an instance of the SSE server app.""" sse = SseServerTransport(self.settings.message_path) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index dbaff30516..37e5821057 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -497,6 +497,7 @@ async def run( lifespan_context, raise_exceptions, ) + logger.debug("Server closed") async def _handle_message( self, diff --git a/src/mcp/server/mqtt.py b/src/mcp/server/mqtt.py new file mode 100644 index 0000000000..7d4339e48a --- /dev/null +++ b/src/mcp/server/mqtt.py @@ -0,0 +1,289 @@ +""" +This module implements the MQTT transport for the MCP server. +""" +from uuid import uuid4 +from mcp.shared.mqtt import MqttTransportBase, MqttOptions, QOS, PROPERTY_K_MQTT_CLIENT_ID, MCP_SERVER_NAME +import asyncio +import anyio.to_thread as anyio_to_thread +import anyio.from_thread as anyio_from_thread +import json +import traceback +import mcp.shared.mqtt_topic as mqtt_topic +import paho.mqtt.client as mqtt +import logging +from paho.mqtt.reasoncodes import ReasonCode +from paho.mqtt.properties import Properties +from paho.mqtt.subscribeoptions import SubscribeOptions +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from typing import Any, TypeAlias, Callable, Awaitable +import mcp.types as types + +RcvStream : TypeAlias = MemoryObjectReceiveStream[types.JSONRPCMessage] +SndStream : TypeAlias = MemoryObjectSendStream[types.JSONRPCMessage] +RcvStreamEx : TypeAlias = MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] +SndStreamEX : TypeAlias = MemoryObjectSendStream[types.JSONRPCMessage | Exception] +ServerSessionRun : TypeAlias = Callable[[RcvStreamEx, SndStream], Awaitable[Any]] + +logger = logging.getLogger(__name__) + +class MqttTransportServer(MqttTransportBase): + + def __init__(self, server_session_run: ServerSessionRun, server_name: str, + server_description: str, + server_meta: dict[str, Any], + client_id: str | None = None, + mqtt_options: MqttOptions = MqttOptions()): + uuid = uuid4().hex + mqtt_clientid = client_id if client_id else uuid + self.server_id = mqtt_clientid + self.server_name = server_name + self.server_description = server_description + self.server_meta = server_meta + self.server_session_run = server_session_run + super().__init__("mcp-server", mqtt_clientid = mqtt_clientid, + mqtt_options = mqtt_options, + disconnected_msg = None, + disconnected_msg_retain = True) + + def get_presence_topic(self) -> str: + return mqtt_topic.get_server_presence_topic(self.server_id, self.server_name) + + def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.ConnectFlags, reason_code : ReasonCode, properties: Properties | None): + super()._on_connect(client, userdata, connect_flags, reason_code, properties) + if reason_code == 0: + if properties and hasattr(properties, "UserProperty"): + user_properties: dict[str, Any] = dict(properties.UserProperty) # type: ignore + if MCP_SERVER_NAME in user_properties: + broker_suggested_server_name = user_properties[MCP_SERVER_NAME] + self.server_name = broker_suggested_server_name + logger.debug(f"Used broker suggested server name: {broker_suggested_server_name}") + else: + logger.error(f"No {PROPERTY_K_MQTT_CLIENT_ID} in UserProperties") + self.server_control_topic = mqtt_topic.get_server_control_topic(self.server_id, self.server_name) + ## Subscribe to the server control topic + client.subscribe(self.server_control_topic, QOS) + ## Reister the server on the presence topic + online_msg = types.JSONRPCMessage( + types.JSONRPCNotification( + jsonrpc="2.0", + method = "notifications/server/online", + params = { + "description": self.server_description, + "meta": self.server_meta + } + )) + self.publish_json_rpc_message( + self.get_presence_topic(), message=online_msg, retain=True) + + def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage): + logger.debug(f"Received message on topic {msg.topic}: {msg.payload.decode()}") + match msg.topic: + case str() as t if t == self.server_control_topic: + self.handle_server_contorl_message(msg) + case str() as t if t.startswith(mqtt_topic.CLIENT_CAPABILITY_CHANGE_BASE): + self.handle_client_capability_change_message(msg) + case str() as t if t.startswith(mqtt_topic.RPC_BASE): + self.handle_rpc_message(msg) + case str() as t if t.startswith(mqtt_topic.CLIENT_PRESENCE_BASE): + self.handle_client_presence_message(msg) + case _: + logger.error(f"Received message on unexpected topic: {msg.topic}") + + def _on_subscribe(self, client: mqtt.Client, userdata: Any, mid: int, + reason_code_list: list[ReasonCode], properties: Properties | None): + if mid in userdata.get("pending_subs", {}): + mcp_client_id, msg, rpc_msg_id = userdata["pending_subs"].pop(mid) + ## only create session if all topic subscribed successfully + if all([rc.value == QOS for rc in reason_code_list]): + logger.debug(f"Subscribed to topics for mcp_client_id: {mcp_client_id}") + anyio_from_thread.run(self.create_session, mcp_client_id, msg) + else: + logger.error(f"Failed to subscribe to topics for mcp_client_id: {mcp_client_id}, reason_codes: {reason_code_list}") + err = types.JSONRPCError( + jsonrpc="2.0", + id=rpc_msg_id, + error=types.ErrorData( + code=types.INTERNAL_ERROR, + message="Failed to subscribe to client topics" + ) + ) + self.publish_json_rpc_message( + mqtt_topic.get_rpc_topic(mcp_client_id, self.server_id, self.server_name), + message = types.JSONRPCMessage(err) + ) + + def handle_server_contorl_message(self, msg: mqtt.MQTTMessage): + if msg.properties and hasattr(msg.properties, "UserProperty"): + user_properties: dict[str, Any] = dict(msg.properties.UserProperty) # type: ignore + if PROPERTY_K_MQTT_CLIENT_ID in user_properties: + mcp_client_id = user_properties[PROPERTY_K_MQTT_CLIENT_ID] + if mcp_client_id in self._read_stream_writers: + anyio_from_thread.run(self._send_message_to_session, mcp_client_id, msg) + else: + self.maybe_subscribe_to_client(mcp_client_id, msg) + else: + logger.error(f"No {PROPERTY_K_MQTT_CLIENT_ID} in UserProperties") + else: + logger.error("No UserProperties in control message") + + def handle_client_capability_change_message(self, msg: mqtt.MQTTMessage) -> None: + mcp_client_id = msg.topic.split("/")[-1] + anyio_from_thread.run(self._send_message_to_session, mcp_client_id, msg) + + def handle_rpc_message(self, msg: mqtt.MQTTMessage) -> None: + mcp_client_id = msg.topic.split("/")[1] + try: + json_msg = json.loads(msg.payload.decode()) + if "method" in json_msg: + if json_msg["method"] == "notifications/disconnected": + stream = self._read_stream_writers[mcp_client_id] + anyio_from_thread.run(stream.aclose) + logger.debug(f"Closed read_stream for mcp_client_id: {mcp_client_id}") + return + else: + anyio_from_thread.run(self._send_message_to_session, mcp_client_id, msg) + else: + anyio_from_thread.run(self._send_message_to_session, mcp_client_id, msg) + except json.JSONDecodeError: + logger.error(f"Invalid JSON in RPC message for mcp_client_id: {mcp_client_id}") + + def handle_client_presence_message(self, msg: mqtt.MQTTMessage) -> None: + mcp_client_id = msg.topic.split("/")[-1] + if mcp_client_id not in self._read_stream_writers: + logger.error(f"No session for mcp_client_id: {mcp_client_id}") + return + try: + json_msg = json.loads(msg.payload.decode()) + if "method" in json_msg: + if json_msg["method"] == "notifications/disconnected": + stream = self._read_stream_writers[mcp_client_id] + anyio_from_thread.run(stream.aclose) + logger.debug(f"Closed read_stream for mcp_client_id: {mcp_client_id}") + else: + logger.error(f"Unknown method in presence message for mcp_client_id: {mcp_client_id}") + else: + logger.error(f"No method in presence message for mcp_client_id: {mcp_client_id}") + except json.JSONDecodeError: + logger.error(f"Invalid JSON in presence message for mcp_client_id: {mcp_client_id}") + + async def create_session(self, mcp_client_id: str, msg: mqtt.MQTTMessage): + ## Streams are used to communicate between the MqttTransportServer and the MCPSession: + ## 1. (msg) --> MqttBroker --> MqttTransportServer -->[read_stream_writer]-->[read_stream]--> MCPSession + ## 2. MqttBroker <-- MqttTransportServer <--[write_stream_reader]--[write_stream]-- MCPSession <-- (msg) + read_stream: RcvStreamEx + read_stream_writer: SndStreamEX + write_stream: SndStream + write_stream_reader: RcvStream + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) # type: ignore + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) # type: ignore + self._read_stream_writers[mcp_client_id] = read_stream_writer + self._task_group.start_soon(self.server_session_run, read_stream, write_stream) + self._task_group.start_soon(self._receieved_from_session, mcp_client_id, write_stream_reader) + logger.debug(f"Created new session for mcp_client_id: {mcp_client_id}") + await self._send_message_to_session(mcp_client_id, msg) + + def maybe_subscribe_to_client(self, mcp_client_id: str, msg: mqtt.MQTTMessage): + try: + json_msg = json.loads(msg.payload.decode()) + if "id" in json_msg: + rpc_msg_id = json_msg["id"] + self.subscribe_to_client(mcp_client_id, msg, rpc_msg_id) + else: + logger.error(f"No id in control message for mcp_client_id: {mcp_client_id}") + except json.JSONDecodeError: + logger.error(f"Invalid JSON in control message for mcp_client_id: {mcp_client_id}") + return + + def subscribe_to_client(self, mcp_client_id: str, msg: mqtt.MQTTMessage, rcp_msg_id: Any): + topic_filters = [ + (mqtt_topic.get_client_presence_topic(mcp_client_id), SubscribeOptions(qos=QOS)), + (mqtt_topic.get_client_capability_change_topic(mcp_client_id), SubscribeOptions(qos=QOS)), + (mqtt_topic.get_rpc_topic(mcp_client_id, self.server_id, self.server_name), SubscribeOptions(qos=QOS, noLocal=True)) + ] + ret, mid = self.client.subscribe(topic=topic_filters) + if ret != mqtt.MQTT_ERR_SUCCESS: + logger.error(f"Failed to subscribe to topics for mcp_client_id: {mcp_client_id}") + return + userdata = self.client.user_data_get() + pending_subs = userdata.get("pending_subs", {}) + pending_subs[mid] = (mcp_client_id, msg, rcp_msg_id) + userdata["pending_subs"] = pending_subs + + async def _send_message_to_session(self, mcp_client_id: str, msg: mqtt.MQTTMessage): + payload = msg.payload.decode() + if mcp_client_id not in self._read_stream_writers: + logger.error(f"No session for mcp_client_id: {mcp_client_id}") + return + read_stream_writer = self._read_stream_writers[mcp_client_id] + try: + message = types.JSONRPCMessage.model_validate_json(payload) + logger.debug(f"Sending msg to session for mcp_client_id: {mcp_client_id}, msg: {message}") + with anyio.fail_after(3): + await read_stream_writer.send(message) + except Exception as exc: + logger.error(f"Failed to send msg to session for mcp_client_id: {mcp_client_id}, exception: {exc}") + traceback.print_exc() + ## TODO: the session does not handle exceptions for now + #await read_stream_writer.send(exc) + + async def _receieved_from_session(self, mcp_client_id: str, write_stream_reader: RcvStream): + async with write_stream_reader: + async for msg in write_stream_reader: + logger.debug(f"Got msg from session for mcp_client_id: {mcp_client_id}, msg: {msg}") + match msg.model_dump(): + case {"method": "notifications/resources/updated"}: + logger.warning("Resource updates should not be sent from the session. Ignoring.") + case {"method": method} if method.endswith("/list_changed"): + logger.warning("Resource updates should not be sent from the session. Ignoring.") + case _: + topic = mqtt_topic.get_rpc_topic(mcp_client_id, self.server_id, self.server_name) + self.publish_json_rpc_message(topic, message = msg) + # cleanup + if mcp_client_id in self._read_stream_writers: + logger.debug(f"Removing session for mcp_client_id: {mcp_client_id}") + stream = self._read_stream_writers.pop(mcp_client_id) + await stream.aclose() + + # unsubscribe from the client topics + logger.debug(f"Unsubscribing from topics for mcp_client_id: {mcp_client_id}") + topic_filters = [ + mqtt_topic.get_client_presence_topic(mcp_client_id), + mqtt_topic.get_client_capability_change_topic(mcp_client_id), + mqtt_topic.get_rpc_topic(mcp_client_id, self.server_id, self.server_name) + ] + self.client.unsubscribe(topic=topic_filters) + + logger.debug(f"Session stream closed for mcp_client_id: {mcp_client_id}") + +async def start_mqtt( + server_session_run: ServerSessionRun, server_name: str, + server_description: str, + server_meta: dict[str, Any], + client_id: str | None = None, + mqtt_options: MqttOptions = MqttOptions()): + async with MqttTransportServer( + server_session_run, + server_name = server_name, + server_description=server_description, + server_meta = server_meta, + client_id = client_id, + mqtt_options = mqtt_options + ) as mqtt_trans: + def start(): + mqtt_trans.connect() + mqtt_trans.client.loop_forever() + try: + await anyio_to_thread.run_sync(start) + except asyncio.CancelledError: + logger.debug("MQTT transport (MCP server) got cancelled") + except Exception as exc: + logger.error(f"MQTT transport (MCP server) failed with exception: {exc}") + +def validate_server_name(name: str): + if "/" not in name: + raise ValueError(f"Invalid server name: {name}, must contain a '/'") + elif ("+" in name) or ("#" in name): + raise ValueError(f"Invalid server name: {name}, must not contain '+' or '#'") + elif name[0] == "/": + raise ValueError(f"Invalid server name: {name}, must not start with '/'") diff --git a/src/mcp/shared/mqtt.py b/src/mcp/shared/mqtt.py new file mode 100644 index 0000000000..a8705d970a --- /dev/null +++ b/src/mcp/shared/mqtt.py @@ -0,0 +1,234 @@ +""" +MQTT Transport Base Module + +""" +from types import TracebackType +import paho.mqtt.client as mqtt +import logging +from paho.mqtt.reasoncodes import ReasonCode +from paho.mqtt.enums import CallbackAPIVersion +from paho.mqtt.properties import Properties +from paho.mqtt.packettypes import PacketTypes +import anyio +import anyio.from_thread as anyio_from_thread +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import BaseModel, SecretStr +from typing import Literal, Optional, Any, TypeAlias, Callable, Awaitable +import mcp.types as types +from typing_extensions import Self +from abc import ABC, abstractmethod + +DEFAULT_LOG_FORMAT = "%(asctime)s - %(message)s" +QOS = 0 +MCP_SERVER_NAME = "MCP-SERVER-NAME" +MCP_SERVER_NAME_FILTERS = "MCP-SERVER-NAME-FILTERS" +MCP_AUTH_ROLE = "MCP-AUTH-ROLE" +PROPERTY_K_MCP_COMPONENT = "MCP-COMPONENT-TYPE" +PROPERTY_K_MQTT_CLIENT_ID = "MCP-MQTT-CLIENT-ID" +logger = logging.getLogger(__name__) + +RcvStream : TypeAlias = MemoryObjectReceiveStream[types.JSONRPCMessage] +SndStream : TypeAlias = MemoryObjectSendStream[types.JSONRPCMessage] +RcvStreamEx : TypeAlias = MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] +SndStreamEX : TypeAlias = MemoryObjectSendStream[types.JSONRPCMessage | Exception] +ServerRun : TypeAlias = Callable[[RcvStreamEx, SndStream], Awaitable[Any]] + +class MqttOptions(BaseModel): + host: str = "localhost" + port: int = 1883 + transport: Literal['tcp', 'websockets', 'unix'] = 'tcp' + keepalive: int = 60 + bind_address: str = '' + bind_port: int = 0 + username: Optional[str] = None + password: Optional[SecretStr] = None + tls_enabled: bool = False + tls_version: Optional[int] = None + tls_insecure: bool = False + ca_certs: Optional[str] = None + certfile: Optional[str] = None + keyfile: Optional[str] = None + ciphers: Optional[str] = None + keyfile_password: Optional[str] = None + alpn_protocols: Optional[list[str]] = None + websocket_path: str = '/mqtt' + websocket_headers: Optional[dict[str, str]] = None + +class MqttTransportBase(ABC): + _read_stream_writers: dict[ + str, SndStreamEX + ] + + def __init__(self, + mcp_component_type: Literal["mcp-client", "mcp-server"], + mqtt_clientid: str | None = None, + mqtt_options: MqttOptions = MqttOptions(), + disconnected_msg: types.JSONRPCMessage | None = None, + disconnected_msg_retain: bool = True): + self._read_stream_writers = {} + self._last_connect_fail_reason = None + self.mqtt_clientid = mqtt_clientid + self.mcp_component_type = mcp_component_type + self.mqtt_options = mqtt_options + self.disconnected_msg = disconnected_msg + self.disconnected_msg_retain = disconnected_msg_retain + client = mqtt.Client( + callback_api_version=CallbackAPIVersion.VERSION2, + client_id=mqtt_clientid, protocol=mqtt.MQTTv5, + userdata={}, + transport=mqtt_options.transport, + reconnect_on_failure=True + ) + client.reconnect_delay_set( + min_delay=1, max_delay=120 + ) + client.username_pw_set(mqtt_options.username, mqtt_options.password.get_secret_value() if mqtt_options.password else None) + if mqtt_options.tls_enabled: + client.tls_set( # type: ignore + ca_certs=mqtt_options.ca_certs, + certfile=mqtt_options.certfile, + keyfile=mqtt_options.keyfile, + tls_version=mqtt_options.tls_version, + ciphers=mqtt_options.ciphers, + keyfile_password=mqtt_options.keyfile_password, + alpn_protocols=mqtt_options.alpn_protocols + ) + client.tls_insecure_set(mqtt_options.tls_insecure) + if mqtt_options.transport == 'websockets': + client.ws_set_options(path=mqtt_options.websocket_path, headers=mqtt_options.websocket_headers) + client.on_connect = self._on_connect + client.on_message = self._on_message + client.on_subscribe = self._on_subscribe + ## We need to set an empty will message to clean the retained presence + ## message when the MCP server goes offline. + ## Note that if the broker suggested a new server name, it's the broker's + ## responsibility to clean the retained presence message and send the + ## last will message on the changed presence topic. + client.will_set( + topic = self.get_presence_topic(), + payload = disconnected_msg.model_dump_json() if disconnected_msg else None, + qos = QOS, + retain = disconnected_msg_retain, + properties = self.get_publish_properties(), + ) + logger.info(f"MCP component type: {mcp_component_type}, MQTT clientid: {mqtt_clientid}, MQTT settings: {mqtt_options}") + self.client = client + + async def __aenter__(self) -> Self: + self._task_group = anyio.create_task_group() + await self._task_group.__aenter__() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + await self.stop_mqtt() + self._task_group.cancel_scope.cancel() + return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + + def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.ConnectFlags, reason_code: ReasonCode, properties: Properties | None): + if reason_code == 0: + logger.debug(f"Connected to MQTT broker_host at {self.mqtt_options.host}:{self.mqtt_options.port}") + self.assert_property(properties, "RetainAvailable", 1) + self.assert_property(properties, "WildcardSubscriptionAvailable", 1) + else: + self._last_connect_fail_reason = reason_code + logger.error(f"Failed to connect, return code {reason_code}") + + def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage): + pass + + def _on_subscribe(self, client: mqtt.Client, userdata: Any, mid: int, + reason_code_list: list[ReasonCode], properties: Properties | None): + pass + + def is_connected(self) -> bool: + return self.client.is_connected() + + def get_last_connect_fail_reason(self) -> ReasonCode | None: + return self._last_connect_fail_reason + + def publish_json_rpc_message(self, topic: str, message: types.JSONRPCMessage | None, + retain: bool = False): + props = self.get_publish_properties() + payload = message.model_dump_json(by_alias=True, exclude_none=True) if message else None + self.client.publish(topic=topic, payload=payload, qos=QOS, retain=retain, properties=props) + + def get_publish_properties(self): + props = Properties(PacketTypes.PUBLISH) + props.UserProperty = [ + (PROPERTY_K_MCP_COMPONENT, self.mcp_component_type), + (PROPERTY_K_MQTT_CLIENT_ID, self.mqtt_clientid) + ] + return props + + def connect(self): + logger.debug("Setting up MQTT connection") + props = Properties(PacketTypes.CONNECT) + props.UserProperty = [ + (PROPERTY_K_MCP_COMPONENT, self.mcp_component_type) + ] + return self.client.connect( + host = self.mqtt_options.host, + port = self.mqtt_options.port, + keepalive = self.mqtt_options.keepalive, + bind_address = self.mqtt_options.bind_address, + bind_port = self.mqtt_options.bind_port, + clean_start=True, + properties=props, + ) + + def assert_property(self, properties: Properties | None, property_name: str, expected_value: Any): + if get_property(properties, property_name) == expected_value: + pass + else: + anyio_from_thread.run(self.stop_mqtt) + raise ValueError(f"{property_name} not available") + + @abstractmethod + def get_presence_topic(self) -> str: + pass + + async def stop_mqtt(self): + self.publish_json_rpc_message( + self.get_presence_topic(), + message = self.disconnected_msg, + retain = self.disconnected_msg_retain + ) + self.client.disconnect() + self.client.loop_stop() + for stream in self._read_stream_writers.values(): + await stream.aclose() + self._read_stream_writers = {} + logger.debug("Disconnected from MQTT broker_host") + +def get_property(properties: Properties | None, property_name: str): + if properties and hasattr(properties, property_name): + return getattr(properties, property_name) + else: + return False + +def configure_logging( + level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO", + format: str = DEFAULT_LOG_FORMAT, +) -> None: + handlers: list[logging.Handler] = [] + try: + from rich.console import Console + from rich.logging import RichHandler + + handlers.append(RichHandler(console=Console(stderr=True), rich_tracebacks=True)) + except ImportError: + pass + + if not handlers: + handlers.append(logging.StreamHandler()) + + logging.basicConfig( + level=level, + format=format, + handlers=handlers, + ) diff --git a/src/mcp/shared/mqtt_topic.py b/src/mcp/shared/mqtt_topic.py new file mode 100644 index 0000000000..1c5e3da536 --- /dev/null +++ b/src/mcp/shared/mqtt_topic.py @@ -0,0 +1,26 @@ + + +SERVER_CONTROL_BASE: str = '$mcp-server' +SERVER_CAPABILITY_CHANGE_BASE: str = '$mcp-server/capability' +SERVER_PRESENCE_BASE: str = '$mcp-server/presence' +CLIENT_PRESENCE_BASE: str = '$mcp-client/presence' +CLIENT_CAPABILITY_CHANGE_BASE: str = '$mcp-client/capability' +RPC_BASE: str = '$mcp-rpc' + +def get_server_control_topic(server_id: str, server_name: str) -> str: + return f"{SERVER_CONTROL_BASE}/{server_id}/{server_name}" + +def get_server_capability_change_topic(server_id: str, server_name: str) -> str: + return f"{SERVER_CAPABILITY_CHANGE_BASE}/{server_id}/{server_name}" + +def get_server_presence_topic(server_id: str, server_name: str) -> str: + return f"{SERVER_PRESENCE_BASE}/{server_id}/{server_name}" + +def get_client_presence_topic(mcp_clientid: str) -> str: + return f"{CLIENT_PRESENCE_BASE}/{mcp_clientid}" + +def get_client_capability_change_topic(mcp_clientid: str) -> str: + return f"{CLIENT_CAPABILITY_CHANGE_BASE}/{mcp_clientid}" + +def get_rpc_topic(mcp_clientid: str, server_id: str, server_name: str) -> str: + return f"{RPC_BASE}/{mcp_clientid}/{server_id}/{server_name}" diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 05fd3ce37f..062830f3d7 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,6 +1,6 @@ import logging from collections.abc import Callable -from contextlib import AsyncExitStack +from contextlib import AsyncExitStack, asynccontextmanager from datetime import timedelta from types import TracebackType from typing import Any, Generic, TypeVar @@ -187,6 +187,7 @@ def __init__( self._receive_notification_type = receive_notification_type self._read_timeout_seconds = read_timeout_seconds self._in_flight = {} + self._receive_loop_alive = None self._exit_stack = AsyncExitStack() @@ -207,7 +208,9 @@ async def __aexit__( # would be very surprising behavior), so make sure to cancel the tasks # in the task group. self._task_group.cancel_scope.cancel() - return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + if self._receive_loop_alive: + return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + return False async def send_request( self, @@ -249,7 +252,12 @@ async def send_request( if self._read_timeout_seconds is None else self._read_timeout_seconds.total_seconds() ): - response_or_error = await response_stream_reader.receive() + async with response_stream_reader: + async for response_or_error in response_stream_reader: + if isinstance(response_or_error, JSONRPCError): + raise McpError(response_or_error.error) + else: + return result_type.model_validate(response_or_error.result) except TimeoutError: raise McpError( ErrorData( @@ -262,11 +270,6 @@ async def send_request( ) ) - if isinstance(response_or_error, JSONRPCError): - raise McpError(response_or_error.error) - else: - return result_type.model_validate(response_or_error.result) - async def send_notification(self, notification: SendNotificationT) -> None: """ Emits a notification, which is a one-way message that does not expect @@ -296,9 +299,19 @@ async def _send_response( await self._write_stream.send(JSONRPCMessage(jsonrpc_response)) async def _receive_loop(self) -> None: + @asynccontextmanager + async def receive_loop_status(): + try: + self._receive_loop_alive = True + yield + finally: + self._receive_loop_alive = False + async with ( self._read_stream, self._write_stream, + self._exit_stack, + receive_loop_status() ): async for message in self._read_stream: if isinstance(message, Exception): diff --git a/uv.lock b/uv.lock index 424e2d4823..1a50bb9d5e 100644 --- a/uv.lock +++ b/uv.lock @@ -492,6 +492,7 @@ dependencies = [ { name = "anyio" }, { name = "httpx" }, { name = "httpx-sse" }, + { name = "paho-mqtt" }, { name = "pydantic" }, { name = "pydantic-settings" }, { name = "sse-starlette" }, @@ -533,6 +534,7 @@ requires-dist = [ { name = "anyio", specifier = ">=4.5" }, { name = "httpx", specifier = ">=0.27" }, { name = "httpx-sse", specifier = ">=0.4" }, + { name = "paho-mqtt", git = "https://github.com/eclipse-paho/paho.mqtt.python.git?tag=v2.1.0" }, { name = "pydantic", specifier = ">=2.7.2,<3.0.0" }, { name = "pydantic-settings", specifier = ">=2.5.2" }, { name = "python-dotenv", marker = "extra == 'cli'", specifier = ">=1.0.0" }, @@ -856,6 +858,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/90/96/04b8e52da071d28f5e21a805b19cb9390aa17a47462ac87f5e2696b9566d/paginate-0.5.7-py2.py3-none-any.whl", hash = "sha256:b885e2af73abcf01d9559fd5216b57ef722f8c42affbb63942377668e35c7591", size = 13746 }, ] +[[package]] +name = "paho-mqtt" +version = "2.1.0" +source = { git = "https://github.com/eclipse-paho/paho.mqtt.python.git?tag=v2.1.0#af64a4365c6ac5a7a4d339e7b00f44df91353b35" } + [[package]] name = "pathspec" version = "0.12.1" @@ -1618,4 +1625,4 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2b/fb/c492d6daa5ec067c2988ac80c61359ace5c4c674c532985ac5a123436cec/websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b359ed09954d7c18bbc1680f380c7301f92c60bf924171629c5db97febb12f04", size = 174155 }, { url = "https://files.pythonhosted.org/packages/68/a1/dcb68430b1d00b698ae7a7e0194433bce4f07ded185f0ee5fb21e2a2e91e/websockets-15.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:cad21560da69f4ce7658ca2cb83138fb4cf695a2ba3e475e0559e05991aa8122", size = 176884 }, { url = "https://files.pythonhosted.org/packages/fa/a8/5b41e0da817d64113292ab1f8247140aac61cbf6cfd085d6a0fa77f4984f/websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f", size = 169743 }, -] \ No newline at end of file +]