From dce207b403a8e33e0a6ce70b724ec54c3e92831e Mon Sep 17 00:00:00 2001 From: Shawn <506895667@qq.com> Date: Mon, 31 Mar 2025 17:35:43 +0800 Subject: [PATCH 01/23] fix: close the session when read_stream closed --- src/mcp/server/lowlevel/server.py | 1 + src/mcp/shared/session.py | 17 +++++++++++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index dbaff30516..37e5821057 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -497,6 +497,7 @@ async def run( lifespan_context, raise_exceptions, ) + logger.debug("Server closed") async def _handle_message( self, diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 05fd3ce37f..9f04bf841a 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,6 +1,6 @@ import logging from collections.abc import Callable -from contextlib import AsyncExitStack +from contextlib import AsyncExitStack, asynccontextmanager from datetime import timedelta from types import TracebackType from typing import Any, Generic, TypeVar @@ -187,6 +187,7 @@ def __init__( self._receive_notification_type = receive_notification_type self._read_timeout_seconds = read_timeout_seconds self._in_flight = {} + self._receive_loop_alive = None self._exit_stack = AsyncExitStack() @@ -207,7 +208,9 @@ async def __aexit__( # would be very surprising behavior), so make sure to cancel the tasks # in the task group. self._task_group.cancel_scope.cancel() - return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + if self._receive_loop_alive: + return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + return False async def send_request( self, @@ -296,9 +299,19 @@ async def _send_response( await self._write_stream.send(JSONRPCMessage(jsonrpc_response)) async def _receive_loop(self) -> None: + @asynccontextmanager + async def receive_loop_status(): + try: + self._receive_loop_alive = True + yield + finally: + self._receive_loop_alive = False + async with ( self._read_stream, self._write_stream, + self._exit_stack, + receive_loop_status() ): async for message in self._read_stream: if isinstance(message, Exception): From 9e197cbd13dab0316d6801e50a094d50d8eb9dbb Mon Sep 17 00:00:00 2001 From: Shawn <506895667@qq.com> Date: Sun, 30 Mar 2025 13:04:12 +0800 Subject: [PATCH 02/23] support mqtt transport --- examples/fastmcp/mqtt_simple_echo.py | 17 ++ pyproject.toml | 2 + src/mcp/server/fastmcp/server.py | 32 ++- src/mcp/server/mqtt.py | 366 +++++++++++++++++++++++++++ src/mcp/shared/mqtt_channel.py | 30 +++ uv.lock | 9 +- 6 files changed, 452 insertions(+), 4 deletions(-) create mode 100644 examples/fastmcp/mqtt_simple_echo.py create mode 100644 src/mcp/server/mqtt.py create mode 100644 src/mcp/shared/mqtt_channel.py diff --git a/examples/fastmcp/mqtt_simple_echo.py b/examples/fastmcp/mqtt_simple_echo.py new file mode 100644 index 0000000000..5cea4f183c --- /dev/null +++ b/examples/fastmcp/mqtt_simple_echo.py @@ -0,0 +1,17 @@ +""" +FastMCP Echo Server +""" + +from mcp.server.fastmcp import FastMCP + +# Create server +mcp = FastMCP( + "demo_server/echo", + log_level="DEBUG", + mqtt_service_description="A simple FastMCP server that echoes back the input text." +) + +@mcp.tool() +def echo(text: str) -> str: + """Echo the input text""" + return text diff --git a/pyproject.toml b/pyproject.toml index 25514cd6b0..52d5577b55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "sse-starlette>=1.6.1", "pydantic-settings>=2.5.2", "uvicorn>=0.23.1", + "paho-mqtt", ] [project.optional-dependencies] @@ -105,6 +106,7 @@ members = ["examples/servers/*"] [tool.uv.sources] mcp = { workspace = true } +paho-mqtt = { git = "https://github.com/eclipse-paho/paho.mqtt.python.git", tag = "v2.1.0" } [tool.pytest.ini_options] xfail_strict = true diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index bf0ce880a5..6f175e29d1 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -35,6 +35,7 @@ from mcp.server.lowlevel.server import lifespan as default_lifespan from mcp.server.session import ServerSession, ServerSessionT from mcp.server.sse import SseServerTransport +from mcp.server.mqtt import validate_service_name, start_mqtt, MqttOptions from mcp.server.stdio import stdio_server from mcp.shared.context import LifespanContextT, RequestContext from mcp.types import ( @@ -52,7 +53,6 @@ logger = get_logger(__name__) - class Settings(BaseSettings, Generic[LifespanResultT]): """FastMCP server settings. @@ -76,6 +76,12 @@ class Settings(BaseSettings, Generic[LifespanResultT]): sse_path: str = "/sse" message_path: str = "/messages/" + # MQTT settings + mqtt_service_description: str = '' + mqtt_service_meta: dict[str, Any] = {} + mqtt_client_id_prefix: str | None = None + mqtt_options: MqttOptions = MqttOptions() + # resource settings warn_on_duplicate_resources: bool = True @@ -145,18 +151,21 @@ def name(self) -> str: def instructions(self) -> str | None: return self._mcp_server.instructions - def run(self, transport: Literal["stdio", "sse"] = "stdio") -> None: + def run(self, transport: Literal["stdio", "sse", "mqtt"] = "stdio") -> None: """Run the FastMCP server. Note this is a synchronous function. Args: transport: Transport protocol to use ("stdio" or "sse") """ - TRANSPORTS = Literal["stdio", "sse"] + TRANSPORTS = Literal["stdio", "sse", "mqtt"] if transport not in TRANSPORTS.__args__: # type: ignore raise ValueError(f"Unknown transport: {transport}") if transport == "stdio": anyio.run(self.run_stdio_async) + elif transport == "mqtt": + validate_service_name(self._mcp_server.name) + anyio.run(self.run_mqtt_async) else: # transport == "sse" anyio.run(self.run_sse_async) @@ -477,6 +486,23 @@ async def run_sse_async(self) -> None: server = uvicorn.Server(config) await server.serve() + async def run_mqtt_async(self) -> None: + """Run the server using MQTT transport.""" + def server_run(read_stream: Any, write_stream: Any): + return self._mcp_server.run( + read_stream, + write_stream, + self._mcp_server.create_initialization_options(), + ) + await start_mqtt( + server_run, + service_name = self._mcp_server.name, + service_description=self.settings.mqtt_service_description, + service_meta = self.settings.mqtt_service_meta, + client_id_prefix = self.settings.mqtt_client_id_prefix, + mqtt_options = self.settings.mqtt_options + ) + def sse_app(self) -> Starlette: """Return an instance of the SSE server app.""" sse = SseServerTransport(self.settings.message_path) diff --git a/src/mcp/server/mqtt.py b/src/mcp/server/mqtt.py new file mode 100644 index 0000000000..62fc691068 --- /dev/null +++ b/src/mcp/server/mqtt.py @@ -0,0 +1,366 @@ +""" +SSE Server Transport Module + +This module implements a Server-Sent Events (SSE) transport layer for MCP servers." +""" + +import asyncio +import json +import traceback +from types import TracebackType +import mcp.shared.mqtt_channel as mqtt_channel +import paho.mqtt.client as mqtt +import logging +from paho.mqtt.reasoncodes import ReasonCode +from paho.mqtt.enums import CallbackAPIVersion +from paho.mqtt.properties import Properties +from paho.mqtt.subscribeoptions import SubscribeOptions +from uuid import uuid4 +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import BaseModel +from typing import Literal, Optional, Any, TypeAlias, Callable, Awaitable +import mcp.types as types +from typing_extensions import Self + +QOS = 1 +logger = logging.getLogger(__name__) + +RcvStream : TypeAlias = MemoryObjectReceiveStream[types.JSONRPCMessage] +SndStream : TypeAlias = MemoryObjectSendStream[types.JSONRPCMessage] +RcvStreamEx : TypeAlias = MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] +SndStreamEX : TypeAlias = MemoryObjectSendStream[types.JSONRPCMessage | Exception] +ServerRun : TypeAlias = Callable[[RcvStreamEx, SndStream], Awaitable[Any]] + +class MqttOptions(BaseModel): + host: str = "localhost" + port: int = 1883 + transport: Literal['tcp', 'websockets', 'unix'] = 'tcp' + keepalive: int = 60 + bind_address: str = '' + bind_port: int = 0 + username: Optional[str] = None + password: Optional[str] = None + tls_enabled: bool = False + tls_version: Optional[int] = None + tls_insecure: bool = False + ca_certs: Optional[str] = None + certfile: Optional[str] = None + keyfile: Optional[str] = None + ciphers: Optional[str] = None + keyfile_password: Optional[str] = None + alpn_protocols: Optional[list[str]] = None + websocket_path: str = '/mqtt' + websocket_headers: Optional[dict[str, str]] = None + +class MqttTransport: + _read_stream_writers: dict[ + str, SndStreamEX + ] + + def __init__(self, server_run: ServerRun, service_name: str, + service_description: str, + service_meta: dict[str, Any], + client_id_prefix: str | None = None, + mqtt_options: MqttOptions = MqttOptions()): + self._read_stream_writers = {} + self.resource_ids: dict[str, str] = {} + uuid = uuid4().hex + service_id = f"{client_id_prefix}-{uuid}" if client_id_prefix else uuid + self.mqtt_options = mqtt_options + self.service_name = service_name + self.service_description = service_description + self.service_meta = service_meta + self.service_id = service_id + self.service_control_channel = mqtt_channel.get_service_control_channel(service_name) + self.service_presence_channel = mqtt_channel.get_service_presence_channel(service_id, service_name) + self.service_capability_change_channel = mqtt_channel.get_service_capability_change_channel(service_id, service_name) + self.server_run = server_run + client = mqtt.Client( + callback_api_version=CallbackAPIVersion.VERSION2, + client_id=service_id, protocol=mqtt.MQTTv5, + userdata={}, + transport=mqtt_options.transport, reconnect_on_failure=True + ) + client.username_pw_set(mqtt_options.username, mqtt_options.password) + if mqtt_options.tls_enabled: + client.tls_set( # type: ignore + ca_certs=mqtt_options.ca_certs, + certfile=mqtt_options.certfile, + keyfile=mqtt_options.keyfile, + tls_version=mqtt_options.tls_version, + ciphers=mqtt_options.ciphers, + keyfile_password=mqtt_options.keyfile_password, + alpn_protocols=mqtt_options.alpn_protocols + ) + client.tls_insecure_set(mqtt_options.tls_insecure) + if mqtt_options.transport == 'websockets': + client.ws_set_options(path=mqtt_options.websocket_path, headers=mqtt_options.websocket_headers) + client.will_set(topic=self.service_presence_channel, payload=None, qos=QOS, retain=True) + client.on_connect = self._on_connect + client.on_message = self._on_message + client.on_subscribe = self._on_subscribe + self.client = client + + async def __aenter__(self) -> Self: + self._task_group = anyio.create_task_group() + await self._task_group.__aenter__() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + self._task_group.cancel_scope.cancel() + return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + + def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.ConnectFlags, reason_code : ReasonCode, properties: Properties | None): + if reason_code == 0: + logger.debug(f"Connected to MQTT broker_host at {self.mqtt_options.host}:{self.mqtt_options.port}") + self.assert_property(properties, "RetainAvailable", 1) + self.assert_property(properties, "WildcardSubscriptionAvailable", 1) + ## Subscribe to the service control channel + client.subscribe(self.service_control_channel, QOS) + ## Reister the service on the presence channel + online_msg = types.JSONRPCNotification( + jsonrpc="2.0", + method = "notifications/service/online", + params = { + "description": self.service_description, + "meta": self.service_meta + } + ) + client.publish(self.service_presence_channel, + payload=online_msg.model_dump_json(), qos=QOS, retain=True) + else: + logger.error(f"Failed to connect, return code {reason_code}") + + def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage): + logger.debug(f"Received message on topic {msg.topic}: {msg.payload.decode()}") + match msg.topic: + case str() as t if t == self.service_control_channel: + self.handle_service_contorl_message(msg) + case str() as t if t.startswith(mqtt_channel.CLIENT_CAPABILITY_CHANGE_BASE): + self.handle_client_capability_change_message(msg) + case str() as t if t.startswith(mqtt_channel.RPC_BASE): + self.handle_rpc_message(msg) + case str() as t if t.startswith(mqtt_channel.CLIENT_PRESENCE_BASE): + self.handle_client_presence_message(msg) + case _: + logger.error(f"Received message on unexpected topic: {msg.topic}") + + def _on_subscribe(self, client: mqtt.Client, userdata: Any, mid: int, + reason_code_list: list[ReasonCode], properties: Properties | None): + if mid in userdata.get("pending_subs", {}): + mcp_client_id, msg, rpc_msg_id = userdata["pending_subs"].pop(mid) + ## only create session if all topic subscribed successfully + if all([rc.value == QOS for rc in reason_code_list]): + logger.debug(f"Subscribed to topics for mcp_client_id: {mcp_client_id}") + anyio.from_thread.run(self.create_session, mcp_client_id, msg) + else: + logger.error(f"Failed to subscribe to topics for mcp_client_id: {mcp_client_id}, reason_codes: {reason_code_list}") + err = types.JSONRPCError( + jsonrpc="2.0", + id=rpc_msg_id, + error=types.ErrorData( + code=types.INTERNAL_ERROR, + message="Failed to subscribe to client topics" + ) + ) + self.publish_json_rpc_message( + mqtt_channel.get_rpc_channel(mcp_client_id, self.service_name), + types.JSONRPCMessage(err) + ) + + def handle_service_contorl_message(self, msg: mqtt.MQTTMessage): + if msg.properties and hasattr(msg.properties, "UserProperty"): + user_properties: dict[str, Any] = dict(msg.properties.UserProperty) # type: ignore + if "mcp_client_id" in user_properties: + mcp_client_id = user_properties["mcp_client_id"] + if mcp_client_id in self._read_stream_writers: + anyio.from_thread.run(self.send_message_to_session, mcp_client_id, msg) + else: + self.maybe_subscribe_to_client(mcp_client_id, msg) + else: + logger.error("No mcp_client_id in UserProperties") + else: + logger.error("No UserProperties in control message") + + def handle_client_capability_change_message(self, msg: mqtt.MQTTMessage) -> None: + mcp_client_id = msg.topic.split("/")[-1] + anyio.from_thread.run(self.send_message_to_session, mcp_client_id, msg) + + def handle_rpc_message(self, msg: mqtt.MQTTMessage) -> None: + mcp_client_id = msg.topic.split("/")[1] + anyio.from_thread.run(self.send_message_to_session, mcp_client_id, msg) + + def handle_client_presence_message(self, msg: mqtt.MQTTMessage) -> None: + mcp_client_id = msg.topic.split("/")[-1] + if mcp_client_id not in self._read_stream_writers: + logger.error(f"No session for mcp_client_id: {mcp_client_id}") + return + try: + json_msg = json.loads(msg.payload.decode()) + if "method" in json_msg: + if json_msg["method"] == "notifications/disconnected": + stream = self._read_stream_writers.pop(mcp_client_id) + anyio.from_thread.run(stream.aclose) + logger.debug(f"Removed session for mcp_client_id: {mcp_client_id}") + else: + logger.error(f"Unknown method in control message for mcp_client_id: {mcp_client_id}") + else: + logger.error(f"No method in control message for mcp_client_id: {mcp_client_id}") + except json.JSONDecodeError: + logger.error(f"Invalid JSON in control message for mcp_client_id: {mcp_client_id}") + + async def create_session(self, mcp_client_id: str, msg: mqtt.MQTTMessage): + ## Streams are used to communicate between the MqttTransport and the MCPSession: + ## 1. (msg) --> MqttBroker --> MqttTransport -->[read_stream_writer]-->[read_stream]--> MCPSession + ## 2. MqttBroker <-- MqttTransport <--[write_stream_reader]--[write_stream]-- MCPSession <-- (msg) + read_stream: RcvStreamEx + read_stream_writer: SndStreamEX + write_stream: SndStream + write_stream_reader: RcvStream + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + self._read_stream_writers[mcp_client_id] = read_stream_writer + self._task_group.start_soon(self.server_run, read_stream, write_stream) + self._task_group.start_soon(self.receieved_from_session, mcp_client_id, write_stream_reader) + logger.debug(f"Created new session for mcp_client_id: {mcp_client_id}") + await self.send_message_to_session(mcp_client_id, msg) + + def maybe_subscribe_to_client(self, mcp_client_id: str, msg: mqtt.MQTTMessage): + try: + json_msg = json.loads(msg.payload.decode()) + if "id" in json_msg: + rpc_msg_id = json_msg["id"] + self.subscribe_to_client(mcp_client_id, msg, rpc_msg_id) + else: + logger.error(f"No id in control message for mcp_client_id: {mcp_client_id}") + except json.JSONDecodeError: + logger.error(f"Invalid JSON in control message for mcp_client_id: {mcp_client_id}") + return + + def subscribe_to_client(self, mcp_client_id: str, msg: mqtt.MQTTMessage, rcp_msg_id: Any): + topic_filters = [ + (mqtt_channel.get_client_presence_channel(mcp_client_id), SubscribeOptions(qos=QOS)), + (mqtt_channel.get_client_capability_change_channel(mcp_client_id), SubscribeOptions(qos=QOS)), + (mqtt_channel.get_rpc_channel(mcp_client_id, self.service_name), SubscribeOptions(qos=QOS, noLocal=True)) + ] + ret, mid = self.client.subscribe(topic=topic_filters) + if ret != mqtt.MQTT_ERR_SUCCESS: + logger.error(f"Failed to subscribe to topics for mcp_client_id: {mcp_client_id}") + return + userdata = self.client.user_data_get() + pending_subs = userdata.get("pending_subs", {}) + pending_subs[mid] = (mcp_client_id, msg, rcp_msg_id) + userdata["pending_subs"] = pending_subs + + async def send_message_to_session(self, mcp_client_id: str, msg: mqtt.MQTTMessage): + payload = msg.payload.decode() + if mcp_client_id not in self._read_stream_writers: + logger.error(f"No session for mcp_client_id: {mcp_client_id}") + return + read_stream_writer = self._read_stream_writers[mcp_client_id] + try: + message = types.JSONRPCMessage.model_validate_json(payload) + logger.debug(f"Sending msg to session for mcp_client_id: {mcp_client_id}, msg: {message}") + with anyio.fail_after(3): + await read_stream_writer.send(message) + except Exception as exc: + logger.error(f"Failed to send msg to session for mcp_client_id: {mcp_client_id}, exception: {exc}") + traceback.print_exc() + ## TODO: the session does not handle exceptions for now + #await read_stream_writer.send(exc) + + async def receieved_from_session(self, mcp_client_id: str, write_stream_reader: RcvStream): + async with write_stream_reader: + async for msg in write_stream_reader: + logger.debug(f"Got msg from session for mcp_client_id: {mcp_client_id}, msg: {msg}") + match msg.model_dump(): + case {"method": "notifications/resources/updated", "params": {"uri": uri}}: + ## Mantain a mapping of resource_id to uri + resource_id = uuid4().hex + self.resource_ids[resource_id] = uri + topic = mqtt_channel.get_service_resource_update_channel(self.service_id, resource_id) + case {"method": method} if method.endswith("/list_changed"): + topic = mqtt_channel.get_service_capability_change_channel(self.service_id, self.service_name) + case _: + topic = mqtt_channel.get_rpc_channel(mcp_client_id, self.service_name) + self.publish_json_rpc_message(topic, msg) + # cleanup + if mcp_client_id in self._read_stream_writers: + logger.debug(f"Removing session for mcp_client_id: {mcp_client_id}") + stream = self._read_stream_writers.pop(mcp_client_id) + await stream.aclose() + + logger.debug(f"Session stream closed for mcp_client_id: {mcp_client_id}") + + def publish_json_rpc_message(self, topic: str, message: types.JSONRPCMessage): + json = message.model_dump_json(by_alias=True, exclude_none=True) + self.client.publish(topic, json, qos=QOS) + + def connect(self): + logger.debug("Setting up MQTT connection") + self.client.connect( + host = self.mqtt_options.host, + port = self.mqtt_options.port, + keepalive = self.mqtt_options.keepalive, + bind_address = self.mqtt_options.bind_address, + bind_port = self.mqtt_options.bind_port, + clean_start=True + ) + + def assert_property(self, properties: Properties | None, property_name: str, expected_value: Any): + if get_property(properties, property_name) == expected_value: + pass + else: + self.stop_mqtt() + raise ValueError(f"{property_name} not available") + + def stop_mqtt(self): + self.client.publish(self.service_presence_channel, payload=None, qos=QOS, retain=True) + self.client.disconnect() + self.client.loop_stop() + for stream in self._read_stream_writers.values(): + anyio.from_thread.run(stream.aclose) + self._read_stream_writers = {} + logger.debug("Disconnected from MQTT broker_host") + +def get_property(properties: Properties | None, property_name: str): + if properties and hasattr(properties, property_name): + return getattr(properties, property_name) + else: + return False + +async def start_mqtt( + server_run: ServerRun, service_name: str, + service_description: str, + service_meta: dict[str, Any], + client_id_prefix: str | None = None, + mqtt_options: MqttOptions = MqttOptions()): + async with MqttTransport( + server_run, + service_name = service_name, + service_description=service_description, + service_meta = service_meta, + client_id_prefix = client_id_prefix, + mqtt_options = mqtt_options + ) as mqtt_trans: + def start(): + mqtt_trans.connect() + mqtt_trans.client.loop_forever() + try: + await anyio.to_thread.run_sync(start) + except asyncio.CancelledError: + logger.debug("MQTT transport got cancelled") + +def validate_service_name(name: str): + if "/" not in name: + raise ValueError(f"Invalid service name: {name}, must contain a '/'") + elif ("+" in name) or ("#" in name): + raise ValueError(f"Invalid service name: {name}, must not contain '+' or '#'") + elif name[0] == "/": + raise ValueError(f"Invalid service name: {name}, must not start with '/'") diff --git a/src/mcp/shared/mqtt_channel.py b/src/mcp/shared/mqtt_channel.py new file mode 100644 index 0000000000..1de1801090 --- /dev/null +++ b/src/mcp/shared/mqtt_channel.py @@ -0,0 +1,30 @@ + + +SERVICE_CONTROL_BASE: str = '$mcp-service' +SERVICE_CAPABILITY_CHANGE_BASE: str = '$mcp-service/capability-change' +SERVICE_RESOURCE_UPDATE_BASE: str = '$mcp-service/resource-update' +SERVICE_PRESENCE_BASE: str = '$mcp-service/presence' +CLIENT_PRESENCE_BASE: str = '$mcp-client/presence' +CLIENT_CAPABILITY_CHANGE_BASE: str = '$mcp-client/capability-change' +RPC_BASE: str = '$mcp-rpc-endpoint' + +def get_service_control_channel(service_name: str) -> str: + return f"{SERVICE_CONTROL_BASE}/{service_name}" + +def get_service_capability_change_channel(service_id: str, service_name: str) -> str: + return f"{SERVICE_CAPABILITY_CHANGE_BASE}/{service_id}/{service_name}" + +def get_service_resource_update_channel(service_id: str, resource_id: str) -> str: + return f"{SERVICE_RESOURCE_UPDATE_BASE}/{service_id}/{resource_id}" + +def get_service_presence_channel(service_id: str, service_name: str) -> str: + return f"{SERVICE_PRESENCE_BASE}/{service_id}/{service_name}" + +def get_client_presence_channel(mcp_clientid: str) -> str: + return f"{CLIENT_PRESENCE_BASE}/{mcp_clientid}" + +def get_client_capability_change_channel(mcp_clientid: str) -> str: + return f"{CLIENT_CAPABILITY_CHANGE_BASE}/{mcp_clientid}" + +def get_rpc_channel(mcp_clientid: str, service_name: str) -> str: + return f"{RPC_BASE}/{mcp_clientid}/{service_name}" diff --git a/uv.lock b/uv.lock index 424e2d4823..1a50bb9d5e 100644 --- a/uv.lock +++ b/uv.lock @@ -492,6 +492,7 @@ dependencies = [ { name = "anyio" }, { name = "httpx" }, { name = "httpx-sse" }, + { name = "paho-mqtt" }, { name = "pydantic" }, { name = "pydantic-settings" }, { name = "sse-starlette" }, @@ -533,6 +534,7 @@ requires-dist = [ { name = "anyio", specifier = ">=4.5" }, { name = "httpx", specifier = ">=0.27" }, { name = "httpx-sse", specifier = ">=0.4" }, + { name = "paho-mqtt", git = "https://github.com/eclipse-paho/paho.mqtt.python.git?tag=v2.1.0" }, { name = "pydantic", specifier = ">=2.7.2,<3.0.0" }, { name = "pydantic-settings", specifier = ">=2.5.2" }, { name = "python-dotenv", marker = "extra == 'cli'", specifier = ">=1.0.0" }, @@ -856,6 +858,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/90/96/04b8e52da071d28f5e21a805b19cb9390aa17a47462ac87f5e2696b9566d/paginate-0.5.7-py2.py3-none-any.whl", hash = "sha256:b885e2af73abcf01d9559fd5216b57ef722f8c42affbb63942377668e35c7591", size = 13746 }, ] +[[package]] +name = "paho-mqtt" +version = "2.1.0" +source = { git = "https://github.com/eclipse-paho/paho.mqtt.python.git?tag=v2.1.0#af64a4365c6ac5a7a4d339e7b00f44df91353b35" } + [[package]] name = "pathspec" version = "0.12.1" @@ -1618,4 +1625,4 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2b/fb/c492d6daa5ec067c2988ac80c61359ace5c4c674c532985ac5a123436cec/websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b359ed09954d7c18bbc1680f380c7301f92c60bf924171629c5db97febb12f04", size = 174155 }, { url = "https://files.pythonhosted.org/packages/68/a1/dcb68430b1d00b698ae7a7e0194433bce4f07ded185f0ee5fb21e2a2e91e/websockets-15.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:cad21560da69f4ce7658ca2cb83138fb4cf695a2ba3e475e0559e05991aa8122", size = 176884 }, { url = "https://files.pythonhosted.org/packages/fa/a8/5b41e0da817d64113292ab1f8247140aac61cbf6cfd085d6a0fa77f4984f/websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f", size = 169743 }, -] \ No newline at end of file +] From bcd729bc16bd453b23c831cbe6a8d1ac5b2d4b59 Mon Sep 17 00:00:00 2001 From: Shawn <506895667@qq.com> Date: Tue, 1 Apr 2025 21:23:56 +0800 Subject: [PATCH 03/23] change user property name mcp_client_id to mcp-client-id --- src/mcp/server/mqtt.py | 52 ++++++++++++++++------------------ src/mcp/shared/mqtt_channel.py | 30 -------------------- src/mcp/shared/mqtt_topic.py | 30 ++++++++++++++++++++ 3 files changed, 55 insertions(+), 57 deletions(-) delete mode 100644 src/mcp/shared/mqtt_channel.py create mode 100644 src/mcp/shared/mqtt_topic.py diff --git a/src/mcp/server/mqtt.py b/src/mcp/server/mqtt.py index 62fc691068..87c309e2a0 100644 --- a/src/mcp/server/mqtt.py +++ b/src/mcp/server/mqtt.py @@ -8,7 +8,7 @@ import json import traceback from types import TracebackType -import mcp.shared.mqtt_channel as mqtt_channel +import mcp.shared.mqtt_topic as mqtt_topic import paho.mqtt.client as mqtt import logging from paho.mqtt.reasoncodes import ReasonCode @@ -24,6 +24,7 @@ from typing_extensions import Self QOS = 1 +PROPERTY_K_MCP_CLIENT_ID = "mcp-client-id" logger = logging.getLogger(__name__) RcvStream : TypeAlias = MemoryObjectReceiveStream[types.JSONRPCMessage] @@ -72,9 +73,9 @@ def __init__(self, server_run: ServerRun, service_name: str, self.service_description = service_description self.service_meta = service_meta self.service_id = service_id - self.service_control_channel = mqtt_channel.get_service_control_channel(service_name) - self.service_presence_channel = mqtt_channel.get_service_presence_channel(service_id, service_name) - self.service_capability_change_channel = mqtt_channel.get_service_capability_change_channel(service_id, service_name) + self.service_control_topic = mqtt_topic.get_service_control_topic(service_name) + self.service_presence_topic = mqtt_topic.get_service_presence_topic(service_id, service_name) + self.service_capability_change_topic = mqtt_topic.get_service_capability_change_topic(service_id, service_name) self.server_run = server_run client = mqtt.Client( callback_api_version=CallbackAPIVersion.VERSION2, @@ -96,7 +97,7 @@ def __init__(self, server_run: ServerRun, service_name: str, client.tls_insecure_set(mqtt_options.tls_insecure) if mqtt_options.transport == 'websockets': client.ws_set_options(path=mqtt_options.websocket_path, headers=mqtt_options.websocket_headers) - client.will_set(topic=self.service_presence_channel, payload=None, qos=QOS, retain=True) + client.will_set(topic=self.service_presence_topic, payload=None, qos=QOS, retain=True) client.on_connect = self._on_connect client.on_message = self._on_message client.on_subscribe = self._on_subscribe @@ -122,7 +123,7 @@ def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.Co self.assert_property(properties, "RetainAvailable", 1) self.assert_property(properties, "WildcardSubscriptionAvailable", 1) ## Subscribe to the service control channel - client.subscribe(self.service_control_channel, QOS) + client.subscribe(self.service_control_topic, QOS) ## Reister the service on the presence channel online_msg = types.JSONRPCNotification( jsonrpc="2.0", @@ -132,7 +133,7 @@ def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.Co "meta": self.service_meta } ) - client.publish(self.service_presence_channel, + client.publish(self.service_presence_topic, payload=online_msg.model_dump_json(), qos=QOS, retain=True) else: logger.error(f"Failed to connect, return code {reason_code}") @@ -140,13 +141,13 @@ def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.Co def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage): logger.debug(f"Received message on topic {msg.topic}: {msg.payload.decode()}") match msg.topic: - case str() as t if t == self.service_control_channel: + case str() as t if t == self.service_control_topic: self.handle_service_contorl_message(msg) - case str() as t if t.startswith(mqtt_channel.CLIENT_CAPABILITY_CHANGE_BASE): + case str() as t if t.startswith(mqtt_topic.CLIENT_CAPABILITY_CHANGE_BASE): self.handle_client_capability_change_message(msg) - case str() as t if t.startswith(mqtt_channel.RPC_BASE): + case str() as t if t.startswith(mqtt_topic.RPC_BASE): self.handle_rpc_message(msg) - case str() as t if t.startswith(mqtt_channel.CLIENT_PRESENCE_BASE): + case str() as t if t.startswith(mqtt_topic.CLIENT_PRESENCE_BASE): self.handle_client_presence_message(msg) case _: logger.error(f"Received message on unexpected topic: {msg.topic}") @@ -170,21 +171,21 @@ def _on_subscribe(self, client: mqtt.Client, userdata: Any, mid: int, ) ) self.publish_json_rpc_message( - mqtt_channel.get_rpc_channel(mcp_client_id, self.service_name), + mqtt_topic.get_rpc_topic(mcp_client_id, self.service_name), types.JSONRPCMessage(err) ) def handle_service_contorl_message(self, msg: mqtt.MQTTMessage): if msg.properties and hasattr(msg.properties, "UserProperty"): user_properties: dict[str, Any] = dict(msg.properties.UserProperty) # type: ignore - if "mcp_client_id" in user_properties: - mcp_client_id = user_properties["mcp_client_id"] + if PROPERTY_K_MCP_CLIENT_ID in user_properties: + mcp_client_id = user_properties[PROPERTY_K_MCP_CLIENT_ID] if mcp_client_id in self._read_stream_writers: anyio.from_thread.run(self.send_message_to_session, mcp_client_id, msg) else: self.maybe_subscribe_to_client(mcp_client_id, msg) else: - logger.error("No mcp_client_id in UserProperties") + logger.error(f"No {PROPERTY_K_MCP_CLIENT_ID} in UserProperties") else: logger.error("No UserProperties in control message") @@ -245,9 +246,9 @@ def maybe_subscribe_to_client(self, mcp_client_id: str, msg: mqtt.MQTTMessage): def subscribe_to_client(self, mcp_client_id: str, msg: mqtt.MQTTMessage, rcp_msg_id: Any): topic_filters = [ - (mqtt_channel.get_client_presence_channel(mcp_client_id), SubscribeOptions(qos=QOS)), - (mqtt_channel.get_client_capability_change_channel(mcp_client_id), SubscribeOptions(qos=QOS)), - (mqtt_channel.get_rpc_channel(mcp_client_id, self.service_name), SubscribeOptions(qos=QOS, noLocal=True)) + (mqtt_topic.get_client_presence_topic(mcp_client_id), SubscribeOptions(qos=QOS)), + (mqtt_topic.get_client_capability_change_topic(mcp_client_id), SubscribeOptions(qos=QOS)), + (mqtt_topic.get_rpc_topic(mcp_client_id, self.service_name), SubscribeOptions(qos=QOS, noLocal=True)) ] ret, mid = self.client.subscribe(topic=topic_filters) if ret != mqtt.MQTT_ERR_SUCCESS: @@ -280,16 +281,13 @@ async def receieved_from_session(self, mcp_client_id: str, write_stream_reader: async for msg in write_stream_reader: logger.debug(f"Got msg from session for mcp_client_id: {mcp_client_id}, msg: {msg}") match msg.model_dump(): - case {"method": "notifications/resources/updated", "params": {"uri": uri}}: - ## Mantain a mapping of resource_id to uri - resource_id = uuid4().hex - self.resource_ids[resource_id] = uri - topic = mqtt_channel.get_service_resource_update_channel(self.service_id, resource_id) + case {"method": "notifications/resources/updated"}: + logger.warning("Resource updates should not be sent from the session. Ignoring.") case {"method": method} if method.endswith("/list_changed"): - topic = mqtt_channel.get_service_capability_change_channel(self.service_id, self.service_name) + logger.warning("Resource updates should not be sent from the session. Ignoring.") case _: - topic = mqtt_channel.get_rpc_channel(mcp_client_id, self.service_name) - self.publish_json_rpc_message(topic, msg) + topic = mqtt_topic.get_rpc_topic(mcp_client_id, self.service_name) + self.publish_json_rpc_message(topic, msg) # cleanup if mcp_client_id in self._read_stream_writers: logger.debug(f"Removing session for mcp_client_id: {mcp_client_id}") @@ -321,7 +319,7 @@ def assert_property(self, properties: Properties | None, property_name: str, exp raise ValueError(f"{property_name} not available") def stop_mqtt(self): - self.client.publish(self.service_presence_channel, payload=None, qos=QOS, retain=True) + self.client.publish(self.service_presence_topic, payload=None, qos=QOS, retain=True) self.client.disconnect() self.client.loop_stop() for stream in self._read_stream_writers.values(): diff --git a/src/mcp/shared/mqtt_channel.py b/src/mcp/shared/mqtt_channel.py deleted file mode 100644 index 1de1801090..0000000000 --- a/src/mcp/shared/mqtt_channel.py +++ /dev/null @@ -1,30 +0,0 @@ - - -SERVICE_CONTROL_BASE: str = '$mcp-service' -SERVICE_CAPABILITY_CHANGE_BASE: str = '$mcp-service/capability-change' -SERVICE_RESOURCE_UPDATE_BASE: str = '$mcp-service/resource-update' -SERVICE_PRESENCE_BASE: str = '$mcp-service/presence' -CLIENT_PRESENCE_BASE: str = '$mcp-client/presence' -CLIENT_CAPABILITY_CHANGE_BASE: str = '$mcp-client/capability-change' -RPC_BASE: str = '$mcp-rpc-endpoint' - -def get_service_control_channel(service_name: str) -> str: - return f"{SERVICE_CONTROL_BASE}/{service_name}" - -def get_service_capability_change_channel(service_id: str, service_name: str) -> str: - return f"{SERVICE_CAPABILITY_CHANGE_BASE}/{service_id}/{service_name}" - -def get_service_resource_update_channel(service_id: str, resource_id: str) -> str: - return f"{SERVICE_RESOURCE_UPDATE_BASE}/{service_id}/{resource_id}" - -def get_service_presence_channel(service_id: str, service_name: str) -> str: - return f"{SERVICE_PRESENCE_BASE}/{service_id}/{service_name}" - -def get_client_presence_channel(mcp_clientid: str) -> str: - return f"{CLIENT_PRESENCE_BASE}/{mcp_clientid}" - -def get_client_capability_change_channel(mcp_clientid: str) -> str: - return f"{CLIENT_CAPABILITY_CHANGE_BASE}/{mcp_clientid}" - -def get_rpc_channel(mcp_clientid: str, service_name: str) -> str: - return f"{RPC_BASE}/{mcp_clientid}/{service_name}" diff --git a/src/mcp/shared/mqtt_topic.py b/src/mcp/shared/mqtt_topic.py new file mode 100644 index 0000000000..1f96f255c9 --- /dev/null +++ b/src/mcp/shared/mqtt_topic.py @@ -0,0 +1,30 @@ + + +SERVICE_CONTROL_BASE: str = '$mcp-service' +SERVICE_CAPABILITY_CHANGE_BASE: str = '$mcp-service/capability/list-changed' +SERVICE_RESOURCE_UPDATE_BASE: str = '$mcp-service/capability/resource-updated' +SERVICE_PRESENCE_BASE: str = '$mcp-service/presence' +CLIENT_PRESENCE_BASE: str = '$mcp-client/presence' +CLIENT_CAPABILITY_CHANGE_BASE: str = '$mcp-client/capability/list-changed' +RPC_BASE: str = '$mcp-rpc-endpoint' + +def get_service_control_topic(service_name: str) -> str: + return f"{SERVICE_CONTROL_BASE}/{service_name}" + +def get_service_capability_change_topic(service_id: str, service_name: str) -> str: + return f"{SERVICE_CAPABILITY_CHANGE_BASE}/{service_id}/{service_name}" + +def get_service_resource_update_topic(service_id: str) -> str: + return f"{SERVICE_RESOURCE_UPDATE_BASE}/{service_id}" + +def get_service_presence_topic(service_id: str, service_name: str) -> str: + return f"{SERVICE_PRESENCE_BASE}/{service_id}/{service_name}" + +def get_client_presence_topic(mcp_clientid: str) -> str: + return f"{CLIENT_PRESENCE_BASE}/{mcp_clientid}" + +def get_client_capability_change_topic(mcp_clientid: str) -> str: + return f"{CLIENT_CAPABILITY_CHANGE_BASE}/{mcp_clientid}" + +def get_rpc_topic(mcp_clientid: str, service_name: str) -> str: + return f"{RPC_BASE}/{mcp_clientid}/{service_name}" From 91ad517bf0370d7b5fd3c11652e208b7190a8c8f Mon Sep 17 00:00:00 2001 From: Shawn <506895667@qq.com> Date: Tue, 1 Apr 2025 22:52:22 +0800 Subject: [PATCH 04/23] separate the mqtt transport to base class and sub classes --- src/mcp/server/mqtt.py | 149 ++++++----------------------------------- src/mcp/shared/mqtt.py | 148 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 167 insertions(+), 130 deletions(-) create mode 100644 src/mcp/shared/mqtt.py diff --git a/src/mcp/server/mqtt.py b/src/mcp/server/mqtt.py index 87c309e2a0..f23b361385 100644 --- a/src/mcp/server/mqtt.py +++ b/src/mcp/server/mqtt.py @@ -1,31 +1,21 @@ """ -SSE Server Transport Module - -This module implements a Server-Sent Events (SSE) transport layer for MCP servers." +This module implements the MQTT transport for the MCP server. """ - +from uuid import uuid4 +from mcp.shared.mqtt import MqttTransportBase, MqttOptions, QOS import asyncio import json import traceback -from types import TracebackType import mcp.shared.mqtt_topic as mqtt_topic import paho.mqtt.client as mqtt import logging from paho.mqtt.reasoncodes import ReasonCode -from paho.mqtt.enums import CallbackAPIVersion from paho.mqtt.properties import Properties from paho.mqtt.subscribeoptions import SubscribeOptions -from uuid import uuid4 import anyio from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import BaseModel -from typing import Literal, Optional, Any, TypeAlias, Callable, Awaitable +from typing import Any, TypeAlias, Callable, Awaitable import mcp.types as types -from typing_extensions import Self - -QOS = 1 -PROPERTY_K_MCP_CLIENT_ID = "mcp-client-id" -logger = logging.getLogger(__name__) RcvStream : TypeAlias = MemoryObjectReceiveStream[types.JSONRPCMessage] SndStream : TypeAlias = MemoryObjectSendStream[types.JSONRPCMessage] @@ -33,98 +23,36 @@ SndStreamEX : TypeAlias = MemoryObjectSendStream[types.JSONRPCMessage | Exception] ServerRun : TypeAlias = Callable[[RcvStreamEx, SndStream], Awaitable[Any]] -class MqttOptions(BaseModel): - host: str = "localhost" - port: int = 1883 - transport: Literal['tcp', 'websockets', 'unix'] = 'tcp' - keepalive: int = 60 - bind_address: str = '' - bind_port: int = 0 - username: Optional[str] = None - password: Optional[str] = None - tls_enabled: bool = False - tls_version: Optional[int] = None - tls_insecure: bool = False - ca_certs: Optional[str] = None - certfile: Optional[str] = None - keyfile: Optional[str] = None - ciphers: Optional[str] = None - keyfile_password: Optional[str] = None - alpn_protocols: Optional[list[str]] = None - websocket_path: str = '/mqtt' - websocket_headers: Optional[dict[str, str]] = None +PROPERTY_K_MCP_CLIENT_ID = "mcp-client-id" +logger = logging.getLogger(__name__) -class MqttTransport: - _read_stream_writers: dict[ - str, SndStreamEX - ] +class MqttTransport(MqttTransportBase): def __init__(self, server_run: ServerRun, service_name: str, service_description: str, service_meta: dict[str, Any], client_id_prefix: str | None = None, mqtt_options: MqttOptions = MqttOptions()): - self._read_stream_writers = {} - self.resource_ids: dict[str, str] = {} uuid = uuid4().hex - service_id = f"{client_id_prefix}-{uuid}" if client_id_prefix else uuid - self.mqtt_options = mqtt_options + mqtt_clientid = f"{client_id_prefix}-{uuid}" if client_id_prefix else uuid + self.service_id = mqtt_clientid self.service_name = service_name self.service_description = service_description self.service_meta = service_meta - self.service_id = service_id self.service_control_topic = mqtt_topic.get_service_control_topic(service_name) - self.service_presence_topic = mqtt_topic.get_service_presence_topic(service_id, service_name) - self.service_capability_change_topic = mqtt_topic.get_service_capability_change_topic(service_id, service_name) + self.service_presence_topic = mqtt_topic.get_service_presence_topic(self.service_id, service_name) + self.service_capability_change_topic = mqtt_topic.get_service_capability_change_topic(self.service_id, service_name) self.server_run = server_run - client = mqtt.Client( - callback_api_version=CallbackAPIVersion.VERSION2, - client_id=service_id, protocol=mqtt.MQTTv5, - userdata={}, - transport=mqtt_options.transport, reconnect_on_failure=True - ) - client.username_pw_set(mqtt_options.username, mqtt_options.password) - if mqtt_options.tls_enabled: - client.tls_set( # type: ignore - ca_certs=mqtt_options.ca_certs, - certfile=mqtt_options.certfile, - keyfile=mqtt_options.keyfile, - tls_version=mqtt_options.tls_version, - ciphers=mqtt_options.ciphers, - keyfile_password=mqtt_options.keyfile_password, - alpn_protocols=mqtt_options.alpn_protocols - ) - client.tls_insecure_set(mqtt_options.tls_insecure) - if mqtt_options.transport == 'websockets': - client.ws_set_options(path=mqtt_options.websocket_path, headers=mqtt_options.websocket_headers) - client.will_set(topic=self.service_presence_topic, payload=None, qos=QOS, retain=True) - client.on_connect = self._on_connect - client.on_message = self._on_message - client.on_subscribe = self._on_subscribe - self.client = client - - async def __aenter__(self) -> Self: - self._task_group = anyio.create_task_group() - await self._task_group.__aenter__() - return self - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> bool | None: - self._task_group.cancel_scope.cancel() - return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + super().__init__(mqtt_clientid, mqtt_options) + self.presence_topic = mqtt_topic.get_service_presence_topic(self.service_id, service_name) + self.client.will_set(topic=self.presence_topic, payload=None, qos=QOS, retain=True) def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.ConnectFlags, reason_code : ReasonCode, properties: Properties | None): if reason_code == 0: - logger.debug(f"Connected to MQTT broker_host at {self.mqtt_options.host}:{self.mqtt_options.port}") - self.assert_property(properties, "RetainAvailable", 1) - self.assert_property(properties, "WildcardSubscriptionAvailable", 1) - ## Subscribe to the service control channel + super()._on_connect(client, userdata, connect_flags, reason_code, properties) + ## Subscribe to the service control topic client.subscribe(self.service_control_topic, QOS) - ## Reister the service on the presence channel + ## Reister the service on the presence topic online_msg = types.JSONRPCNotification( jsonrpc="2.0", method = "notifications/service/online", @@ -133,10 +61,8 @@ def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.Co "meta": self.service_meta } ) - client.publish(self.service_presence_topic, - payload=online_msg.model_dump_json(), qos=QOS, retain=True) - else: - logger.error(f"Failed to connect, return code {reason_code}") + client.publish(self.presence_topic, payload=online_msg.model_dump_json(), + qos=QOS, retain=True) def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage): logger.debug(f"Received message on topic {msg.topic}: {msg.payload.decode()}") @@ -296,43 +222,6 @@ async def receieved_from_session(self, mcp_client_id: str, write_stream_reader: logger.debug(f"Session stream closed for mcp_client_id: {mcp_client_id}") - def publish_json_rpc_message(self, topic: str, message: types.JSONRPCMessage): - json = message.model_dump_json(by_alias=True, exclude_none=True) - self.client.publish(topic, json, qos=QOS) - - def connect(self): - logger.debug("Setting up MQTT connection") - self.client.connect( - host = self.mqtt_options.host, - port = self.mqtt_options.port, - keepalive = self.mqtt_options.keepalive, - bind_address = self.mqtt_options.bind_address, - bind_port = self.mqtt_options.bind_port, - clean_start=True - ) - - def assert_property(self, properties: Properties | None, property_name: str, expected_value: Any): - if get_property(properties, property_name) == expected_value: - pass - else: - self.stop_mqtt() - raise ValueError(f"{property_name} not available") - - def stop_mqtt(self): - self.client.publish(self.service_presence_topic, payload=None, qos=QOS, retain=True) - self.client.disconnect() - self.client.loop_stop() - for stream in self._read_stream_writers.values(): - anyio.from_thread.run(stream.aclose) - self._read_stream_writers = {} - logger.debug("Disconnected from MQTT broker_host") - -def get_property(properties: Properties | None, property_name: str): - if properties and hasattr(properties, property_name): - return getattr(properties, property_name) - else: - return False - async def start_mqtt( server_run: ServerRun, service_name: str, service_description: str, diff --git a/src/mcp/shared/mqtt.py b/src/mcp/shared/mqtt.py new file mode 100644 index 0000000000..cff6ff18f6 --- /dev/null +++ b/src/mcp/shared/mqtt.py @@ -0,0 +1,148 @@ +""" +MQTT Transport Base Module + +""" +from types import TracebackType +import paho.mqtt.client as mqtt +import logging +from paho.mqtt.reasoncodes import ReasonCode +from paho.mqtt.enums import CallbackAPIVersion +from paho.mqtt.properties import Properties +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import BaseModel +from typing import Literal, Optional, Any, TypeAlias, Callable, Awaitable +import mcp.types as types +from typing_extensions import Self + +QOS = 1 +logger = logging.getLogger(__name__) + +RcvStream : TypeAlias = MemoryObjectReceiveStream[types.JSONRPCMessage] +SndStream : TypeAlias = MemoryObjectSendStream[types.JSONRPCMessage] +RcvStreamEx : TypeAlias = MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] +SndStreamEX : TypeAlias = MemoryObjectSendStream[types.JSONRPCMessage | Exception] +ServerRun : TypeAlias = Callable[[RcvStreamEx, SndStream], Awaitable[Any]] + +class MqttOptions(BaseModel): + host: str = "localhost" + port: int = 1883 + transport: Literal['tcp', 'websockets', 'unix'] = 'tcp' + keepalive: int = 60 + bind_address: str = '' + bind_port: int = 0 + username: Optional[str] = None + password: Optional[str] = None + tls_enabled: bool = False + tls_version: Optional[int] = None + tls_insecure: bool = False + ca_certs: Optional[str] = None + certfile: Optional[str] = None + keyfile: Optional[str] = None + ciphers: Optional[str] = None + keyfile_password: Optional[str] = None + alpn_protocols: Optional[list[str]] = None + websocket_path: str = '/mqtt' + websocket_headers: Optional[dict[str, str]] = None + +class MqttTransportBase: + _read_stream_writers: dict[ + str, SndStreamEX + ] + + def __init__(self, mqtt_clientid: str | None = None, + mqtt_options: MqttOptions = MqttOptions()): + self._read_stream_writers = {} + self.mqtt_clientid = mqtt_clientid + self.mqtt_options = mqtt_options + client = mqtt.Client( + callback_api_version=CallbackAPIVersion.VERSION2, + client_id=mqtt_clientid, protocol=mqtt.MQTTv5, + userdata={}, + transport=mqtt_options.transport, reconnect_on_failure=True + ) + client.username_pw_set(mqtt_options.username, mqtt_options.password) + if mqtt_options.tls_enabled: + client.tls_set( # type: ignore + ca_certs=mqtt_options.ca_certs, + certfile=mqtt_options.certfile, + keyfile=mqtt_options.keyfile, + tls_version=mqtt_options.tls_version, + ciphers=mqtt_options.ciphers, + keyfile_password=mqtt_options.keyfile_password, + alpn_protocols=mqtt_options.alpn_protocols + ) + client.tls_insecure_set(mqtt_options.tls_insecure) + if mqtt_options.transport == 'websockets': + client.ws_set_options(path=mqtt_options.websocket_path, headers=mqtt_options.websocket_headers) + client.on_connect = self._on_connect + client.on_message = self._on_message + client.on_subscribe = self._on_subscribe + self.client = client + + async def __aenter__(self) -> Self: + self._task_group = anyio.create_task_group() + await self._task_group.__aenter__() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + self._task_group.cancel_scope.cancel() + return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + + def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.ConnectFlags, reason_code : ReasonCode, properties: Properties | None): + if reason_code == 0: + logger.debug(f"Connected to MQTT broker_host at {self.mqtt_options.host}:{self.mqtt_options.port}") + self.assert_property(properties, "RetainAvailable", 1) + self.assert_property(properties, "WildcardSubscriptionAvailable", 1) + else: + logger.error(f"Failed to connect, return code {reason_code}") + + def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage): + pass + + def _on_subscribe(self, client: mqtt.Client, userdata: Any, mid: int, + reason_code_list: list[ReasonCode], properties: Properties | None): + pass + + def publish_json_rpc_message(self, topic: str, message: types.JSONRPCMessage): + json = message.model_dump_json(by_alias=True, exclude_none=True) + self.client.publish(topic, json, qos=QOS) + + def connect(self): + logger.debug("Setting up MQTT connection") + self.client.connect( + host = self.mqtt_options.host, + port = self.mqtt_options.port, + keepalive = self.mqtt_options.keepalive, + bind_address = self.mqtt_options.bind_address, + bind_port = self.mqtt_options.bind_port, + clean_start=True + ) + + def assert_property(self, properties: Properties | None, property_name: str, expected_value: Any): + if get_property(properties, property_name) == expected_value: + pass + else: + self.stop_mqtt() + raise ValueError(f"{property_name} not available") + + def stop_mqtt(self): + self.client.publish(self.service_presence_topic, payload=None, qos=QOS, retain=True) + self.client.disconnect() + self.client.loop_stop() + for stream in self._read_stream_writers.values(): + anyio.from_thread.run(stream.aclose) + self._read_stream_writers = {} + logger.debug("Disconnected from MQTT broker_host") + +def get_property(properties: Properties | None, property_name: str): + if properties and hasattr(properties, property_name): + return getattr(properties, property_name) + else: + return False + From a4142a8dbb0390a8d3f436147bf8f49ff9f9197f Mon Sep 17 00:00:00 2001 From: Shawn <506895667@qq.com> Date: Thu, 3 Apr 2025 09:55:06 +0800 Subject: [PATCH 05/23] add mqtt transport for mcp clients --- .../clients/mqtt-clients/client_apis_demo.py | 53 +++ examples/fastmcp/mqtt_simple_echo.py | 7 +- pyproject.toml | 2 +- src/mcp/client/mqtt.py | 415 ++++++++++++++++++ src/mcp/server/mqtt.py | 73 +-- src/mcp/shared/mqtt.py | 59 ++- src/mcp/shared/mqtt_topic.py | 4 +- 7 files changed, 567 insertions(+), 46 deletions(-) create mode 100644 examples/clients/mqtt-clients/client_apis_demo.py create mode 100644 src/mcp/client/mqtt.py diff --git a/examples/clients/mqtt-clients/client_apis_demo.py b/examples/clients/mqtt-clients/client_apis_demo.py new file mode 100644 index 0000000000..0d29a08d57 --- /dev/null +++ b/examples/clients/mqtt-clients/client_apis_demo.py @@ -0,0 +1,53 @@ +import logging +import anyio +import mcp.client.mqtt as mcp_mqtt +from mcp.shared.mqtt import configure_logging + +configure_logging(level="DEBUG") +logger = logging.getLogger(__name__) + +async def on_mcp_server_presence(client, service_name, status): + if status == "online": + logger.info(f"Connecting to {service_name}...") + await client.initialize_mcp_server(service_name) + +async def on_mcp_connect(client, service_name, connect_result): + logger.info(f"Connect result to {service_name}: {connect_result}") + capabilities = client.service_sessions[service_name].server_info.capabilities + logger.info(f"Capabilities of {service_name}: {capabilities}") + if capabilities.prompts: + prompts = await client.list_prompts(service_name) + logger.info(f"Prompts of {service_name}: {prompts}") + if capabilities.resources: + resources = await client.list_resources(service_name) + logger.info(f"Resources of {service_name}: {resources}") + resource_templates = await client.list_resource_templates(service_name) + logger.info(f"Resources templates of {service_name}: {resource_templates}") + if capabilities.tools: + tools = await client.list_tools(service_name) + logger.info(f"Tools of {service_name}: {tools}") + +async def on_mcp_disconnect(client, service_name, reason): + logger.info(f"Disconnected from {service_name}, reason: {reason}") + logger.info(f"Services now: {client.service_sessions}") + +async def main(): + async with mcp_mqtt.MqttTransportClient( + "test_client", + auto_connect_to_mcp_server = True, + on_mcp_server_presence = on_mcp_server_presence, + on_mcp_connect = on_mcp_connect, + on_mcp_disconnect = on_mcp_disconnect, + mqtt_options = mcp_mqtt.MqttOptions( + host="broker.emqx.io", + port=1883, + keepalive=60 + ) + ) as client: + client.start() + while True: + logger.info("Other works while the MQTT transport client is running in the background...") + await anyio.sleep(10) + +if __name__ == "__main__": + anyio.run(main) diff --git a/examples/fastmcp/mqtt_simple_echo.py b/examples/fastmcp/mqtt_simple_echo.py index 5cea4f183c..5e187b8b84 100644 --- a/examples/fastmcp/mqtt_simple_echo.py +++ b/examples/fastmcp/mqtt_simple_echo.py @@ -8,7 +8,12 @@ mcp = FastMCP( "demo_server/echo", log_level="DEBUG", - mqtt_service_description="A simple FastMCP server that echoes back the input text." + mqtt_service_description="A simple FastMCP server that echoes back the input text.", + mqtt_options={ + "host": "broker.emqx.io", + "port": 1883, + "keepalive": 60, + }, ) @mcp.tool() diff --git a/pyproject.toml b/pyproject.toml index 52d5577b55..9da6b75ec4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,7 +102,7 @@ target-version = "py310" "tests/server/fastmcp/test_func_metadata.py" = ["E501"] [tool.uv.workspace] -members = ["examples/servers/*"] +members = ["examples/servers/*", "examples/clients/mqtt-clients/smart-home"] [tool.uv.sources] mcp = { workspace = true } diff --git a/src/mcp/client/mqtt.py b/src/mcp/client/mqtt.py new file mode 100644 index 0000000000..a7498c2731 --- /dev/null +++ b/src/mcp/client/mqtt.py @@ -0,0 +1,415 @@ +""" +This module implements the MQTT transport for the MCP server. +""" +from contextlib import AsyncExitStack +from uuid import uuid4 +from datetime import timedelta +import random +from pydantic import AnyUrl, BaseModel +from mcp.client.session import ClientSession, SamplingFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT +from mcp.shared.exceptions import McpError +from mcp.shared.mqtt import MqttTransportBase, MqttOptions, QOS +import asyncio +import anyio.to_thread as anyio_to_thread +import anyio.from_thread as anyio_from_thread +import traceback +import mcp.shared.mqtt_topic as mqtt_topic +import paho.mqtt.client as mqtt +import logging +from paho.mqtt.reasoncodes import ReasonCode +from paho.mqtt.properties import Properties +from paho.mqtt.subscribeoptions import SubscribeOptions +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from typing import Any, Literal, TypeAlias, Callable, Awaitable +import mcp.types as types + +RcvStream : TypeAlias = MemoryObjectReceiveStream[types.JSONRPCMessage] +SndStream : TypeAlias = MemoryObjectSendStream[types.JSONRPCMessage] +RcvStreamEx : TypeAlias = MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] +SndStreamEX : TypeAlias = MemoryObjectSendStream[types.JSONRPCMessage | Exception] +ServerRun : TypeAlias = Callable[[RcvStreamEx, SndStream], Awaitable[Any]] + +ServiceName : TypeAlias = str +ServiceId : TypeAlias = str +InitializeResult : TypeAlias = Literal["ok"] | Literal["already_connected"] | tuple[Literal["error"], str] +ConnectResult : TypeAlias = tuple[Literal["ok"], types.InitializeResult] | tuple[Literal["error"], Any] +DisconnectReason : TypeAlias = Literal["client_initiated_disconnect", "server_initiated_disconnect"] + +logger = logging.getLogger(__name__) + +class ServiceDefinition(BaseModel): + description: str + meta: dict[str, Any] = {} + +class ServiceOnlineNotification(BaseModel): + jsonrpc: Literal["2.0"] + method: str = "notifications/service/online" + params: ServiceDefinition + +class MqttClientSession(ClientSession): + def __init__( + self, + service_id: ServiceId, + service_name: ServiceName, + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], + write_stream: MemoryObjectSendStream[types.JSONRPCMessage], + read_timeout_seconds: timedelta | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + ) -> None: + super().__init__( + read_stream, + write_stream, + read_timeout_seconds, + sampling_callback, + list_roots_callback, + logging_callback, + message_handler, + ) + self.service_id = service_id + self.service_name = service_name + self.server_info: types.InitializeResult | None = None + +class MqttTransportClient(MqttTransportBase): + + def __init__(self, mcp_client_name: str, client_id_prefix: str | None = None, + service_name_filter: str = '#', + auto_connect_to_mcp_server: bool = False, + on_mcp_connect: Callable[["MqttTransportClient", ServiceName, ConnectResult], Awaitable[Any]] | None = None, + on_mcp_disconnect: Callable[["MqttTransportClient", ServiceName, DisconnectReason], Awaitable[Any]] | None = None, + on_mcp_server_presence: Callable[["MqttTransportClient", ServiceName, Literal["online", "offline"]], Awaitable[Any]] | None = None, + mqtt_options: MqttOptions = MqttOptions()): + self.exit_stack: AsyncExitStack = AsyncExitStack() + uuid = uuid4().hex + mqtt_clientid = f"{client_id_prefix}-{uuid}" if client_id_prefix else uuid + self.service_list: dict[ServiceName, dict[ServiceId, ServiceDefinition]] = {} + self.service_sessions: dict[ServiceName, MqttClientSession] = {} + self.mcp_client_id = mqtt_clientid + self.mcp_client_name = mcp_client_name + self.service_name_filter = service_name_filter + self.auto_connect_to_mcp_server = auto_connect_to_mcp_server #TODO: not implemented yet + self.on_mcp_connect = on_mcp_connect + self.on_mcp_disconnect = on_mcp_disconnect #TODO: not implemented yet + self.on_mcp_server_presence = on_mcp_server_presence + self.client_capability_change_topic = mqtt_topic.get_client_capability_change_topic(self.mcp_client_id) + super().__init__("mcp-client", mqtt_clientid = mqtt_clientid, mqtt_options = mqtt_options) + self.presence_topic = mqtt_topic.get_client_presence_topic(self.mcp_client_id) + ## Send disconnected notification when disconnects + self.disconnected_msg = types.JSONRPCNotification( + jsonrpc="2.0", + method = "notifications/disconnected" + ) + self.disconnected_msg_retain = False + self.client.will_set( + topic=self.presence_topic, + payload=self.disconnected_msg.model_dump_json(), + qos=QOS + ) + + def start(self): + def do_start(): + self.connect() + self.client.loop_forever() + try: + asyncio.create_task(anyio_to_thread.run_sync(do_start)) + except asyncio.CancelledError: + logger.debug("MQTT transport (MCP client) got cancelled") + except Exception as exc: + logger.error(f"MQTT transport (MCP client) failed: {exc}") + traceback.print_exc() + + async def initialize_mcp_server( + self, service_name: str, + read_timeout_seconds: timedelta | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None) -> InitializeResult: + if service_name in self.service_sessions: + return "already_connected" + if service_name not in self.service_list: + logger.error(f"MCP server not found, service name: {service_name}") + return ("error", "MCP server not found") + service_id = self.pick_service_id(service_name) + + async def after_subscribed( + subscribe_result: Literal["success", "error"] + ): + if subscribe_result == "error": + if self.on_mcp_connect: + self._task_group.start_soon(self.on_mcp_connect, self, service_name, ("error", "subscribe_mcp_server_topics_failed")) + client_session = self._create_session( + service_id, + service_name, + read_timeout_seconds, + sampling_callback, + list_roots_callback, + logging_callback, + message_handler + ) + self.service_sessions[service_name] = client_session + try: + logger.debug(f"before initialize: {service_name}") + async def after_initialize(): + try: + session = await self.exit_stack.enter_async_context(client_session) + init_result = await session.initialize() + session.server_info = init_result + if self.on_mcp_connect: + self._task_group.start_soon(self.on_mcp_connect, self, service_name, ("ok", init_result)) + except Exception as e: + logging.error(f"Failed to initialize server {service_name}: {e}") + await self.exit_stack.aclose() + raise + self._task_group.start_soon(after_initialize) + logger.debug(f"after initialize: {service_name}") + except McpError as exc: + logger.error(f"Failed to connect to MCP server: {exc}") + if self.on_mcp_connect: + self._task_group.start_soon(self.on_mcp_connect, self, service_name, ("error", McpError)) + + if self._subscribe_mcp_server_topics(service_id, service_name, after_subscribed): + return "ok" + else: + return ("error", "send_subscribe_request_failed") + + async def send_ping(self, service_name: ServiceName) -> bool | types.EmptyResult: + return await self._with_session(service_name, lambda s: s.send_ping()) + + async def send_progress_notification( + self, service_name: ServiceName, progress_token: str | int, + progress: float, total: float | None = None + ) -> bool | None: + return await self._with_session(service_name, lambda s: s.send_progress_notification(progress_token, progress, total)) + + async def set_logging_level(self, service_name: ServiceName, + level: types.LoggingLevel) -> bool | types.EmptyResult: + return await self._with_session(service_name, lambda s: s.set_logging_level(level)) + + async def list_resources(self, service_name: ServiceName) -> bool | types.ListResourcesResult: + return await self._with_session(service_name, lambda s: s.list_resources()) + + async def list_resource_templates(self, service_name: ServiceName) -> bool | types.ListResourceTemplatesResult: + return await self._with_session(service_name, lambda s: s.list_resource_templates()) + + async def read_resource(self, service_name: ServiceName, + uri: AnyUrl) -> bool | types.ReadResourceResult: + return await self._with_session(service_name, lambda s: s.read_resource(uri)) + + async def subscribe_resource(self, service_name: ServiceName, + uri: AnyUrl) -> bool | types.EmptyResult: + return await self._with_session(service_name, lambda s: s.subscribe_resource(uri)) + + async def unsubscribe_resource(self, service_name: ServiceName, + uri: AnyUrl) -> bool | types.EmptyResult: + return await self._with_session(service_name, lambda s: s.unsubscribe_resource(uri)) + + async def call_tool( + self, service_name: ServiceName, name: str, arguments: dict[str, Any] | None = None + ) -> bool | types.CallToolResult: + return await self._with_session(service_name, lambda s: s.call_tool(name, arguments)) + + async def list_prompts(self, service_name: ServiceName) -> bool | types.ListPromptsResult: + return await self._with_session(service_name, lambda s: s.list_prompts()) + + async def get_prompt( + self, service_name: ServiceName, name: str, arguments: dict[str, str] | None = None + ) -> bool | types.GetPromptResult: + return await self._with_session(service_name, lambda s: s.get_prompt(name, arguments)) + + async def complete( + self, + service_name: ServiceName, + ref: types.ResourceReference | types.PromptReference, + argument: dict[str, str], + ) -> bool | types.CompleteResult: + return await self._with_session(service_name, lambda s: s.complete(ref, argument)) + + async def list_tools(self, service_name: ServiceName) -> bool | types.ListToolsResult: + return await self._with_session(service_name, lambda s: s.list_tools()) + + async def send_roots_list_changed(self, service_name: ServiceName) -> bool | None: + return await self._with_session(service_name, lambda s: s.send_roots_list_changed()) + + async def _with_session( + self, service_name: ServiceName, + async_callback: Callable[[MqttClientSession], Awaitable[bool | Any]]) -> bool | Any: + if not (client_session := self.service_sessions.get(service_name)): + logger.error(f"No session for service_name: {service_name}") + return False + return await async_callback(client_session) + + def _create_session( + self, service_id: ServiceId, service_name: ServiceName, + read_timeout_seconds: timedelta | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None + ): + ## Streams are used to communicate between the MqttTransportClient and the MCPSession: + ## 1. (msg) --> MqttBroker --> MqttTransportClient -->[read_stream_writer]-->[read_stream]--> MCPSession + ## 2. MqttBroker <-- MqttTransportClient <--[write_stream_reader]--[write_stream]-- MCPSession <-- (msg) + read_stream: RcvStreamEx + read_stream_writer: SndStreamEX + write_stream: SndStream + write_stream_reader: RcvStream + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + self._read_stream_writers[service_id] = read_stream_writer + self._task_group.start_soon(self._receieved_from_session, service_id, service_name, write_stream_reader) + logger.debug(f"Created new session for service_id: {service_id}") + return MqttClientSession( + service_id, + service_name, + read_stream, + write_stream, + read_timeout_seconds, + sampling_callback, + list_roots_callback, + logging_callback, + message_handler + ) + + def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.ConnectFlags, reason_code : ReasonCode, properties: Properties | None): + if reason_code == 0: + super()._on_connect(client, userdata, connect_flags, reason_code, properties) + ## Subscribe to the MCP server's service presence topic + client.subscribe(mqtt_topic.get_service_presence_topic('+', self.service_name_filter), qos=QOS) + + def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage): + logger.debug(f"Received message on topic {msg.topic}: {msg.payload.decode()}") + match msg.topic: + case str() as t if t.startswith(mqtt_topic.SERVICE_PRESENCE_BASE): + self._handle_service_presence_message(msg) + case str() as t if t.startswith(mqtt_topic.RPC_BASE): + self._handle_rpc_message(msg) + case str() as t if t.startswith(mqtt_topic.SERVICE_CAPABILITY_CHANGE_BASE): + self._handle_service_capability_list_changed_message(msg) + case str() as t if t.startswith(mqtt_topic.SERVICE_RESOURCE_UPDATE_BASE): + self._handle_service_capability_resource_updated_message(msg) + case _: + logger.error(f"Received message on unexpected topic: {msg.topic}") + + def _on_subscribe(self, client: mqtt.Client, userdata: Any, mid: int, + reason_code_list: list[ReasonCode], properties: Properties | None): + if mid in userdata.get("pending_subs", {}): + service_name, service_id, after_subscribed = userdata["pending_subs"].pop(mid) + ## only create session if all topic subscribed successfully + if all([rc.value == QOS for rc in reason_code_list]): + logger.debug(f"Subscribed to topics for service_name: {service_name}, service_id: {service_id}") + anyio_from_thread.run(after_subscribed, "success") + else: + anyio_from_thread.run(after_subscribed, "error") + logger.error(f"Failed to subscribe to topics for service_name: {service_name}, service_id: {service_id}, reason_codes: {reason_code_list}") + + def _handle_service_presence_message(self, msg: mqtt.MQTTMessage) -> None: + topic_words = msg.topic.split("/") + service_id = topic_words[2] + service_name = "/".join(topic_words[3:]) + if msg.payload: + newly_added_service = False if service_name in self.service_list else True + service_notif = ServiceOnlineNotification.model_validate_json(msg.payload.decode()) + self.service_list.setdefault(service_name, {})[service_id] = service_notif.params + logger.debug(f"Service {service_name} with id {service_id} is online") + if newly_added_service: + if self.on_mcp_server_presence: + anyio_from_thread.run(self.on_mcp_server_presence, self, service_name, "online") + else: + existing_service = True if service_name in self.service_list else False + if service_id in self.service_list.get(service_name, {}): + logger.debug(f"Service {service_name} with id {service_id} is offline") + self.service_list[service_name].pop(service_id) + if existing_service: + if self.on_mcp_server_presence: + anyio_from_thread.run(self.on_mcp_server_presence, self, service_name, "offline") + + def _handle_rpc_message(self, msg: mqtt.MQTTMessage) -> None: + service_name = "/".join(msg.topic.split("/")[2:]) + anyio_from_thread.run(self._send_message_to_session, service_name, msg) + + def _handle_service_capability_list_changed_message(self, msg: mqtt.MQTTMessage) -> None: + service_name = "/".join(msg.topic.split("/")[4:]) + anyio_from_thread.run(self._send_message_to_session, service_name, msg) + + def _handle_service_capability_resource_updated_message(self, msg: mqtt.MQTTMessage) -> None: + service_name = "/".join(msg.topic.split("/")[4:]) + anyio_from_thread.run(self._send_message_to_session, service_name, msg) + + def _subscribe_mcp_server_topics(self, service_id: ServiceId, service_name: ServiceName, + after_subscribed: Callable[[Any], Awaitable[None]]): + topic_filters = [ + (mqtt_topic.get_service_capability_change_topic(service_id, service_name), SubscribeOptions(qos=QOS)), + (mqtt_topic.get_service_resource_update_topic(service_id, service_name), SubscribeOptions(qos=QOS)), + (mqtt_topic.get_rpc_topic(self.mcp_client_id, service_name), SubscribeOptions(qos=QOS, noLocal=True)) + ] + ret, mid = self.client.subscribe(topic=topic_filters) + if ret != mqtt.MQTT_ERR_SUCCESS: + logger.error(f"Failed to subscribe to topics for service_name: {service_name}") + return False + userdata = self.client.user_data_get() + pending_subs = userdata.get("pending_subs", {}) + pending_subs[mid] = (service_name, service_id, after_subscribed) + userdata["pending_subs"] = pending_subs + return True + + async def _send_message_to_session(self, service_name: ServiceName, msg: mqtt.MQTTMessage): + if service_name not in self.service_sessions: + logger.error(f"_send_message_to_session: No session for service_name: {service_name}") + return + client_session: MqttClientSession = self.service_sessions[service_name] + payload = msg.payload.decode() + service_id = client_session.service_id + if service_id not in self._read_stream_writers: + logger.error(f"No session for service_id: {service_id}") + return + read_stream_writer = self._read_stream_writers[service_id] + try: + message = types.JSONRPCMessage.model_validate_json(payload) + logger.debug(f"Sending msg to session for service_id: {service_id}, msg: {message}") + with anyio.fail_after(3): + await read_stream_writer.send(message) + except Exception as exc: + logger.error(f"Failed to send msg to session for service_id: {service_id}, exception: {exc}") + traceback.print_exc() + ## TODO: the session does not handle exceptions for now + #await read_stream_writer.send(exc) + async def _receieved_from_session(self, service_id: ServiceId, service_name: ServiceName, + write_stream_reader: RcvStream): + async with write_stream_reader: + async for msg in write_stream_reader: + logger.debug(f"Got msg from session for service_id: {service_id}, msg: {msg}") + match msg.model_dump(): + case {"method": method} if method == "notifications/initialized": + logger.debug(f"Session initialized for service_id: {service_id}") + topic = mqtt_topic.get_rpc_topic(self.mcp_client_id, service_name) + case {"method": method} if method.endswith("/list_changed"): + topic = None + logger.warning("Resource updates should not be sent from the session. Ignoring.") + case {"method": method} if method == "initialize": + topic = mqtt_topic.get_service_control_topic(service_name) + case _: + topic = mqtt_topic.get_rpc_topic(self.mcp_client_id, service_name) + if topic: + self.publish_json_rpc_message(topic, message = msg) + # cleanup + if service_id in self._read_stream_writers: + logger.debug(f"Removing session for service_id: {service_id}") + stream = self._read_stream_writers.pop(service_id) + await stream.aclose() + + logger.debug(f"Session stream closed for service_id: {service_id}") + + def pick_service_id(self, service_name: str) -> ServiceId: + return random.choice(list(self.service_list[service_name].keys())) + +def validate_service_name(name: str): + if "/" not in name: + raise ValueError(f"Invalid service name: {name}, must contain a '/'") + elif ("+" in name) or ("#" in name): + raise ValueError(f"Invalid service name: {name}, must not contain '+' or '#'") + elif name[0] == "/": + raise ValueError(f"Invalid service name: {name}, must not start with '/'") diff --git a/src/mcp/server/mqtt.py b/src/mcp/server/mqtt.py index f23b361385..80ecb523f0 100644 --- a/src/mcp/server/mqtt.py +++ b/src/mcp/server/mqtt.py @@ -2,8 +2,10 @@ This module implements the MQTT transport for the MCP server. """ from uuid import uuid4 -from mcp.shared.mqtt import MqttTransportBase, MqttOptions, QOS +from mcp.shared.mqtt import MqttTransportBase, MqttOptions, QOS, PROPERTY_K_MQTT_CLIENT_ID import asyncio +import anyio.to_thread as anyio_to_thread +import anyio.from_thread as anyio_from_thread import json import traceback import mcp.shared.mqtt_topic as mqtt_topic @@ -23,10 +25,9 @@ SndStreamEX : TypeAlias = MemoryObjectSendStream[types.JSONRPCMessage | Exception] ServerRun : TypeAlias = Callable[[RcvStreamEx, SndStream], Awaitable[Any]] -PROPERTY_K_MCP_CLIENT_ID = "mcp-client-id" logger = logging.getLogger(__name__) -class MqttTransport(MqttTransportBase): +class MqttTransportServer(MqttTransportBase): def __init__(self, server_run: ServerRun, service_name: str, service_description: str, @@ -43,8 +44,9 @@ def __init__(self, server_run: ServerRun, service_name: str, self.service_presence_topic = mqtt_topic.get_service_presence_topic(self.service_id, service_name) self.service_capability_change_topic = mqtt_topic.get_service_capability_change_topic(self.service_id, service_name) self.server_run = server_run - super().__init__(mqtt_clientid, mqtt_options) + super().__init__("mcp-server", mqtt_clientid = mqtt_clientid, mqtt_options = mqtt_options) self.presence_topic = mqtt_topic.get_service_presence_topic(self.service_id, service_name) + self.disconnected_msg = None self.client.will_set(topic=self.presence_topic, payload=None, qos=QOS, retain=True) def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.ConnectFlags, reason_code : ReasonCode, properties: Properties | None): @@ -53,16 +55,17 @@ def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.Co ## Subscribe to the service control topic client.subscribe(self.service_control_topic, QOS) ## Reister the service on the presence topic - online_msg = types.JSONRPCNotification( - jsonrpc="2.0", - method = "notifications/service/online", - params = { - "description": self.service_description, - "meta": self.service_meta - } - ) - client.publish(self.presence_topic, payload=online_msg.model_dump_json(), - qos=QOS, retain=True) + online_msg = types.JSONRPCMessage( + types.JSONRPCNotification( + jsonrpc="2.0", + method = "notifications/service/online", + params = { + "description": self.service_description, + "meta": self.service_meta + } + )) + self.publish_json_rpc_message( + self.presence_topic, message=online_msg, retain=True) def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage): logger.debug(f"Received message on topic {msg.topic}: {msg.payload.decode()}") @@ -85,7 +88,7 @@ def _on_subscribe(self, client: mqtt.Client, userdata: Any, mid: int, ## only create session if all topic subscribed successfully if all([rc.value == QOS for rc in reason_code_list]): logger.debug(f"Subscribed to topics for mcp_client_id: {mcp_client_id}") - anyio.from_thread.run(self.create_session, mcp_client_id, msg) + anyio_from_thread.run(self.create_session, mcp_client_id, msg) else: logger.error(f"Failed to subscribe to topics for mcp_client_id: {mcp_client_id}, reason_codes: {reason_code_list}") err = types.JSONRPCError( @@ -98,30 +101,30 @@ def _on_subscribe(self, client: mqtt.Client, userdata: Any, mid: int, ) self.publish_json_rpc_message( mqtt_topic.get_rpc_topic(mcp_client_id, self.service_name), - types.JSONRPCMessage(err) + message = types.JSONRPCMessage(err) ) def handle_service_contorl_message(self, msg: mqtt.MQTTMessage): if msg.properties and hasattr(msg.properties, "UserProperty"): user_properties: dict[str, Any] = dict(msg.properties.UserProperty) # type: ignore - if PROPERTY_K_MCP_CLIENT_ID in user_properties: - mcp_client_id = user_properties[PROPERTY_K_MCP_CLIENT_ID] + if PROPERTY_K_MQTT_CLIENT_ID in user_properties: + mcp_client_id = user_properties[PROPERTY_K_MQTT_CLIENT_ID] if mcp_client_id in self._read_stream_writers: - anyio.from_thread.run(self.send_message_to_session, mcp_client_id, msg) + anyio_from_thread.run(self._send_message_to_session, mcp_client_id, msg) else: self.maybe_subscribe_to_client(mcp_client_id, msg) else: - logger.error(f"No {PROPERTY_K_MCP_CLIENT_ID} in UserProperties") + logger.error(f"No {PROPERTY_K_MQTT_CLIENT_ID} in UserProperties") else: logger.error("No UserProperties in control message") def handle_client_capability_change_message(self, msg: mqtt.MQTTMessage) -> None: mcp_client_id = msg.topic.split("/")[-1] - anyio.from_thread.run(self.send_message_to_session, mcp_client_id, msg) + anyio_from_thread.run(self._send_message_to_session, mcp_client_id, msg) def handle_rpc_message(self, msg: mqtt.MQTTMessage) -> None: mcp_client_id = msg.topic.split("/")[1] - anyio.from_thread.run(self.send_message_to_session, mcp_client_id, msg) + anyio_from_thread.run(self._send_message_to_session, mcp_client_id, msg) def handle_client_presence_message(self, msg: mqtt.MQTTMessage) -> None: mcp_client_id = msg.topic.split("/")[-1] @@ -133,7 +136,7 @@ def handle_client_presence_message(self, msg: mqtt.MQTTMessage) -> None: if "method" in json_msg: if json_msg["method"] == "notifications/disconnected": stream = self._read_stream_writers.pop(mcp_client_id) - anyio.from_thread.run(stream.aclose) + anyio_from_thread.run(stream.aclose) logger.debug(f"Removed session for mcp_client_id: {mcp_client_id}") else: logger.error(f"Unknown method in control message for mcp_client_id: {mcp_client_id}") @@ -143,9 +146,9 @@ def handle_client_presence_message(self, msg: mqtt.MQTTMessage) -> None: logger.error(f"Invalid JSON in control message for mcp_client_id: {mcp_client_id}") async def create_session(self, mcp_client_id: str, msg: mqtt.MQTTMessage): - ## Streams are used to communicate between the MqttTransport and the MCPSession: - ## 1. (msg) --> MqttBroker --> MqttTransport -->[read_stream_writer]-->[read_stream]--> MCPSession - ## 2. MqttBroker <-- MqttTransport <--[write_stream_reader]--[write_stream]-- MCPSession <-- (msg) + ## Streams are used to communicate between the MqttTransportServer and the MCPSession: + ## 1. (msg) --> MqttBroker --> MqttTransportServer -->[read_stream_writer]-->[read_stream]--> MCPSession + ## 2. MqttBroker <-- MqttTransportServer <--[write_stream_reader]--[write_stream]-- MCPSession <-- (msg) read_stream: RcvStreamEx read_stream_writer: SndStreamEX write_stream: SndStream @@ -154,9 +157,9 @@ async def create_session(self, mcp_client_id: str, msg: mqtt.MQTTMessage): write_stream, write_stream_reader = anyio.create_memory_object_stream(0) self._read_stream_writers[mcp_client_id] = read_stream_writer self._task_group.start_soon(self.server_run, read_stream, write_stream) - self._task_group.start_soon(self.receieved_from_session, mcp_client_id, write_stream_reader) + self._task_group.start_soon(self._receieved_from_session, mcp_client_id, write_stream_reader) logger.debug(f"Created new session for mcp_client_id: {mcp_client_id}") - await self.send_message_to_session(mcp_client_id, msg) + await self._send_message_to_session(mcp_client_id, msg) def maybe_subscribe_to_client(self, mcp_client_id: str, msg: mqtt.MQTTMessage): try: @@ -185,7 +188,7 @@ def subscribe_to_client(self, mcp_client_id: str, msg: mqtt.MQTTMessage, rcp_msg pending_subs[mid] = (mcp_client_id, msg, rcp_msg_id) userdata["pending_subs"] = pending_subs - async def send_message_to_session(self, mcp_client_id: str, msg: mqtt.MQTTMessage): + async def _send_message_to_session(self, mcp_client_id: str, msg: mqtt.MQTTMessage): payload = msg.payload.decode() if mcp_client_id not in self._read_stream_writers: logger.error(f"No session for mcp_client_id: {mcp_client_id}") @@ -202,7 +205,7 @@ async def send_message_to_session(self, mcp_client_id: str, msg: mqtt.MQTTMessag ## TODO: the session does not handle exceptions for now #await read_stream_writer.send(exc) - async def receieved_from_session(self, mcp_client_id: str, write_stream_reader: RcvStream): + async def _receieved_from_session(self, mcp_client_id: str, write_stream_reader: RcvStream): async with write_stream_reader: async for msg in write_stream_reader: logger.debug(f"Got msg from session for mcp_client_id: {mcp_client_id}, msg: {msg}") @@ -213,7 +216,7 @@ async def receieved_from_session(self, mcp_client_id: str, write_stream_reader: logger.warning("Resource updates should not be sent from the session. Ignoring.") case _: topic = mqtt_topic.get_rpc_topic(mcp_client_id, self.service_name) - self.publish_json_rpc_message(topic, msg) + self.publish_json_rpc_message(topic, message = msg) # cleanup if mcp_client_id in self._read_stream_writers: logger.debug(f"Removing session for mcp_client_id: {mcp_client_id}") @@ -228,7 +231,7 @@ async def start_mqtt( service_meta: dict[str, Any], client_id_prefix: str | None = None, mqtt_options: MqttOptions = MqttOptions()): - async with MqttTransport( + async with MqttTransportServer( server_run, service_name = service_name, service_description=service_description, @@ -240,9 +243,11 @@ def start(): mqtt_trans.connect() mqtt_trans.client.loop_forever() try: - await anyio.to_thread.run_sync(start) + await anyio_to_thread.run_sync(start) except asyncio.CancelledError: - logger.debug("MQTT transport got cancelled") + logger.debug("MQTT transport (MCP server) got cancelled") + except Exception as exc: + logger.error(f"MQTT transport (MCP server) failed with exception: {exc}") def validate_service_name(name: str): if "/" not in name: diff --git a/src/mcp/shared/mqtt.py b/src/mcp/shared/mqtt.py index cff6ff18f6..7d87173739 100644 --- a/src/mcp/shared/mqtt.py +++ b/src/mcp/shared/mqtt.py @@ -8,14 +8,19 @@ from paho.mqtt.reasoncodes import ReasonCode from paho.mqtt.enums import CallbackAPIVersion from paho.mqtt.properties import Properties +from paho.mqtt.packettypes import PacketTypes import anyio +import anyio.from_thread as anyio_from_thread from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import BaseModel from typing import Literal, Optional, Any, TypeAlias, Callable, Awaitable import mcp.types as types from typing_extensions import Self +DEFAULT_LOG_FORMAT = "%(asctime)s - %(message)s" QOS = 1 +PROPERTY_K_MCP_COMPONENT = "MCP-COMPONENT-TYPE" +PROPERTY_K_MQTT_CLIENT_ID = "MQTT-CLIENT-ID" logger = logging.getLogger(__name__) RcvStream : TypeAlias = MemoryObjectReceiveStream[types.JSONRPCMessage] @@ -50,11 +55,17 @@ class MqttTransportBase: str, SndStreamEX ] - def __init__(self, mqtt_clientid: str | None = None, + def __init__(self, + mcp_component_type: Literal["mcp-client", "mcp-server"], + mqtt_clientid: str | None = None, mqtt_options: MqttOptions = MqttOptions()): self._read_stream_writers = {} self.mqtt_clientid = mqtt_clientid + self.mcp_component_type = mcp_component_type self.mqtt_options = mqtt_options + self.presence_topic = '' + self.disconnected_msg = None + self.disconnected_msg_retain = True client = mqtt.Client( callback_api_version=CallbackAPIVersion.VERSION2, client_id=mqtt_clientid, protocol=mqtt.MQTTv5, @@ -91,6 +102,7 @@ async def __aexit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> bool | None: + await self.stop_mqtt() self._task_group.cancel_scope.cancel() return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) @@ -109,9 +121,15 @@ def _on_subscribe(self, client: mqtt.Client, userdata: Any, mid: int, reason_code_list: list[ReasonCode], properties: Properties | None): pass - def publish_json_rpc_message(self, topic: str, message: types.JSONRPCMessage): - json = message.model_dump_json(by_alias=True, exclude_none=True) - self.client.publish(topic, json, qos=QOS) + def publish_json_rpc_message(self, topic: str, message: types.JSONRPCMessage | None, + retain: bool = False): + props = Properties(PacketTypes.PUBLISH) + props.UserProperty = [ + (PROPERTY_K_MCP_COMPONENT, self.mcp_component_type), + (PROPERTY_K_MQTT_CLIENT_ID, self.mqtt_clientid) + ] + payload = message.model_dump_json(by_alias=True, exclude_none=True) if message else None + self.client.publish(topic=topic, payload=payload, qos=QOS, retain=retain, properties=props) def connect(self): logger.debug("Setting up MQTT connection") @@ -128,15 +146,19 @@ def assert_property(self, properties: Properties | None, property_name: str, exp if get_property(properties, property_name) == expected_value: pass else: - self.stop_mqtt() + anyio_from_thread.run(self.stop_mqtt) raise ValueError(f"{property_name} not available") - def stop_mqtt(self): - self.client.publish(self.service_presence_topic, payload=None, qos=QOS, retain=True) + async def stop_mqtt(self): + self.publish_json_rpc_message( + self.presence_topic, + message = self.disconnected_msg, + retain = self.disconnected_msg_retain + ) self.client.disconnect() self.client.loop_stop() for stream in self._read_stream_writers.values(): - anyio.from_thread.run(stream.aclose) + await stream.aclose() self._read_stream_writers = {} logger.debug("Disconnected from MQTT broker_host") @@ -146,3 +168,24 @@ def get_property(properties: Properties | None, property_name: str): else: return False +def configure_logging( + level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO", + format: str = DEFAULT_LOG_FORMAT, +) -> None: + handlers: list[logging.Handler] = [] + try: + from rich.console import Console + from rich.logging import RichHandler + + handlers.append(RichHandler(console=Console(stderr=True), rich_tracebacks=True)) + except ImportError: + pass + + if not handlers: + handlers.append(logging.StreamHandler()) + + logging.basicConfig( + level=level, + format=format, + handlers=handlers, + ) diff --git a/src/mcp/shared/mqtt_topic.py b/src/mcp/shared/mqtt_topic.py index 1f96f255c9..9319bf6e8d 100644 --- a/src/mcp/shared/mqtt_topic.py +++ b/src/mcp/shared/mqtt_topic.py @@ -14,8 +14,8 @@ def get_service_control_topic(service_name: str) -> str: def get_service_capability_change_topic(service_id: str, service_name: str) -> str: return f"{SERVICE_CAPABILITY_CHANGE_BASE}/{service_id}/{service_name}" -def get_service_resource_update_topic(service_id: str) -> str: - return f"{SERVICE_RESOURCE_UPDATE_BASE}/{service_id}" +def get_service_resource_update_topic(service_id: str, service_name: str) -> str: + return f"{SERVICE_RESOURCE_UPDATE_BASE}/{service_id}/{service_name}" def get_service_presence_topic(service_id: str, service_name: str) -> str: return f"{SERVICE_PRESENCE_BASE}/{service_id}/{service_name}" From 77366ecac0a05d818dd7f538a3b72da22d8eaf1f Mon Sep 17 00:00:00 2001 From: Shawn <506895667@qq.com> Date: Sat, 5 Apr 2025 20:34:37 +0800 Subject: [PATCH 06/23] rename service to server --- .../clients/mqtt-clients/client_apis_demo.py | 36 +-- examples/fastmcp/mqtt_simple_echo.py | 4 +- src/mcp/client/mqtt.py | 270 +++++++++--------- src/mcp/server/fastmcp/server.py | 18 +- src/mcp/server/mqtt.py | 74 ++--- src/mcp/shared/mqtt_topic.py | 28 +- 6 files changed, 214 insertions(+), 216 deletions(-) diff --git a/examples/clients/mqtt-clients/client_apis_demo.py b/examples/clients/mqtt-clients/client_apis_demo.py index 0d29a08d57..f303a340bc 100644 --- a/examples/clients/mqtt-clients/client_apis_demo.py +++ b/examples/clients/mqtt-clients/client_apis_demo.py @@ -6,30 +6,30 @@ configure_logging(level="DEBUG") logger = logging.getLogger(__name__) -async def on_mcp_server_presence(client, service_name, status): +async def on_mcp_server_presence(client, server_name, status): if status == "online": - logger.info(f"Connecting to {service_name}...") - await client.initialize_mcp_server(service_name) + logger.info(f"Connecting to {server_name}...") + await client.initialize_mcp_server(server_name) -async def on_mcp_connect(client, service_name, connect_result): - logger.info(f"Connect result to {service_name}: {connect_result}") - capabilities = client.service_sessions[service_name].server_info.capabilities - logger.info(f"Capabilities of {service_name}: {capabilities}") +async def on_mcp_connect(client, server_name, connect_result): + logger.info(f"Connect result to {server_name}: {connect_result}") + capabilities = client.server_sessions[server_name].server_info.capabilities + logger.info(f"Capabilities of {server_name}: {capabilities}") if capabilities.prompts: - prompts = await client.list_prompts(service_name) - logger.info(f"Prompts of {service_name}: {prompts}") + prompts = await client.list_prompts(server_name) + logger.info(f"Prompts of {server_name}: {prompts}") if capabilities.resources: - resources = await client.list_resources(service_name) - logger.info(f"Resources of {service_name}: {resources}") - resource_templates = await client.list_resource_templates(service_name) - logger.info(f"Resources templates of {service_name}: {resource_templates}") + resources = await client.list_resources(server_name) + logger.info(f"Resources of {server_name}: {resources}") + resource_templates = await client.list_resource_templates(server_name) + logger.info(f"Resources templates of {server_name}: {resource_templates}") if capabilities.tools: - tools = await client.list_tools(service_name) - logger.info(f"Tools of {service_name}: {tools}") + tools = await client.list_tools(server_name) + logger.info(f"Tools of {server_name}: {tools}") -async def on_mcp_disconnect(client, service_name, reason): - logger.info(f"Disconnected from {service_name}, reason: {reason}") - logger.info(f"Services now: {client.service_sessions}") +async def on_mcp_disconnect(client, server_name, reason): + logger.info(f"Disconnected from {server_name}, reason: {reason}") + logger.info(f"Server sessions now: {client.server_sessions}") async def main(): async with mcp_mqtt.MqttTransportClient( diff --git a/examples/fastmcp/mqtt_simple_echo.py b/examples/fastmcp/mqtt_simple_echo.py index 5e187b8b84..430128f1ca 100644 --- a/examples/fastmcp/mqtt_simple_echo.py +++ b/examples/fastmcp/mqtt_simple_echo.py @@ -8,11 +8,9 @@ mcp = FastMCP( "demo_server/echo", log_level="DEBUG", - mqtt_service_description="A simple FastMCP server that echoes back the input text.", + mqtt_server_description="A simple FastMCP server that echoes back the input text.", mqtt_options={ "host": "broker.emqx.io", - "port": 1883, - "keepalive": 60, }, ) diff --git a/src/mcp/client/mqtt.py b/src/mcp/client/mqtt.py index a7498c2731..0eff41cda0 100644 --- a/src/mcp/client/mqtt.py +++ b/src/mcp/client/mqtt.py @@ -30,28 +30,28 @@ SndStreamEX : TypeAlias = MemoryObjectSendStream[types.JSONRPCMessage | Exception] ServerRun : TypeAlias = Callable[[RcvStreamEx, SndStream], Awaitable[Any]] -ServiceName : TypeAlias = str -ServiceId : TypeAlias = str +ServerName : TypeAlias = str +ServerId : TypeAlias = str InitializeResult : TypeAlias = Literal["ok"] | Literal["already_connected"] | tuple[Literal["error"], str] ConnectResult : TypeAlias = tuple[Literal["ok"], types.InitializeResult] | tuple[Literal["error"], Any] DisconnectReason : TypeAlias = Literal["client_initiated_disconnect", "server_initiated_disconnect"] logger = logging.getLogger(__name__) -class ServiceDefinition(BaseModel): +class ServerDefinition(BaseModel): description: str meta: dict[str, Any] = {} -class ServiceOnlineNotification(BaseModel): +class ServerOnlineNotification(BaseModel): jsonrpc: Literal["2.0"] - method: str = "notifications/service/online" - params: ServiceDefinition + method: str = "notifications/server/online" + params: ServerDefinition class MqttClientSession(ClientSession): def __init__( self, - service_id: ServiceId, - service_name: ServiceName, + server_id: ServerId, + server_name: ServerName, read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception], write_stream: MemoryObjectSendStream[types.JSONRPCMessage], read_timeout_seconds: timedelta | None = None, @@ -69,27 +69,27 @@ def __init__( logging_callback, message_handler, ) - self.service_id = service_id - self.service_name = service_name + self.server_id = server_id + self.server_name = server_name self.server_info: types.InitializeResult | None = None class MqttTransportClient(MqttTransportBase): def __init__(self, mcp_client_name: str, client_id_prefix: str | None = None, - service_name_filter: str = '#', + server_name_filter: str = '#', auto_connect_to_mcp_server: bool = False, - on_mcp_connect: Callable[["MqttTransportClient", ServiceName, ConnectResult], Awaitable[Any]] | None = None, - on_mcp_disconnect: Callable[["MqttTransportClient", ServiceName, DisconnectReason], Awaitable[Any]] | None = None, - on_mcp_server_presence: Callable[["MqttTransportClient", ServiceName, Literal["online", "offline"]], Awaitable[Any]] | None = None, + on_mcp_connect: Callable[["MqttTransportClient", ServerName, ConnectResult], Awaitable[Any]] | None = None, + on_mcp_disconnect: Callable[["MqttTransportClient", ServerName, DisconnectReason], Awaitable[Any]] | None = None, + on_mcp_server_presence: Callable[["MqttTransportClient", ServerName, Literal["online", "offline"]], Awaitable[Any]] | None = None, mqtt_options: MqttOptions = MqttOptions()): self.exit_stack: AsyncExitStack = AsyncExitStack() uuid = uuid4().hex mqtt_clientid = f"{client_id_prefix}-{uuid}" if client_id_prefix else uuid - self.service_list: dict[ServiceName, dict[ServiceId, ServiceDefinition]] = {} - self.service_sessions: dict[ServiceName, MqttClientSession] = {} + self.server_list: dict[ServerName, dict[ServerId, ServerDefinition]] = {} + self.server_sessions: dict[ServerName, MqttClientSession] = {} self.mcp_client_id = mqtt_clientid self.mcp_client_name = mcp_client_name - self.service_name_filter = service_name_filter + self.server_name_filter = server_name_filter self.auto_connect_to_mcp_server = auto_connect_to_mcp_server #TODO: not implemented yet self.on_mcp_connect = on_mcp_connect self.on_mcp_disconnect = on_mcp_disconnect #TODO: not implemented yet @@ -122,128 +122,128 @@ def do_start(): traceback.print_exc() async def initialize_mcp_server( - self, service_name: str, + self, server_name: str, read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None) -> InitializeResult: - if service_name in self.service_sessions: + if server_name in self.server_sessions: return "already_connected" - if service_name not in self.service_list: - logger.error(f"MCP server not found, service name: {service_name}") + if server_name not in self.server_list: + logger.error(f"MCP server not found, server name: {server_name}") return ("error", "MCP server not found") - service_id = self.pick_service_id(service_name) + server_id = self.pick_server_id(server_name) async def after_subscribed( subscribe_result: Literal["success", "error"] ): if subscribe_result == "error": if self.on_mcp_connect: - self._task_group.start_soon(self.on_mcp_connect, self, service_name, ("error", "subscribe_mcp_server_topics_failed")) + self._task_group.start_soon(self.on_mcp_connect, self, server_name, ("error", "subscribe_mcp_server_topics_failed")) client_session = self._create_session( - service_id, - service_name, + server_id, + server_name, read_timeout_seconds, sampling_callback, list_roots_callback, logging_callback, message_handler ) - self.service_sessions[service_name] = client_session + self.server_sessions[server_name] = client_session try: - logger.debug(f"before initialize: {service_name}") + logger.debug(f"before initialize: {server_name}") async def after_initialize(): try: session = await self.exit_stack.enter_async_context(client_session) init_result = await session.initialize() session.server_info = init_result if self.on_mcp_connect: - self._task_group.start_soon(self.on_mcp_connect, self, service_name, ("ok", init_result)) + self._task_group.start_soon(self.on_mcp_connect, self, server_name, ("ok", init_result)) except Exception as e: - logging.error(f"Failed to initialize server {service_name}: {e}") + logging.error(f"Failed to initialize server {server_name}: {e}") await self.exit_stack.aclose() raise self._task_group.start_soon(after_initialize) - logger.debug(f"after initialize: {service_name}") + logger.debug(f"after initialize: {server_name}") except McpError as exc: logger.error(f"Failed to connect to MCP server: {exc}") if self.on_mcp_connect: - self._task_group.start_soon(self.on_mcp_connect, self, service_name, ("error", McpError)) + self._task_group.start_soon(self.on_mcp_connect, self, server_name, ("error", McpError)) - if self._subscribe_mcp_server_topics(service_id, service_name, after_subscribed): + if self._subscribe_mcp_server_topics(server_id, server_name, after_subscribed): return "ok" else: return ("error", "send_subscribe_request_failed") - async def send_ping(self, service_name: ServiceName) -> bool | types.EmptyResult: - return await self._with_session(service_name, lambda s: s.send_ping()) + async def send_ping(self, server_name: ServerName) -> bool | types.EmptyResult: + return await self._with_session(server_name, lambda s: s.send_ping()) async def send_progress_notification( - self, service_name: ServiceName, progress_token: str | int, + self, server_name: ServerName, progress_token: str | int, progress: float, total: float | None = None ) -> bool | None: - return await self._with_session(service_name, lambda s: s.send_progress_notification(progress_token, progress, total)) + return await self._with_session(server_name, lambda s: s.send_progress_notification(progress_token, progress, total)) - async def set_logging_level(self, service_name: ServiceName, + async def set_logging_level(self, server_name: ServerName, level: types.LoggingLevel) -> bool | types.EmptyResult: - return await self._with_session(service_name, lambda s: s.set_logging_level(level)) + return await self._with_session(server_name, lambda s: s.set_logging_level(level)) - async def list_resources(self, service_name: ServiceName) -> bool | types.ListResourcesResult: - return await self._with_session(service_name, lambda s: s.list_resources()) + async def list_resources(self, server_name: ServerName) -> bool | types.ListResourcesResult: + return await self._with_session(server_name, lambda s: s.list_resources()) - async def list_resource_templates(self, service_name: ServiceName) -> bool | types.ListResourceTemplatesResult: - return await self._with_session(service_name, lambda s: s.list_resource_templates()) + async def list_resource_templates(self, server_name: ServerName) -> bool | types.ListResourceTemplatesResult: + return await self._with_session(server_name, lambda s: s.list_resource_templates()) - async def read_resource(self, service_name: ServiceName, + async def read_resource(self, server_name: ServerName, uri: AnyUrl) -> bool | types.ReadResourceResult: - return await self._with_session(service_name, lambda s: s.read_resource(uri)) + return await self._with_session(server_name, lambda s: s.read_resource(uri)) - async def subscribe_resource(self, service_name: ServiceName, + async def subscribe_resource(self, server_name: ServerName, uri: AnyUrl) -> bool | types.EmptyResult: - return await self._with_session(service_name, lambda s: s.subscribe_resource(uri)) + return await self._with_session(server_name, lambda s: s.subscribe_resource(uri)) - async def unsubscribe_resource(self, service_name: ServiceName, + async def unsubscribe_resource(self, server_name: ServerName, uri: AnyUrl) -> bool | types.EmptyResult: - return await self._with_session(service_name, lambda s: s.unsubscribe_resource(uri)) + return await self._with_session(server_name, lambda s: s.unsubscribe_resource(uri)) async def call_tool( - self, service_name: ServiceName, name: str, arguments: dict[str, Any] | None = None + self, server_name: ServerName, name: str, arguments: dict[str, Any] | None = None ) -> bool | types.CallToolResult: - return await self._with_session(service_name, lambda s: s.call_tool(name, arguments)) + return await self._with_session(server_name, lambda s: s.call_tool(name, arguments)) - async def list_prompts(self, service_name: ServiceName) -> bool | types.ListPromptsResult: - return await self._with_session(service_name, lambda s: s.list_prompts()) + async def list_prompts(self, server_name: ServerName) -> bool | types.ListPromptsResult: + return await self._with_session(server_name, lambda s: s.list_prompts()) async def get_prompt( - self, service_name: ServiceName, name: str, arguments: dict[str, str] | None = None + self, server_name: ServerName, name: str, arguments: dict[str, str] | None = None ) -> bool | types.GetPromptResult: - return await self._with_session(service_name, lambda s: s.get_prompt(name, arguments)) + return await self._with_session(server_name, lambda s: s.get_prompt(name, arguments)) async def complete( self, - service_name: ServiceName, + server_name: ServerName, ref: types.ResourceReference | types.PromptReference, argument: dict[str, str], ) -> bool | types.CompleteResult: - return await self._with_session(service_name, lambda s: s.complete(ref, argument)) + return await self._with_session(server_name, lambda s: s.complete(ref, argument)) - async def list_tools(self, service_name: ServiceName) -> bool | types.ListToolsResult: - return await self._with_session(service_name, lambda s: s.list_tools()) + async def list_tools(self, server_name: ServerName) -> bool | types.ListToolsResult: + return await self._with_session(server_name, lambda s: s.list_tools()) - async def send_roots_list_changed(self, service_name: ServiceName) -> bool | None: - return await self._with_session(service_name, lambda s: s.send_roots_list_changed()) + async def send_roots_list_changed(self, server_name: ServerName) -> bool | None: + return await self._with_session(server_name, lambda s: s.send_roots_list_changed()) async def _with_session( - self, service_name: ServiceName, + self, server_name: ServerName, async_callback: Callable[[MqttClientSession], Awaitable[bool | Any]]) -> bool | Any: - if not (client_session := self.service_sessions.get(service_name)): - logger.error(f"No session for service_name: {service_name}") + if not (client_session := self.server_sessions.get(server_name)): + logger.error(f"No session for server_name: {server_name}") return False return await async_callback(client_session) def _create_session( - self, service_id: ServiceId, service_name: ServiceName, + self, server_id: ServerId, server_name: ServerName, read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, list_roots_callback: ListRootsFnT | None = None, @@ -259,12 +259,12 @@ def _create_session( write_stream_reader: RcvStream read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - self._read_stream_writers[service_id] = read_stream_writer - self._task_group.start_soon(self._receieved_from_session, service_id, service_name, write_stream_reader) - logger.debug(f"Created new session for service_id: {service_id}") + self._read_stream_writers[server_id] = read_stream_writer + self._task_group.start_soon(self._receieved_from_session, server_id, server_name, write_stream_reader) + logger.debug(f"Created new session for server_id: {server_id}") return MqttClientSession( - service_id, - service_name, + server_id, + server_name, read_stream, write_stream, read_timeout_seconds, @@ -277,139 +277,139 @@ def _create_session( def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.ConnectFlags, reason_code : ReasonCode, properties: Properties | None): if reason_code == 0: super()._on_connect(client, userdata, connect_flags, reason_code, properties) - ## Subscribe to the MCP server's service presence topic - client.subscribe(mqtt_topic.get_service_presence_topic('+', self.service_name_filter), qos=QOS) + ## Subscribe to the MCP server's presence topic + client.subscribe(mqtt_topic.get_server_presence_topic('+', self.server_name_filter), qos=QOS) def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage): logger.debug(f"Received message on topic {msg.topic}: {msg.payload.decode()}") match msg.topic: - case str() as t if t.startswith(mqtt_topic.SERVICE_PRESENCE_BASE): - self._handle_service_presence_message(msg) + case str() as t if t.startswith(mqtt_topic.SERVER_PRESENCE_BASE): + self._handle_server_presence_message(msg) case str() as t if t.startswith(mqtt_topic.RPC_BASE): self._handle_rpc_message(msg) - case str() as t if t.startswith(mqtt_topic.SERVICE_CAPABILITY_CHANGE_BASE): - self._handle_service_capability_list_changed_message(msg) - case str() as t if t.startswith(mqtt_topic.SERVICE_RESOURCE_UPDATE_BASE): - self._handle_service_capability_resource_updated_message(msg) + case str() as t if t.startswith(mqtt_topic.SERVER_CAPABILITY_CHANGE_BASE): + self._handle_server_capability_list_changed_message(msg) + case str() as t if t.startswith(mqtt_topic.SERVER_RESOURCE_UPDATE_BASE): + self._handle_server_capability_resource_updated_message(msg) case _: logger.error(f"Received message on unexpected topic: {msg.topic}") def _on_subscribe(self, client: mqtt.Client, userdata: Any, mid: int, reason_code_list: list[ReasonCode], properties: Properties | None): if mid in userdata.get("pending_subs", {}): - service_name, service_id, after_subscribed = userdata["pending_subs"].pop(mid) + server_name, server_id, after_subscribed = userdata["pending_subs"].pop(mid) ## only create session if all topic subscribed successfully if all([rc.value == QOS for rc in reason_code_list]): - logger.debug(f"Subscribed to topics for service_name: {service_name}, service_id: {service_id}") + logger.debug(f"Subscribed to topics for server_name: {server_name}, server_id: {server_id}") anyio_from_thread.run(after_subscribed, "success") else: anyio_from_thread.run(after_subscribed, "error") - logger.error(f"Failed to subscribe to topics for service_name: {service_name}, service_id: {service_id}, reason_codes: {reason_code_list}") + logger.error(f"Failed to subscribe to topics for server_name: {server_name}, server_id: {server_id}, reason_codes: {reason_code_list}") - def _handle_service_presence_message(self, msg: mqtt.MQTTMessage) -> None: + def _handle_server_presence_message(self, msg: mqtt.MQTTMessage) -> None: topic_words = msg.topic.split("/") - service_id = topic_words[2] - service_name = "/".join(topic_words[3:]) + server_id = topic_words[2] + server_name = "/".join(topic_words[3:]) if msg.payload: - newly_added_service = False if service_name in self.service_list else True - service_notif = ServiceOnlineNotification.model_validate_json(msg.payload.decode()) - self.service_list.setdefault(service_name, {})[service_id] = service_notif.params - logger.debug(f"Service {service_name} with id {service_id} is online") - if newly_added_service: + newly_added_server = False if server_name in self.server_list else True + server_notif = ServerOnlineNotification.model_validate_json(msg.payload.decode()) + self.server_list.setdefault(server_name, {})[server_id] = server_notif.params + logger.debug(f"Server {server_name} with id {server_id} is online") + if newly_added_server: if self.on_mcp_server_presence: - anyio_from_thread.run(self.on_mcp_server_presence, self, service_name, "online") + anyio_from_thread.run(self.on_mcp_server_presence, self, server_name, "online") else: - existing_service = True if service_name in self.service_list else False - if service_id in self.service_list.get(service_name, {}): - logger.debug(f"Service {service_name} with id {service_id} is offline") - self.service_list[service_name].pop(service_id) - if existing_service: + existing_server = True if server_name in self.server_list else False + if server_id in self.server_list.get(server_name, {}): + logger.debug(f"Server {server_name} with id {server_id} is offline") + self.server_list[server_name].pop(server_id) + if existing_server: if self.on_mcp_server_presence: - anyio_from_thread.run(self.on_mcp_server_presence, self, service_name, "offline") + anyio_from_thread.run(self.on_mcp_server_presence, self, server_name, "offline") def _handle_rpc_message(self, msg: mqtt.MQTTMessage) -> None: - service_name = "/".join(msg.topic.split("/")[2:]) - anyio_from_thread.run(self._send_message_to_session, service_name, msg) + server_name = "/".join(msg.topic.split("/")[2:]) + anyio_from_thread.run(self._send_message_to_session, server_name, msg) - def _handle_service_capability_list_changed_message(self, msg: mqtt.MQTTMessage) -> None: - service_name = "/".join(msg.topic.split("/")[4:]) - anyio_from_thread.run(self._send_message_to_session, service_name, msg) + def _handle_server_capability_list_changed_message(self, msg: mqtt.MQTTMessage) -> None: + server_name = "/".join(msg.topic.split("/")[4:]) + anyio_from_thread.run(self._send_message_to_session, server_name, msg) - def _handle_service_capability_resource_updated_message(self, msg: mqtt.MQTTMessage) -> None: - service_name = "/".join(msg.topic.split("/")[4:]) - anyio_from_thread.run(self._send_message_to_session, service_name, msg) + def _handle_server_capability_resource_updated_message(self, msg: mqtt.MQTTMessage) -> None: + server_name = "/".join(msg.topic.split("/")[4:]) + anyio_from_thread.run(self._send_message_to_session, server_name, msg) - def _subscribe_mcp_server_topics(self, service_id: ServiceId, service_name: ServiceName, + def _subscribe_mcp_server_topics(self, server_id: ServerId, server_name: ServerName, after_subscribed: Callable[[Any], Awaitable[None]]): topic_filters = [ - (mqtt_topic.get_service_capability_change_topic(service_id, service_name), SubscribeOptions(qos=QOS)), - (mqtt_topic.get_service_resource_update_topic(service_id, service_name), SubscribeOptions(qos=QOS)), - (mqtt_topic.get_rpc_topic(self.mcp_client_id, service_name), SubscribeOptions(qos=QOS, noLocal=True)) + (mqtt_topic.get_server_capability_change_topic(server_id, server_name), SubscribeOptions(qos=QOS)), + (mqtt_topic.get_server_resource_update_topic(server_id, server_name), SubscribeOptions(qos=QOS)), + (mqtt_topic.get_rpc_topic(self.mcp_client_id, server_name), SubscribeOptions(qos=QOS, noLocal=True)) ] ret, mid = self.client.subscribe(topic=topic_filters) if ret != mqtt.MQTT_ERR_SUCCESS: - logger.error(f"Failed to subscribe to topics for service_name: {service_name}") + logger.error(f"Failed to subscribe to topics for server_name: {server_name}") return False userdata = self.client.user_data_get() pending_subs = userdata.get("pending_subs", {}) - pending_subs[mid] = (service_name, service_id, after_subscribed) + pending_subs[mid] = (server_name, server_id, after_subscribed) userdata["pending_subs"] = pending_subs return True - async def _send_message_to_session(self, service_name: ServiceName, msg: mqtt.MQTTMessage): - if service_name not in self.service_sessions: - logger.error(f"_send_message_to_session: No session for service_name: {service_name}") + async def _send_message_to_session(self, server_name: ServerName, msg: mqtt.MQTTMessage): + if server_name not in self.server_sessions: + logger.error(f"_send_message_to_session: No session for server_name: {server_name}") return - client_session: MqttClientSession = self.service_sessions[service_name] + client_session: MqttClientSession = self.server_sessions[server_name] payload = msg.payload.decode() - service_id = client_session.service_id - if service_id not in self._read_stream_writers: - logger.error(f"No session for service_id: {service_id}") + server_id = client_session.server_id + if server_id not in self._read_stream_writers: + logger.error(f"No session for server_id: {server_id}") return - read_stream_writer = self._read_stream_writers[service_id] + read_stream_writer = self._read_stream_writers[server_id] try: message = types.JSONRPCMessage.model_validate_json(payload) - logger.debug(f"Sending msg to session for service_id: {service_id}, msg: {message}") + logger.debug(f"Sending msg to session for server_id: {server_id}, msg: {message}") with anyio.fail_after(3): await read_stream_writer.send(message) except Exception as exc: - logger.error(f"Failed to send msg to session for service_id: {service_id}, exception: {exc}") + logger.error(f"Failed to send msg to session for server_id: {server_id}, exception: {exc}") traceback.print_exc() ## TODO: the session does not handle exceptions for now #await read_stream_writer.send(exc) - async def _receieved_from_session(self, service_id: ServiceId, service_name: ServiceName, + async def _receieved_from_session(self, server_id: ServerId, server_name: ServerName, write_stream_reader: RcvStream): async with write_stream_reader: async for msg in write_stream_reader: - logger.debug(f"Got msg from session for service_id: {service_id}, msg: {msg}") + logger.debug(f"Got msg from session for server_id: {server_id}, msg: {msg}") match msg.model_dump(): case {"method": method} if method == "notifications/initialized": - logger.debug(f"Session initialized for service_id: {service_id}") - topic = mqtt_topic.get_rpc_topic(self.mcp_client_id, service_name) + logger.debug(f"Session initialized for server_id: {server_id}") + topic = mqtt_topic.get_rpc_topic(self.mcp_client_id, server_name) case {"method": method} if method.endswith("/list_changed"): topic = None logger.warning("Resource updates should not be sent from the session. Ignoring.") case {"method": method} if method == "initialize": - topic = mqtt_topic.get_service_control_topic(service_name) + topic = mqtt_topic.get_server_control_topic(server_name) case _: - topic = mqtt_topic.get_rpc_topic(self.mcp_client_id, service_name) + topic = mqtt_topic.get_rpc_topic(self.mcp_client_id, server_name) if topic: self.publish_json_rpc_message(topic, message = msg) # cleanup - if service_id in self._read_stream_writers: - logger.debug(f"Removing session for service_id: {service_id}") - stream = self._read_stream_writers.pop(service_id) + if server_id in self._read_stream_writers: + logger.debug(f"Removing session for server_id: {server_id}") + stream = self._read_stream_writers.pop(server_id) await stream.aclose() - logger.debug(f"Session stream closed for service_id: {service_id}") + logger.debug(f"Session stream closed for server_id: {server_id}") - def pick_service_id(self, service_name: str) -> ServiceId: - return random.choice(list(self.service_list[service_name].keys())) + def pick_server_id(self, server_name: str) -> ServerId: + return random.choice(list(self.server_list[server_name].keys())) -def validate_service_name(name: str): +def validate_server_name(name: str): if "/" not in name: - raise ValueError(f"Invalid service name: {name}, must contain a '/'") + raise ValueError(f"Invalid server name: {name}, must contain a '/'") elif ("+" in name) or ("#" in name): - raise ValueError(f"Invalid service name: {name}, must not contain '+' or '#'") + raise ValueError(f"Invalid server name: {name}, must not contain '+' or '#'") elif name[0] == "/": - raise ValueError(f"Invalid service name: {name}, must not start with '/'") + raise ValueError(f"Invalid server name: {name}, must not start with '/'") diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 6f175e29d1..e51ee74f55 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -35,7 +35,7 @@ from mcp.server.lowlevel.server import lifespan as default_lifespan from mcp.server.session import ServerSession, ServerSessionT from mcp.server.sse import SseServerTransport -from mcp.server.mqtt import validate_service_name, start_mqtt, MqttOptions +from mcp.server.mqtt import validate_server_name, start_mqtt, MqttOptions from mcp.server.stdio import stdio_server from mcp.shared.context import LifespanContextT, RequestContext from mcp.types import ( @@ -77,8 +77,8 @@ class Settings(BaseSettings, Generic[LifespanResultT]): message_path: str = "/messages/" # MQTT settings - mqtt_service_description: str = '' - mqtt_service_meta: dict[str, Any] = {} + mqtt_server_description: str = '' + mqtt_server_meta: dict[str, Any] = {} mqtt_client_id_prefix: str | None = None mqtt_options: MqttOptions = MqttOptions() @@ -164,7 +164,7 @@ def run(self, transport: Literal["stdio", "sse", "mqtt"] = "stdio") -> None: if transport == "stdio": anyio.run(self.run_stdio_async) elif transport == "mqtt": - validate_service_name(self._mcp_server.name) + validate_server_name(self._mcp_server.name) anyio.run(self.run_mqtt_async) else: # transport == "sse" anyio.run(self.run_sse_async) @@ -488,17 +488,17 @@ async def run_sse_async(self) -> None: async def run_mqtt_async(self) -> None: """Run the server using MQTT transport.""" - def server_run(read_stream: Any, write_stream: Any): + def server_session_run(read_stream: Any, write_stream: Any): return self._mcp_server.run( read_stream, write_stream, self._mcp_server.create_initialization_options(), ) await start_mqtt( - server_run, - service_name = self._mcp_server.name, - service_description=self.settings.mqtt_service_description, - service_meta = self.settings.mqtt_service_meta, + server_session_run, + server_name = self._mcp_server.name, + server_description=self.settings.mqtt_server_description, + server_meta = self.settings.mqtt_server_meta, client_id_prefix = self.settings.mqtt_client_id_prefix, mqtt_options = self.settings.mqtt_options ) diff --git a/src/mcp/server/mqtt.py b/src/mcp/server/mqtt.py index 80ecb523f0..d42d7a8592 100644 --- a/src/mcp/server/mqtt.py +++ b/src/mcp/server/mqtt.py @@ -23,45 +23,45 @@ SndStream : TypeAlias = MemoryObjectSendStream[types.JSONRPCMessage] RcvStreamEx : TypeAlias = MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] SndStreamEX : TypeAlias = MemoryObjectSendStream[types.JSONRPCMessage | Exception] -ServerRun : TypeAlias = Callable[[RcvStreamEx, SndStream], Awaitable[Any]] +ServerSessionRun : TypeAlias = Callable[[RcvStreamEx, SndStream], Awaitable[Any]] logger = logging.getLogger(__name__) class MqttTransportServer(MqttTransportBase): - def __init__(self, server_run: ServerRun, service_name: str, - service_description: str, - service_meta: dict[str, Any], + def __init__(self, server_session_run: ServerSessionRun, server_name: str, + server_description: str, + server_meta: dict[str, Any], client_id_prefix: str | None = None, mqtt_options: MqttOptions = MqttOptions()): uuid = uuid4().hex mqtt_clientid = f"{client_id_prefix}-{uuid}" if client_id_prefix else uuid - self.service_id = mqtt_clientid - self.service_name = service_name - self.service_description = service_description - self.service_meta = service_meta - self.service_control_topic = mqtt_topic.get_service_control_topic(service_name) - self.service_presence_topic = mqtt_topic.get_service_presence_topic(self.service_id, service_name) - self.service_capability_change_topic = mqtt_topic.get_service_capability_change_topic(self.service_id, service_name) - self.server_run = server_run + self.server_id = mqtt_clientid + self.server_name = server_name + self.server_description = server_description + self.server_meta = server_meta + self.server_control_topic = mqtt_topic.get_server_control_topic(server_name) + self.server_presence_topic = mqtt_topic.get_server_presence_topic(self.server_id, server_name) + self.server_capability_change_topic = mqtt_topic.get_server_capability_change_topic(self.server_id, server_name) + self.server_session_run = server_session_run super().__init__("mcp-server", mqtt_clientid = mqtt_clientid, mqtt_options = mqtt_options) - self.presence_topic = mqtt_topic.get_service_presence_topic(self.service_id, service_name) + self.presence_topic = mqtt_topic.get_server_presence_topic(self.server_id, server_name) self.disconnected_msg = None self.client.will_set(topic=self.presence_topic, payload=None, qos=QOS, retain=True) def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.ConnectFlags, reason_code : ReasonCode, properties: Properties | None): if reason_code == 0: super()._on_connect(client, userdata, connect_flags, reason_code, properties) - ## Subscribe to the service control topic - client.subscribe(self.service_control_topic, QOS) - ## Reister the service on the presence topic + ## Subscribe to the server control topic + client.subscribe(self.server_control_topic, QOS) + ## Reister the server on the presence topic online_msg = types.JSONRPCMessage( types.JSONRPCNotification( jsonrpc="2.0", - method = "notifications/service/online", + method = "notifications/server/online", params = { - "description": self.service_description, - "meta": self.service_meta + "description": self.server_description, + "meta": self.server_meta } )) self.publish_json_rpc_message( @@ -70,8 +70,8 @@ def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.Co def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage): logger.debug(f"Received message on topic {msg.topic}: {msg.payload.decode()}") match msg.topic: - case str() as t if t == self.service_control_topic: - self.handle_service_contorl_message(msg) + case str() as t if t == self.server_control_topic: + self.handle_server_contorl_message(msg) case str() as t if t.startswith(mqtt_topic.CLIENT_CAPABILITY_CHANGE_BASE): self.handle_client_capability_change_message(msg) case str() as t if t.startswith(mqtt_topic.RPC_BASE): @@ -100,11 +100,11 @@ def _on_subscribe(self, client: mqtt.Client, userdata: Any, mid: int, ) ) self.publish_json_rpc_message( - mqtt_topic.get_rpc_topic(mcp_client_id, self.service_name), + mqtt_topic.get_rpc_topic(mcp_client_id, self.server_name), message = types.JSONRPCMessage(err) ) - def handle_service_contorl_message(self, msg: mqtt.MQTTMessage): + def handle_server_contorl_message(self, msg: mqtt.MQTTMessage): if msg.properties and hasattr(msg.properties, "UserProperty"): user_properties: dict[str, Any] = dict(msg.properties.UserProperty) # type: ignore if PROPERTY_K_MQTT_CLIENT_ID in user_properties: @@ -156,7 +156,7 @@ async def create_session(self, mcp_client_id: str, msg: mqtt.MQTTMessage): read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) self._read_stream_writers[mcp_client_id] = read_stream_writer - self._task_group.start_soon(self.server_run, read_stream, write_stream) + self._task_group.start_soon(self.server_session_run, read_stream, write_stream) self._task_group.start_soon(self._receieved_from_session, mcp_client_id, write_stream_reader) logger.debug(f"Created new session for mcp_client_id: {mcp_client_id}") await self._send_message_to_session(mcp_client_id, msg) @@ -177,7 +177,7 @@ def subscribe_to_client(self, mcp_client_id: str, msg: mqtt.MQTTMessage, rcp_msg topic_filters = [ (mqtt_topic.get_client_presence_topic(mcp_client_id), SubscribeOptions(qos=QOS)), (mqtt_topic.get_client_capability_change_topic(mcp_client_id), SubscribeOptions(qos=QOS)), - (mqtt_topic.get_rpc_topic(mcp_client_id, self.service_name), SubscribeOptions(qos=QOS, noLocal=True)) + (mqtt_topic.get_rpc_topic(mcp_client_id, self.server_name), SubscribeOptions(qos=QOS, noLocal=True)) ] ret, mid = self.client.subscribe(topic=topic_filters) if ret != mqtt.MQTT_ERR_SUCCESS: @@ -215,7 +215,7 @@ async def _receieved_from_session(self, mcp_client_id: str, write_stream_reader: case {"method": method} if method.endswith("/list_changed"): logger.warning("Resource updates should not be sent from the session. Ignoring.") case _: - topic = mqtt_topic.get_rpc_topic(mcp_client_id, self.service_name) + topic = mqtt_topic.get_rpc_topic(mcp_client_id, self.server_name) self.publish_json_rpc_message(topic, message = msg) # cleanup if mcp_client_id in self._read_stream_writers: @@ -226,16 +226,16 @@ async def _receieved_from_session(self, mcp_client_id: str, write_stream_reader: logger.debug(f"Session stream closed for mcp_client_id: {mcp_client_id}") async def start_mqtt( - server_run: ServerRun, service_name: str, - service_description: str, - service_meta: dict[str, Any], + server_session_run: ServerSessionRun, server_name: str, + server_description: str, + server_meta: dict[str, Any], client_id_prefix: str | None = None, mqtt_options: MqttOptions = MqttOptions()): async with MqttTransportServer( - server_run, - service_name = service_name, - service_description=service_description, - service_meta = service_meta, + server_session_run, + server_name = server_name, + server_description=server_description, + server_meta = server_meta, client_id_prefix = client_id_prefix, mqtt_options = mqtt_options ) as mqtt_trans: @@ -249,10 +249,10 @@ def start(): except Exception as exc: logger.error(f"MQTT transport (MCP server) failed with exception: {exc}") -def validate_service_name(name: str): +def validate_server_name(name: str): if "/" not in name: - raise ValueError(f"Invalid service name: {name}, must contain a '/'") + raise ValueError(f"Invalid server name: {name}, must contain a '/'") elif ("+" in name) or ("#" in name): - raise ValueError(f"Invalid service name: {name}, must not contain '+' or '#'") + raise ValueError(f"Invalid server name: {name}, must not contain '+' or '#'") elif name[0] == "/": - raise ValueError(f"Invalid service name: {name}, must not start with '/'") + raise ValueError(f"Invalid server name: {name}, must not start with '/'") diff --git a/src/mcp/shared/mqtt_topic.py b/src/mcp/shared/mqtt_topic.py index 9319bf6e8d..566dfbdebc 100644 --- a/src/mcp/shared/mqtt_topic.py +++ b/src/mcp/shared/mqtt_topic.py @@ -1,24 +1,24 @@ -SERVICE_CONTROL_BASE: str = '$mcp-service' -SERVICE_CAPABILITY_CHANGE_BASE: str = '$mcp-service/capability/list-changed' -SERVICE_RESOURCE_UPDATE_BASE: str = '$mcp-service/capability/resource-updated' -SERVICE_PRESENCE_BASE: str = '$mcp-service/presence' +SERVER_CONTROL_BASE: str = '$mcp-server' +SERVER_CAPABILITY_CHANGE_BASE: str = '$mcp-server/capability/list-changed' +SERVER_RESOURCE_UPDATE_BASE: str = '$mcp-server/capability/resource-updated' +SERVER_PRESENCE_BASE: str = '$mcp-server/presence' CLIENT_PRESENCE_BASE: str = '$mcp-client/presence' CLIENT_CAPABILITY_CHANGE_BASE: str = '$mcp-client/capability/list-changed' RPC_BASE: str = '$mcp-rpc-endpoint' -def get_service_control_topic(service_name: str) -> str: - return f"{SERVICE_CONTROL_BASE}/{service_name}" +def get_server_control_topic(server_name: str) -> str: + return f"{SERVER_CONTROL_BASE}/{server_name}" -def get_service_capability_change_topic(service_id: str, service_name: str) -> str: - return f"{SERVICE_CAPABILITY_CHANGE_BASE}/{service_id}/{service_name}" +def get_server_capability_change_topic(server_id: str, server_name: str) -> str: + return f"{SERVER_CAPABILITY_CHANGE_BASE}/{server_id}/{server_name}" -def get_service_resource_update_topic(service_id: str, service_name: str) -> str: - return f"{SERVICE_RESOURCE_UPDATE_BASE}/{service_id}/{service_name}" +def get_server_resource_update_topic(server_id: str, server_name: str) -> str: + return f"{SERVER_RESOURCE_UPDATE_BASE}/{server_id}/{server_name}" -def get_service_presence_topic(service_id: str, service_name: str) -> str: - return f"{SERVICE_PRESENCE_BASE}/{service_id}/{service_name}" +def get_server_presence_topic(server_id: str, server_name: str) -> str: + return f"{SERVER_PRESENCE_BASE}/{server_id}/{server_name}" def get_client_presence_topic(mcp_clientid: str) -> str: return f"{CLIENT_PRESENCE_BASE}/{mcp_clientid}" @@ -26,5 +26,5 @@ def get_client_presence_topic(mcp_clientid: str) -> str: def get_client_capability_change_topic(mcp_clientid: str) -> str: return f"{CLIENT_CAPABILITY_CHANGE_BASE}/{mcp_clientid}" -def get_rpc_topic(mcp_clientid: str, service_name: str) -> str: - return f"{RPC_BASE}/{mcp_clientid}/{service_name}" +def get_rpc_topic(mcp_clientid: str, server_name: str) -> str: + return f"{RPC_BASE}/{mcp_clientid}/{server_name}" From e63c0189125a5a955dbc3cbba9ed70bcbf3285b5 Mon Sep 17 00:00:00 2001 From: Shawn <506895667@qq.com> Date: Sat, 5 Apr 2025 22:10:01 +0800 Subject: [PATCH 07/23] refactor some callbacks of mcp client --- .../clients/mqtt-clients/client_apis_demo.py | 16 +++---- src/mcp/client/mqtt.py | 44 +++++++++++-------- 2 files changed, 32 insertions(+), 28 deletions(-) diff --git a/examples/clients/mqtt-clients/client_apis_demo.py b/examples/clients/mqtt-clients/client_apis_demo.py index f303a340bc..ca040220cf 100644 --- a/examples/clients/mqtt-clients/client_apis_demo.py +++ b/examples/clients/mqtt-clients/client_apis_demo.py @@ -6,14 +6,13 @@ configure_logging(level="DEBUG") logger = logging.getLogger(__name__) -async def on_mcp_server_presence(client, server_name, status): - if status == "online": - logger.info(f"Connecting to {server_name}...") - await client.initialize_mcp_server(server_name) +async def on_mcp_server_discovered(client, server_name): + logger.info(f"Discovered {server_name}, connecting ...") + await client.initialize_mcp_server(server_name) async def on_mcp_connect(client, server_name, connect_result): logger.info(f"Connect result to {server_name}: {connect_result}") - capabilities = client.server_sessions[server_name].server_info.capabilities + capabilities = client.get_session(server_name).server_info.capabilities logger.info(f"Capabilities of {server_name}: {capabilities}") if capabilities.prompts: prompts = await client.list_prompts(server_name) @@ -27,15 +26,14 @@ async def on_mcp_connect(client, server_name, connect_result): tools = await client.list_tools(server_name) logger.info(f"Tools of {server_name}: {tools}") -async def on_mcp_disconnect(client, server_name, reason): - logger.info(f"Disconnected from {server_name}, reason: {reason}") - logger.info(f"Server sessions now: {client.server_sessions}") +async def on_mcp_disconnect(client, server_name): + logger.info(f"Disconnected from {server_name}") async def main(): async with mcp_mqtt.MqttTransportClient( "test_client", auto_connect_to_mcp_server = True, - on_mcp_server_presence = on_mcp_server_presence, + on_mcp_server_discovered = on_mcp_server_discovered, on_mcp_connect = on_mcp_connect, on_mcp_disconnect = on_mcp_disconnect, mqtt_options = mcp_mqtt.MqttOptions( diff --git a/src/mcp/client/mqtt.py b/src/mcp/client/mqtt.py index 0eff41cda0..c6048c9805 100644 --- a/src/mcp/client/mqtt.py +++ b/src/mcp/client/mqtt.py @@ -34,7 +34,6 @@ ServerId : TypeAlias = str InitializeResult : TypeAlias = Literal["ok"] | Literal["already_connected"] | tuple[Literal["error"], str] ConnectResult : TypeAlias = tuple[Literal["ok"], types.InitializeResult] | tuple[Literal["error"], Any] -DisconnectReason : TypeAlias = Literal["client_initiated_disconnect", "server_initiated_disconnect"] logger = logging.getLogger(__name__) @@ -79,21 +78,20 @@ def __init__(self, mcp_client_name: str, client_id_prefix: str | None = None, server_name_filter: str = '#', auto_connect_to_mcp_server: bool = False, on_mcp_connect: Callable[["MqttTransportClient", ServerName, ConnectResult], Awaitable[Any]] | None = None, - on_mcp_disconnect: Callable[["MqttTransportClient", ServerName, DisconnectReason], Awaitable[Any]] | None = None, - on_mcp_server_presence: Callable[["MqttTransportClient", ServerName, Literal["online", "offline"]], Awaitable[Any]] | None = None, + on_mcp_disconnect: Callable[["MqttTransportClient", ServerName], Awaitable[Any]] | None = None, + on_mcp_server_discovered: Callable[["MqttTransportClient", ServerName], Awaitable[Any]] | None = None, mqtt_options: MqttOptions = MqttOptions()): - self.exit_stack: AsyncExitStack = AsyncExitStack() uuid = uuid4().hex mqtt_clientid = f"{client_id_prefix}-{uuid}" if client_id_prefix else uuid self.server_list: dict[ServerName, dict[ServerId, ServerDefinition]] = {} - self.server_sessions: dict[ServerName, MqttClientSession] = {} + self.client_sessions: dict[ServerName, MqttClientSession] = {} self.mcp_client_id = mqtt_clientid self.mcp_client_name = mcp_client_name self.server_name_filter = server_name_filter self.auto_connect_to_mcp_server = auto_connect_to_mcp_server #TODO: not implemented yet self.on_mcp_connect = on_mcp_connect - self.on_mcp_disconnect = on_mcp_disconnect #TODO: not implemented yet - self.on_mcp_server_presence = on_mcp_server_presence + self.on_mcp_disconnect = on_mcp_disconnect + self.on_mcp_server_discovered = on_mcp_server_discovered self.client_capability_change_topic = mqtt_topic.get_client_capability_change_topic(self.mcp_client_id) super().__init__("mcp-client", mqtt_clientid = mqtt_clientid, mqtt_options = mqtt_options) self.presence_topic = mqtt_topic.get_client_presence_topic(self.mcp_client_id) @@ -121,6 +119,9 @@ def do_start(): logger.error(f"MQTT transport (MCP client) failed: {exc}") traceback.print_exc() + def get_session(self, server_name: ServerName) -> MqttClientSession | None: + return self.client_sessions.get(server_name, None) + async def initialize_mcp_server( self, server_name: str, read_timeout_seconds: timedelta | None = None, @@ -128,7 +129,7 @@ async def initialize_mcp_server( list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None) -> InitializeResult: - if server_name in self.server_sessions: + if server_name in self.client_sessions: return "already_connected" if server_name not in self.server_list: logger.error(f"MCP server not found, server name: {server_name}") @@ -150,20 +151,20 @@ async def after_subscribed( logging_callback, message_handler ) - self.server_sessions[server_name] = client_session + self.client_sessions[server_name] = client_session try: logger.debug(f"before initialize: {server_name}") async def after_initialize(): + exit_stack = AsyncExitStack() try: - session = await self.exit_stack.enter_async_context(client_session) + session = await exit_stack.enter_async_context(client_session) init_result = await session.initialize() session.server_info = init_result if self.on_mcp_connect: self._task_group.start_soon(self.on_mcp_connect, self, server_name, ("ok", init_result)) except Exception as e: logging.error(f"Failed to initialize server {server_name}: {e}") - await self.exit_stack.aclose() - raise + await exit_stack.aclose() self._task_group.start_soon(after_initialize) logger.debug(f"after initialize: {server_name}") except McpError as exc: @@ -237,7 +238,7 @@ async def send_roots_list_changed(self, server_name: ServerName) -> bool | None: async def _with_session( self, server_name: ServerName, async_callback: Callable[[MqttClientSession], Awaitable[bool | Any]]) -> bool | Any: - if not (client_session := self.server_sessions.get(server_name)): + if not (client_session := self.client_sessions.get(server_name, None)): logger.error(f"No session for server_name: {server_name}") return False return await async_callback(client_session) @@ -316,16 +317,22 @@ def _handle_server_presence_message(self, msg: mqtt.MQTTMessage) -> None: self.server_list.setdefault(server_name, {})[server_id] = server_notif.params logger.debug(f"Server {server_name} with id {server_id} is online") if newly_added_server: - if self.on_mcp_server_presence: - anyio_from_thread.run(self.on_mcp_server_presence, self, server_name, "online") + if self.on_mcp_server_discovered: + anyio_from_thread.run(self.on_mcp_server_discovered, self, server_name) else: existing_server = True if server_name in self.server_list else False if server_id in self.server_list.get(server_name, {}): logger.debug(f"Server {server_name} with id {server_id} is offline") self.server_list[server_name].pop(server_id) + if not self.server_list[server_name]: + self.server_list.pop(server_name) + if server_name in self.client_sessions: + _ = self.client_sessions.pop(server_name) + stream = self._read_stream_writers.pop(server_id) + stream.close() if existing_server: - if self.on_mcp_server_presence: - anyio_from_thread.run(self.on_mcp_server_presence, self, server_name, "offline") + if self.on_mcp_disconnect: + anyio_from_thread.run(self.on_mcp_disconnect, self, server_name) def _handle_rpc_message(self, msg: mqtt.MQTTMessage) -> None: server_name = "/".join(msg.topic.split("/")[2:]) @@ -357,10 +364,9 @@ def _subscribe_mcp_server_topics(self, server_id: ServerId, server_name: ServerN return True async def _send_message_to_session(self, server_name: ServerName, msg: mqtt.MQTTMessage): - if server_name not in self.server_sessions: + if not (client_session := self.client_sessions.get(server_name, None)): logger.error(f"_send_message_to_session: No session for server_name: {server_name}") return - client_session: MqttClientSession = self.server_sessions[server_name] payload = msg.payload.decode() server_id = client_session.server_id if server_id not in self._read_stream_writers: From 5583d774bce45e6f7917c04075117d084a259f47 Mon Sep 17 00:00:00 2001 From: Shawn <506895667@qq.com> Date: Sat, 5 Apr 2025 22:10:51 +0800 Subject: [PATCH 08/23] fix: session crashes on response_stream_reader EOF --- src/mcp/shared/session.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 9f04bf841a..062830f3d7 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -252,7 +252,12 @@ async def send_request( if self._read_timeout_seconds is None else self._read_timeout_seconds.total_seconds() ): - response_or_error = await response_stream_reader.receive() + async with response_stream_reader: + async for response_or_error in response_stream_reader: + if isinstance(response_or_error, JSONRPCError): + raise McpError(response_or_error.error) + else: + return result_type.model_validate(response_or_error.result) except TimeoutError: raise McpError( ErrorData( @@ -265,11 +270,6 @@ async def send_request( ) ) - if isinstance(response_or_error, JSONRPCError): - raise McpError(response_or_error.error) - else: - return result_type.model_validate(response_or_error.result) - async def send_notification(self, notification: SendNotificationT) -> None: """ Emits a notification, which is a one-way message that does not expect From 955b30df552f671258fa915d08783fb89b6253c3 Mon Sep 17 00:00:00 2001 From: Shawn <506895667@qq.com> Date: Sat, 5 Apr 2025 22:16:08 +0800 Subject: [PATCH 09/23] improve client_apis_demo.py --- examples/clients/mqtt-clients/client_apis_demo.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/clients/mqtt-clients/client_apis_demo.py b/examples/clients/mqtt-clients/client_apis_demo.py index ca040220cf..31da377f63 100644 --- a/examples/clients/mqtt-clients/client_apis_demo.py +++ b/examples/clients/mqtt-clients/client_apis_demo.py @@ -11,7 +11,6 @@ async def on_mcp_server_discovered(client, server_name): await client.initialize_mcp_server(server_name) async def on_mcp_connect(client, server_name, connect_result): - logger.info(f"Connect result to {server_name}: {connect_result}") capabilities = client.get_session(server_name).server_info.capabilities logger.info(f"Capabilities of {server_name}: {capabilities}") if capabilities.prompts: @@ -23,7 +22,8 @@ async def on_mcp_connect(client, server_name, connect_result): resource_templates = await client.list_resource_templates(server_name) logger.info(f"Resources templates of {server_name}: {resource_templates}") if capabilities.tools: - tools = await client.list_tools(server_name) + toolsResult = await client.list_tools(server_name) + tools = toolsResult.tools logger.info(f"Tools of {server_name}: {tools}") async def on_mcp_disconnect(client, server_name): @@ -38,8 +38,6 @@ async def main(): on_mcp_disconnect = on_mcp_disconnect, mqtt_options = mcp_mqtt.MqttOptions( host="broker.emqx.io", - port=1883, - keepalive=60 ) ) as client: client.start() From 6c8e1fcfd23771c49d614dd9a0feb5e13e08c8cb Mon Sep 17 00:00:00 2001 From: Shawn <506895667@qq.com> Date: Thu, 17 Apr 2025 09:29:18 +0800 Subject: [PATCH 10/23] add server-id into contorl topic --- src/mcp/client/mqtt.py | 2 +- src/mcp/server/mqtt.py | 2 +- src/mcp/shared/mqtt_topic.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mcp/client/mqtt.py b/src/mcp/client/mqtt.py index c6048c9805..03bff3ddc4 100644 --- a/src/mcp/client/mqtt.py +++ b/src/mcp/client/mqtt.py @@ -396,7 +396,7 @@ async def _receieved_from_session(self, server_id: ServerId, server_name: Server topic = None logger.warning("Resource updates should not be sent from the session. Ignoring.") case {"method": method} if method == "initialize": - topic = mqtt_topic.get_server_control_topic(server_name) + topic = mqtt_topic.get_server_control_topic(server_id, server_name) case _: topic = mqtt_topic.get_rpc_topic(self.mcp_client_id, server_name) if topic: diff --git a/src/mcp/server/mqtt.py b/src/mcp/server/mqtt.py index d42d7a8592..c0e89c2b34 100644 --- a/src/mcp/server/mqtt.py +++ b/src/mcp/server/mqtt.py @@ -40,7 +40,7 @@ def __init__(self, server_session_run: ServerSessionRun, server_name: str, self.server_name = server_name self.server_description = server_description self.server_meta = server_meta - self.server_control_topic = mqtt_topic.get_server_control_topic(server_name) + self.server_control_topic = mqtt_topic.get_server_control_topic(self.server_id, server_name) self.server_presence_topic = mqtt_topic.get_server_presence_topic(self.server_id, server_name) self.server_capability_change_topic = mqtt_topic.get_server_capability_change_topic(self.server_id, server_name) self.server_session_run = server_session_run diff --git a/src/mcp/shared/mqtt_topic.py b/src/mcp/shared/mqtt_topic.py index 566dfbdebc..ee5a543aa5 100644 --- a/src/mcp/shared/mqtt_topic.py +++ b/src/mcp/shared/mqtt_topic.py @@ -8,8 +8,8 @@ CLIENT_CAPABILITY_CHANGE_BASE: str = '$mcp-client/capability/list-changed' RPC_BASE: str = '$mcp-rpc-endpoint' -def get_server_control_topic(server_name: str) -> str: - return f"{SERVER_CONTROL_BASE}/{server_name}" +def get_server_control_topic(server_id: str, server_name: str) -> str: + return f"{SERVER_CONTROL_BASE}/{server_id}/{server_name}" def get_server_capability_change_topic(server_id: str, server_name: str) -> str: return f"{SERVER_CAPABILITY_CHANGE_BASE}/{server_id}/{server_name}" From b71fa798791123038372d3b979a2af8987e132c5 Mon Sep 17 00:00:00 2001 From: Shawn <506895667@qq.com> Date: Thu, 17 Apr 2025 23:26:41 +0800 Subject: [PATCH 11/23] support broker suggested MCP server name --- src/mcp/client/mqtt.py | 17 +++++++-------- src/mcp/server/mqtt.py | 27 ++++++++++++++++-------- src/mcp/shared/mqtt.py | 47 ++++++++++++++++++++++++++++++++++-------- 3 files changed, 64 insertions(+), 27 deletions(-) diff --git a/src/mcp/client/mqtt.py b/src/mcp/client/mqtt.py index 03bff3ddc4..e9c6915e0c 100644 --- a/src/mcp/client/mqtt.py +++ b/src/mcp/client/mqtt.py @@ -93,19 +93,18 @@ def __init__(self, mcp_client_name: str, client_id_prefix: str | None = None, self.on_mcp_disconnect = on_mcp_disconnect self.on_mcp_server_discovered = on_mcp_server_discovered self.client_capability_change_topic = mqtt_topic.get_client_capability_change_topic(self.mcp_client_id) - super().__init__("mcp-client", mqtt_clientid = mqtt_clientid, mqtt_options = mqtt_options) - self.presence_topic = mqtt_topic.get_client_presence_topic(self.mcp_client_id) ## Send disconnected notification when disconnects - self.disconnected_msg = types.JSONRPCNotification( + disconnected_msg = types.JSONRPCNotification( jsonrpc="2.0", method = "notifications/disconnected" ) - self.disconnected_msg_retain = False - self.client.will_set( - topic=self.presence_topic, - payload=self.disconnected_msg.model_dump_json(), - qos=QOS - ) + super().__init__("mcp-client", mqtt_clientid = mqtt_clientid, + mqtt_options = mqtt_options, + disconnected_msg = types.JSONRPCMessage(disconnected_msg), + disconnected_msg_retain = False) + + def get_presence_topic(self) -> str: + return mqtt_topic.get_client_presence_topic(self.mcp_client_id) def start(self): def do_start(): diff --git a/src/mcp/server/mqtt.py b/src/mcp/server/mqtt.py index c0e89c2b34..9882fe8c8a 100644 --- a/src/mcp/server/mqtt.py +++ b/src/mcp/server/mqtt.py @@ -2,7 +2,7 @@ This module implements the MQTT transport for the MCP server. """ from uuid import uuid4 -from mcp.shared.mqtt import MqttTransportBase, MqttOptions, QOS, PROPERTY_K_MQTT_CLIENT_ID +from mcp.shared.mqtt import MqttTransportBase, MqttOptions, QOS, PROPERTY_K_MQTT_CLIENT_ID, MCP_SERVER_NAME import asyncio import anyio.to_thread as anyio_to_thread import anyio.from_thread as anyio_from_thread @@ -40,18 +40,27 @@ def __init__(self, server_session_run: ServerSessionRun, server_name: str, self.server_name = server_name self.server_description = server_description self.server_meta = server_meta - self.server_control_topic = mqtt_topic.get_server_control_topic(self.server_id, server_name) - self.server_presence_topic = mqtt_topic.get_server_presence_topic(self.server_id, server_name) - self.server_capability_change_topic = mqtt_topic.get_server_capability_change_topic(self.server_id, server_name) self.server_session_run = server_session_run - super().__init__("mcp-server", mqtt_clientid = mqtt_clientid, mqtt_options = mqtt_options) - self.presence_topic = mqtt_topic.get_server_presence_topic(self.server_id, server_name) - self.disconnected_msg = None - self.client.will_set(topic=self.presence_topic, payload=None, qos=QOS, retain=True) + super().__init__("mcp-server", mqtt_clientid = mqtt_clientid, + mqtt_options = mqtt_options, + disconnected_msg = None, + disconnected_msg_retain = True) + + def get_presence_topic(self) -> str: + return mqtt_topic.get_server_presence_topic(self.server_id, self.server_name) def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.ConnectFlags, reason_code : ReasonCode, properties: Properties | None): if reason_code == 0: super()._on_connect(client, userdata, connect_flags, reason_code, properties) + if properties and hasattr(properties, "UserProperty"): + user_properties: dict[str, Any] = dict(properties.UserProperty) # type: ignore + if MCP_SERVER_NAME in user_properties: + broker_suggested_server_name = user_properties[MCP_SERVER_NAME] + self.server_name = broker_suggested_server_name + logger.debug(f"Used broker suggested server name: {broker_suggested_server_name}") + else: + logger.error(f"No {PROPERTY_K_MQTT_CLIENT_ID} in UserProperties") + self.server_control_topic = mqtt_topic.get_server_control_topic(self.server_id, self.server_name) ## Subscribe to the server control topic client.subscribe(self.server_control_topic, QOS) ## Reister the server on the presence topic @@ -65,7 +74,7 @@ def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.Co } )) self.publish_json_rpc_message( - self.presence_topic, message=online_msg, retain=True) + self.get_presence_topic(), message=online_msg, retain=True) def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage): logger.debug(f"Received message on topic {msg.topic}: {msg.payload.decode()}") diff --git a/src/mcp/shared/mqtt.py b/src/mcp/shared/mqtt.py index 7d87173739..8bfa4631f7 100644 --- a/src/mcp/shared/mqtt.py +++ b/src/mcp/shared/mqtt.py @@ -16,9 +16,12 @@ from typing import Literal, Optional, Any, TypeAlias, Callable, Awaitable import mcp.types as types from typing_extensions import Self +from abc import ABC, abstractmethod DEFAULT_LOG_FORMAT = "%(asctime)s - %(message)s" QOS = 1 +MCP_SERVER_NAME = "MCP-SERVER-NAME" +MCP_AUTH_ROLE = "MCP-AUTH-ROLE" PROPERTY_K_MCP_COMPONENT = "MCP-COMPONENT-TYPE" PROPERTY_K_MQTT_CLIENT_ID = "MQTT-CLIENT-ID" logger = logging.getLogger(__name__) @@ -50,7 +53,7 @@ class MqttOptions(BaseModel): websocket_path: str = '/mqtt' websocket_headers: Optional[dict[str, str]] = None -class MqttTransportBase: +class MqttTransportBase(ABC): _read_stream_writers: dict[ str, SndStreamEX ] @@ -58,14 +61,15 @@ class MqttTransportBase: def __init__(self, mcp_component_type: Literal["mcp-client", "mcp-server"], mqtt_clientid: str | None = None, - mqtt_options: MqttOptions = MqttOptions()): + mqtt_options: MqttOptions = MqttOptions(), + disconnected_msg: types.JSONRPCMessage | None = None, + disconnected_msg_retain: bool = True): self._read_stream_writers = {} self.mqtt_clientid = mqtt_clientid self.mcp_component_type = mcp_component_type self.mqtt_options = mqtt_options - self.presence_topic = '' - self.disconnected_msg = None - self.disconnected_msg_retain = True + self.disconnected_msg = disconnected_msg + self.disconnected_msg_retain = disconnected_msg_retain client = mqtt.Client( callback_api_version=CallbackAPIVersion.VERSION2, client_id=mqtt_clientid, protocol=mqtt.MQTTv5, @@ -89,6 +93,18 @@ def __init__(self, client.on_connect = self._on_connect client.on_message = self._on_message client.on_subscribe = self._on_subscribe + ## We need to set an empty will message to clean the retained presence + ## message when the MCP server goes offline. + ## Note that if the broker suggested a new server name, it's the broker's + ## responsibility to clean the retained presence message and send the + ## last will message on the changed presence topic. + client.will_set( + topic = self.get_presence_topic(), + payload = disconnected_msg.model_dump_json() if disconnected_msg else None, + qos = QOS, + retain = disconnected_msg_retain, + properties = self.get_publish_properties(), + ) self.client = client async def __aenter__(self) -> Self: @@ -123,23 +139,32 @@ def _on_subscribe(self, client: mqtt.Client, userdata: Any, mid: int, def publish_json_rpc_message(self, topic: str, message: types.JSONRPCMessage | None, retain: bool = False): + props = self.get_publish_properties() + payload = message.model_dump_json(by_alias=True, exclude_none=True) if message else None + self.client.publish(topic=topic, payload=payload, qos=QOS, retain=retain, properties=props) + + def get_publish_properties(self): props = Properties(PacketTypes.PUBLISH) props.UserProperty = [ (PROPERTY_K_MCP_COMPONENT, self.mcp_component_type), (PROPERTY_K_MQTT_CLIENT_ID, self.mqtt_clientid) ] - payload = message.model_dump_json(by_alias=True, exclude_none=True) if message else None - self.client.publish(topic=topic, payload=payload, qos=QOS, retain=retain, properties=props) + return props def connect(self): logger.debug("Setting up MQTT connection") + props = Properties(PacketTypes.CONNECT) + props.UserProperty = [ + (PROPERTY_K_MCP_COMPONENT, self.mcp_component_type) + ] self.client.connect( host = self.mqtt_options.host, port = self.mqtt_options.port, keepalive = self.mqtt_options.keepalive, bind_address = self.mqtt_options.bind_address, bind_port = self.mqtt_options.bind_port, - clean_start=True + clean_start=True, + properties=props, ) def assert_property(self, properties: Properties | None, property_name: str, expected_value: Any): @@ -149,9 +174,13 @@ def assert_property(self, properties: Properties | None, property_name: str, exp anyio_from_thread.run(self.stop_mqtt) raise ValueError(f"{property_name} not available") + @abstractmethod + def get_presence_topic(self) -> str: + pass + async def stop_mqtt(self): self.publish_json_rpc_message( - self.presence_topic, + self.get_presence_topic(), message = self.disconnected_msg, retain = self.disconnected_msg_retain ) From 7f9edc869daa6055d4c627253cbacb45035f7196 Mon Sep 17 00:00:00 2001 From: Shawn <506895667@qq.com> Date: Wed, 14 May 2025 11:58:58 +0800 Subject: [PATCH 12/23] log mqtt settings --- src/mcp/shared/mqtt.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/mcp/shared/mqtt.py b/src/mcp/shared/mqtt.py index 8bfa4631f7..94186592cc 100644 --- a/src/mcp/shared/mqtt.py +++ b/src/mcp/shared/mqtt.py @@ -12,7 +12,7 @@ import anyio import anyio.from_thread as anyio_from_thread from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import BaseModel +from pydantic import BaseModel, SecretStr from typing import Literal, Optional, Any, TypeAlias, Callable, Awaitable import mcp.types as types from typing_extensions import Self @@ -40,7 +40,7 @@ class MqttOptions(BaseModel): bind_address: str = '' bind_port: int = 0 username: Optional[str] = None - password: Optional[str] = None + password: Optional[SecretStr] = None tls_enabled: bool = False tls_version: Optional[int] = None tls_insecure: bool = False @@ -76,7 +76,7 @@ def __init__(self, userdata={}, transport=mqtt_options.transport, reconnect_on_failure=True ) - client.username_pw_set(mqtt_options.username, mqtt_options.password) + client.username_pw_set(mqtt_options.username, mqtt_options.password.get_secret_value() if mqtt_options.password else None) if mqtt_options.tls_enabled: client.tls_set( # type: ignore ca_certs=mqtt_options.ca_certs, @@ -105,6 +105,7 @@ def __init__(self, retain = disconnected_msg_retain, properties = self.get_publish_properties(), ) + logger.info(f"MCP component type: {mcp_component_type}, MQTT clientid: {mqtt_clientid}, MQTT settings: {mqtt_options}") self.client = client async def __aenter__(self) -> Self: From 5d45514d725315e7b05fd1fff6ca9422038c9686 Mon Sep 17 00:00:00 2001 From: Shawn <506895667@qq.com> Date: Fri, 23 May 2025 17:58:22 +0800 Subject: [PATCH 13/23] add server-id into the rpc topic; change client_id_prefix to client_id --- src/mcp/client/mqtt.py | 12 +++++++----- src/mcp/server/fastmcp/server.py | 4 ++-- src/mcp/server/mqtt.py | 14 +++++++------- src/mcp/shared/mqtt_topic.py | 4 ++-- 4 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/mcp/client/mqtt.py b/src/mcp/client/mqtt.py index e9c6915e0c..345c0742e3 100644 --- a/src/mcp/client/mqtt.py +++ b/src/mcp/client/mqtt.py @@ -74,7 +74,9 @@ def __init__( class MqttTransportClient(MqttTransportBase): - def __init__(self, mcp_client_name: str, client_id_prefix: str | None = None, + def __init__(self, + mcp_client_name: str, + client_id: str | None = None, server_name_filter: str = '#', auto_connect_to_mcp_server: bool = False, on_mcp_connect: Callable[["MqttTransportClient", ServerName, ConnectResult], Awaitable[Any]] | None = None, @@ -82,7 +84,7 @@ def __init__(self, mcp_client_name: str, client_id_prefix: str | None = None, on_mcp_server_discovered: Callable[["MqttTransportClient", ServerName], Awaitable[Any]] | None = None, mqtt_options: MqttOptions = MqttOptions()): uuid = uuid4().hex - mqtt_clientid = f"{client_id_prefix}-{uuid}" if client_id_prefix else uuid + mqtt_clientid = client_id if client_id else uuid self.server_list: dict[ServerName, dict[ServerId, ServerDefinition]] = {} self.client_sessions: dict[ServerName, MqttClientSession] = {} self.mcp_client_id = mqtt_clientid @@ -350,7 +352,7 @@ def _subscribe_mcp_server_topics(self, server_id: ServerId, server_name: ServerN topic_filters = [ (mqtt_topic.get_server_capability_change_topic(server_id, server_name), SubscribeOptions(qos=QOS)), (mqtt_topic.get_server_resource_update_topic(server_id, server_name), SubscribeOptions(qos=QOS)), - (mqtt_topic.get_rpc_topic(self.mcp_client_id, server_name), SubscribeOptions(qos=QOS, noLocal=True)) + (mqtt_topic.get_rpc_topic(self.mcp_client_id, server_id, server_name), SubscribeOptions(qos=QOS, noLocal=True)) ] ret, mid = self.client.subscribe(topic=topic_filters) if ret != mqtt.MQTT_ERR_SUCCESS: @@ -390,14 +392,14 @@ async def _receieved_from_session(self, server_id: ServerId, server_name: Server match msg.model_dump(): case {"method": method} if method == "notifications/initialized": logger.debug(f"Session initialized for server_id: {server_id}") - topic = mqtt_topic.get_rpc_topic(self.mcp_client_id, server_name) + topic = mqtt_topic.get_rpc_topic(self.mcp_client_id, server_id, server_name) case {"method": method} if method.endswith("/list_changed"): topic = None logger.warning("Resource updates should not be sent from the session. Ignoring.") case {"method": method} if method == "initialize": topic = mqtt_topic.get_server_control_topic(server_id, server_name) case _: - topic = mqtt_topic.get_rpc_topic(self.mcp_client_id, server_name) + topic = mqtt_topic.get_rpc_topic(self.mcp_client_id, server_id, server_name) if topic: self.publish_json_rpc_message(topic, message = msg) # cleanup diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index e51ee74f55..2108b1cf61 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -79,7 +79,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]): # MQTT settings mqtt_server_description: str = '' mqtt_server_meta: dict[str, Any] = {} - mqtt_client_id_prefix: str | None = None + mqtt_client_id: str | None = None mqtt_options: MqttOptions = MqttOptions() # resource settings @@ -499,7 +499,7 @@ def server_session_run(read_stream: Any, write_stream: Any): server_name = self._mcp_server.name, server_description=self.settings.mqtt_server_description, server_meta = self.settings.mqtt_server_meta, - client_id_prefix = self.settings.mqtt_client_id_prefix, + client_id = self.settings.mqtt_client_id, mqtt_options = self.settings.mqtt_options ) diff --git a/src/mcp/server/mqtt.py b/src/mcp/server/mqtt.py index 9882fe8c8a..7c38593609 100644 --- a/src/mcp/server/mqtt.py +++ b/src/mcp/server/mqtt.py @@ -32,10 +32,10 @@ class MqttTransportServer(MqttTransportBase): def __init__(self, server_session_run: ServerSessionRun, server_name: str, server_description: str, server_meta: dict[str, Any], - client_id_prefix: str | None = None, + client_id: str | None = None, mqtt_options: MqttOptions = MqttOptions()): uuid = uuid4().hex - mqtt_clientid = f"{client_id_prefix}-{uuid}" if client_id_prefix else uuid + mqtt_clientid = client_id if client_id else uuid self.server_id = mqtt_clientid self.server_name = server_name self.server_description = server_description @@ -109,7 +109,7 @@ def _on_subscribe(self, client: mqtt.Client, userdata: Any, mid: int, ) ) self.publish_json_rpc_message( - mqtt_topic.get_rpc_topic(mcp_client_id, self.server_name), + mqtt_topic.get_rpc_topic(mcp_client_id, self.server_id, self.server_name), message = types.JSONRPCMessage(err) ) @@ -186,7 +186,7 @@ def subscribe_to_client(self, mcp_client_id: str, msg: mqtt.MQTTMessage, rcp_msg topic_filters = [ (mqtt_topic.get_client_presence_topic(mcp_client_id), SubscribeOptions(qos=QOS)), (mqtt_topic.get_client_capability_change_topic(mcp_client_id), SubscribeOptions(qos=QOS)), - (mqtt_topic.get_rpc_topic(mcp_client_id, self.server_name), SubscribeOptions(qos=QOS, noLocal=True)) + (mqtt_topic.get_rpc_topic(mcp_client_id, self.server_id, self.server_name), SubscribeOptions(qos=QOS, noLocal=True)) ] ret, mid = self.client.subscribe(topic=topic_filters) if ret != mqtt.MQTT_ERR_SUCCESS: @@ -224,7 +224,7 @@ async def _receieved_from_session(self, mcp_client_id: str, write_stream_reader: case {"method": method} if method.endswith("/list_changed"): logger.warning("Resource updates should not be sent from the session. Ignoring.") case _: - topic = mqtt_topic.get_rpc_topic(mcp_client_id, self.server_name) + topic = mqtt_topic.get_rpc_topic(mcp_client_id, self.server_id, self.server_name) self.publish_json_rpc_message(topic, message = msg) # cleanup if mcp_client_id in self._read_stream_writers: @@ -238,14 +238,14 @@ async def start_mqtt( server_session_run: ServerSessionRun, server_name: str, server_description: str, server_meta: dict[str, Any], - client_id_prefix: str | None = None, + client_id: str | None = None, mqtt_options: MqttOptions = MqttOptions()): async with MqttTransportServer( server_session_run, server_name = server_name, server_description=server_description, server_meta = server_meta, - client_id_prefix = client_id_prefix, + client_id = client_id, mqtt_options = mqtt_options ) as mqtt_trans: def start(): diff --git a/src/mcp/shared/mqtt_topic.py b/src/mcp/shared/mqtt_topic.py index ee5a543aa5..8befc32a60 100644 --- a/src/mcp/shared/mqtt_topic.py +++ b/src/mcp/shared/mqtt_topic.py @@ -26,5 +26,5 @@ def get_client_presence_topic(mcp_clientid: str) -> str: def get_client_capability_change_topic(mcp_clientid: str) -> str: return f"{CLIENT_CAPABILITY_CHANGE_BASE}/{mcp_clientid}" -def get_rpc_topic(mcp_clientid: str, server_name: str) -> str: - return f"{RPC_BASE}/{mcp_clientid}/{server_name}" +def get_rpc_topic(mcp_clientid: str, server_id: str, server_name: str) -> str: + return f"{RPC_BASE}/{mcp_clientid}/{server_id}/{server_name}" From 8f1e139c34687561b4a9b7a28093c3c273f735e9 Mon Sep 17 00:00:00 2001 From: Shawn <506895667@qq.com> Date: Sat, 24 May 2025 21:18:12 +0800 Subject: [PATCH 14/23] change MQTT-CLIENT-ID to MCP-MQTT-CLIENT-ID --- src/mcp/shared/mqtt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/shared/mqtt.py b/src/mcp/shared/mqtt.py index 94186592cc..f56cf51c6e 100644 --- a/src/mcp/shared/mqtt.py +++ b/src/mcp/shared/mqtt.py @@ -23,7 +23,7 @@ MCP_SERVER_NAME = "MCP-SERVER-NAME" MCP_AUTH_ROLE = "MCP-AUTH-ROLE" PROPERTY_K_MCP_COMPONENT = "MCP-COMPONENT-TYPE" -PROPERTY_K_MQTT_CLIENT_ID = "MQTT-CLIENT-ID" +PROPERTY_K_MQTT_CLIENT_ID = "MCP-MQTT-CLIENT-ID" logger = logging.getLogger(__name__) RcvStream : TypeAlias = MemoryObjectReceiveStream[types.JSONRPCMessage] From cc3a0ea64ab4dd91fea1f5b0a167320c668dcc4d Mon Sep 17 00:00:00 2001 From: Shawn <506895667@qq.com> Date: Sat, 24 May 2025 23:11:19 +0800 Subject: [PATCH 15/23] fix: got incorrect server_name in clients --- src/mcp/client/mqtt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/client/mqtt.py b/src/mcp/client/mqtt.py index 345c0742e3..1c790d8acf 100644 --- a/src/mcp/client/mqtt.py +++ b/src/mcp/client/mqtt.py @@ -336,7 +336,7 @@ def _handle_server_presence_message(self, msg: mqtt.MQTTMessage) -> None: anyio_from_thread.run(self.on_mcp_disconnect, self, server_name) def _handle_rpc_message(self, msg: mqtt.MQTTMessage) -> None: - server_name = "/".join(msg.topic.split("/")[2:]) + server_name = "/".join(msg.topic.split("/")[3:]) anyio_from_thread.run(self._send_message_to_session, server_name, msg) def _handle_server_capability_list_changed_message(self, msg: mqtt.MQTTMessage) -> None: From 17cc30e7333f6609d5224dbf1babcd71f25ebca3 Mon Sep 17 00:00:00 2001 From: Shawn <506895667@qq.com> Date: Fri, 6 Jun 2025 13:18:14 +0800 Subject: [PATCH 16/23] wait for connection success when start a mqtt client --- src/mcp/client/mqtt.py | 20 +++++++++++++++++--- src/mcp/server/mqtt.py | 2 +- src/mcp/shared/mqtt.py | 16 ++++++++++++++-- 3 files changed, 32 insertions(+), 6 deletions(-) diff --git a/src/mcp/client/mqtt.py b/src/mcp/client/mqtt.py index 1c790d8acf..93627cbf12 100644 --- a/src/mcp/client/mqtt.py +++ b/src/mcp/client/mqtt.py @@ -108,17 +108,31 @@ def __init__(self, def get_presence_topic(self) -> str: return mqtt_topic.get_client_presence_topic(self.mcp_client_id) - def start(self): + async def start(self, timeout: timedelta | None = None) -> bool | str: + connect_result = self.connect() def do_start(): - self.connect() self.client.loop_forever() try: asyncio.create_task(anyio_to_thread.run_sync(do_start)) + if connect_result and connect_result != mqtt.MQTT_ERR_SUCCESS: + logger.error(f"Failed to connect to MQTT broker, error code: {connect_result}") + return mqtt.error_string(connect_result) + # test if the client is connected and wait until it is connected + if timeout: + while not self.is_connected(): + await asyncio.sleep(0.1) + if timeout.total_seconds() <= 0: + logger.error(f"Timeout while waiting for MQTT client to connect, reason: {self.get_last_connect_fail_reason()}") + return self.get_last_connect_fail_reason() or "timeout" + timeout -= timedelta(seconds=0.1) + return True except asyncio.CancelledError: logger.debug("MQTT transport (MCP client) got cancelled") + return "cancelled" except Exception as exc: logger.error(f"MQTT transport (MCP client) failed: {exc}") traceback.print_exc() + return "error" def get_session(self, server_name: ServerName) -> MqttClientSession | None: return self.client_sessions.get(server_name, None) @@ -277,8 +291,8 @@ def _create_session( ) def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.ConnectFlags, reason_code : ReasonCode, properties: Properties | None): + super()._on_connect(client, userdata, connect_flags, reason_code, properties) if reason_code == 0: - super()._on_connect(client, userdata, connect_flags, reason_code, properties) ## Subscribe to the MCP server's presence topic client.subscribe(mqtt_topic.get_server_presence_topic('+', self.server_name_filter), qos=QOS) diff --git a/src/mcp/server/mqtt.py b/src/mcp/server/mqtt.py index 7c38593609..3971dd7c66 100644 --- a/src/mcp/server/mqtt.py +++ b/src/mcp/server/mqtt.py @@ -50,8 +50,8 @@ def get_presence_topic(self) -> str: return mqtt_topic.get_server_presence_topic(self.server_id, self.server_name) def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.ConnectFlags, reason_code : ReasonCode, properties: Properties | None): + super()._on_connect(client, userdata, connect_flags, reason_code, properties) if reason_code == 0: - super()._on_connect(client, userdata, connect_flags, reason_code, properties) if properties and hasattr(properties, "UserProperty"): user_properties: dict[str, Any] = dict(properties.UserProperty) # type: ignore if MCP_SERVER_NAME in user_properties: diff --git a/src/mcp/shared/mqtt.py b/src/mcp/shared/mqtt.py index f56cf51c6e..86b2adb168 100644 --- a/src/mcp/shared/mqtt.py +++ b/src/mcp/shared/mqtt.py @@ -65,6 +65,7 @@ def __init__(self, disconnected_msg: types.JSONRPCMessage | None = None, disconnected_msg_retain: bool = True): self._read_stream_writers = {} + self._last_connect_fail_reason = None self.mqtt_clientid = mqtt_clientid self.mcp_component_type = mcp_component_type self.mqtt_options = mqtt_options @@ -74,7 +75,11 @@ def __init__(self, callback_api_version=CallbackAPIVersion.VERSION2, client_id=mqtt_clientid, protocol=mqtt.MQTTv5, userdata={}, - transport=mqtt_options.transport, reconnect_on_failure=True + transport=mqtt_options.transport, + reconnect_on_failure=True + ) + client.reconnect_delay_set( + min_delay=1, max_delay=120 ) client.username_pw_set(mqtt_options.username, mqtt_options.password.get_secret_value() if mqtt_options.password else None) if mqtt_options.tls_enabled: @@ -123,12 +128,13 @@ async def __aexit__( self._task_group.cancel_scope.cancel() return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) - def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.ConnectFlags, reason_code : ReasonCode, properties: Properties | None): + def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.ConnectFlags, reason_code: ReasonCode, properties: Properties | None): if reason_code == 0: logger.debug(f"Connected to MQTT broker_host at {self.mqtt_options.host}:{self.mqtt_options.port}") self.assert_property(properties, "RetainAvailable", 1) self.assert_property(properties, "WildcardSubscriptionAvailable", 1) else: + self._last_connect_fail_reason = reason_code logger.error(f"Failed to connect, return code {reason_code}") def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage): @@ -138,6 +144,12 @@ def _on_subscribe(self, client: mqtt.Client, userdata: Any, mid: int, reason_code_list: list[ReasonCode], properties: Properties | None): pass + def is_connected(self) -> bool: + return self.client.is_connected() + + def get_last_connect_fail_reason(self) -> ReasonCode | None: + return self._last_connect_fail_reason + def publish_json_rpc_message(self, topic: str, message: types.JSONRPCMessage | None, retain: bool = False): props = self.get_publish_properties() From 37e14bd6137b7a8c7a7e7da348c69dd79bf2b6f9 Mon Sep 17 00:00:00 2001 From: Shawn <506895667@qq.com> Date: Sat, 7 Jun 2025 14:50:10 +0800 Subject: [PATCH 17/23] fix: handle connection refused exceptions --- src/mcp/client/mqtt.py | 27 ++++++++++++++++----------- src/mcp/shared/mqtt.py | 2 +- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/mcp/client/mqtt.py b/src/mcp/client/mqtt.py index 93627cbf12..b702c9640a 100644 --- a/src/mcp/client/mqtt.py +++ b/src/mcp/client/mqtt.py @@ -109,12 +109,10 @@ def get_presence_topic(self) -> str: return mqtt_topic.get_client_presence_topic(self.mcp_client_id) async def start(self, timeout: timedelta | None = None) -> bool | str: - connect_result = self.connect() - def do_start(): - self.client.loop_forever() try: - asyncio.create_task(anyio_to_thread.run_sync(do_start)) - if connect_result and connect_result != mqtt.MQTT_ERR_SUCCESS: + connect_result = self.connect() + asyncio.create_task(anyio_to_thread.run_sync(self.client.loop_forever)) + if connect_result != mqtt.MQTT_ERR_SUCCESS: logger.error(f"Failed to connect to MQTT broker, error code: {connect_result}") return mqtt.error_string(connect_result) # test if the client is connected and wait until it is connected @@ -122,17 +120,24 @@ def do_start(): while not self.is_connected(): await asyncio.sleep(0.1) if timeout.total_seconds() <= 0: - logger.error(f"Timeout while waiting for MQTT client to connect, reason: {self.get_last_connect_fail_reason()}") - return self.get_last_connect_fail_reason() or "timeout" + last_fail_reason = self.get_last_connect_fail_reason() + if last_fail_reason: + return last_fail_reason.getName() + return "timeout" timeout -= timedelta(seconds=0.1) return True except asyncio.CancelledError: logger.debug("MQTT transport (MCP client) got cancelled") return "cancelled" + except ConnectionRefusedError as exc: + logger.error(f"MQTT transport (MCP client) failed to connect: {exc}") + return "connection_refused" + except TimeoutError as exc: + logger.error(f"MQTT transport (MCP client) timed out: {exc}") + return "timeout" except Exception as exc: logger.error(f"MQTT transport (MCP client) failed: {exc}") - traceback.print_exc() - return "error" + return f"connect mqtt error: {str(exc)}" def get_session(self, server_name: ServerName) -> MqttClientSession | None: return self.client_sessions.get(server_name, None) @@ -273,8 +278,8 @@ def _create_session( read_stream_writer: SndStreamEX write_stream: SndStream write_stream_reader: RcvStream - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) # type: ignore + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) # type: ignore self._read_stream_writers[server_id] = read_stream_writer self._task_group.start_soon(self._receieved_from_session, server_id, server_name, write_stream_reader) logger.debug(f"Created new session for server_id: {server_id}") diff --git a/src/mcp/shared/mqtt.py b/src/mcp/shared/mqtt.py index 86b2adb168..f884e66a2c 100644 --- a/src/mcp/shared/mqtt.py +++ b/src/mcp/shared/mqtt.py @@ -170,7 +170,7 @@ def connect(self): props.UserProperty = [ (PROPERTY_K_MCP_COMPONENT, self.mcp_component_type) ] - self.client.connect( + return self.client.connect( host = self.mqtt_options.host, port = self.mqtt_options.port, keepalive = self.mqtt_options.keepalive, From f098c2a419efa2e18c4efa15df1df02dbf6c2071 Mon Sep 17 00:00:00 2001 From: Shawn <506895667@qq.com> Date: Thu, 19 Jun 2025 19:09:29 +0800 Subject: [PATCH 18/23] use the same topic to send list_changed and resource updated notifications --- src/mcp/client/mqtt.py | 1 - src/mcp/shared/mqtt_topic.py | 10 +++------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/src/mcp/client/mqtt.py b/src/mcp/client/mqtt.py index b702c9640a..44e33d34da 100644 --- a/src/mcp/client/mqtt.py +++ b/src/mcp/client/mqtt.py @@ -370,7 +370,6 @@ def _subscribe_mcp_server_topics(self, server_id: ServerId, server_name: ServerN after_subscribed: Callable[[Any], Awaitable[None]]): topic_filters = [ (mqtt_topic.get_server_capability_change_topic(server_id, server_name), SubscribeOptions(qos=QOS)), - (mqtt_topic.get_server_resource_update_topic(server_id, server_name), SubscribeOptions(qos=QOS)), (mqtt_topic.get_rpc_topic(self.mcp_client_id, server_id, server_name), SubscribeOptions(qos=QOS, noLocal=True)) ] ret, mid = self.client.subscribe(topic=topic_filters) diff --git a/src/mcp/shared/mqtt_topic.py b/src/mcp/shared/mqtt_topic.py index 8befc32a60..1c5e3da536 100644 --- a/src/mcp/shared/mqtt_topic.py +++ b/src/mcp/shared/mqtt_topic.py @@ -1,12 +1,11 @@ SERVER_CONTROL_BASE: str = '$mcp-server' -SERVER_CAPABILITY_CHANGE_BASE: str = '$mcp-server/capability/list-changed' -SERVER_RESOURCE_UPDATE_BASE: str = '$mcp-server/capability/resource-updated' +SERVER_CAPABILITY_CHANGE_BASE: str = '$mcp-server/capability' SERVER_PRESENCE_BASE: str = '$mcp-server/presence' CLIENT_PRESENCE_BASE: str = '$mcp-client/presence' -CLIENT_CAPABILITY_CHANGE_BASE: str = '$mcp-client/capability/list-changed' -RPC_BASE: str = '$mcp-rpc-endpoint' +CLIENT_CAPABILITY_CHANGE_BASE: str = '$mcp-client/capability' +RPC_BASE: str = '$mcp-rpc' def get_server_control_topic(server_id: str, server_name: str) -> str: return f"{SERVER_CONTROL_BASE}/{server_id}/{server_name}" @@ -14,9 +13,6 @@ def get_server_control_topic(server_id: str, server_name: str) -> str: def get_server_capability_change_topic(server_id: str, server_name: str) -> str: return f"{SERVER_CAPABILITY_CHANGE_BASE}/{server_id}/{server_name}" -def get_server_resource_update_topic(server_id: str, server_name: str) -> str: - return f"{SERVER_RESOURCE_UPDATE_BASE}/{server_id}/{server_name}" - def get_server_presence_topic(server_id: str, server_name: str) -> str: return f"{SERVER_PRESENCE_BASE}/{server_id}/{server_name}" From 577664dd34c6eaf4d1792c6e0a638438f8586d68 Mon Sep 17 00:00:00 2001 From: Shawn <506895667@qq.com> Date: Fri, 20 Jun 2025 17:11:27 +0800 Subject: [PATCH 19/23] mcp client support disconnect a specific mcp server --- src/mcp/client/mqtt.py | 71 ++++++++++++++++++++++++++---------------- src/mcp/server/mqtt.py | 38 +++++++++++++++++----- src/mcp/shared/mqtt.py | 2 +- 3 files changed, 76 insertions(+), 35 deletions(-) diff --git a/src/mcp/client/mqtt.py b/src/mcp/client/mqtt.py index 44e33d34da..081b6b82e5 100644 --- a/src/mcp/client/mqtt.py +++ b/src/mcp/client/mqtt.py @@ -85,6 +85,7 @@ def __init__(self, mqtt_options: MqttOptions = MqttOptions()): uuid = uuid4().hex mqtt_clientid = client_id if client_id else uuid + self._current_server_id: dict[ServerName, ServerId] = {} self.server_list: dict[ServerName, dict[ServerId, ServerDefinition]] = {} self.client_sessions: dict[ServerName, MqttClientSession] = {} self.mcp_client_id = mqtt_clientid @@ -96,13 +97,15 @@ def __init__(self, self.on_mcp_server_discovered = on_mcp_server_discovered self.client_capability_change_topic = mqtt_topic.get_client_capability_change_topic(self.mcp_client_id) ## Send disconnected notification when disconnects - disconnected_msg = types.JSONRPCNotification( - jsonrpc="2.0", - method = "notifications/disconnected" + self._disconnected_msg = types.JSONRPCMessage( + types.JSONRPCNotification( + jsonrpc="2.0", + method = "notifications/disconnected" + ) ) super().__init__("mcp-client", mqtt_clientid = mqtt_clientid, mqtt_options = mqtt_options, - disconnected_msg = types.JSONRPCMessage(disconnected_msg), + disconnected_msg = self._disconnected_msg, disconnected_msg_retain = False) def get_presence_topic(self) -> str: @@ -197,6 +200,12 @@ async def after_initialize(): else: return ("error", "send_subscribe_request_failed") + async def deinitialize_mcp_server(self, server_name: ServerName) -> None: + server_id = self._current_server_id[server_name] + topic = mqtt_topic.get_rpc_topic(self.mcp_client_id, server_id, server_name) + self.publish_json_rpc_message(topic, message = self._disconnected_msg, retain=False) + self._remove_server(server_id, server_name) + async def send_ping(self, server_name: ServerName) -> bool | types.EmptyResult: return await self._with_session(server_name, lambda s: s.send_ping()) @@ -309,9 +318,7 @@ def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage) case str() as t if t.startswith(mqtt_topic.RPC_BASE): self._handle_rpc_message(msg) case str() as t if t.startswith(mqtt_topic.SERVER_CAPABILITY_CHANGE_BASE): - self._handle_server_capability_list_changed_message(msg) - case str() as t if t.startswith(mqtt_topic.SERVER_RESOURCE_UPDATE_BASE): - self._handle_server_capability_resource_updated_message(msg) + self._handle_server_capability_message(msg) case _: logger.error(f"Received message on unexpected topic: {msg.topic}") @@ -340,29 +347,19 @@ def _handle_server_presence_message(self, msg: mqtt.MQTTMessage) -> None: if self.on_mcp_server_discovered: anyio_from_thread.run(self.on_mcp_server_discovered, self, server_name) else: - existing_server = True if server_name in self.server_list else False - if server_id in self.server_list.get(server_name, {}): - logger.debug(f"Server {server_name} with id {server_id} is offline") - self.server_list[server_name].pop(server_id) - if not self.server_list[server_name]: - self.server_list.pop(server_name) - if server_name in self.client_sessions: - _ = self.client_sessions.pop(server_name) - stream = self._read_stream_writers.pop(server_id) - stream.close() - if existing_server: - if self.on_mcp_disconnect: - anyio_from_thread.run(self.on_mcp_disconnect, self, server_name) + # server is offline if the payload is empty + logger.debug(f"Server {server_name} with id {server_id} is offline") + self._remove_server(server_id, server_name) + + def _remove_server(self, server_id: ServerId, server_name: ServerName) -> None: + if server_id in self.server_list.get(server_name, {}): + self._read_stream_writers[server_id].close() def _handle_rpc_message(self, msg: mqtt.MQTTMessage) -> None: server_name = "/".join(msg.topic.split("/")[3:]) anyio_from_thread.run(self._send_message_to_session, server_name, msg) - def _handle_server_capability_list_changed_message(self, msg: mqtt.MQTTMessage) -> None: - server_name = "/".join(msg.topic.split("/")[4:]) - anyio_from_thread.run(self._send_message_to_session, server_name, msg) - - def _handle_server_capability_resource_updated_message(self, msg: mqtt.MQTTMessage) -> None: + def _handle_server_capability_message(self, msg: mqtt.MQTTMessage) -> None: server_name = "/".join(msg.topic.split("/")[4:]) anyio_from_thread.run(self._send_message_to_session, server_name, msg) @@ -426,10 +423,32 @@ async def _receieved_from_session(self, server_id: ServerId, server_name: Server stream = self._read_stream_writers.pop(server_id) await stream.aclose() + # unsubscribe from the topics + logger.debug(f"Unsubscribing from topics for server_id: {server_id}, server_name: {server_name}") + topic_filters = [ + mqtt_topic.get_server_capability_change_topic(server_id, server_name), + mqtt_topic.get_rpc_topic(self.mcp_client_id, server_id, server_name) + ] + self.client.unsubscribe(topic=topic_filters) + + if server_id in self.server_list.get(server_name, {}): + _ = self.server_list[server_name].pop(server_id) + if not self.server_list[server_name]: + _ = self.server_list.pop(server_name) + if self.on_mcp_disconnect: + self._task_group.start_soon(self.on_mcp_disconnect, self, server_name) + + if server_name in self.client_sessions: + _ = self.client_sessions.pop(server_name) + + if server_name in self._current_server_id: + _ = self._current_server_id.pop(server_name) logger.debug(f"Session stream closed for server_id: {server_id}") def pick_server_id(self, server_name: str) -> ServerId: - return random.choice(list(self.server_list[server_name].keys())) + server_id = random.choice(list(self.server_list[server_name].keys())) + self._current_server_id[server_name] = server_id + return server_id def validate_server_name(name: str): if "/" not in name: diff --git a/src/mcp/server/mqtt.py b/src/mcp/server/mqtt.py index 3971dd7c66..7d4339e48a 100644 --- a/src/mcp/server/mqtt.py +++ b/src/mcp/server/mqtt.py @@ -133,7 +133,20 @@ def handle_client_capability_change_message(self, msg: mqtt.MQTTMessage) -> None def handle_rpc_message(self, msg: mqtt.MQTTMessage) -> None: mcp_client_id = msg.topic.split("/")[1] - anyio_from_thread.run(self._send_message_to_session, mcp_client_id, msg) + try: + json_msg = json.loads(msg.payload.decode()) + if "method" in json_msg: + if json_msg["method"] == "notifications/disconnected": + stream = self._read_stream_writers[mcp_client_id] + anyio_from_thread.run(stream.aclose) + logger.debug(f"Closed read_stream for mcp_client_id: {mcp_client_id}") + return + else: + anyio_from_thread.run(self._send_message_to_session, mcp_client_id, msg) + else: + anyio_from_thread.run(self._send_message_to_session, mcp_client_id, msg) + except json.JSONDecodeError: + logger.error(f"Invalid JSON in RPC message for mcp_client_id: {mcp_client_id}") def handle_client_presence_message(self, msg: mqtt.MQTTMessage) -> None: mcp_client_id = msg.topic.split("/")[-1] @@ -144,15 +157,15 @@ def handle_client_presence_message(self, msg: mqtt.MQTTMessage) -> None: json_msg = json.loads(msg.payload.decode()) if "method" in json_msg: if json_msg["method"] == "notifications/disconnected": - stream = self._read_stream_writers.pop(mcp_client_id) + stream = self._read_stream_writers[mcp_client_id] anyio_from_thread.run(stream.aclose) - logger.debug(f"Removed session for mcp_client_id: {mcp_client_id}") + logger.debug(f"Closed read_stream for mcp_client_id: {mcp_client_id}") else: - logger.error(f"Unknown method in control message for mcp_client_id: {mcp_client_id}") + logger.error(f"Unknown method in presence message for mcp_client_id: {mcp_client_id}") else: - logger.error(f"No method in control message for mcp_client_id: {mcp_client_id}") + logger.error(f"No method in presence message for mcp_client_id: {mcp_client_id}") except json.JSONDecodeError: - logger.error(f"Invalid JSON in control message for mcp_client_id: {mcp_client_id}") + logger.error(f"Invalid JSON in presence message for mcp_client_id: {mcp_client_id}") async def create_session(self, mcp_client_id: str, msg: mqtt.MQTTMessage): ## Streams are used to communicate between the MqttTransportServer and the MCPSession: @@ -162,8 +175,8 @@ async def create_session(self, mcp_client_id: str, msg: mqtt.MQTTMessage): read_stream_writer: SndStreamEX write_stream: SndStream write_stream_reader: RcvStream - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) # type: ignore + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) # type: ignore self._read_stream_writers[mcp_client_id] = read_stream_writer self._task_group.start_soon(self.server_session_run, read_stream, write_stream) self._task_group.start_soon(self._receieved_from_session, mcp_client_id, write_stream_reader) @@ -232,6 +245,15 @@ async def _receieved_from_session(self, mcp_client_id: str, write_stream_reader: stream = self._read_stream_writers.pop(mcp_client_id) await stream.aclose() + # unsubscribe from the client topics + logger.debug(f"Unsubscribing from topics for mcp_client_id: {mcp_client_id}") + topic_filters = [ + mqtt_topic.get_client_presence_topic(mcp_client_id), + mqtt_topic.get_client_capability_change_topic(mcp_client_id), + mqtt_topic.get_rpc_topic(mcp_client_id, self.server_id, self.server_name) + ] + self.client.unsubscribe(topic=topic_filters) + logger.debug(f"Session stream closed for mcp_client_id: {mcp_client_id}") async def start_mqtt( diff --git a/src/mcp/shared/mqtt.py b/src/mcp/shared/mqtt.py index f884e66a2c..a0bc3a37cb 100644 --- a/src/mcp/shared/mqtt.py +++ b/src/mcp/shared/mqtt.py @@ -19,7 +19,7 @@ from abc import ABC, abstractmethod DEFAULT_LOG_FORMAT = "%(asctime)s - %(message)s" -QOS = 1 +QOS = 0 MCP_SERVER_NAME = "MCP-SERVER-NAME" MCP_AUTH_ROLE = "MCP-AUTH-ROLE" PROPERTY_K_MCP_COMPONENT = "MCP-COMPONENT-TYPE" From 2138d876c45f4f19bf80c3a299bb41a623a8ca37 Mon Sep 17 00:00:00 2001 From: Jianbo He Date: Wed, 25 Jun 2025 11:40:17 +0800 Subject: [PATCH 20/23] fix: Pop client session if it failed to initialized --- src/mcp/client/mqtt.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/mcp/client/mqtt.py b/src/mcp/client/mqtt.py index 081b6b82e5..b49c3a467b 100644 --- a/src/mcp/client/mqtt.py +++ b/src/mcp/client/mqtt.py @@ -186,11 +186,13 @@ async def after_initialize(): if self.on_mcp_connect: self._task_group.start_soon(self.on_mcp_connect, self, server_name, ("ok", init_result)) except Exception as e: + self.client_sessions.pop(server_name) logging.error(f"Failed to initialize server {server_name}: {e}") await exit_stack.aclose() self._task_group.start_soon(after_initialize) logger.debug(f"after initialize: {server_name}") except McpError as exc: + self.client_sessions.pop(server_name) logger.error(f"Failed to connect to MCP server: {exc}") if self.on_mcp_connect: self._task_group.start_soon(self.on_mcp_connect, self, server_name, ("error", McpError)) From f13ade2c9342a2e6b3808472fa57dfbc83413937 Mon Sep 17 00:00:00 2001 From: Shawn <506895667@qq.com> Date: Mon, 21 Jul 2025 18:47:07 +0800 Subject: [PATCH 21/23] support broker suggested server-name-filters --- .../clients/mqtt-clients/client_apis_demo.py | 2 +- src/mcp/client/mqtt.py | 26 ++++++++++++++----- src/mcp/shared/mqtt.py | 1 + 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/examples/clients/mqtt-clients/client_apis_demo.py b/examples/clients/mqtt-clients/client_apis_demo.py index 31da377f63..6ba4d965e4 100644 --- a/examples/clients/mqtt-clients/client_apis_demo.py +++ b/examples/clients/mqtt-clients/client_apis_demo.py @@ -32,7 +32,7 @@ async def on_mcp_disconnect(client, server_name): async def main(): async with mcp_mqtt.MqttTransportClient( "test_client", - auto_connect_to_mcp_server = True, + auto_connect_to_mcp_servers = True, on_mcp_server_discovered = on_mcp_server_discovered, on_mcp_connect = on_mcp_connect, on_mcp_disconnect = on_mcp_disconnect, diff --git a/src/mcp/client/mqtt.py b/src/mcp/client/mqtt.py index b49c3a467b..59d0b7ec5a 100644 --- a/src/mcp/client/mqtt.py +++ b/src/mcp/client/mqtt.py @@ -2,13 +2,14 @@ This module implements the MQTT transport for the MCP server. """ from contextlib import AsyncExitStack +import json from uuid import uuid4 from datetime import timedelta import random from pydantic import AnyUrl, BaseModel from mcp.client.session import ClientSession, SamplingFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT from mcp.shared.exceptions import McpError -from mcp.shared.mqtt import MqttTransportBase, MqttOptions, QOS +from mcp.shared.mqtt import MqttTransportBase, MqttOptions, QOS, MCP_SERVER_NAME_FILTERS import asyncio import anyio.to_thread as anyio_to_thread import anyio.from_thread as anyio_from_thread @@ -77,8 +78,8 @@ class MqttTransportClient(MqttTransportBase): def __init__(self, mcp_client_name: str, client_id: str | None = None, - server_name_filter: str = '#', - auto_connect_to_mcp_server: bool = False, + server_name_filters: str | list[str] = '#', + auto_connect_to_mcp_servers: bool = False, on_mcp_connect: Callable[["MqttTransportClient", ServerName, ConnectResult], Awaitable[Any]] | None = None, on_mcp_disconnect: Callable[["MqttTransportClient", ServerName], Awaitable[Any]] | None = None, on_mcp_server_discovered: Callable[["MqttTransportClient", ServerName], Awaitable[Any]] | None = None, @@ -90,8 +91,11 @@ def __init__(self, self.client_sessions: dict[ServerName, MqttClientSession] = {} self.mcp_client_id = mqtt_clientid self.mcp_client_name = mcp_client_name - self.server_name_filter = server_name_filter - self.auto_connect_to_mcp_server = auto_connect_to_mcp_server #TODO: not implemented yet + if isinstance(server_name_filters, str): + self.server_name_filters = [server_name_filters] + else: + self.server_name_filters = server_name_filters + self.auto_connect_to_mcp_servers = auto_connect_to_mcp_servers self.on_mcp_connect = on_mcp_connect self.on_mcp_disconnect = on_mcp_disconnect self.on_mcp_server_discovered = on_mcp_server_discovered @@ -308,9 +312,16 @@ def _create_session( def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.ConnectFlags, reason_code : ReasonCode, properties: Properties | None): super()._on_connect(client, userdata, connect_flags, reason_code, properties) + if properties and hasattr(properties, "UserProperty"): + user_properties: dict[str, Any] = dict(properties.UserProperty) # type: ignore + if MCP_SERVER_NAME_FILTERS in user_properties: + self.server_name_filters = json.loads(user_properties[MCP_SERVER_NAME_FILTERS]) + logger.debug(f"Use broker suggested server name filters: {self.server_name_filters}") if reason_code == 0: ## Subscribe to the MCP server's presence topic - client.subscribe(mqtt_topic.get_server_presence_topic('+', self.server_name_filter), qos=QOS) + for server_name_filter in self.server_name_filters: + logger.debug(f"Subscribing to server presence topic for server_name_filter: {server_name_filter}") + client.subscribe(mqtt_topic.get_server_presence_topic('+', server_name_filter), qos=QOS) def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage): logger.debug(f"Received message on topic {msg.topic}: {msg.payload.decode()}") @@ -346,6 +357,9 @@ def _handle_server_presence_message(self, msg: mqtt.MQTTMessage) -> None: self.server_list.setdefault(server_name, {})[server_id] = server_notif.params logger.debug(f"Server {server_name} with id {server_id} is online") if newly_added_server: + if self.auto_connect_to_mcp_servers: + logger.debug(f"Auto connecting to MCP server {server_name}") + anyio_from_thread.run(self.initialize_mcp_server, server_name) if self.on_mcp_server_discovered: anyio_from_thread.run(self.on_mcp_server_discovered, self, server_name) else: diff --git a/src/mcp/shared/mqtt.py b/src/mcp/shared/mqtt.py index a0bc3a37cb..a8705d970a 100644 --- a/src/mcp/shared/mqtt.py +++ b/src/mcp/shared/mqtt.py @@ -21,6 +21,7 @@ DEFAULT_LOG_FORMAT = "%(asctime)s - %(message)s" QOS = 0 MCP_SERVER_NAME = "MCP-SERVER-NAME" +MCP_SERVER_NAME_FILTERS = "MCP-SERVER-NAME-FILTERS" MCP_AUTH_ROLE = "MCP-AUTH-ROLE" PROPERTY_K_MCP_COMPONENT = "MCP-COMPONENT-TYPE" PROPERTY_K_MQTT_CLIENT_ID = "MCP-MQTT-CLIENT-ID" From 9404d578e47613b4237e19e638fbf00b9c63f853 Mon Sep 17 00:00:00 2001 From: Shawn <506895667@qq.com> Date: Tue, 22 Jul 2025 09:34:52 +0800 Subject: [PATCH 22/23] fix: make the params backward-compatible --- .../clients/mqtt-clients/client_apis_demo.py | 2 +- src/mcp/client/mqtt.py | 24 +++++++++---------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/examples/clients/mqtt-clients/client_apis_demo.py b/examples/clients/mqtt-clients/client_apis_demo.py index 6ba4d965e4..31da377f63 100644 --- a/examples/clients/mqtt-clients/client_apis_demo.py +++ b/examples/clients/mqtt-clients/client_apis_demo.py @@ -32,7 +32,7 @@ async def on_mcp_disconnect(client, server_name): async def main(): async with mcp_mqtt.MqttTransportClient( "test_client", - auto_connect_to_mcp_servers = True, + auto_connect_to_mcp_server = True, on_mcp_server_discovered = on_mcp_server_discovered, on_mcp_connect = on_mcp_connect, on_mcp_disconnect = on_mcp_disconnect, diff --git a/src/mcp/client/mqtt.py b/src/mcp/client/mqtt.py index 59d0b7ec5a..3bab5ff8c7 100644 --- a/src/mcp/client/mqtt.py +++ b/src/mcp/client/mqtt.py @@ -78,8 +78,8 @@ class MqttTransportClient(MqttTransportBase): def __init__(self, mcp_client_name: str, client_id: str | None = None, - server_name_filters: str | list[str] = '#', - auto_connect_to_mcp_servers: bool = False, + server_name_filter: str | list[str] = '#', + auto_connect_to_mcp_server: bool = False, on_mcp_connect: Callable[["MqttTransportClient", ServerName, ConnectResult], Awaitable[Any]] | None = None, on_mcp_disconnect: Callable[["MqttTransportClient", ServerName], Awaitable[Any]] | None = None, on_mcp_server_discovered: Callable[["MqttTransportClient", ServerName], Awaitable[Any]] | None = None, @@ -91,11 +91,11 @@ def __init__(self, self.client_sessions: dict[ServerName, MqttClientSession] = {} self.mcp_client_id = mqtt_clientid self.mcp_client_name = mcp_client_name - if isinstance(server_name_filters, str): - self.server_name_filters = [server_name_filters] + if isinstance(server_name_filter, str): + self.server_name_filter = [server_name_filter] else: - self.server_name_filters = server_name_filters - self.auto_connect_to_mcp_servers = auto_connect_to_mcp_servers + self.server_name_filter = server_name_filter + self.auto_connect_to_mcp_server = auto_connect_to_mcp_server self.on_mcp_connect = on_mcp_connect self.on_mcp_disconnect = on_mcp_disconnect self.on_mcp_server_discovered = on_mcp_server_discovered @@ -315,13 +315,13 @@ def _on_connect(self, client: mqtt.Client, userdata: Any, connect_flags: mqtt.Co if properties and hasattr(properties, "UserProperty"): user_properties: dict[str, Any] = dict(properties.UserProperty) # type: ignore if MCP_SERVER_NAME_FILTERS in user_properties: - self.server_name_filters = json.loads(user_properties[MCP_SERVER_NAME_FILTERS]) - logger.debug(f"Use broker suggested server name filters: {self.server_name_filters}") + self.server_name_filter = json.loads(user_properties[MCP_SERVER_NAME_FILTERS]) + logger.debug(f"Use broker suggested server name filters: {self.server_name_filter}") if reason_code == 0: ## Subscribe to the MCP server's presence topic - for server_name_filter in self.server_name_filters: - logger.debug(f"Subscribing to server presence topic for server_name_filter: {server_name_filter}") - client.subscribe(mqtt_topic.get_server_presence_topic('+', server_name_filter), qos=QOS) + for snf in self.server_name_filter: + logger.debug(f"Subscribing to server presence topic for server_name_filter: {snf}") + client.subscribe(mqtt_topic.get_server_presence_topic('+', snf), qos=QOS) def _on_message(self, client: mqtt.Client, userdata: Any, msg: mqtt.MQTTMessage): logger.debug(f"Received message on topic {msg.topic}: {msg.payload.decode()}") @@ -357,7 +357,7 @@ def _handle_server_presence_message(self, msg: mqtt.MQTTMessage) -> None: self.server_list.setdefault(server_name, {})[server_id] = server_notif.params logger.debug(f"Server {server_name} with id {server_id} is online") if newly_added_server: - if self.auto_connect_to_mcp_servers: + if self.auto_connect_to_mcp_server: logger.debug(f"Auto connecting to MCP server {server_name}") anyio_from_thread.run(self.initialize_mcp_server, server_name) if self.on_mcp_server_discovered: From 4e6647c7e13976ada1e849f9bfe7878c69504dad Mon Sep 17 00:00:00 2001 From: Shawn <506895667@qq.com> Date: Wed, 3 Sep 2025 13:03:40 +0800 Subject: [PATCH 23/23] fix: mqtt mcp client crash on clean up server_ids --- src/mcp/client/mqtt.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/mcp/client/mqtt.py b/src/mcp/client/mqtt.py index 3bab5ff8c7..b9fe0bcbff 100644 --- a/src/mcp/client/mqtt.py +++ b/src/mcp/client/mqtt.py @@ -369,7 +369,9 @@ def _handle_server_presence_message(self, msg: mqtt.MQTTMessage) -> None: def _remove_server(self, server_id: ServerId, server_name: ServerName) -> None: if server_id in self.server_list.get(server_name, {}): - self._read_stream_writers[server_id].close() + if server_id in self._read_stream_writers: + logger.debug(f"Closing stream writer for server_id: {server_id}") + self._read_stream_writers[server_id].close() def _handle_rpc_message(self, msg: mqtt.MQTTMessage) -> None: server_name = "/".join(msg.topic.split("/")[3:])